import torch
import torch.nn as nn
from torchvision.models.vgg import vgg19

from iconflow.model.rescae import get_residual_block

from .featext import FeatureExtractor


class Generator(nn.Module):
    """Encoder/decoder generator that colours a sketch using reference features.

    The encoder is a stem convolution followed by six residual blocks with
    2x max-pooling between them. The reference feature vector is projected
    to 2048 channels and added to the bottleneck at every spatial position.
    The decoder mirrors the encoder with nearest-neighbour upsampling and
    concatenates encoder activations as skip connections.

    NOTE: only the ``down1``, ``down2``, ``down3`` and ``down5`` outputs are
    used as skips. The ``down4`` output is not concatenated anywhere; the
    input width of ``up4`` (512+256) and ``up3`` (128+64) is sized for
    exactly this wiring, so adding it would change the parameter shapes.

    Args:
        in_ch: number of channels of the input sketch.
    """

    def __init__(self, in_ch=1):
        super().__init__()

        self.downsample = nn.MaxPool2d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

        def lrelu():
            # Factory handed to get_residual_block; every call builds a
            # fresh activation module.
            return nn.LeakyReLU(0.2, inplace=True)

        # Encoder.
        self.down0 = nn.Sequential(
            nn.Conv2d(in_ch, 16, 3, padding='same', padding_mode='replicate'),
            nn.BatchNorm2d(16),
            lrelu(),
        )
        self.down1 = get_residual_block(16, 16, nonlinearity=lrelu)
        self.down2 = get_residual_block(16, 32, nonlinearity=lrelu)
        self.down3 = get_residual_block(32, 64, nonlinearity=lrelu)
        self.down4 = get_residual_block(64, 128, nonlinearity=lrelu)
        self.down5 = get_residual_block(128, 256, nonlinearity=lrelu)
        self.down6 = get_residual_block(256, 2048, nonlinearity=lrelu)

        # Decoder; "a+b" input widths are (previous decoder output + skip).
        self.up6 = get_residual_block(2048, 512, nonlinearity=lrelu)
        self.up5 = get_residual_block(512, 512, nonlinearity=lrelu)
        self.up4 = get_residual_block(512+256, 128, nonlinearity=lrelu)
        self.up3 = get_residual_block(128+64, 64, nonlinearity=lrelu)
        self.up2 = get_residual_block(64+32, 32, nonlinearity=lrelu)
        self.up1 = get_residual_block(32+16, 64, nonlinearity=lrelu)
        self.up0 = nn.Conv2d(64, 3, 3, padding='same', padding_mode='replicate')

        # Projects the reference feature vector to the bottleneck width.
        self.feat = nn.Linear(4096, 2048)

    def forward(self, sketch, reference_features, return_last=True):
        """Generate a 3-channel image from a sketch and reference features.

        Args:
            sketch: tensor of shape [B, C, H, W]. The skip concatenations
                need matching spatial sizes, so H and W presumably have to
                be multiples of 32 (assumes the residual blocks keep the
                spatial size - TODO confirm against get_residual_block).
            reference_features: tensor of shape [B, 4096].
            return_last: if True return only the generated image; otherwise
                return a dict that also exposes two 16x-downsampled
                activations.

        Returns:
            The sigmoid-activated output of shape [B, 3, H, W], or a dict
            with keys ``output`` (that tensor), ``bottom1`` (``down5``
            output, 256 channels) and ``bottom2`` (``up5`` output, 512
            channels).
        """
        # Shape comments are [C, H, W] for a 256x256 sketch.

        # Encoder: keep only the activations that are reused later.
        h = self.down0(sketch)
        d1 = self.down1(h)                      # [16, 256, 256]
        d2 = self.down2(self.downsample(d1))    # [32, 128, 128]
        d3 = self.down3(self.downsample(d2))    # [64, 64, 64]
        h = self.down4(self.downsample(d3))     # [128, 32, 32] (not a skip)
        d5 = self.down5(self.downsample(h))     # [256, 16, 16]
        h = self.down6(self.downsample(d5))     # [2048, 8, 8]

        # Inject the reference: broadcast the projected vector over H and W.
        s = self.feat(reference_features)       # [B, 2048]
        h = h + s[:, :, None, None]             # [2048, 8, 8]

        # Decoder.
        h = self.up6(h)                         # [512, 8, 8]
        u5 = self.up5(self.upsample(h))         # [512, 16, 16]
        # At this level the skip is concatenated *before* upsampling, so
        # the 16x16 encoder activation is upsampled together with u5.
        h = self.upsample(torch.cat([u5, d5], 1))  # [512+256, 32, 32]
        h = self.up4(h)                         # [128, 32, 32]
        h = torch.cat([self.upsample(h), d3], 1)   # [128+64, 64, 64]
        h = self.up3(h)                         # [64, 64, 64]
        h = torch.cat([self.upsample(h), d2], 1)   # [64+32, 128, 128]
        h = self.up2(h)                         # [32, 128, 128]
        h = torch.cat([self.upsample(h), d1], 1)   # [32+16, 256, 256]
        h = self.up1(h)                         # [64, 256, 256]
        output = torch.sigmoid(self.up0(h))     # [3, 256, 256], values in [0, 1]

        if return_last:
            return output

        return dict(
            output=output,
            bottom1=d5, # [256, 16, 16]
            bottom2=u5 # [512, 16, 16]
        )


