import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


class LowLevelNetwork(nn.Module):
    """Convolutional encoder that extracts shared low-level features.

    The network is three stages of two 3x3 convolutions each. The first
    convolution of a stage uses stride 2 (halving the spatial size, rounded
    up), the second uses stride 1 and widens the channels
    (64 -> 128, 128 -> 256, 256 -> 512). Every convolution is followed by
    batch normalisation and LeakyReLU(0.2).

    Args:
        in_ch: Number of channels of the input image.
    """

    @staticmethod
    def _conv3x3(n_in, n_out, step):
        """Build a replicate-padded 3x3 convolution with the given stride."""
        return nn.Conv2d(
            in_channels=n_in,
            out_channels=n_out,
            kernel_size=3,
            stride=step,
            padding=1,
            padding_mode='replicate',
        )

    def __init__(self, in_ch=1):
        super().__init__()
        make_conv = self._conv3x3

        # Stage 1: in_ch -> 64 (downsample) -> 128
        self.conv1_1 = make_conv(in_ch, 64, 2)
        self.bn1_1 = nn.BatchNorm2d(64)
        self.conv1_2 = make_conv(64, 128, 1)
        self.bn1_2 = nn.BatchNorm2d(128)

        # Stage 2: 128 -> 128 (downsample) -> 256
        self.conv2_1 = make_conv(128, 128, 2)
        self.bn2_1 = nn.BatchNorm2d(128)
        self.conv2_2 = make_conv(128, 256, 1)
        self.bn2_2 = nn.BatchNorm2d(256)

        # Stage 3: 256 -> 256 (downsample) -> 512
        self.conv3_1 = make_conv(256, 256, 2)
        self.bn3_1 = nn.BatchNorm2d(256)
        self.conv3_2 = make_conv(256, 512, 1)
        self.bn3_2 = nn.BatchNorm2d(512)

        self.lrelu = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        """Return a 512-channel feature map for the image batch ``x``."""
        layers = (
            (self.conv1_1, self.bn1_1),
            (self.conv1_2, self.bn1_2),
            (self.conv2_1, self.bn2_1),
            (self.conv2_2, self.bn2_2),
            (self.conv3_1, self.bn3_1),
            (self.conv3_2, self.bn3_2),
        )
        out = x
        for conv, norm in layers:
            out = self.lrelu(norm(conv(out)))
        return out

class MidLevelNetwork(nn.Module):
    """Two-layer convolutional head producing 256-channel mid-level features.

    Both convolutions are 3x3 with stride 1 and replicate padding, so the
    spatial size of the input is preserved. Each is followed by batch
    normalisation and LeakyReLU(0.2).

    Args:
        in_ch: Number of channels of the incoming feature map.
    """

    @staticmethod
    def _conv3x3(n_in, n_out):
        """Build a stride-1, replicate-padded 3x3 convolution."""
        return nn.Conv2d(
            in_channels=n_in,
            out_channels=n_out,
            kernel_size=3,
            stride=1,
            padding=1,
            padding_mode='replicate',
        )

    def __init__(self, in_ch=512):
        super().__init__()
        self.conv1_1 = self._conv3x3(in_ch, 512)
        self.bn1_1 = nn.BatchNorm2d(512)
        self.conv1_2 = self._conv3x3(512, 256)
        self.bn1_2 = nn.BatchNorm2d(256)
        self.lrelu = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        """Map ``x`` of shape [B, in_ch, H, W] to [B, 256, H, W]."""
        out = x
        for conv, norm in ((self.conv1_1, self.bn1_1), (self.conv1_2, self.bn1_2)):
            out = self.lrelu(norm(conv(out)))
        return out


