KSM: Fast Multiple Task Adaption via Kernel-Wise Soft Mask Learning

Li Yang, Zhezhi He, Junshan Zhang, Deliang Fan; Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 2021, pp. 13845-13853

Abstract


Deep Neural Networks (DNNs) can forget knowledge of earlier tasks when learning new tasks, a phenomenon known as catastrophic forgetting. To learn new tasks without forgetting, mask-based learning methods (e.g., Piggyback) have recently been proposed, which learn only a binary element-wise mask while keeping the backbone model fixed. However, the binary mask has limited modeling capacity for new tasks. A more recent work proposes a compress-grow-based method (CPG) that achieves better accuracy on new tasks by partially training the backbone model, but at an order-of-magnitude higher training cost, which makes it infeasible to deploy on popular edge/mobile learning platforms. The primary goal of this work is to simultaneously achieve fast and high-accuracy multi-task adaption in a continual learning setting. Thus motivated, we propose a new training method called Kernel-wise Soft Mask (KSM), which learns a kernel-wise hybrid binary and real-value soft mask for each task. Such a soft mask can be viewed as a superposition of a binary mask and a properly scaled real-value tensor, which offers a richer representation capability without requiring low-level kernel support, thus meeting the objective of low hardware overhead. We validate KSM on multiple benchmark datasets against recent state-of-the-art methods (e.g., Piggyback, PackNet, CPG), showing good improvement in both accuracy and training cost.
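
To make the masking idea concrete, below is a minimal PyTorch-style sketch (not the authors' implementation) of wrapping a frozen backbone convolution with a learnable kernel-wise soft mask that superposes a binary part and a scaled real-valued part. The threshold, scaling factor alpha, and mask initialization are illustrative assumptions, not values from the paper.

import torch
import torch.nn as nn
import torch.nn.functional as F

class KernelWiseSoftMaskConv(nn.Module):
    # Sketch: frozen backbone conv layer plus a learnable kernel-wise soft mask.
    # The applied mask is a superposition of a thresholded binary part and a
    # scaled real-valued part; threshold/alpha here are illustrative only.
    def __init__(self, backbone_conv, threshold=0.5, alpha=0.1):
        super().__init__()
        self.conv = backbone_conv
        for p in self.conv.parameters():
            p.requires_grad = False            # backbone stays fixed across tasks
        out_c, in_c = self.conv.weight.shape[:2]
        # one real-valued score per kernel, i.e., per (out_channel, in_channel) pair,
        # initialized just above the threshold so training starts from the full backbone
        self.mask_real = nn.Parameter(torch.full((out_c, in_c, 1, 1), 1.1 * threshold))
        self.threshold = threshold
        self.alpha = alpha

    def forward(self, x):
        binary = (self.mask_real > self.threshold).float()   # binary kernel mask
        soft_mask = binary + self.alpha * self.mask_real      # superpose scaled real-valued part
        masked_weight = self.conv.weight * soft_mask          # broadcast over the k x k kernel dims
        return F.conv2d(x, masked_weight, self.conv.bias,
                        stride=self.conv.stride, padding=self.conv.padding,
                        dilation=self.conv.dilation, groups=self.conv.groups)

# Example: adapt one frozen layer to a new task by training only its soft mask.
layer = KernelWiseSoftMaskConv(nn.Conv2d(64, 128, kernel_size=3, padding=1))
optimizer = torch.optim.SGD([layer.mask_real], lr=1e-2)

In this sketch only the per-task mask parameters are optimized, so the per-task storage and training cost stay small while the real-valued component gives more capacity than a purely binary mask.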

Related Material


[pdf] [arXiv]
[bibtex]
@InProceedings{Yang_2021_CVPR,
    author    = {Yang, Li and He, Zhezhi and Zhang, Junshan and Fan, Deliang},
    title     = {KSM: Fast Multiple Task Adaption via Kernel-Wise Soft Mask Learning},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2021},
    pages     = {13845-13853}
}