"""
Code adapted from: https://github.com/HsinYingLee/DRIT with modification
"""
import torch
import torch.nn as nn
import torchvision.models as models
import torch.nn.functional as F
from models.submodules import InterpolationLayer


class StyleEncoder(nn.Module):
    """Encoder built from the early layers of a torchvision ResNet-18.

    Two trunks can be created:

    * ``conv_layers`` (always): DropBlock -> fresh 7x7 stride-2 conv ->
      ResNet-18 ``bn1``, ``relu``, ``layer1`` and the first block of
      ``layer2`` (``pretrained=True``).  ResNet's max-pool is skipped.
    * ``attribute_encoder`` (only if ``use_attributes``): same layout but with
      the whole of ``layer2`` and ``pretrained=False``.

    Args:
        input_dim: Number of channels of the input tensor.
        shared_layers: Module applied to the output of ``conv_layers``;
            stored as ``self.conv_share``.
        attribute_channels: Accepted for interface compatibility but not
            consulted; ``self.attribute_channels`` is always set to 128.
        use_attributes: Whether to build and run the attribute encoder.
    """

    def __init__(self, input_dim, shared_layers, attribute_channels, use_attributes=False):
        super(StyleEncoder, self).__init__()
        self.use_attributes = use_attributes
        # NOTE: the constructor argument of the same name is ignored here.
        self.attribute_channels = 128

        self.conv_layers = self._build_trunk(input_dim, pretrained=True, whole_layer2=False)
        self.conv_share = shared_layers

        if self.use_attributes:
            self.attribute_encoder = self._build_trunk(input_dim, pretrained=False, whole_layer2=True)

    @staticmethod
    def _build_trunk(input_dim, pretrained, whole_layer2):
        """Assemble DropBlock + stem conv + selected ResNet-18 children."""
        layers = [
            DropBlock(block_size=19, p=0.98),
            nn.Conv2d(input_dim, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False),
        ]
        # Instantiate the backbone once and pick every needed child from that
        # single instance, instead of building (and loading weights for) a new
        # ResNet-18 for each slice.
        backbone = list(models.resnet18(pretrained=pretrained).children())
        layers += backbone[1:3]  # bn1, relu
        layers += backbone[4:5]  # layer1 (index 3, the max-pool, is skipped)
        layers.append(backbone[5] if whole_layer2 else backbone[5][0])  # layer2 / its first block
        return nn.Sequential(*layers)

    def forward(self, x, attribute_only=False):
        """Encode ``x``.

        Returns:
            * ``(output_mean, output_logvar, z_attr)`` when both
              ``use_attributes`` and ``attribute_only`` are true.
            * ``((x_conv_f, x_conv), output_mean, output_logvar, z_attr)``
              otherwise, where ``x_conv`` is the ``conv_layers`` output and
              ``x_conv_f`` is ``conv_share(x_conv)``.  The last three entries
              are ``None`` when ``use_attributes`` is false.

            ``output_logvar`` and ``z_attr`` are the very same tensor as
            ``output_mean`` (the attribute encoder output); no separate
            log-variance or sampling is computed.
        """
        if self.use_attributes:
            output_mean = self.attribute_encoder(x)
            output_logvar = output_mean
            z_attr = output_mean

            if attribute_only:
                return output_mean, output_logvar, z_attr

        x_conv = self.conv_layers(x)
        x_conv_f = self.conv_share(x_conv)

        if not self.use_attributes:
            return (x_conv_f, x_conv), None, None, None

        return (x_conv_f, x_conv), output_mean, output_logvar, z_attr