class ColorizationNetwork(nn.Module):
    """Decoder that turns fused 256-channel features into an image.

    The channel count shrinks 256 -> 128 -> 64 -> 64 -> 32 -> output_channels
    through 3x3 replicate-padded convolutions, with two nearest-neighbour 2x
    upsampling steps in between, so the output is 4x larger than the input
    in each spatial dimension. The final convolution has no batch norm and
    is squashed with a sigmoid, so outputs lie in [0, 1].

    Args:
        output_channels: Number of channels of the produced image.
    """

    @staticmethod
    def _conv3x3(n_in, n_out):
        """Build a stride-1, replicate-padded 3x3 convolution."""
        return nn.Conv2d(
            n_in, n_out, kernel_size=3, stride=1, padding=1, padding_mode='replicate'
        )

    @staticmethod
    def _upsample2x(t):
        """Double the spatial size of ``t`` with nearest-neighbour sampling."""
        return F.interpolate(t, scale_factor=2, mode='nearest')

    def __init__(self, output_channels=3):
        super().__init__()

        # Block 1 (input resolution)
        self.conv1_1 = self._conv3x3(256, 128)
        self.bn1_1 = nn.BatchNorm2d(128)

        # Block 2 (after the first upsampling)
        self.conv2_1 = self._conv3x3(128, 64)
        self.bn2_1 = nn.BatchNorm2d(64)
        self.conv2_2 = self._conv3x3(64, 64)
        self.bn2_2 = nn.BatchNorm2d(64)

        # Block 3 (after the second upsampling)
        self.conv3_1 = self._conv3x3(64, 32)
        self.bn3_1 = nn.BatchNorm2d(32)
        self.conv3_2 = self._conv3x3(32, output_channels)

        self.lrelu = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        """Decode ``x`` ([B, 256, H, W]) into [B, output_channels, 4H, 4W]."""
        out = self.lrelu(self.bn1_1(self.conv1_1(x)))

        out = self._upsample2x(out)
        out = self.lrelu(self.bn2_1(self.conv2_1(out)))
        out = self.lrelu(self.bn2_2(self.conv2_2(out)))

        out = self._upsample2x(out)
        out = self.lrelu(self.bn3_1(self.conv3_1(out)))
        return torch.sigmoid(self.conv3_2(out))
    
class FusionLayer(nn.Module):
    """Fuse a spatial feature map with per-image feature vectors.

    The vectors are concatenated, broadcast to every spatial location of the
    feature map, concatenated with it along the channel axis and mixed by a
    1x1 convolution followed by ReLU, giving 256 output channels.

    Args:
        in_ch: Total channel count after concatenation, i.e. the channels of
            the feature map plus the lengths of all feature vectors.
    """

    def __init__(self, in_ch=256+256+216):
        super().__init__()
        self.conv = nn.Conv2d(in_channels=in_ch, out_channels=256, kernel_size=1)

    def forward(self, h, one_dimension_feature_list):
        """Return the fused [B, 256, H, W] map.

        Args:
            h: Spatial feature map of shape [B, C, H, W].
            one_dimension_feature_list: Sequence of [B, F_i] tensors.
        """
        spatial = h.shape[2:]
        vec = torch.cat(one_dimension_feature_list, dim=1)
        # Add two singleton spatial axes, then broadcast (no copy) over H x W.
        tiled = vec.unsqueeze(2).unsqueeze(3).expand(-1, -1, *spatial)
        fused = torch.cat((h, tiled), dim=1)
        return torch.relu(self.conv(fused))

