from torch import nn
import torchvision
from .resnet_cifar import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152

class BarlowTwins(nn.Module):
    """Backbone + MLP projector applied to two augmented views of a batch.

    The backbone is selected by ``args.arch`` and its final ``fc`` layer is
    replaced with ``nn.Identity`` so that it emits features; those features
    are fed to a projector whose layer widths come from
    ``args.projector_img`` (a dash-separated string such as ``"2048-2048"``).

    Attributes:
        args: The configuration object passed to the constructor.
        backbone: The selected network with ``fc`` replaced by identity.
        projector: ``nn.Sequential`` MLP built from ``args.projector_img``.
        encoder: ``nn.Sequential(backbone, projector)``; it holds the very
            same module objects as ``backbone`` and ``projector`` (no copy).
    """

    def __init__(self, args):
        """Build the backbone, the projector and the combined encoder.

        Args:
            args: Configuration object. This class reads ``args.arch`` and
                ``args.projector_img``; the object is also forwarded to the
                backbone constructor.
        """
        super().__init__()
        self.args = args
        self.backbone = BarlowTwins.get_backbone(self.args.arch, self.args)
        # Input width of the backbone's classifier == width of its features.
        # Must be read *before* the classifier is replaced below.
        out_dim = self.backbone.fc.weight.shape[1]
        self.backbone.fc = nn.Identity()

        self.projector = self._build_projector(out_dim, args.projector_img)

        self.encoder = nn.Sequential(
            self.backbone,
            self.projector,
        )

    @staticmethod
    def _build_projector(in_dim, spec):
        """Build the projector MLP for ``in_dim`` inputs from a width spec.

        Every layer except the last is ``Linear(bias=False)`` followed by
        ``BatchNorm1d`` and in-place ``ReLU``; the last layer is a bare
        ``Linear(bias=False)``.

        Args:
            in_dim: Number of input features.
            spec: Dash-separated layer widths, e.g. ``"512-512-128"``.

        Returns:
            An ``nn.Sequential`` containing the projector layers.

        Raises:
            ValueError: If a component of ``spec`` is not an integer.
        """
        sizes = [in_dim] + [int(width) for width in spec.split('-')]
        layers = []
        # Hidden layers: consecutive pairs up to, but excluding, the last one.
        for fan_in, fan_out in zip(sizes[:-2], sizes[1:-1]):
            layers.append(nn.Linear(fan_in, fan_out, bias=False))
            layers.append(nn.BatchNorm1d(fan_out))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Linear(sizes[-2], sizes[-1], bias=False))
        return nn.Sequential(*layers)

    @staticmethod
    def get_backbone(backbone_name, args=None):
        """Instantiate the backbone network registered under ``backbone_name``.

        Only the requested architecture is constructed; the lookup table
        holds constructors, not instances.

        Args:
            backbone_name: One of ``'resnet18'``, ``'resnet34'``,
                ``'resnet50'``, ``'resnet101'``, ``'resnet152'``.
            args: Passed through to the constructor as ``args=args``.

        Returns:
            A newly constructed backbone.

        Raises:
            KeyError: If ``backbone_name`` is not a known architecture.
        """
        constructors = {
            'resnet18': ResNet18,
            'resnet34': ResNet34,
            'resnet50': ResNet50,
            'resnet101': ResNet101,
            'resnet152': ResNet152,
        }
        try:
            constructor = constructors[backbone_name]
        except KeyError:
            raise KeyError(
                f"Unknown backbone {backbone_name!r}; "
                f"expected one of {sorted(constructors)}"
            ) from None
        return constructor(args=args)

    def forward(self, im_aug1, im_aug2):
        """Encode both views with the shared encoder.

        Args:
            im_aug1: First input, passed to ``self.encoder``.
            im_aug2: Second input, passed to ``self.encoder``.

        Returns:
            A dict with key ``'za'`` (encoding of ``im_aug1``) and key
            ``'zb'`` (encoding of ``im_aug2``).
        """
        z1 = self.encoder(im_aug1)
        z2 = self.encoder(im_aug2)

        return {'za': z1, 'zb': z2}