class StyleDecoder(torch.nn.Module):
    """Decoder made of two transposed-conv upsampling stages with dense skips.

    Every stage is ``conv -> DropBlock -> Dropout2d(0.1) -> LayerNorm -> Mish``.
    The ``nn.LayerNorm`` shapes are hard-coded, which fixes the intermediate
    spatial sizes at 69, 143, 136 and finally 128.

    Args:
        input_c: Channels of the content tensor ``x``.
        output_c: Channels of the decoded output.
        attribute_channels: Channels of the attribute tensor ``z``.
        sensor_name: Accepted but not used by this class.
    """

    def __init__(self, input_c, output_c, attribute_channels, sensor_name):
        super(StyleDecoder, self).__init__()
        self.attribute_channels = attribute_channels

        # Stage 1: upsample the concatenated (content, attribute) features.
        self.decoder_list_1 = self._stage(
            ConvTranspose2d(input_c + attribute_channels, 64, 7, stride=2),
            block_size=9, norm_shape=[64, 69, 69])
        self.conv_list1 = self._stage(
            Conv2d(64, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False),
            block_size=9, norm_shape=[64, 69, 69])

        # Stage 2: consumes 128 channels (stage-1 output concatenated with its refinement).
        self.decoder_list_2 = self._stage(
            ConvTranspose2d(128, 32, 7, stride=2),
            block_size=19, norm_shape=[32, 143, 143])
        self.conv_list2 = self._stage(
            Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), bias=False),
            block_size=19, norm_shape=[32, 143, 143])

        # Un-padded convolutions shrink 143 -> 136 -> 128.
        self.conv_list3 = self._stage(
            Conv2d(64, 16, kernel_size=(8, 8), stride=(1, 1), padding=(0, 0), bias=False),
            block_size=19, norm_shape=[16, 136, 136])
        self.conv_list4 = self._stage(
            Conv2d(16, output_c, kernel_size=(9, 9), stride=(1, 1), padding=(0, 0), bias=False),
            block_size=19, norm_shape=[output_c, 128, 128])

    @staticmethod
    def _stage(conv, block_size, norm_shape):
        """Wrap ``conv`` with DropBlock, channel dropout, LayerNorm and Mish."""
        return nn.Sequential(
            conv,
            DropBlock(block_size=block_size, p=0.98),
            nn.Dropout2d(p=0.1),
            nn.LayerNorm(norm_shape),
            nn.Mish(),
        )

    def forward(self, x, z):
        """Decode content ``x`` and attribute ``z`` (concatenated along dim 1)."""
        up1 = self.decoder_list_1(torch.cat((x, z), 1))
        dense1 = torch.cat((up1, self.conv_list1(up1)), 1)
        up2 = self.decoder_list_2(dense1)
        dense2 = torch.cat((up2, self.conv_list2(up2)), 1)
        return self.conv_list4(self.conv_list3(dense2))


class ContentDiscriminator(nn.Module):
    """Two-way classifier over content feature maps.

    The convolutional stack keeps ``nr_channels`` channels throughout, is
    reduced to one channel by a 1x1 conv, flattened, and fed to
    ``Linear(9, 9) -> Mish -> Linear(9, 2) -> Softmax``.  The first linear
    layer therefore requires the flattened conv output to hold exactly 9
    values per sample.

    Args:
        nr_channels: Channel count of the input feature map.
        smaller_input: If true, the third conv uses a 4x4 kernel instead of
            7x7.
    """

    def __init__(self, nr_channels, smaller_input=False):
        super(ContentDiscriminator, self).__init__()
        third_kernel = 4 if smaller_input else 7
        model = [
            DropBlock(block_size=19, p=0.99),
            LeakyReLUConv2d(nr_channels, nr_channels, kernel_size=7, stride=2, padding=1, norm='Instance'),
            DropBlock(block_size=9, p=0.98),
            LeakyReLUConv2d(nr_channels, nr_channels, kernel_size=7, stride=1, padding=1, norm='Instance'),
            LeakyReLUConv2d(nr_channels, nr_channels, kernel_size=third_kernel, stride=1, padding=1,
                            norm='Instance'),
            LeakyReLUConv2d(nr_channels, nr_channels, kernel_size=4, stride=1, padding=0),
            nn.Conv2d(nr_channels, 1, kernel_size=1, stride=1, padding=0),
            nn.Flatten(),
            nn.Linear(9, 9),
            nn.Mish(),
            nn.Linear(9, 2),
            # The input here is 2-D (batch, 2).  Naming the class dimension
            # explicitly yields the same result as the implicit choice while
            # avoiding the deprecation warning emitted on every forward pass.
            nn.Softmax(dim=1),
        ]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        """Return, per sample, the softmax probability of class index 0."""
        out = self.model(x)
        return out[:, 0]


