## 🛠️ Requirements and Installation

Basic Dependencies:
* Python >= 3.8
* PyTorch >= 2.0.0
* CUDA Version >= 11.8

[Local] Install Inf-CL:
```bash
pip install -e .
```

Install required packages:
```bash
pip install -r requirements.txt
```
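
The dependency list above can be checked before installing. The snippet below is a minimal sketch, not part of the Inf-CL package: `parse_version`, `meets`, and `check_environment` are hypothetical helper names, and the CUDA check reads `torch.version.cuda`, which is the CUDA version PyTorch was built against rather than the system toolkit version.

```python
import platform
import re


def parse_version(text):
    """Turn a string such as '2.1.0+cu118' into a tuple of ints, e.g. (2, 1, 0)."""
    parts = []
    for piece in text.split("+")[0].split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    return tuple(parts)


def meets(found, minimum):
    """Return True if version string `found` is at least `minimum`."""
    a, b = parse_version(found), parse_version(minimum)
    width = max(len(a), len(b))
    # Pad with zeros so that '2.0' and '2.0.0' compare as equal.
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b


def check_environment():
    """Compare the local setup with the minimum versions listed above."""
    report = {"python": meets(platform.python_version(), "3.8")}
    try:
        import torch
    except ImportError:
        report["torch"] = False
        report["cuda"] = False
        return report
    report["torch"] = meets(torch.__version__, "2.0.0")
    # torch.version.cuda is None for CPU-only builds of PyTorch.
    cuda = torch.version.cuda
    report["cuda"] = cuda is not None and meets(cuda, "11.8")
    return report


if __name__ == "__main__":
    for name, ok in check_environment().items():
        print(f"{name}: {'ok' if ok else 'does not meet the requirement'}")
```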

## ⭐ Features

`inf_cl` is the Triton implementation of the Inf-CL loss:
* [x] [Ring-CL (inf_cl/ring.py#L238)](inf_clip/models/ops/ring.py#L238)
* [x] [Inf-CL  (inf_cl/ring.py#L251)](inf_clip/models/ops/ring.py#L251)

`inf_clip` is the CLIP training codebase with the Inf-CL loss and other training features:
- [x] [Gradient Accumulation (inf_clip/train/train.py#L180)](inf_clip_train/train.py#L180)
- [x] [Gradient Cache (inf_clip/train/train.py#L292)](inf_clip_train/train.py#L292)


## 🔑 Usage

A simple example of how to adopt the Inf-CL loss for contrastive learning. Run it with the following command:
```bash
torchrun --nproc_per_node 2 tests/example.py
```

```python
import torch
import torch.nn.functional as F
import torch.distributed as dist
import numpy as np

from inf_cl import cal_inf_loss


def create_cl_tensors(rank, world_size):
    # Parameters
    dtype = torch.float32
    num_heads = 3        # Number of attention heads
    seq_length_q = 32768 # Total number of query features across all ranks
    seq_length_k = 32768 # Total number of key features across all ranks
    d_model = 256        # Dimension of each head (feature dim is num_heads * d_model)

    # Randomly initialize inputs
    q = torch.rand((seq_length_q // world_size, num_heads * d_model), dtype=dtype, device=f"cuda:{rank}")
    k = torch.rand((seq_length_k // world_size, num_heads * d_model), dtype=dtype, device=f"cuda:{rank}")
    l = torch.ones([], dtype=dtype, device=f"cuda:{rank}") * np.log(1 / 0.07)

    q = F.normalize(q, p=2, dim=-1).requires_grad_() # Query
    k = F.normalize(k, p=2, dim=-1).requires_grad_() # Key
    l = l.requires_grad_() # Logit scale

    return q, k, l


if __name__ == "__main__":
    # Initialize the default process group (torchrun provides the rank and world size)
    dist.init_process_group("nccl")

    rank = dist.get_rank()
    world_size = dist.get_world_size()

    torch.cuda.set_device(rank)

    # Taking image-text contrastive learning as an example: q holds the global image features,
    # k holds the text features, and l is the log of the logit scale (exponentiated below).
    q, k, l = create_cl_tensors(rank, world_size)

    # By default, the labels are the diagonal elements, i.e.:
    # labels = torch.arange(q.shape[0])
    loss = cal_inf_loss(q, k, scale=l.exp())

    print(loss)

```
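
For sanity checks on small inputs, it can help to have a plain PyTorch reference that materializes the full similarity matrix. The sketch below is an assumption-laden illustration, not the Inf-CL implementation: `dense_contrastive_loss` is a hypothetical name, it runs in a single process on un-sharded features, and it computes a one-directional cross-entropy with diagonal labels and mean reduction. Whether `cal_inf_loss` uses the same direction and reduction is not stated here, so verify against the package before comparing numbers.

```python
import torch
import torch.nn.functional as F


def dense_contrastive_loss(q, k, scale):
    """Cross-entropy over the full (N x N) similarity matrix with diagonal labels.

    q, k: (N, D) L2-normalized features; scale: scalar logit scale.
    """
    logits = scale * (q @ k.t())                        # (N, N) similarities
    labels = torch.arange(q.shape[0], device=q.device)  # i-th query matches i-th key
    return F.cross_entropy(logits, labels)


if __name__ == "__main__":
    torch.manual_seed(0)
    # Tiny shapes so that the dense matrix fits comfortably in memory on CPU.
    q = F.normalize(torch.rand(8, 16), p=2, dim=-1)
    k = F.normalize(torch.rand(8, 16), p=2, dim=-1)
    scale = torch.tensor(1 / 0.07)
    print(dense_contrastive_loss(q, k, scale))
```

Because this version stores all N x N logits at once, it is only practical for small N.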

## 🗝️ Training & Evaluation

### Quick Start

To facilitate further development on top of our codebase, we provide a quick-start guide on how to use Inf-CLIP to train a customized CLIP model and evaluate the trained model on mainstream CLIP benchmarks.

1. Training Data Structure:
```bash
Inf-CLIP
├── datasets
│   ├── cc3m/ # https://github.com/rom1504/img2dataset/blob/main/dataset_examples/cc3m.md
│   │   ├── 0000.tar
│   │   ├── 0001.tar
│   │   ├── ...
│   │   └── 0301.tar
│   ├── cc12m/ # https://github.com/rom1504/img2dataset/blob/main/dataset_examples/cc12m.md
│   │   ├── 0000.tar
│   │   ├── 0001.tar
│   │   ├── ...
│   │   └── 1044.tar
│   ├── laion400m/ # https://github.com/rom1504/img2dataset/blob/main/dataset_examples/laion400m.md
│   │   ├── 00000.tar
│   │   ├── 00001.tar
│   │   ├── ...
│   │   └── 41407.tar
```
2. Training Command:
```bash
bash scripts/cc3m/lit_vit-b-32_bs16k.sh
bash scripts/cc12m/lit_vit-b-32_bs32k.sh
bash scripts/laion400m/lit_vit-b-32_bs256k.sh
```
3. Evaluation Data Structure:
```bash
Inf-CLIP
├── datasets
│   ├── imagenet-1k/ # download val_images.tar.gz of imagenet
│   │   └── val/
│   │       ├── n01440764
│   │       ├── n01443537
│   │       ├── ...
│   │       └── n15075141
│   ├── clip-benchmark/ # bash datasets/benchmarks_download.sh
│   │   ├── wds_mscoco_captions
│   │   ├── wds_flickr8k
│   │   ├── wds_flickr30k
│   │   ├── wds_imagenet1k
│   │   ├── wds_imagenetv2
│   │   ├── wds_imagenet_sketch
│   │   ├── wds_imagenet-a
│   │   ├── wds_imagenet-r
│   │   ├── wds_imagenet-o
│   │   └── wds_objectnet
```
4. Evaluation Command:
```bash
# imagenet evaluation
bash scripts/imagenet_eval.sh
# overall evaluation
bash scripts/benchmarks_eval.sh
```
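
Before launching the training scripts, it can be useful to confirm that the training data matches the directory layout shown in step 1. The following is a minimal sketch with hypothetical helper names (`count_shards`, `check_layout`). It assumes it is run from the `Inf-CLIP` root and only counts `.tar` shards; it does not validate their contents.

```python
from pathlib import Path


def count_shards(dataset_dir):
    """Return the number of .tar shards directly inside dataset_dir (0 if it is missing)."""
    path = Path(dataset_dir)
    if not path.is_dir():
        return 0
    return sum(1 for _ in path.glob("*.tar"))


def check_layout(root="datasets", names=("cc3m", "cc12m", "laion400m")):
    """Map each expected training dataset folder to its shard count."""
    return {name: count_shards(Path(root) / name) for name in names}


if __name__ == "__main__":
    for name, n_shards in check_layout().items():
        status = "ok" if n_shards else "missing or empty"
        print(f"{name}: {n_shards} shard(s) [{status}]")
```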
