from torch import nn
from .resnet_cifar import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152


class projection_MLP(nn.Module):
    """Projection head: Linear-BN-ReLU blocks followed by a Linear + affine-free BN.

    Args:
        in_dim: number of input features.
        out_dim: number of output features (also used as the hidden width).
        num_layers: 2 (layer1 -> layer3) or 3 (layer1 -> layer2 -> layer3).

    Raises:
        ValueError: if ``num_layers`` is not 2 or 3.
    """

    _SUPPORTED_LAYERS = (2, 3)

    def __init__(self, in_dim, out_dim, num_layers=2):
        super().__init__()
        if num_layers not in self._SUPPORTED_LAYERS:
            # Previously an unsupported value made forward() return its input
            # unchanged; fail loudly instead of silently skipping the head.
            raise ValueError(
                f"num_layers must be one of {self._SUPPORTED_LAYERS}, got {num_layers!r}"
            )
        hidden_dim = out_dim
        self.num_layers = num_layers

        self.layer1 = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True)
        )

        # Always constructed (even when num_layers == 2) so state_dict keys stay
        # identical across configurations and existing checkpoints still load.
        self.layer2 = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True)
        )
        self.layer3 = nn.Sequential(
            nn.Linear(hidden_dim, out_dim),
            # No affine parameters on the output BN; the reference below presumably
            # points at the SimSiam paper -- TODO confirm.
            nn.BatchNorm1d(out_dim, affine=False)  # Page:5, Paragraph:2
        )

    def forward(self, x):
        """Project ``x`` (assumed shape (batch, in_dim) -- TODO confirm) to out_dim features."""
        if self.num_layers == 2:
            return self.layer3(self.layer1(x))
        if self.num_layers == 3:
            return self.layer3(self.layer2(self.layer1(x)))
        # Guards against num_layers being mutated after construction.
        raise ValueError(
            f"num_layers must be one of {self._SUPPORTED_LAYERS}, got {self.num_layers!r}"
        )


class prediction_MLP(nn.Module):
    """Bottleneck prediction head: Linear-BN-ReLU down to in_dim/4, then Linear back up.

    Args:
        in_dim: input width; the output has the same width.
    """

    def __init__(self, in_dim=2048):
        super(prediction_MLP, self).__init__()
        # Bottleneck width is a quarter of the input width (truncated toward zero).
        bottleneck = int(in_dim / 4)

        # Registration order (layer1, then layer2) is kept so parameter order,
        # init RNG consumption and state_dict keys match.
        self.layer1 = nn.Sequential(
            nn.Linear(in_dim, bottleneck),
            nn.BatchNorm1d(bottleneck),
            nn.ReLU(inplace=True),
        )
        self.layer2 = nn.Linear(bottleneck, in_dim)

    def forward(self, x):
        """Map ``x`` through the bottleneck and back to ``in_dim`` features."""
        return self.layer2(self.layer1(x))


class SepCL_v1_Simsiam(nn.Module):
    def __init__(self, args):
        super(SepCL_v1_Simsiam, self).__init__()
        self.backbone = SepCL_v1_Simsiam.get_backbone(args.arch, args)
        out_dim = self.backbone.fc.weight.shape[1]
        self.backbone.fc = nn.Identity()

        self.size1 = int(round(out_dim * args.hp1)) ## out_dim 512
        self.size2 = out_dim - self.size1

        self.projector_img = projection_MLP(self.size1, args.feat_dim, args.num_proj_layers)
        self.predictor_img = prediction_MLP(args.feat_dim)

        self.projector_aug = projection_MLP(self.size2, args.feat_dim, args.num_proj_layers)
        self.predictor_aug = prediction_MLP(args.feat_dim)

    @staticmethod
    def get_backbone(backbone_name, args = None):
        return {'resnet18': ResNet18(args=args),
                'resnet34': ResNet34(args=args),
                'resnet50': ResNet50(args=args),
                'resnet101': ResNet101(args=args),
                'resnet152': ResNet152(args=args)}[backbone_name]

    def forward(self, im_aug1, im_aug2):

        z1 = self.backbone(im_aug1)
        z1_sep_img = z1[:,:self.size1]
        z1_sep_aug = z1[:,self.size1:]

        z2 = self.backbone(im_aug2)
        z2_sep_img = z2[:,:self.size1]
        z2_sep_aug = z2[:,self.size1:]

        z1_feature_img = self.projector_img(z1_sep_img)
        z2_feature_img = self.projector_img(z2_sep_img)

        z1_feature_aug = self.projector_aug(z1_sep_aug)
        z2_feature_aug = self.projector_aug(z2_sep_aug)

        p1_img = self.predictor_img(z1_feature_img)
        p2_img = self.predictor_img(z2_feature_img)

        p1_aug = self.predictor_aug(z1_feature_aug)
        p2_aug = self.predictor_aug(z2_feature_aug)

        return {'za_img': z1_feature_img, 'zb_img': z2_feature_img, 'za_aug': z1_feature_aug, 'zb_aug': z2_feature_aug, 'pa_img': p1_img, 'pb_img': p2_img, 'pa_aug': p1_aug, 'pb_aug': p2_aug, 'z1_sep_img': z1_sep_img, 'z1_sep_aug': z1_sep_aug, 'z2_sep_img': z2_sep_img, 'z2_sep_aug': z2_sep_aug}







