import os
import ei
import json
import torch
import torch.nn as nn
from .model.rescae import get_residual_block
    

class Upsampler128to256(nn.Module):
    """Upsample a 128x128 image to 256x256, guided by a 256x256 condition map.

    ``x`` is passed through a residual block and bilinearly resized to
    256x256; ``c`` is passed through its own residual block. The two feature
    maps are concatenated along the channel axis and decoded by further
    residual blocks and a final 3x3 convolution back to ``x_ch`` channels.

    Args:
        x_ch: channel count of the low-resolution input ``x`` (and of the output).
        c_ch: channel count of the high-resolution condition ``c``.
        dim: feature width of every residual block.
    """

    def __init__(self, x_ch=3, c_ch=1, dim=64):
        super().__init__()

        # Remembered so forward() validates against the configured channel
        # counts instead of hard-coded 3 / 1.
        self.x_ch = x_ch
        self.c_ch = c_ch

        self.x_up = nn.Sequential(
            get_residual_block(x_ch, dim),
            nn.UpsamplingBilinear2d(256),
        )

        self.c_res = get_residual_block(c_ch, dim)

        self.r_res = nn.Sequential(
            get_residual_block(dim + dim, dim),
            get_residual_block(dim, dim),
            nn.Conv2d(dim, x_ch, 3, padding='same', padding_mode='replicate')
        )

    def forward(self, x, c):
        """Return the 256x256 output for a 128x128 ``x`` and 256x256 ``c``.

        Args:
            x: tensor of shape (B, x_ch, 128, 128).
            c: tensor of shape (B, c_ch, 256, 256).

        Returns:
            Tensor of shape (B, x_ch, 256, 256) with values in (-0.5, 0.5).
        """
        assert x.dim() == 4 and x.shape[1:] == (self.x_ch, 128, 128)
        assert c.dim() == 4 and c.shape[1:] == (self.c_ch, 256, 256)

        features = torch.cat([self.x_up(x), self.c_res(c)], 1)
        # tanh bounds the output to (-1, 1); halving maps it to (-0.5, 0.5).
        return torch.tanh(self.r_res(features)) / 2


class Upsampler128to512(nn.Module):
    """Upsample a 128x128 image to 512x512, guided by a 512x512 condition map.

    ``x`` is upsampled in two stages (residual block -> bilinear resize to
    256 -> residual block -> bilinear resize to 512); ``c`` is passed through
    its own residual block. The two feature maps are concatenated along the
    channel axis and decoded by four residual blocks and a final 3x3
    convolution back to ``x_ch`` channels.

    Args:
        x_ch: channel count of the low-resolution input ``x`` (and of the output).
        c_ch: channel count of the high-resolution condition ``c``.
        dim: feature width of every residual block.
    """

    def __init__(self, x_ch=3, c_ch=1, dim=64):
        super().__init__()

        # Remembered so forward() validates against the configured channel
        # counts instead of hard-coded 3 / 1.
        self.x_ch = x_ch
        self.c_ch = c_ch

        self.x_up = nn.Sequential(
            get_residual_block(x_ch, dim),
            nn.UpsamplingBilinear2d(256),
            get_residual_block(dim, dim),
            nn.UpsamplingBilinear2d(512)
        )

        self.c_res = get_residual_block(c_ch, dim)

        self.r_res = nn.Sequential(
            get_residual_block(dim + dim, dim),
            get_residual_block(dim, dim),
            get_residual_block(dim, dim),
            get_residual_block(dim, dim),
            nn.Conv2d(dim, x_ch, 3, padding='same', padding_mode='replicate')
        )

    def forward(self, x, c):
        """Return the 512x512 output for a 128x128 ``x`` and 512x512 ``c``.

        Args:
            x: tensor of shape (B, x_ch, 128, 128).
            c: tensor of shape (B, c_ch, 512, 512).

        Returns:
            Tensor of shape (B, x_ch, 512, 512) with values in (-0.5, 0.5).
        """
        assert x.dim() == 4 and x.shape[1:] == (self.x_ch, 128, 128)
        assert c.dim() == 4 and c.shape[1:] == (self.c_ch, 512, 512)

        features = torch.cat([self.x_up(x), self.c_res(c)], 1)
        # tanh bounds the output to (-1, 1); halving maps it to (-0.5, 0.5).
        return torch.tanh(self.r_res(features)) / 2


