import torch
from torch import nn


class SimSiamLoss(nn.Module):
    """Symmetrised negative-cosine-similarity loss with a stop-gradient target.

    Args:
        version: ``'original'`` normalises both inputs along ``dim=1`` and
            takes their dot product; ``'simplified'`` calls
            ``cosine_similarity`` along the last dimension. For 2-D inputs
            both address the same axis.
    """

    def __init__(self, version='simplified'):
        super().__init__()
        self.ver = version

    def asymmetric_loss(self, p, z):
        """Return the negative mean cosine similarity between ``p`` and ``z``.

        ``z`` is detached first, so gradients only flow through ``p``.

        Raises:
            ValueError: if ``self.ver`` is neither ``'original'`` nor
                ``'simplified'``.
        """
        z = z.detach()  # stop gradient

        if self.ver == 'original':
            p = nn.functional.normalize(p, dim=1)
            z = nn.functional.normalize(z, dim=1)
            return -(p * z).sum(dim=1).mean()

        if self.ver == 'simplified':
            return -nn.functional.cosine_similarity(p, z, dim=-1).mean()

        # Fail loudly instead of silently returning None, which would only
        # surface later as an obscure TypeError inside forward().
        raise ValueError(
            f"Unknown SimSiamLoss version {self.ver!r}; "
            "expected 'original' or 'simplified'"
        )

    def forward(self, z1, z2, p1, p2):
        """Compute the symmetric loss.

        Each ``p`` is compared against the ``z`` carrying the other index
        (``p1`` with ``z2``, ``p2`` with ``z1``).

        Returns:
            ``[loss, 0.5 * loss1, 0.5 * loss2]`` where ``loss`` is the sum of
            the two halved asymmetric terms.
        """
        loss1 = self.asymmetric_loss(p1, z2)
        loss2 = self.asymmetric_loss(p2, z1)

        loss = 0.5 * loss1 + 0.5 * loss2
        return [loss, 0.5 * loss1, 0.5 * loss2]


def off_diagonal(x):
    """Return the off-diagonal entries of a square matrix as a 1-D tensor.

    Entries come out in row-major order. ``x`` must be 2-D and square.
    """
    rows, cols = x.shape
    assert rows == cols
    # Dropping the final (diagonal) element leaves n*n - 1 values, which fit
    # exactly into (n - 1) rows of (n + 1) columns. In that layout every row
    # starts with a diagonal entry, so removing column 0 removes the diagonal.
    without_last = x.flatten()[:-1]
    shifted = without_last.view(rows - 1, rows + 1)
    return shifted[:, 1:].flatten()

