SKFAC: Training Neural Networks With Faster Kronecker-Factored Approximate Curvature

Zedong Tang, Fenlong Jiang, Maoguo Gong, Hao Li, Yue Wu, Fan Yu, Zidong Wang, Min Wang; Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 2021, pp. 13479-13487

Abstract


The computational burden of second-order optimization algorithms limits their widespread use for training deep neural networks. In this paper, we present a computationally efficient approximation to natural gradient descent, named Swift Kronecker-Factored Approximate Curvature (SKFAC), which combines Kronecker factorization with a fast low-rank matrix inversion technique. Our method targets both fully connected and convolutional layers. For fully connected layers, by exploiting the low-rank property of the Kronecker factors of the Fisher information matrix, our method only needs to invert a small matrix to approximate the curvature with desirable accuracy. For convolutional layers, we save computation without affecting empirical performance by reducing across the spatial dimension or the receptive fields of the feature maps. Specifically, we propose two effective dimension reduction methods for this purpose: Spatial Subsampling and Reduce Sum. Experimental results from training several deep neural networks on the CIFAR-10 and ImageNet-1k datasets demonstrate that SKFAC captures the main curvature and yields performance comparable to K-FAC. The proposed method bridges the wall-clock time gap between first-order and second-order algorithms.
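The low-rank inversion idea for fully connected layers can be illustrated with the Woodbury matrix identity. The sketch below is not taken from the paper and may differ from the authors' exact formulation. It assumes a damped Kronecker factor of the form `acts.T @ acts / n + damping * I`, built from a batch of `n` activation vectors of dimension `d` with `n` much smaller than `d`. The function name, shapes, and damping value are hypothetical choices made for illustration.

```python
import numpy as np


def damped_factor_inverse_lowrank(acts, damping):
    """Invert (acts.T @ acts / n + damping * I_d) via the Woodbury identity.

    acts has shape (n, d). Only an n x n system is solved, which is cheap
    when the batch size n is much smaller than the layer width d.
    Illustrative helper, not the paper's implementation.
    """
    n, d = acts.shape
    # Small n x n matrix that replaces the d x d inversion.
    small = acts @ acts.T + n * damping * np.eye(n)
    # Woodbury: (damping*I + A^T A / n)^{-1}
    #         = (I - A^T (n*damping*I + A A^T)^{-1} A) / damping
    correction = acts.T @ np.linalg.solve(small, acts)
    return (np.eye(d) - correction) / damping


# Hypothetical sizes: batch of 8 samples, layer width 64.
rng = np.random.default_rng(0)
acts = rng.standard_normal((8, 64))
damping = 0.1

fast = damped_factor_inverse_lowrank(acts, damping)
# Reference result: direct d x d inversion of the same damped factor.
direct = np.linalg.inv(acts.T @ acts / acts.shape[0] + damping * np.eye(64))
max_abs_diff = np.max(np.abs(fast - direct))
```

Under these assumptions the two results agree up to floating-point error, while the Woodbury route solves an 8 x 8 system instead of inverting a 64 x 64 matrix. The saving grows as the layer width increases relative to the batch size.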

Related Material


[bibtex]
@InProceedings{Tang_2021_CVPR,
    author    = {Tang, Zedong and Jiang, Fenlong and Gong, Maoguo and Li, Hao and Wu, Yue and Yu, Fan and Wang, Zidong and Wang, Min},
    title     = {SKFAC: Training Neural Networks With Faster Kronecker-Factored Approximate Curvature},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2021},
    pages     = {13479-13487}
}