class CrossDiscriminator(nn.Module):
    """Convolutional two-way classifier with optional spectral normalisation.

    Args:
        input_dim: Channel count of the input tensor.
        n_layer: Controls depth; ``n_layer - 2`` stride-2 blocks are added in
            the loop, followed by one more stride-2 block.
        norm: Normalisation name forwarded to ``LeakyReLUConv2d`` for all but
            the last stride-2 block (which always uses ``'None'``).
        sn: Whether convolutions are wrapped in spectral normalisation.
    """

    def __init__(self, input_dim, n_layer=6, norm= 'None', sn=True):
        super(CrossDiscriminator, self).__init__()
        ch = 64
        self.model = self._make_net(ch, input_dim, n_layer, norm, sn)

    def _make_net(self, ch, input_dim, n_layer, norm, sn):
        """Build the network; channels double at every stride-2 block."""
        model = []
        model += [DropBlock(block_size=19, p=0.98)]
        model += [LeakyReLUConv2d(input_dim, ch, kernel_size=3, stride=1, padding=1, norm=norm, sn=sn)]
        tch = ch

        for _ in range(1, n_layer - 1):
            model += [DropBlock(block_size=9, p=0.99)]
            model += [LeakyReLUConv2d(tch, tch * 2, kernel_size=3, stride=2, padding=1, norm=norm, sn=sn)]
            tch *= 2

        model += [LeakyReLUConv2d(tch, tch * 2, kernel_size=3, stride=2, padding=1, norm='None', sn=sn)]
        tch *= 2

        # 1x1 conv collapsing the features to a single channel.
        head = nn.Conv2d(tch, 1, kernel_size=1, stride=1, padding=0)
        if sn:
            head = torch.nn.utils.spectral_norm(head)
        model += [head]

        # Linear(16, ...) requires exactly 16 flattened values per sample.
        model += [nn.Flatten()]
        model += [nn.Linear(16, 16)]
        model += [nn.Mish()]
        model += [nn.Linear(16, 2)]
        # Input is 2-D (batch, 2): name the class dimension explicitly rather
        # than relying on the deprecated implicit-dim behaviour.
        model += [nn.Softmax(dim=1)]

        return nn.Sequential(*model)

    def cuda(self, gpu=None):
        """Move the network to a CUDA device and return ``self``.

        Returning ``self`` keeps the usual ``net = net.cuda(...)`` idiom of
        ``nn.Module.cuda`` working; ``gpu`` is optional for the same reason.
        """
        self.model.cuda(gpu)
        return self

    def forward(self, x_A):
        """Return, per sample, the softmax probability of class index 0."""
        out_A = self.model(x_A)
        return out_A[:, 0]


####################################################################
# -------------------------- Basic Blocks --------------------------
####################################################################

def conv3x3(in_planes, out_planes):
    """Return a one-element list holding a biased 3x3, stride-1, pad-1 conv.

    The list form lets callers splice the layer into a layer list with ``+=``.
    """
    layer = nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=1, padding=1, bias=True)
    return [layer]



def gaussian_weights_init(m):
    """Re-draw ``m.weight`` from N(0, 0.02) if ``m``'s class name starts with 'Conv'.

    Intended for use with ``nn.Module.apply``; modules whose class name does
    not begin with ``Conv`` are left untouched.
    """
    if not type(m).__name__.startswith('Conv'):
        return
    m.weight.data.normal_(0.0, 0.02)

def my_weights_init(m):
    """Re-draw ``m.weight`` from N(0.02, 0.005) if ``m``'s class name starts with 'Conv'."""
    if not type(m).__name__.startswith('Conv'):
        return
    m.weight.data.normal_(0.02, 0.005)