class TwinsLoss(nn.Module):
    """Cross-correlation loss over the leading ``size1`` feature columns.

    The diagonal of the cross-correlation matrix is pushed towards 1 and the
    off-diagonal entries towards 0, the latter weighted by ``args.lambd``.

    Args:
        args: configuration object; this class reads ``feat_dim``, ``hp1``,
            ``batch_size`` and ``lambd`` from it.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        # Split the feature dimension: the first size1 columns are used by
        # forward(), size2 is whatever remains.
        self.size1 = int(round(self.args.feat_dim * self.args.hp1))
        self.size2 = self.args.feat_dim - self.size1
        self.bn1 = nn.BatchNorm1d(self.size1, affine=False)
        # bn2 is not used by forward() in this class; it is kept so the
        # module's registered state stays unchanged.
        self.bn2 = nn.BatchNorm1d(self.size2, affine=False)

    def forward(self, za, zb):
        """Compute the loss for two batches of features.

        Args:
            za, zb: 2-D tensors; only columns ``[:size1]`` are used.

        Returns:
            ``[loss, on_diag, lambd * off_diag, c]`` where ``c`` is the
            (unmodified) ``size1 x size1`` cross-correlation matrix.
        """
        za1 = za[:, :self.size1]
        zb1 = zb[:, :self.size1]

        # empirical cross-correlation matrix
        c = self.bn1(za1).T @ self.bn1(zb1)

        # Normalise by the configured batch size. No cross-process reduction
        # is performed here.
        c.div_(self.args.batch_size)

        # Use out-of-place ops: torch.diagonal() returns a view, so in-place
        # add_/pow_ would overwrite the diagonal of the matrix we return
        # (and, for tiny matrices, off_diagonal() can also alias ``c``).
        on_diag = (torch.diagonal(c) - 1).pow(2).sum()
        off_diag = off_diagonal(c).pow(2).sum()
        weighted_off_diag = self.args.lambd * off_diag
        loss = on_diag + weighted_off_diag

        return [loss, on_diag, weighted_off_diag, c]


class SepCLLoss_v1_OTL1(nn.Module):
    """Cross-correlation loss on image features plus an L1 penalty on the
    mean cosine similarity of the augmentation features.

    ``args`` must provide ``feat_dim``, ``batch_size``, ``lambd`` and
    ``sep_lambd``.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.bn = nn.BatchNorm1d(self.args.feat_dim, affine=False)
        self.l1 = nn.L1Loss()

    def forward(self, z1_feature_img, z2_feature_img, z1_feature_aug, z2_feature_aug):
        """Return ``[loss, loss_img, on_diag_img, lambd * off_diag_img, loss_aug]``."""
        # Image branch: batch-normalise both views, then build the
        # cross-correlation matrix scaled by the configured batch size.
        normed_first = self.bn(z1_feature_img)
        normed_second = self.bn(z2_feature_img)
        c_img = (normed_first.T @ normed_second) / self.args.batch_size

        # Diagonal should approach 1, everything else 0.
        on_diag_img = (torch.diagonal(c_img) - 1).pow(2).sum()
        off_diag_img = off_diagonal(c_img).pow(2).sum()
        weighted_off_diag_img = self.args.lambd * off_diag_img
        loss_img = on_diag_img + weighted_off_diag_img

        # Augmentation branch: L1 distance between the mean cosine similarity
        # and a zero target, i.e. its absolute value.
        mean_similarity = nn.functional.cosine_similarity(
            z1_feature_aug, z2_feature_aug, dim=-1
        ).mean()
        loss_aug = self.l1(mean_similarity, 0 * mean_similarity)

        total = loss_img + self.args.sep_lambd * loss_aug
        return [total, loss_img, on_diag_img, weighted_off_diag_img, loss_aug]

class SepCLLoss_v1_SimSiam_OTL1(nn.Module):
    """Negative-cosine loss on the image branch plus an L1 penalty on the
    cosine similarity of the augmentation branch.

    ``args`` must provide ``sep_lambd``, the weight of the augmentation term.
    """

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.l1 = nn.L1Loss()

    def asymmetric_loss(self, p, z):
        """Negative mean cosine similarity of ``p`` against a detached ``z``."""
        target = z.detach()  # stop gradient
        similarity = nn.functional.cosine_similarity(p, target, dim=-1)
        return -similarity.mean()

    def forward(self, z1_img, z2_img, z1_aug, z2_aug, p1_img, p2_img, p1_aug, p2_aug):
        """Return ``[loss, loss_img, half_img_1, half_img_2, loss_aug,
        half_aug_1, half_aug_2]``.

        Each ``p`` is paired with the ``z`` carrying the other index.
        """
        # Image branch: the usual negative cosine similarity, halved.
        half_img_1 = 0.5 * self.asymmetric_loss(p1_img, z2_img)
        half_img_2 = 0.5 * self.asymmetric_loss(p2_img, z1_img)
        loss_img = half_img_1 + half_img_2

        # Augmentation branch: flipping the sign recovers the (positive)
        # cosine similarity, which is then compared to zero with L1.
        half_aug_1 = 0.5 * (-self.asymmetric_loss(p1_aug, z2_aug))
        half_aug_2 = 0.5 * (-self.asymmetric_loss(p2_aug, z1_aug))
        mean_aug = half_aug_1 + half_aug_2
        loss_aug = self.l1(mean_aug, 0 * mean_aug)

        total = loss_img + self.args.sep_lambd * loss_aug  # default 1.0
        return [total, loss_img, half_img_1, half_img_2, loss_aug, half_aug_1, half_aug_2]

