import torch.nn as nn
import torch.nn.functional as F
import torch
from torchvision.models import vgg19
import math


class FeatureExtractor(nn.Module):
    """Wrap the first 35 layers of a pretrained torchvision VGG19 ``features`` stack.

    The attribute name ``vgg19_54`` suggests the cut ends at conv5_4 (before its
    activation); confirm against torchvision's layer layout if that matters.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): `pretrained=True` is deprecated in torchvision >= 0.13 in
        # favour of `weights=...`; kept unchanged so older torchvision still works.
        backbone = vgg19(pretrained=True).features
        kept_layers = list(backbone.children())[:35]
        # The attribute name is part of the state_dict keys; do not rename it.
        self.vgg19_54 = nn.Sequential(*kept_layers)

    def forward(self, img):
        """Return the truncated VGG19 feature map of ``img``."""
        return self.vgg19_54(img)


class DenseResidualBlock(nn.Module):
    """Densely connected conv block with a scaled residual connection.

    Core module of "Residual Dense Network for Image Super-Resolution" (CVPR 18).
    Dense layer ``i`` sees the channel-wise concatenation of the block input and
    every earlier layer's output (``(i + 1) * filters`` channels) and emits
    ``filters`` new channels via 3x3 conv -> InstanceNorm -> LeakyReLU(0.2).
    A 1x1 conv fuses the final concatenation back to ``filters`` channels; the
    result is scaled by ``res_scale`` and added to the block input, so the output
    has the same shape as the input.

    Args:
        filters: channel count of the block input/output and of each dense layer.
        res_scale: multiplier on the fused features before the skip addition.
        n_layers: number of dense conv layers.
        dilation: if True the dense layers cycle through dilation rates 1, 2, 3;
            otherwise every layer uses dilation 1.
    """

    def __init__(self, filters, res_scale=0.2, n_layers=4, dilation=False):
        super().__init__()
        self.res_scale = res_scale
        # [1, 2, 5] is another dilation cycle worth trying.
        self.dilation = [1, 2, 3] if dilation else [1, 1, 1]

        self.blocks = nn.ModuleList()
        for idx in range(n_layers):
            rate = self.dilation[idx % len(self.dilation)]
            # For a 3x3 kernel, padding == dilation rate keeps H and W unchanged.
            layer = nn.Sequential(
                nn.Conv2d((idx + 1) * filters, filters, 3, 1, padding=rate, bias=True, dilation=rate),
                nn.InstanceNorm2d(filters),
                nn.LeakyReLU(0.2, inplace=True),
            )
            # Explicit names keep state_dict keys as "blocks.dense_layer_<i>.*".
            # NOTE: because of these names, integer indexing (self.blocks[i]) does
            # not work on this ModuleList; iterate over it instead.
            self.blocks.add_module('dense_layer_%d' % idx, layer)

        # 1x1 fusion conv over the input plus all n_layers dense outputs.
        self.conv = nn.Conv2d((n_layers + 1) * filters, filters, 1, 1, 0, bias=True)

    def forward(self, x):
        """Apply the dense layers and return ``x + res_scale * fused``."""
        collected = [x]
        for layer in self.blocks:
            collected.append(layer(torch.cat(collected, 1)))
        fused = self.conv(torch.cat(collected, 1))
        return x + fused * self.res_scale


class ResidualInResidualDenseBlock(nn.Module):
    """Chain of ``n_drb`` DenseResidualBlocks wrapped in an outer scaled residual.

    Args:
        filters: channel count of the input/output.
        res_scale: multiplier on the chain's output before the outer skip is
            added. It is NOT forwarded to the inner blocks, which keep their own
            default ``res_scale``.
        n_layers: dense layers per inner DenseResidualBlock.
        n_drb: number of chained DenseResidualBlocks.
        dilation: forwarded unchanged to every inner block.
    """

    def __init__(self, filters, res_scale=0.2, n_layers=4, n_drb=3, dilation=False):
        super().__init__()
        self.res_scale = res_scale
        self.dense_blocks = nn.Sequential(
            *(DenseResidualBlock(filters, n_layers=n_layers, dilation=dilation) for _ in range(n_drb))
        )

    def forward(self, x):
        """Return ``x + res_scale * dense_blocks(x)``."""
        trunk = self.dense_blocks(x)
        return x + trunk * self.res_scale


class GeneratorRRDB(nn.Module):
    """Super-resolution generator built from Residual-in-Residual Dense Blocks.

    Pipeline: 3x3 conv -> ``num_res_blocks`` RRDBs -> 3x3 conv, plus a long skip
    from the first conv's output -> ``num_upsample`` stages of (3x3 conv to
    ``4 * filters`` channels, LeakyReLU, PixelShuffle(2)) -> 3x3 conv, LeakyReLU,
    3x3 conv back to ``channels``. Each upsampling stage doubles H and W, so the
    output is ``2 ** num_upsample`` times larger spatially than the input.

    Args:
        channels: image channels of both the input and the output.
        filters: feature width used throughout the network.
        num_res_blocks: number of ResidualInResidualDenseBlocks.
        num_upsample: number of x2 PixelShuffle upsampling stages.
    """

    def __init__(self, channels, filters=64, num_res_blocks=16, num_upsample=2):
        super(GeneratorRRDB, self).__init__()

        def conv3x3(c_in, c_out):
            # Stride 1 and padding 1 preserve spatial size.
            return nn.Conv2d(c_in, c_out, kernel_size=3, stride=1, padding=1)

        # Submodules are created in the same order as before, so seeded
        # initialisation and state_dict keys are unchanged.
        self.conv1 = conv3x3(channels, filters)
        self.res_blocks = nn.Sequential(*(ResidualInResidualDenseBlock(filters) for _ in range(num_res_blocks)))
        self.conv2 = conv3x3(filters, filters)

        stages = []
        for _ in range(num_upsample):
            # 4x channels, then PixelShuffle(2) trades them for 2x H and 2x W.
            stages.extend((conv3x3(filters, filters * 4), nn.LeakyReLU(), nn.PixelShuffle(upscale_factor=2)))
        self.upsampling = nn.Sequential(*stages)

        self.conv3 = nn.Sequential(conv3x3(filters, filters), nn.LeakyReLU(), conv3x3(filters, channels))

    def forward(self, x):
        """Map a low-resolution batch ``x`` to its upsampled reconstruction."""
        shallow = self.conv1(x)
        fused = shallow + self.conv2(self.res_blocks(shallow))
        return self.conv3(self.upsampling(fused))


class Discriminator(nn.Module):
    """Convolutional discriminator producing a spatial grid of real/fake scores.

    Four blocks of (3x3 conv stride 1 [+ BatchNorm except in the first block] ->
    LeakyReLU -> 3x3 conv stride 2 -> BatchNorm -> LeakyReLU) each halve the
    spatial size, rounding up. A final 3x3 conv then maps the 512 feature
    channels to 1.

    Args:
        input_shape: ``(channels, height, width)`` of the input images.

    Attributes:
        output_shape: ``(1, out_h, out_w)``, the per-sample shape ``forward``
            returns for inputs of ``input_shape``.
    """

    def __init__(self, input_shape):
        super(Discriminator, self).__init__()

        self.input_shape = input_shape
        in_channels, in_height, in_width = self.input_shape
        block_filters = [64, 128, 256, 512]

        def downsampled(size):
            # Each block ends in a stride-2 conv (kernel 3, padding 1), which maps
            # s -> (s + 2 * 1 - 3) // 2 + 1 == ceil(s / 2). Flooring instead
            # (int(size / 2 ** 4)) under-reports the grid whenever size is not a
            # multiple of 16, so output_shape would not match forward()'s output.
            for _ in block_filters:
                size = (size - 1) // 2 + 1
            return size

        self.output_shape = (1, downsampled(in_height), downsampled(in_width))

        def discriminator_block(in_filters, out_filters, first_block=False):
            layers = []
            layers.append(nn.Conv2d(in_filters, out_filters, kernel_size=3, stride=1, padding=1))
            # The first block skips normalisation directly on the raw input.
            if not first_block:
                layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            layers.append(nn.Conv2d(out_filters, out_filters, kernel_size=3, stride=2, padding=1))
            layers.append(nn.BatchNorm2d(out_filters))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        layers = []
        in_filters = in_channels
        for i, out_filters in enumerate(block_filters):
            layers.extend(discriminator_block(in_filters, out_filters, first_block=(i == 0)))
            in_filters = out_filters

        # Score head: one output channel per spatial location.
        layers.append(nn.Conv2d(block_filters[-1], 1, kernel_size=3, stride=1, padding=1))

        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Return per-location scores of shape ``(batch, *self.output_shape)``."""
        return self.model(img)