def my_weights_init2(m):
    """Re-draw ``m.weight`` uniformly from [-0.05, 0.05) if ``m``'s class name starts with 'Conv'."""
    if not type(m).__name__.startswith('Conv'):
        return
    m.weight.data.uniform_(-0.05, 0.05)
        
class Conv2d(nn.Conv2d):
    """``nn.Conv2d`` whose kernel is standardised on every forward pass.

    For each slice along dim 0 of the weight, the mean is subtracted and the
    result is divided by its standard deviation (plus 1e-5) before the
    convolution is applied.  The stored parameter itself is not modified.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=False):
        super().__init__(in_channels, out_channels, kernel_size, stride,
                         padding, dilation, groups, bias=bias)
        # Replace the attribute set by the parent with the value exactly as passed.
        self.stride = stride

    def forward(self, x):
        kernel = self.weight

        # Mean over dims 1, 2 and 3, reduced one axis at a time.
        centre = kernel
        for axis in (1, 2, 3):
            centre = centre.mean(dim=axis, keepdim=True)
        kernel = kernel - centre

        # One standard deviation per dim-0 slice; epsilon avoids division by zero.
        spread = kernel.view(kernel.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
        kernel = kernel / spread.expand_as(kernel)

        return F.conv2d(x, kernel, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

class ConvTranspose2d(nn.ConvTranspose2d):
    """``nn.ConvTranspose2d`` whose kernel is standardised on every forward pass.

    For each slice along dim 0 of the weight, the mean is subtracted and the
    result is divided by its standard deviation (plus 1e-5) before the
    transposed convolution is applied.  Note that for a transposed
    convolution dim 0 of the weight is the *input*-channel axis.

    All constructor arguments are forwarded to ``nn.ConvTranspose2d`` and
    honoured in ``forward`` (padding, output_padding, groups, dilation,
    padding_mode, device, dtype).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1, bias=False, dilation=1, padding_mode='zeros', device=None, dtype=None):
        super(ConvTranspose2d, self).__init__(
            in_channels, out_channels, kernel_size,
            stride=stride, padding=padding, output_padding=output_padding,
            groups=groups, bias=bias, dilation=dilation,
            padding_mode=padding_mode, device=device, dtype=dtype)
        # Keep exposing the stride exactly as passed (int or tuple), as before.
        self.stride = stride

    def forward(self, x):
        weight = self.weight
        weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2,
                                  keepdim=True).mean(dim=3, keepdim=True)
        weight = weight - weight_mean
        # One standard deviation per dim-0 slice; epsilon avoids division by zero.
        std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5
        weight = weight / std.expand_as(weight)
        return F.conv_transpose2d(x, weight, self.bias, self.stride,
                                  padding=self.padding,
                                  output_padding=self.output_padding,
                                  groups=self.groups,
                                  dilation=self.dilation)


def meanpoolConv(inplanes, outplanes):
    """Return ``AvgPool2d(2, 2) -> 1x1 Conv2d (biased) -> Dropout2d(0.1)`` as a Sequential."""
    return nn.Sequential(
        nn.AvgPool2d(kernel_size=2, stride=2),
        nn.Conv2d(inplanes, outplanes, kernel_size=1, stride=1, padding=0, bias=True),
        nn.Dropout2d(p=0.1),
    )


def convMeanpool(inplanes, outplanes):
    """Return ``conv3x3 -> Dropout2d(0.1) -> AvgPool2d(2, 2) -> Dropout2d(0.1)`` as a Sequential."""
    layers = [
        *conv3x3(inplanes, outplanes),
        nn.Dropout2d(p=0.1),
        nn.AvgPool2d(kernel_size=2, stride=2),
        nn.Dropout2d(p=0.1),
    ]
    return nn.Sequential(*layers)