class GlobalNetwork(nn.Module):
    """Summarise a 512-channel feature map into one 256-d vector per image.

    Two stride-2 / stride-1 pairs of 3x3 convolutions reduce the spatial size
    by a factor of four, after which the map is flattened and passed through
    three fully connected layers (1024 -> 512 -> 256). Every layer is followed
    by batch normalisation and LeakyReLU(0.2).

    The first linear layer expects ``7 * 7 * 512`` inputs, so the incoming
    feature map must shrink to 7x7 after the two stride-2 convolutions
    (for example a 28x28 input).
    """

    @staticmethod
    def _conv3x3(step):
        """Build a 512 -> 512 replicate-padded 3x3 convolution."""
        return nn.Conv2d(
            512, 512, kernel_size=3, stride=step, padding=1, padding_mode='replicate'
        )

    def __init__(self):
        super().__init__()
        # Convolutional part: each pair halves the resolution once.
        self.conv1_1 = self._conv3x3(2)
        self.bn1_1 = nn.BatchNorm2d(512)
        self.conv1_2 = self._conv3x3(1)
        self.bn1_2 = nn.BatchNorm2d(512)

        self.conv2_1 = self._conv3x3(2)
        self.bn2_1 = nn.BatchNorm2d(512)
        self.conv2_2 = self._conv3x3(1)
        self.bn2_2 = nn.BatchNorm2d(512)

        # Fully connected part operating on the flattened 7x7x512 map.
        self.l3_1 = nn.Linear(7 * 7 * 512, 1024)
        self.bn3_1 = nn.BatchNorm1d(1024)
        self.l3_2 = nn.Linear(1024, 512)
        self.bn3_2 = nn.BatchNorm1d(512)
        self.l3_3 = nn.Linear(512, 256)
        self.bn3_3 = nn.BatchNorm1d(256)

        self.lrelu = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        """Return a [B, 256] global feature vector for ``x``."""
        conv_layers = (
            (self.conv1_1, self.bn1_1),
            (self.conv1_2, self.bn1_2),
            (self.conv2_1, self.bn2_1),
            (self.conv2_2, self.bn2_2),
        )
        dense_layers = (
            (self.l3_1, self.bn3_1),
            (self.l3_2, self.bn3_2),
            (self.l3_3, self.bn3_3),
        )

        out = x
        for conv, norm in conv_layers:
            out = self.lrelu(norm(conv(out)))

        # Collapse channel and spatial axes into a single feature axis.
        out = out.reshape(out.shape[0], -1)
        for dense, norm in dense_layers:
            out = self.lrelu(norm(dense(out)))
        return out
    
class HistogramNetwork(nn.Module):
    """Parameter-free module computing a joint colour histogram per image.

    Each channel is quantised into ``num_bins`` equal-width bins over
    [0, 1]; the per-channel bin indices of a pixel are combined into one
    joint cell index (channel 0 is the most significant digit), and the
    cells are counted. The result is a probability distribution over
    ``num_bins ** C`` cells.

    Args:
        num_bins: Number of quantisation levels per channel.
    """

    def __init__(
            self,
            num_bins=6,
    ):
        super().__init__()
        self.num_bins = num_bins

    def forward(self, h: torch.Tensor):
        """Return the normalised joint histogram of every image in ``h``.

        Args:
            h: Image batch of shape [B, C, H, W]. Values are clamped to
                [0, 1] before binning.

        Returns:
            float32 tensor of shape [B, num_bins ** C]; every row sums to 1.
            The result carries no gradient (the input is detached).
        """
        B, C = h.shape[:2]
        num_cells = self.num_bins ** C

        # Histograms are computed on a fixed 64x64 nearest-neighbour thumbnail
        # so the cost does not depend on the input resolution.
        pixels = F.interpolate(h.detach().clamp(0, 1), (64, 64), mode='nearest')

        # Truncate to a bin index; a value of exactly 1.0 would land in bin
        # `num_bins`, so fold it into the last valid bin.
        bins = (pixels * self.num_bins).long().clamp_(max=self.num_bins - 1)

        # Positional weights [num_bins**(C-1), ..., num_bins, 1] turn the C
        # per-channel indices into a single joint cell index per pixel.
        weights = self.num_bins ** torch.arange(C - 1, -1, -1, device=bins.device)
        cells = (bins * weights.view(1, C, 1, 1)).sum(dim=1).flatten(1)
        num_pixels = cells.shape[1]

        # Shift every image into its own index range so a single bincount
        # call yields all per-image histograms at once.
        offsets = torch.arange(B, device=cells.device).unsqueeze(1) * num_cells
        counts = torch.bincount(
            (cells + offsets).reshape(-1), minlength=B * num_cells
        ).reshape(B, num_cells)

        # Normalise by the number of pixels actually counted (the thumbnail),
        # not by the original H * W, so each histogram sums to one regardless
        # of the input resolution.
        return counts.float() / num_pixels

