# Learning Intra-class Multimodal Distributions with Orthonormal Matrices

We use MMClassification v0.25.0 (https://github.com/open-mmlab/mmpretrain/releases/tag/v0.25.0; the project has since been renamed MMPreTrain) as our codebase.

## Requirements

* Python 3.10.10
* CUDA 11.3.1
* torch==1.11.0+cu113
* torchvision==0.12.0+cu113

See https://pytorch.org/get-started/previous-versions/ for how to install these PyTorch and torchvision versions.
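
For pip, the previous-versions page distributes the CUDA 11.3 builds of these releases from PyTorch's own wheel index. The function below wraps that command as a minimal sketch. `install_pytorch` and the `PIP` override are hypothetical conveniences, not part of this repository. torchaudio is left out because it is not listed above.

```shell
# Hypothetical helper: installs the CUDA 11.3 builds of torch 1.11.0 and torchvision 0.12.0.
# PIP defaults to `pip`; it is left unquoted so a value like "python -m pip" also works.
install_pytorch() {
  ${PIP:-pip} install torch==1.11.0+cu113 torchvision==0.12.0+cu113 \
    --extra-index-url https://download.pytorch.org/whl/cu113
}
```

Run `install_pytorch` with the Python 3.10 environment active, before the installation steps below, since `mim install mmcv-full` picks a build that matches the installed PyTorch.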

## Installation

```shell
pip install openmim==0.3.7
mim install mmcv-full==1.7.1
pip install geoopt==0.5.0
pip install matplotlib==3.7.1
```
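
After installation, it can help to confirm that the environment resolved to the pinned versions. The function below is a minimal sketch, not part of the repository: `check_env` is a hypothetical name, and it assumes `python3` is the interpreter of the environment you installed into. It reads package metadata through `importlib.metadata` instead of importing the packages, and it does not check the Python or CUDA toolkit versions.

```shell
# Hypothetical helper: compares installed distribution versions with the pins above.
# Prints one line per package and returns 1 if any of them differs or is missing.
check_env() {
  python3 - <<'EOF'
from importlib.metadata import PackageNotFoundError, version

# Distribution names and versions copied from the Requirements and Installation lists
expected = {
    "torch": "1.11.0+cu113",
    "torchvision": "0.12.0+cu113",
    "openmim": "0.3.7",
    "mmcv-full": "1.7.1",
    "geoopt": "0.5.0",
    "matplotlib": "3.7.1",
}
all_ok = True
for dist, want in expected.items():
    try:
        have = version(dist)
    except PackageNotFoundError:
        have = "missing"
    ok = have == want
    all_ok = all_ok and ok
    print(f"{dist:12} want {want:14} have {have:14} {'ok' if ok else 'MISMATCH'}")
raise SystemExit(0 if all_ok else 1)
EOF
}
```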

## Training

### CIFAR

* For CIFAR-10 with ResNet-50:  
  ```shell
  ./tools/dist_train.sh ./configs/stiefel/resnet50_4xb32_cifar10.py 4 \
  --seed 0 --work-dir work_dirs/resnet50_4xb32_cifar10/seed0 --deterministic
  ```

* For CIFAR-10 with ResNet-101:  
  ```shell
  ./tools/dist_train.sh ./configs/stiefel/resnet101_4xb32_cifar10.py 4 \
  --seed 0 --work-dir work_dirs/resnet101_4xb32_cifar10/seed0 --deterministic
  ```

* For CIFAR-100 with ResNet-50:  
  ```shell
  ./tools/dist_train.sh ./configs/stiefel/resnet50_4xb32_cifar100.py 4 \
  --seed 0 --work-dir work_dirs/resnet50_4xb32_cifar100/seed0 --deterministic
  ```

* For CIFAR-100 with ResNet-101:  
  ```shell
  ./tools/dist_train.sh ./configs/stiefel/resnet101_4xb32_cifar100.py 4 \
  --seed 0 --work-dir work_dirs/resnet101_4xb32_cifar100/seed0 --deterministic
  ```

* For CIFAR-20 with ResNet-50:  
  ```shell
  ./tools/dist_train.sh ./configs/stiefel/resnet50_4xb32_cifar20.py 4 \
  --seed 0 --work-dir work_dirs/resnet50_4xb32_cifar20/seed0 --deterministic
  ```

* For CIFAR-20 with ResNet-101:  
  ```shell
  ./tools/dist_train.sh ./configs/stiefel/resnet101_4xb32_cifar20.py 4 \
  --seed 0 --work-dir work_dirs/resnet101_4xb32_cifar20/seed0 --deterministic
  ```
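
The six CIFAR commands above differ only in dataset and backbone, so they can be generated in a loop. The function below is a minimal sketch, not shipped with the repository: `train_all_cifar`, `DRY_RUN`, and `SEEDS` are hypothetical names. It assumes it is called from the repository root and keeps 4 GPUs to match the `4xb32` in the config names; with the defaults it reproduces exactly the seed-0 commands listed above.

```shell
# Hypothetical helper, not shipped with the repository: runs every CIFAR config listed above.
# DRY_RUN=1 prints the commands instead of executing them.
# SEEDS is a space-separated list of seeds; it defaults to the single seed 0 used above.
train_all_cifar() {
  local gpus=4  # "4xb32" in the config names: 4 GPUs x 32 images per GPU
  local dataset arch seed name
  for dataset in cifar10 cifar100 cifar20; do
    for arch in resnet50 resnet101; do
      name="${arch}_${gpus}xb32_${dataset}"
      # Unquoted on purpose so that SEEDS="0 1 2" expands to three seeds
      for seed in ${SEEDS:-0}; do
        # Same arguments as the commands above, collected in the positional parameters
        set -- ./tools/dist_train.sh "./configs/stiefel/${name}.py" "$gpus" \
          --seed "$seed" --work-dir "work_dirs/${name}/seed${seed}" --deterministic
        if [ "${DRY_RUN:-0}" = 1 ]; then
          printf '%s\n' "$*"
        else
          "$@" || return  # stop at the first failed run
        fi
      done
    done
  done
}
```

For example, `DRY_RUN=1 train_all_cifar` prints the six commands above, and `SEEDS="0 1 2" train_all_cifar` would train each configuration with three seeds, writing to `seed0`, `seed1`, and `seed2` work directories. Runs execute one after another because each one is launched on 4 GPUs.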

### ImageNet

Before training, prepare the ImageNet dataset under `data/imagenet`.

* For ImageNet with ResNet-50:  
  ```shell
  ./tools/dist_train.sh ./configs/stiefel/resnet50_8xb32_in1k.py 8 \
  --seed 0 --work-dir work_dirs/resnet50_8xb32_in1k/seed0 --deterministic
  ```

* For ImageNet with ResNet-101:  
  ```shell
  ./tools/dist_train.sh ./configs/stiefel/resnet101_8xb32_in1k.py 8 \
  --seed 0 --work-dir work_dirs/resnet101_8xb32_in1k/seed0 --deterministic
  ```

* For ImageNet-127 with ResNet-50:  
  ```shell
  ./tools/dist_train.sh ./configs/stiefel/resnet50_8xb32_in127.py 8 \
  --seed 0 --work-dir work_dirs/resnet50_8xb32_in127/seed0 --deterministic
  ```

* For ImageNet-127 with ResNet-101:  
  ```shell
  ./tools/dist_train.sh ./configs/stiefel/resnet101_8xb32_in127.py 8 \
  --seed 0 --work-dir work_dirs/resnet101_8xb32_in127/seed0 --deterministic
  ```
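
The four ImageNet commands follow the same pattern with 8 GPUs, matching the `8xb32` in the config names. The function below is a minimal sketch, not shipped with the repository: `train_all_imagenet`, `DRY_RUN`, and `SEEDS` are hypothetical names, and it assumes it is called from the repository root. Its data check assumes that both the ImageNet and ImageNet-127 configs read from `data/imagenet`, which the text above states only for ImageNet.

```shell
# Hypothetical helper, not shipped with the repository: runs every ImageNet config listed above.
# DRY_RUN=1 prints the commands instead of executing them (and skips the data check).
# SEEDS is a space-separated list of seeds; it defaults to the single seed 0 used above.
train_all_imagenet() {
  local gpus=8  # "8xb32" in the config names: 8 GPUs x 32 images per GPU
  local dataset arch seed name
  # Assumption: both the in1k and in127 configs read images from data/imagenet
  if [ "${DRY_RUN:-0}" != 1 ] && [ ! -d data/imagenet ]; then
    echo "data/imagenet not found; prepare the ImageNet dataset first" >&2
    return 1
  fi
  for dataset in in1k in127; do
    for arch in resnet50 resnet101; do
      name="${arch}_${gpus}xb32_${dataset}"
      for seed in ${SEEDS:-0}; do  # unquoted on purpose: "0 1 2" gives three seeds
        set -- ./tools/dist_train.sh "./configs/stiefel/${name}.py" "$gpus" \
          --seed "$seed" --work-dir "work_dirs/${name}/seed${seed}" --deterministic
        if [ "${DRY_RUN:-0}" = 1 ]; then
          printf '%s\n' "$*"
        else
          "$@" || return  # stop at the first failed run
        fi
      done
    done
  done
}
```

Checking for the dataset up front makes a missing `data/imagenet` fail immediately instead of after the distributed launcher has started.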