# The code of LayerNorm is modified from MUNIT (https://github.com/NVlabs/MUNIT)
class LayerNorm(nn.Module):
    """Layer normalisation over all non-batch dimensions of the input.

    The normalised shape is taken from the input at call time
    (``x.size()[1:]``), so the spatial size does not have to be known in
    advance.

    Args:
        n_out: Number of channels; size of the per-channel affine parameters.
        eps: Value added to the variance for numerical stability.
        affine: If true, learn a per-channel scale and shift of shape
            ``(n_out, 1, 1)`` that is broadcast over the remaining dims.
    """

    def __init__(self, n_out, eps=1e-5, affine=True):
        super(LayerNorm, self).__init__()
        self.n_out = n_out
        self.eps = eps
        self.affine = affine
        if self.affine:
            self.weight = nn.Parameter(torch.ones(n_out, 1, 1))
            self.bias = nn.Parameter(torch.zeros(n_out, 1, 1))

    def forward(self, x):
        normalized_shape = x.size()[1:]
        if self.affine:
            return F.layer_norm(x, normalized_shape,
                                self.weight.expand(normalized_shape),
                                self.bias.expand(normalized_shape),
                                eps=self.eps)
        return F.layer_norm(x, normalized_shape, eps=self.eps)


class LeakyReLUConv2d(nn.Module):
    """Conv2d -> Dropout2d(0.1) -> optional InstanceNorm2d -> Mish.

    Despite the class name, the activation used is ``nn.Mish``.

    Args:
        n_in: Input channels.
        n_out: Output channels.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution padding.
        norm: ``'Instance'`` adds a non-affine ``InstanceNorm2d``; any other
            value adds no normalisation layer.
        sn: If true, the convolution is wrapped in spectral normalisation.
    """

    def __init__(self, n_in, n_out, kernel_size, stride, padding=0, norm='None', sn=False):
        super(LeakyReLUConv2d, self).__init__()

        conv = nn.Conv2d(n_in, n_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=True)
        if sn:
            conv = torch.nn.utils.spectral_norm(conv)

        layers = [conv, nn.Dropout2d(p=0.1)]
        if norm == 'Instance':
            layers.append(nn.InstanceNorm2d(n_out, affine=False))
        layers.append(nn.Mish())

        self.model = nn.Sequential(*layers)
        # Re-initialise the weights of every sub-module whose class name starts with 'Conv'.
        self.model.apply(gaussian_weights_init)

    def forward(self, x):
        return self.model(x)


# Adapted from https://zhuanlan.zhihu.com/p/425636663
class DropBlock(nn.Module):
    """Drop contiguous ``block_size`` x ``block_size`` regions during training.

    In eval mode the input is returned unchanged.  In training mode a
    Bernoulli seed mask is drawn, grown into square blocks with a max-pool,
    and the surviving activations are rescaled by ``numel / kept``.

    Args:
        block_size: Side length of the square regions that are dropped.
        p: Probability to *keep* an activation (not the drop probability).
    """

    def __init__(self, block_size: int, p: float = 0.8):
        super().__init__()
        self.block_size = block_size
        self.p = p

    def calculate_gamma(self, x):
        """Return the Bernoulli rate for the seed mask.

        Only the size of the last dimension of ``x`` is used.

        Args:
            x (Tensor): input tensor
        Returns:
            float: gamma
        """
        invalid = (1 - self.p) / (self.block_size ** 2)
        valid = (x.shape[-1] ** 2) / ((x.shape[-1] - self.block_size + 1) ** 2)
        return invalid * valid

    def forward(self, x):
        if not self.training:
            return x

        gamma = self.calculate_gamma(x)
        mask = torch.bernoulli(torch.ones_like(x) * gamma)
        mask_block = 1 - F.max_pool2d(
            mask,
            kernel_size=(self.block_size, self.block_size),
            stride=(1, 1),
            padding=(self.block_size // 2, self.block_size // 2),
        )
        # With an even block_size the pooled mask is one element larger than
        # x in each spatial dimension; crop it back (a no-op for odd sizes).
        mask_block = mask_block[..., :x.shape[-2], :x.shape[-1]]

        # If every element was dropped the sum is 0; clamping keeps the scale
        # finite so the output is all zeros instead of NaN.  The sum is
        # integer-valued, so the clamp changes nothing in any other case.
        kept = mask_block.sum().clamp(min=1)
        return mask_block * x * (mask_block.numel() / kept)