def load_method(res=512, output_dir='output/iconflow/final'):
    """Wrap a trained base IconFlow net and a trained upsampler into one method.

    Args:
        res: resolution suffix of the upsampler run; its ``config.json`` and
            ``checkpoint.pt`` are read from ``<output_dir>/up_<res>/``.
        output_dir: run directory of the base method, handed to
            ``train_net.load_method``.

    Returns:
        An ``IconFlowHighRes`` instance. The upsampler class and width are
        chosen from the upsampler's ``config.json`` (``image_size`` and
        ``up_dim``), and its weights are loaded with ``map_location='cpu'``.
    """
    # Aliased so it does not shadow this function's own name.
    from .train_net import load_method as load_base_method

    up_run_dir = os.path.join(output_dir, f'up_{res}')

    base_method = load_base_method(output_dir)
    lr_size = base_method.image_size
    net = base_method.net

    with open(os.path.join(up_run_dir, 'config.json')) as config_file:
        config = json.load(config_file)

    # An unsupported image_size raises KeyError here.
    upsampler_by_size = {256: Upsampler128to256, 512: Upsampler128to512}
    up: nn.Module = upsampler_by_size[config['image_size']](dim=config['up_dim'])

    checkpoint_path = os.path.join(up_run_dir, 'checkpoint.pt')
    up.load_state_dict(torch.load(checkpoint_path, map_location='cpu')['up'])

    from typing import List
    from PIL import Image
    import torchvision.transforms.functional as T
    from others.api import ComparedMethod

    hr_size = config['image_size']

    class IconFlowHighRes(ComparedMethod):
        """Run the base net at ``lr_size`` then upsample to ``image_size``.

        NOTE(review): ``c_mode``, ``x_mode``, ``c_ch``, ``x_ch`` and ``device``
        are read from ``self`` but not set here; presumably provided by
        ``ComparedMethod`` -- confirm.
        """
        image_size = hr_size

        def __init__(self, net, up) -> None:
            super().__init__()
            self.net = net.eval()
            self.up = up.eval()

        @torch.no_grad()
        def forward(self, c_lr: torch.Tensor, x_lr: torch.Tensor, c: torch.Tensor):
            # Inputs arrive in [0, 1] (see inputs_to_tensor); both models are
            # fed values shifted by -0.5 and the final output is shifted back.
            low_res_out = self.net(c_lr - 0.5, x_lr - 0.5)
            return self.up(low_res_out, c - 0.5) + 0.5

        def inputs_to_tensor(self, cs: List[Image.Image], xs: List[Image.Image]):
            # Returns (c_lr, x_lr, c): low-res condition, low-res x, and the
            # full-resolution condition, each stacked into a batch on self.device.
            cs = [img.convert(self.c_mode) for img in cs]
            xs = [img.convert(self.x_mode) for img in xs]

            lr_shape = (lr_size, lr_size)
            cs_lr = [img.resize(lr_shape, Image.BICUBIC) for img in cs]
            xs_lr = [img.resize(lr_shape, Image.BICUBIC) for img in xs]

            def to_batch(images):
                return torch.stack([T.to_tensor(img).to(self.device) for img in images])

            return to_batch(cs_lr), to_batch(xs_lr), to_batch(cs)

        def check_inputs(self, c_lr: torch.Tensor, x_lr: torch.Tensor, c: torch.Tensor):
            # Same batch size everywhere; low-res tensors at lr_size, the
            # condition at the upsampler's output resolution.
            assert c_lr.shape[0] == x_lr.shape[0] == c.shape[0]
            assert c_lr.shape[1:] == (self.c_ch, lr_size, lr_size)
            assert x_lr.shape[1:] == (self.x_ch, lr_size, lr_size)
            assert c.shape[1:] == (self.c_ch, self.image_size, self.image_size)

    return IconFlowHighRes(net, up)

