from torch import nn
import torchvision
from .resnet_cifar import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152

class SepCL_v1(nn.Module):
    """Shared backbone whose feature vector is split into two separately projected parts.

    The backbone's ``fc`` layer is replaced by ``nn.Identity`` so it returns
    its features directly. The first ``size1`` feature columns are fed to
    ``projector_img``; the remaining ``size2`` columns are fed to
    ``projector_aug``.

    ``args`` must provide:
        arch: backbone key, one of ``'resnet18'``, ``'resnet34'``,
            ``'resnet50'``, ``'resnet101'``, ``'resnet152'``.
        hp1: fraction of the backbone feature width given to the image
            branch; ``size1 = round(out_dim * hp1)``.
        projector_img / projector_aug: ``'-'``-separated integer layer
            widths for the two projector MLPs (format example: ``'512-128'``).
    ``args`` is also forwarded unchanged to the backbone constructor.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.backbone = SepCL_v1.get_backbone(self.args.arch, self.args)
        # Feature width = input width of the backbone's fc layer; the fc layer
        # is then dropped so the backbone emits features instead of logits.
        out_dim = self.backbone.fc.weight.shape[1]
        self.backbone.fc = nn.Identity()

        # NOTE(review): an hp1 that rounds size1 to 0 or to out_dim leaves one
        # branch with zero input features, and hp1 outside [0, 1] yields a
        # negative width. Confirm whether such values are ever intended.
        self.size1 = int(round(out_dim * self.args.hp1))
        self.size2 = out_dim - self.size1

        # Registration order (backbone, projector_img, projector_aug) is kept
        # so parameter/state_dict ordering stays the same.
        self.projector_img = self._build_projector(self.size1, args.projector_img)
        self.projector_aug = self._build_projector(self.size2, args.projector_aug)

    @staticmethod
    def _build_projector(in_dim, spec):
        """Build an MLP projector mapping ``in_dim`` features to the widths in ``spec``.

        Args:
            in_dim: number of input features.
            spec: ``'-'``-separated output widths. Every width except the last
                gets ``Linear(bias=False) -> BatchNorm1d -> ReLU``; the last
                width gets a single bias-free ``Linear``.

        Returns:
            ``nn.Sequential`` containing those layers in that order.

        Raises:
            ValueError: if an entry of ``spec`` is not an integer.
        """
        dims = [in_dim] + [int(width) for width in spec.split('-')]
        layers = []
        # Hidden layers: every consecutive pair except the final one.
        for d_in, d_out in zip(dims[:-2], dims[1:-1]):
            layers += [
                nn.Linear(d_in, d_out, bias=False),
                nn.BatchNorm1d(d_out),
                nn.ReLU(inplace=True),
            ]
        layers.append(nn.Linear(dims[-2], dims[-1], bias=False))
        return nn.Sequential(*layers)

    @staticmethod
    def get_backbone(backbone_name, args = None):
        """Instantiate the backbone registered under ``backbone_name``.

        Only the requested network is constructed. Building every candidate
        just to pick one wastes time and memory, and presumably consumes extra
        RNG draws during weight initialisation.

        Args:
            backbone_name: one of the keys in ``constructors`` below.
            args: passed through as ``args=`` to the backbone constructor.

        Raises:
            KeyError: if ``backbone_name`` is not a supported backbone.
        """
        constructors = {
            'resnet18': ResNet18,
            'resnet34': ResNet34,
            'resnet50': ResNet50,
            'resnet101': ResNet101,
            'resnet152': ResNet152,
        }
        try:
            constructor = constructors[backbone_name]
        except KeyError:
            raise KeyError(
                f'Unknown backbone {backbone_name!r}; '
                f'expected one of {sorted(constructors)}'
            ) from None
        return constructor(args=args)

    def forward(self, im_aug1, im_aug2):
        """Encode two input batches and project each half of their features.

        Args:
            im_aug1, im_aug2: input batches for the backbone (presumably two
                augmented views of the same images; confirm with the caller).

        Returns:
            dict with, for view 1 (``a``) and view 2 (``b``):
                ``'za_img'`` / ``'zb_img'``: ``projector_img`` output on the
                    first ``size1`` feature columns;
                ``'za_aug'`` / ``'zb_aug'``: ``projector_aug`` output on the
                    remaining columns;
                ``'z1_sep_img'``, ``'z1_sep_aug'``, ``'z2_sep_img'``,
                ``'z2_sep_aug'``: the raw feature slices themselves.
        """
        # Each view runs through the networks separately (not concatenated),
        # so any BatchNorm layer computes statistics per view.
        z1 = self.backbone(im_aug1)
        z2 = self.backbone(im_aug2)

        # Split along the feature axis (dim 1) at size1.
        z1_sep_img = z1[:, :self.size1]
        z1_sep_aug = z1[:, self.size1:]

        z2_sep_img = z2[:, :self.size1]
        z2_sep_aug = z2[:, self.size1:]

        z1_feature_img = self.projector_img(z1_sep_img)
        z2_feature_img = self.projector_img(z2_sep_img)

        z1_feature_aug = self.projector_aug(z1_sep_aug)
        z2_feature_aug = self.projector_aug(z2_sep_aug)

        return {
            'za_img': z1_feature_img,
            'zb_img': z2_feature_img,
            'za_aug': z1_feature_aug,
            'zb_aug': z2_feature_aug,
            'z1_sep_img': z1_sep_img,
            'z1_sep_aug': z1_sep_aug,
            'z2_sep_img': z2_sep_img,
            'z2_sep_aug': z2_sep_aug,
        }

