import torch

from b2b.torch.select                           import select_optimizer
from b2b.base.losses                            import GANLoss, cal_gradient_penalty
from b2b.base.vgg_loss                          import VGGLoss
from b2b.models.discriminator                   import construct_discriminator
from b2b.models.generator                       import construct_generator
from b2b.models.deblur.nafnet.archs.NAFNet_arch import NAFNetLocal
from b2b.models.deblur.nafnet.util              import load_network, define_network
from b2b.models.explore.kernel_encoding.kernel_wizard import KernelWizard

from .model_base import ModelBase
from .named_dict import NamedDict
from .funcs import set_two_domain_input

import yaml
import torch.nn.functional as F
from copy import deepcopy
import numpy as np
import cv2

class B2BBaseModel(ModelBase):
    """Blur2Blur model: a generator ``gen_ab`` translating domain-A images to domain B.

    Training optimises ``gen_ab`` against a discriminator ``disc_b`` using an
    adversarial loss (``GANLoss``) plus a VGG perceptual loss. The perceptual
    loss is applied at three generator output scales, each against a resized
    copy of the generator input. When built with ``is_train=False``, a
    pretrained NAFNet deblurring network is also loaded and applied in
    ``forward``. A pretrained ``KernelWizard`` is always loaded for an
    optional blur-augmentation step (``_augment_blur``).
    """

    def _setup_images(self, _config):
        """Declare the image slots tracked by this model."""
        return NamedDict('real_a', 'fake_b', 'real_b')

    def _setup_models(self, config):
        """Construct ``gen_ab`` and, when training, the discriminator ``disc_b``.

        Assuming ``shape[0]`` is the channel dimension, ``disc_b`` is sized for
        ``C_a + C_b`` input channels, while the backward passes feed it a
        B-domain image concatenated with itself. That only matches when both
        domains have the same channel count.
        """
        image_shape_a = config.data.datasets[0].shape
        image_shape_b = config.data.datasets[1].shape

        assert image_shape_a[1:] == image_shape_b[1:], \
            "Pix2Pix needs images in both domains to have the same size"

        models = {
            'gen_ab': construct_generator(
                config.generator, image_shape_a, image_shape_b, self.device
            ),
        }

        if self.is_train:
            extended_image_shape = (
                image_shape_a[0] + image_shape_b[0], *image_shape_a[1:]
            )
            models['disc_b'] = construct_discriminator(
                config.discriminator, extended_image_shape, self.device
            )

        return NamedDict(**models)

    def _setup_losses(self, config):
        """Declare the loss slots reported by this model."""
        return NamedDict(
            'gen_ab', 'gen_perc_ab', 'disc_b', 'disc_b_fake', 'disc_b_real'
        )

    def _setup_optimizers(self, config):
        """Create one optimizer for ``gen_ab`` and one for ``disc_b``."""
        optimizers = NamedDict('gen_ab', 'disc_b')

        optimizers.gen_ab = select_optimizer(
            self.models.gen_ab.parameters(), config.generator.optimizer
        )
        optimizers.disc_b = select_optimizer(
            self.models.disc_b.parameters(), config.discriminator.optimizer
        )

        return optimizers

    def __init__(self, savedir, config, is_train, device,
                 lambda_idt=0.8, lambda_triplet=0.5):
        """Build the model, its loss criteria and the auxiliary pretrained nets.

        Args:
            savedir: Forwarded to ``ModelBase``.
            config: Model configuration; must describe exactly two datasets.
            is_train: If False, the NAFNet deblurring network is loaded;
                ``disc_b`` is only built when True.
            device: Device the criteria and auxiliary networks are moved to.
            lambda_idt: Weight applied to the summed multi-scale perceptual loss.
            lambda_triplet: Stored for a triplet loss; not used in this class.
        """
        # Validate before ``ModelBase.__init__``. ``_setup_models`` indexes
        # ``datasets[0]`` and ``datasets[1]`` and is presumably invoked from
        # there, so a late check could be pre-empted by an IndexError.
        assert len(config.data.datasets) == 2, \
            "Blur2Blur expects a pair of datasets"
        super().__init__(savedir, config, is_train, device)

        self.lambda_idt = lambda_idt
        self.lambda_triplet = lambda_triplet

        self.criterion_gan = GANLoss(config.loss).to(self.device)
        self.criterion_perc = VGGLoss().to(self.device)
        self.criterion_triplet = torch.nn.TripletMarginLoss(margin=1.0, p=2)
        self.gradient_penalty = config.gradient_penalty

        if not self.is_train:
            self.deblur = self._load_deblur_network(
                "options/deblur/REDS/NAFNet-width64.yml"
            )

        self.genblur = self._load_kernel_wizard(
            "b2b/models/explore/generate_blur/augmentation.yml"
        )

    def _load_deblur_network(self, opt_path):
        """Build the NAFNet deblurring network described by ``opt_path``, in eval mode."""
        with open(opt_path, "r", encoding="utf-8") as f:
            # PyYAML >= 6 rejects ``yaml.load`` without an explicit Loader.
            opt = yaml.safe_load(f)

        deblur = define_network(deepcopy(opt['network_g'])).to(self.device)
        load_path = opt['path'].get('pretrain_network_g', None)
        if load_path is not None:
            load_network(
                deblur, load_path,
                opt['path'].get('strict_load_g', True),
                param_key=opt['path'].get('param_key', 'params'),
            )
        deblur.eval()
        return deblur

    def _load_kernel_wizard(self, opt_path):
        """Build the pretrained ``KernelWizard`` from the ``KernelWizard`` section of ``opt_path``."""
        with open(opt_path, "r", encoding="utf-8") as f:
            opt = yaml.safe_load(f)["KernelWizard"]
        model_path = opt["pretrained"]

        genblur = KernelWizard(opt)
        print("Loading KernelWizard...")
        genblur.eval()
        # Load to CPU first so a checkpoint saved on another device still
        # loads; the module is moved to ``self.device`` right after.
        genblur.load_state_dict(torch.load(model_path, map_location='cpu'))
        return genblur.to(self.device)

    def _set_input(self, inputs, domain):
        """Store a batch from ``inputs`` into ``self.images`` on ``self.device``."""
        set_two_domain_input(self.images, inputs, domain, self.device)

    def deblurring_step(self, x):
        """Run the frozen deblurring network on ``x`` without gradients.

        If the network returns a list, only its last element is returned.
        """
        self.deblur.eval()
        with torch.no_grad():
            pred = self.deblur(x)
            if isinstance(pred, list):
                pred = pred[-1]

        return pred

    def _augment_blur(self):
        """Optionally re-blur ``self.real_d`` with a kernel sampled via KernelWizard.

        Runs only when ``self.real_b``, ``self.real_c`` and ``self.real_d``
        are all present; otherwise it does nothing.
        """
        # NOTE(review): nothing in this file sets ``real_b``/``real_c``/``real_d``
        # as attributes (inputs live in ``self.images``). The result is also
        # written to ``self.real_b`` rather than ``self.images.real_b``, so the
        # losses below do not consume it. Confirm the intended wiring.
        real_b = getattr(self, 'real_b', None)
        real_c = getattr(self, 'real_c', None)
        real_d = getattr(self, 'real_d', None)
        if real_b is None or real_c is None or real_d is None:
            return

        with torch.no_grad():
            kernel_mean, kernel_sigma = self.genblur(real_b, real_c)
            # Sample a kernel as mean + sigma * N(0, 1).
            self.kernel_real = (
                kernel_mean + kernel_sigma * torch.randn_like(kernel_mean)
            )
            self.real_b = self.genblur.adaptKernel(real_d, self.kernel_real)

    def forward(self):
        """Run the generator, plus the deblurring network when not training.

        Sets ``images.fake_b`` to the raw (multi-scale) generator output and
        ``fake_b_`` to its full-resolution element. When not training, it also
        sets ``images.deb_fake_b`` and ``images.deb_real_a``.
        """
        self._augment_blur()

        self.images.fake_b = self.models.gen_ab(self.images.real_a)
        # Index 2 is the element compared against the full-resolution input in
        # ``backward_generator_base``; it is what the discriminator scores.
        self.fake_b_ = self.images.fake_b[2]

        if not self.is_train:
            self.images.deb_fake_b = self.deblurring_step(self.fake_b_)
            self.images.deb_real_a = self.deblurring_step(self.images.real_a)

    def backward_discriminator_base(self, model, real, fake, preimage):
        """Compute and backpropagate the discriminator loss.

        Each image is concatenated with itself along dim 1 before scoring.
        ``preimage`` is unused.

        Returns:
            ``(loss, loss_real, loss_fake)``. ``loss`` is the mean of the two
            terms plus the gradient penalty when one is configured.
        """
        cond_real = torch.cat([real, real], dim=1)
        cond_fake = torch.cat([fake, fake], dim=1).detach()

        loss_real = self.criterion_gan(model(cond_real), True)
        loss_fake = self.criterion_gan(model(cond_fake), False)

        loss = (loss_real + loss_fake) * 0.5
        if self.gradient_penalty is not None:
            loss = loss + cal_gradient_penalty(
                model, cond_real, cond_fake, real.device,
                **self.gradient_penalty
            )[0]

        loss.backward()
        return (loss, loss_real, loss_fake)

    def backward_discriminators(self):
        """Update ``self.losses`` with the ``disc_b`` losses on real_b vs fake_b_."""
        (
            self.losses.disc_b,
            self.losses.disc_b_real,
            self.losses.disc_b_fake,
        ) = self.backward_discriminator_base(
            self.models.disc_b,
            self.images.real_b, self.fake_b_, self.images.real_a
        )

    def backward_generator_base(self, disc, real, fake, preimage):
        """Compute and backpropagate the generator's adversarial + perceptual loss.

        Args:
            disc: Discriminator scoring the full-resolution output.
            real: Unused; kept for signature symmetry.
            fake: Generator output indexed as three scales. ``fake[0]``,
                ``fake[1]`` and ``fake[2]`` are compared with ``preimage``
                bilinearly resized by 0.25, 0.5 and 1.0 respectively.
            preimage: The generator input the perceptual loss compares to.

        Returns:
            ``(loss_gen, loss_perc)``: the adversarial term and the
            ``lambda_idt``-weighted perceptual term.
        """
        # Score the full-resolution element, matching what the discriminator
        # is trained on; ``torch.cat`` needs tensors, not the scale sequence.
        fake_full = fake[2]
        loss_gen = self.criterion_gan(
            disc(torch.cat([fake_full, fake_full], dim=1)), True
        )

        preimage_half = F.interpolate(preimage, scale_factor=0.5, mode='bilinear')
        preimage_quarter = F.interpolate(
            preimage, scale_factor=0.25, mode='bilinear'
        )
        loss_perc = (
            self.criterion_perc(fake[0], preimage_quarter)
            + self.criterion_perc(fake[1], preimage_half)
            + self.criterion_perc(fake_full, preimage)
        ) * self.lambda_idt

        loss = loss_gen + loss_perc
        loss.backward()
        return (loss_gen, loss_perc)

    def backward_generators(self):
        """Update ``self.losses`` with the ``gen_ab`` adversarial and perceptual losses."""
        self.losses.gen_ab, self.losses.gen_perc_ab = self.backward_generator_base(
            self.models.disc_b,
            self.images.real_b, self.images.fake_b, self.images.real_a
        )

    def optimization_step(self):
        """Run one training step: forward, a generator update, then a discriminator update."""
        self.forward()

        # Generators: freeze the discriminator so only gen_ab receives updates.
        self.set_requires_grad(self.models.disc_b, False)
        self.optimizers.gen_ab.zero_grad()
        self.backward_generators()
        self.optimizers.gen_ab.step()

        # Discriminators
        self.set_requires_grad(self.models.disc_b, True)
        self.optimizers.disc_b.zero_grad()
        self.backward_discriminators()
        self.optimizers.disc_b.step()