class GuideDecoder(nn.Module):
    """Convolutional decoder that upsamples a feature map 16x into an image.

    Five 3x3 conv + batch-norm + ReLU stages narrow the channels
    (in_ch -> 512 -> 256 -> 128 -> 64 -> 32); each of the first four is
    followed by a 2x upsampling. A final 3x3 convolution with a sigmoid
    produces ``out_ch`` channels in [0, 1].

    Args:
        in_ch: number of channels of the incoming feature map.
        out_ch: number of channels of the decoded image.
    """

    def __init__(self, in_ch, out_ch=3):
        super().__init__()

        def conv3x3(c_in, c_out):
            # Size-preserving 3x3 convolution with replicated borders.
            return nn.Conv2d(c_in, c_out, kernel_size=3,
                             padding='same', padding_mode='replicate')

        self.upsample = nn.Upsample(scale_factor=2)

        # Attribute names and creation order are kept stable so that
        # state_dict keys and seeded initialisation do not change.
        self.conv5, self.bn5 = conv3x3(in_ch, 512), nn.BatchNorm2d(512)
        self.conv4, self.bn4 = conv3x3(512, 256), nn.BatchNorm2d(256)
        self.conv3, self.bn3 = conv3x3(256, 128), nn.BatchNorm2d(128)
        self.conv2, self.bn2 = conv3x3(128, 64), nn.BatchNorm2d(64)
        self.conv1, self.bn1 = conv3x3(64, 32), nn.BatchNorm2d(32)
        self.conv0 = conv3x3(32, out_ch)

    def forward(self, h):
        """Decode ``h`` of shape [B, in_ch, H, W] to [B, out_ch, 16*H, 16*W]."""
        upsampled_stages = (
            (self.conv5, self.bn5),
            (self.conv4, self.bn4),
            (self.conv3, self.bn3),
            (self.conv2, self.bn2),
        )
        for conv, norm in upsampled_stages:
            h = self.upsample(torch.relu(norm(conv(h))))

        # The last hidden stage runs at full resolution (no upsampling).
        h = torch.relu(self.bn1(self.conv1(h)))
        return torch.sigmoid(self.conv0(h))