class Comi(nn.Module):
    """Reference-guided colourisation model.

    A single-channel sketch is encoded into shared low-level features, which
    feed both a spatial mid-level branch and a global branch. These are fused
    with a colour histogram of the reference image and decoded into a
    3-channel image of the same size as the sketch.
    """

    def __init__(self):
        super().__init__()
        self.low_level = LowLevelNetwork(1)
        self.mid_level = MidLevelNetwork()
        self.fusion_layer = FusionLayer()
        self.colorization = ColorizationNetwork()
        self.global_network = GlobalNetwork()
        self.histogram_network = HistogramNetwork()

    def forward(self, sketch, reference):
        """Colourise ``sketch`` using the colour statistics of ``reference``.

        The shape comments below use 224x224 inputs as the example; the
        global branch's first linear layer takes 7 * 7 * 512 features, which
        corresponds to a 28x28 low-level feature map.

        Args:
            sketch: [B, 1, 224, 224] line-art batch.
            reference: [B, 3, H, W] colour reference batch.

        Returns:
            [B, 3, 224, 224] tensor with values in [0, 1].
        """
        shared = self.low_level(sketch)                    # [B, 512, 28, 28]
        local_feat = self.mid_level(shared)                # [B, 256, 28, 28]
        global_vec = self.global_network(shared)           # [B, 256]
        colour_hist = self.histogram_network(reference)    # [B, 216]

        fused = self.fusion_layer(local_feat, [global_vec, colour_hist])  # [B, 256, 28, 28]
        coarse = self.colorization(fused)                  # [B, 3, 112, 112]

        # The decoder only upsamples 4x while the encoder downsamples 8x;
        # a final bilinear 2x step restores the input resolution.
        return F.interpolate(coarse, scale_factor=2, mode='bilinear')


class Discriminator(nn.Module):
    """Convolutional critic producing one unbounded score per image.

    The input is average-pooled by 2, passed through four stride-2 4x4
    convolutions (64, 128, 256, 512 channels) with ReLU activations, then
    flattened and projected to a single logit. Overall the spatial size is
    reduced by a factor of 32.

    ``bn0`` is registered but is not applied in ``forward``: the first
    convolution is followed directly by ReLU.

    Args:
        size: Side length of the (square) input images.
    """

    def __init__(self, size=224):
        super().__init__()
        self.size = size

        def downsample(n_in, n_out):
            # 4x4, stride-2 convolution: halves the spatial size.
            return nn.Conv2d(
                n_in, n_out, kernel_size=4, stride=2, padding=1, padding_mode='replicate'
            )

        self.c0 = downsample(3, 64)
        self.c1 = downsample(64, 128)
        self.c2 = downsample(128, 256)
        self.c3 = downsample(256, 512)
        self.bn0 = nn.BatchNorm2d(64)
        self.bn1 = nn.BatchNorm2d(128)
        self.bn2 = nn.BatchNorm2d(256)
        self.bn3 = nn.BatchNorm2d(512)
        self.flatten = nn.Flatten()

        # One pooling step plus four stride-2 convolutions: 2 ** 5 = 32.
        side = size // 32
        self.l0z = nn.Linear(side * side * 512, 1)

    def forward(self, x):
        """Return a [B, 1] score for the image batch ``x`` ([B, 3, size, size])."""
        assert tuple(x.shape[2:]) == (self.size, self.size)

        feat = F.avg_pool2d(x, kernel_size=2, stride=2)
        feat = torch.relu(self.c0(feat))
        for conv, norm in ((self.c1, self.bn1), (self.c2, self.bn2), (self.c3, self.bn3)):
            feat = torch.relu(norm(conv(feat)))

        score = self.l0z(self.flatten(feat))
        return score


if __name__ == '__main__':
    # Smoke test: push random tensors through the generator and the
    # discriminator and print the resulting shapes.

    # Fall back to the CPU so the script also runs on machines without CUDA.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    net = Comi().to(device)
    # Allocate the inputs directly on the target device.
    c = torch.randn(2, 1, 224, 224, device=device)
    x = torch.randn(2, 3, 224, 224, device=device)
    dis = Discriminator().to(device)
    with torch.no_grad():
        out = net(c, x)
        print(out.shape)
        d = dis(out)
        print(d.shape)

    # Optional interactive shell for inspecting the results; `ei` is a
    # debugging aid, so its absence should not make the smoke test fail.
    try:
        import ei
    except ImportError:
        print('ei is not installed; skipping interactive shell')
    else:
        ei.embed()
    