def train(
    device,
    image_size,
    output_dir='output/iconflow/final',
    up_dim=64,
    dataset_root='datasets/icon/data/in_memory',
    batch_size=64,
    num_workers=12,
    log_int=100,
    sample_int=1000,
    save_int=500,
    save_iters=None,
    end_iter=300000,
):
    """Train a 128 -> ``image_size`` upsampler on top of a frozen IconFlow net.

    The base network loaded from ``output_dir`` is only run under
    ``torch.no_grad`` and is not given to the optimizer. Only the upsampler is
    optimized (Adam, lr 1e-4), with an MSE loss against the high-resolution
    target. ``config.json``, TensorBoard logs and checkpoints are written to
    ``<output_dir>/up_<image_size>/``. If a ``checkpoint.pt`` exists there,
    training resumes from it.

    Args:
        device: anything ``torch.device`` accepts.
        image_size: target resolution, 256 or 512.
        output_dir: run directory of the trained base IconFlow method.
        up_dim: feature width of the upsampler.
        dataset_root: root passed to ``IconContourDownscaleDataset``.
        batch_size: ``DataLoader`` batch size.
        num_workers: ``DataLoader`` worker count.
        log_int: log training scalars every ``log_int`` iterations.
        sample_int: write sample grids every ``sample_int`` iterations
            (every 200 iterations while ``it < 1000``).
        save_int: overwrite ``checkpoint.pt`` every ``save_int`` iterations.
        save_iters: for each ``i`` listed, an extra ``checkpoint_{i-1}.pt`` is
            kept. ``None`` means no extra snapshots.
        end_iter: stop, and save, once this iteration is reached.
    """
    assert image_size in (256, 512)

    # Avoid a shared mutable default. Normalised before the locals()
    # snapshot below so config.json still records a list.
    if save_iters is None:
        save_iters = []

    net_output_dir = output_dir
    output_dir = os.path.join(net_output_dir, f'up_{image_size}')

    # Snapshot of the arguments (plus net_output_dir) for config.json and the
    # checkpoints. Do not define other locals above this line.
    config = locals().copy()

    import random
    import torch.nn.functional as F
    from torch import optim
    from torchvision.utils import make_grid
    from torch.utils.data.dataloader import DataLoader
    from iconflow.dataset import IconContourDownscaleDataset
    from .utils import cycle_iter, random_sampler
    from tqdm import tqdm
    from torch.utils.tensorboard.writer import SummaryWriter
    from iconflow.train_net import load_method
    ei.patch()

    os.makedirs(output_dir, exist_ok=True)

    with open(os.path.join(output_dir, 'config.json'), 'w') as f:
        json.dump(config, f)

    device = torch.device(device)
    writer = SummaryWriter(output_dir)

    net = load_method(net_output_dir).net.to(device).eval()
    up: nn.Module = {
        256: Upsampler128to256,
        512: Upsampler128to512,
    }[image_size](3, 1, up_dim).to(device)
    # Build the optimizer after the parameters are on their final device.
    opt = optim.Adam(up.parameters(), lr=1e-4)

    d = {
        'aug_train': IconContourDownscaleDataset(dataset_root, image_size, 128, True, True, True, (0.0, 0.9), legacy_normalize=True),
        'train': IconContourDownscaleDataset(dataset_root, image_size, 128, split=(0, 0.9), legacy_normalize=True),
        'test': IconContourDownscaleDataset(dataset_root, image_size, 128, split=(0.9, 1.0), legacy_normalize=True)
    }

    train_loader = DataLoader(
        d['aug_train'],
        batch_size=batch_size,
        sampler=random_sampler(len(d['aug_train'])),
        pin_memory=(device.type == 'cuda'),
        num_workers=num_workers
    )

    # Resume only when a checkpoint exists. A corrupt or incompatible
    # checkpoint now raises instead of silently restarting from it=0. The
    # old behavior could leave `up` partially loaded, and the immediate
    # save() at it=0 would then overwrite the old checkpoint.
    ckpt_path = os.path.join(output_dir, 'checkpoint.pt')
    try:
        state = torch.load(ckpt_path, map_location=device)
    except FileNotFoundError:
        it = 0
    else:
        up.load_state_dict(state['up'])
        opt.load_state_dict(state['opt'])
        it = state['it']
        print(f'loaded from checkpoint, it: {it}')
        del state

    '''
    Sample
    '''

    def sample():
        """Return two image grids for TensorBoard (low-res results, upsampled results).

        A few train/test items are picked with an RNG seeded from ``it``. For
        each shift n of the batch, the base net is run on (C, X rolled by n
        along the batch), and the upsampler is run on that result. Each grid
        row is one picked item: [X, C, R_0, R_1, ...] in the first grid, and
        nearest-upsampled X and C followed by the upsampled outputs in the
        second. Values are shifted by +0.5 and clamped to [0, 1].
        """
        was_training = up.training
        up.eval()
        try:
            n_train, n_test = {256: (2, 2),
                               512: (1, 1)}[image_size]

            rng = random.Random(it % 300000 + 1337)
            picked = rng.choices(d['train'], k=n_train) + rng.choices(d['test'], k=n_test)
            sX, sC, _, sC_up = zip(*picked)
            sX, sC, sC_up = map(torch.stack, [sX, sC, sC_up])

            with torch.no_grad():
                sRs = [net(sC.to(device), torch.roll(sX, n, 0).to(device)) for n in range(sX.shape[0])]
                sUs = [up(rs, sC_up.to(device)).cpu() for rs in sRs]
                sRs = [rs.cpu() for rs in sRs]
        finally:
            # Restore the previous mode. Otherwise `up` would stay in eval
            # mode for the rest of training after the first sample. Whether
            # that changes results depends on the layers built by
            # get_residual_block (e.g. normalization/dropout).
            up.train(was_training)

        sC = sC.repeat(1, 3, 1, 1) if sC.shape[1] == 1 else sC
        sGr = torch.stack([sX, sC, *sRs], 1)
        sGr = torch.clamp(sGr + 0.5, 0.0, 1.0)
        image1 = make_grid(sGr.reshape(-1, *sGr.shape[2:]), nrow=sGr.shape[1])

        # F.interpolate replaces the deprecated F.upsample_nearest.
        sGu = torch.stack([
            F.interpolate(sX, size=image_size, mode='nearest'),
            F.interpolate(sC, size=image_size, mode='nearest'),
            *sUs,
        ], 1)
        sGu = torch.clamp(sGu + 0.5, 0.0, 1.0)
        image2 = make_grid(sGu.reshape(-1, *sGu.shape[2:]), nrow=sGu.shape[1])

        return image1, image2

    '''
    Save
    '''

    def save(add_postfix=False):
        """Write upsampler/optimizer state, ``it`` and ``config`` to ``output_dir``.

        With ``add_postfix`` the file is ``checkpoint_{it}.pt``; otherwise
        ``checkpoint.pt`` is overwritten.
        """
        if add_postfix:
            file_name = f'checkpoint_{it}.pt'
        else:
            file_name = 'checkpoint.pt'
        output_path = os.path.join(output_dir, file_name)

        torch.save({
            'up': up.state_dict(),
            'opt': opt.state_dict(),
            'it': it,
            'config': config,
        }, output_path)

    '''
    Train
    '''

    def step(X, C, X_up, C_up):
        """Run one optimizer step on the upsampler and return scalars to log."""
        X = X.to(device)        # low-res (128x128), input to the base net
        C = C.to(device)        # low-res (128x128), input to the base net
        X_up = X_up.to(device)  # image_size x image_size, regression target
        C_up = C_up.to(device)  # image_size x image_size, upsampler condition

        # The base net is frozen; only the upsampler receives gradients.
        with torch.no_grad():
            R = net(C, X)

        mse_loss = F.mse_loss(up(R, C_up), X_up)
        loss = mse_loss
        log_dict = {'mse_loss': mse_loss.item()}

        if loss.isnan():
            ei.embed('Loss contains nan!!!', exit=True)

        opt.zero_grad()
        loss.backward()
        opt.step()

        return log_dict

    try:
        with tqdm(cycle_iter(train_loader), total=end_iter - it) as iter_loader:
            for X, C, X_up, C_up in iter_loader:
                if isinstance(end_iter, int) and it >= end_iter:
                    # Reuse the interrupt path so the final state is saved.
                    raise KeyboardInterrupt

                log_dict = step(X, C, X_up, C_up)

                iter_loader.set_postfix({'it': it, **log_dict})

                if it % log_int == 0:
                    for key, value in log_dict.items():
                        writer.add_scalar(f'training/{key}', value, it)

                if (it < 1000 and it % 200 == 0) or (it >= 1000 and it % sample_int == 0):
                    for i, img in enumerate(sample()):
                        writer.add_image(f'sampling/sample{i}', img, it)

                if it % save_int == 0:
                    save()

                if (it + 1) in save_iters:
                    save(add_postfix=True)

                it += 1

    except KeyboardInterrupt:
        print('saving checkpoint')
        save()
    finally:
        # Flush pending TensorBoard events on every exit path.
        writer.close()

if __name__ == '__main__':
    # Command-line entry point: python-fire maps CLI flags onto train()'s parameters.
    from fire import Fire

    Fire(train)