class StyleEncoder(FeatureExtractor):
    """Feature extractor that exposes one hooked activation of a VGG-19.

    Wraps a pretrained torchvision VGG-19 in ``FeatureExtractor`` and
    registers a single hook on ``classifier.0`` with declared shape
    ``(4096,)``. ``forward`` returns just that entry.

    Args:
        image_size: side length of the (square, 3-channel) input images.
    """

    def __init__(self, image_size):
        # NOTE(review): newer torchvision releases prefer the `weights=`
        # argument over `pretrained=True` - confirm against the pinned version.
        backbone = vgg19(pretrained=True)
        input_shape = (3, image_size, image_size)
        super().__init__(backbone, input_shape)

        # Key under which the parent's forward() stores the hooked output.
        self.hook_key = ('classifier.0', (4096,))
        layer_name, feature_shape = self.hook_key
        self.add_hook(layer_name, feature_shape)

    def forward(self, x):
        """Return the hooked ``classifier.0`` activation for ``x``."""
        hooked = super().forward(x)
        return hooked[self.hook_key] # [B, 4096]


class Discriminator(nn.Module):
    """DCGAN-style convolutional discriminator.

    Four bias-free 4x4 stride-2 convolutions halve the resolution while
    widening the channels (base_dim, 2x, 4x, 8x); all but the first are
    followed by batch norm. A final 4x4 valid convolution and a sigmoid
    yield one score per remaining spatial location (a single score for a
    64x64 input).

    Args:
        in_ch: number of channels of the input image.
        base_dim: channel width after the first convolution.
    """

    def __init__(self, in_ch=3, base_dim=64):
        super().__init__()

        def halving_conv(c_in, c_out):
            # 4x4 kernel, stride 2, padding 1: halves H and W.
            return nn.Conv2d(c_in, c_out, kernel_size=4, stride=2,
                             padding=1, bias=False)

        # Resolutions below assume a 64x64 input.
        # in_ch x 64 x 64 -> base_dim x 32 x 32
        self.conv1 = halving_conv(in_ch, base_dim)

        # -> (base_dim*2) x 16 x 16
        self.conv2 = halving_conv(base_dim, base_dim * 2)
        self.bn2 = nn.BatchNorm2d(base_dim * 2)

        # -> (base_dim*4) x 8 x 8
        self.conv3 = halving_conv(base_dim * 2, base_dim * 4)
        self.bn3 = nn.BatchNorm2d(base_dim * 4)

        # -> (base_dim*8) x 4 x 4
        self.conv4 = halving_conv(base_dim * 4, base_dim * 8)
        self.bn4 = nn.BatchNorm2d(base_dim * 8)

        # -> 1 x 1 x 1
        self.conv5 = nn.Conv2d(base_dim * 8, 1, kernel_size=4, stride=1,
                               padding=0, bias=False)

        self.lrelu = nn.LeakyReLU(0.2, inplace=True)

        def weights_init(module):
            # Dispatch on the class name: conv weights ~ N(0, 0.02);
            # batch-norm scales ~ N(1, 0.02) with zero shift.
            kind = type(module).__name__
            if 'Conv' in kind:
                nn.init.normal_(module.weight, mean=0.0, std=0.02)
            elif 'BatchNorm' in kind:
                nn.init.normal_(module.weight, mean=1.0, std=0.02)
                nn.init.constant_(module.bias, 0)

        self.apply(weights_init)

    def forward(self, x):
        """Return sigmoid scores of shape [B, 1, H', W'] for images ``x``."""
        x = self.lrelu(self.conv1(x))
        normed_stages = (
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
        )
        for conv, norm in normed_stages:
            x = self.lrelu(norm(conv(x)))
        return torch.sigmoid(self.conv5(x))


if __name__ == '__main__':
    # Smoke check: build every network once when the module is run directly.
    x = torch.randn(1, 3, 256, 256)  # reference image batch
    c = torch.randn(1, 1, 256, 256)  # single-channel sketch batch

    G, D = Generator(), Discriminator()
    en = StyleEncoder(256)
    # Input widths 256 and 512 presumably match the `bottom1` / `bottom2`
    # activations returned by Generator(..., return_last=False).
    gd1, gd2 = GuideDecoder(256), GuideDecoder(512)

    # Disabled end-to-end forward pass, kept as a usage example.
    if False:
        s = en(x)
        out = G(c, s)
        outs = G(c, s, return_last=False)
        