'''
python -m others.comi.train --device cuda:0 --output-dir output/comi/224_1 --num-workers 8

ReLU -> LeakyReLU
Discriminator -> Icon, Hinge Loss
'''

import ei
import os, sys
import json
import warnings


def get_icon_dataset(image_size=224, train_ratio=0.9):
    '''
    Build the ``IconContourDataset`` splits used by the training script.

    Args:
        image_size: Forwarded to every dataset as ``output_size``. The
            dataset's own ``image_size`` argument is fixed at 256.
        train_ratio: Upper bound of the augmented training split.

    Returns:
        A dict with the keys ``'train'`` and ``'test'`` and, only when
        ``train_ratio > 0.0``, ``'aug_train'``.

        * ``'train'`` covers ``(0, n)`` and ``'test'`` covers ``(n, 1.0)``
          where ``n = min(train_ratio, 0.9)``.
        * ``'aug_train'`` covers ``(0, train_ratio)`` with random crop and
          random transpose enabled. Note that it uses ``train_ratio``
          uncapped, so for ``train_ratio > 0.9`` its range extends past ``n``
          into the range given to ``'test'``.
    '''
    from .dataset import IconContourDataset

    # Arguments shared by every split.
    shared = dict(
        root='datasets/icon4/data/in_memory',
        image_size=256,
        output_size=image_size,
    )

    # The un-augmented train/test boundary never exceeds 0.9.
    boundary = min(train_ratio, 0.9)

    datasets = {}
    datasets['train'] = IconContourDataset(**shared, split=(0, boundary))
    datasets['test'] = IconContourDataset(**shared, split=(boundary, 1.0))

    if not train_ratio > 0.0:
        return datasets

    datasets['aug_train'] = IconContourDataset(
        **shared,
        random_crop=True,
        random_transpose=True,
        split=(0, train_ratio),
    )
    return datasets

def get_net():
    '''
    Instantiate the two networks and bundle them in one ``nn.ModuleDict``.

    Returns:
        ``nn.ModuleDict`` with ``'G'`` (a ``Comi`` instance) and ``'D'``
        (a ``Discriminator`` instance), both built with default arguments.
    '''
    from torch.nn import ModuleDict
    from .model import Comi
    from .enhence import Discriminator

    # Build G before D, in the same order the dict entries are declared.
    generator = Comi()
    discriminator = Discriminator()

    return ModuleDict({'G': generator, 'D': discriminator})


def get_opt(net):
    '''
    Create one Adam optimizer per sub-network.

    Args:
        net: Mapping-like module holding the sub-networks under the keys
            ``'G'`` and ``'D'``.

    Returns:
        ``(opt_G, opt_D)``: Adam optimizers over ``net['G']`` and
        ``net['D']`` respectively, both with ``lr=0.0001`` and
        ``weight_decay=0.0001``.
    '''
    from torch.optim import Adam

    def make_adam(module):
        # Both networks share the same hyper-parameters.
        return Adam(module.parameters(), lr=0.0001, weight_decay=0.0001)

    return make_adam(net['G']), make_adam(net['D'])

def train(
    image_size=224,

    device='cpu',
    batch_size=64,
    num_workers=4,
    output_dir='output/temp/test',
    train_ratio=0.9,

    log_int=100,
    sample_int=1000,
    save_int=1000,
    save_iters=[],
    end_iter=300000,

    G_lambda=50.0,
):
    '''
    Train the generator ``G`` against the discriminator ``D``.

    The generator is optimised with ``G_gan + G_lambda * G_rec`` where
    ``G_rec`` is the MSE between the generator output and the input image and
    ``G_gan = -D(G(...)).mean()``. The discriminator is optimised with a
    hinge loss on real images and detached generator outputs.

    If ``<output_dir>/checkpoint.pt`` exists, training resumes from it. A
    checkpoint that exists but cannot be loaded raises instead of being
    silently discarded (and later overwritten).

    Args:
        image_size: Output size requested from the dataset; must be 224.
        device: Anything accepted by ``torch.device``.
        batch_size: Batch size of the training ``DataLoader``.
        num_workers: Worker count of the training ``DataLoader``.
        output_dir: Directory receiving TensorBoard events, ``config.json``
            and checkpoints. Created if missing.
        train_ratio: Forwarded to ``get_icon_dataset``.
        log_int: Scalars are written to TensorBoard whenever
            ``it % log_int == 0``.
        sample_int: A sample grid is written whenever
            ``it % sample_int == 0`` once ``it >= 1000`` (every 200
            iterations before that).
        save_int: ``checkpoint.pt`` is rewritten whenever
            ``it % save_int == 0``.
        save_iters: When ``it + 1`` is in this collection an additional
            ``checkpoint_{it}.pt`` is written. Only read, never mutated.
        end_iter: Training stops once ``it >= end_iter``. A non-int value
            (e.g. ``None``) disables the limit; stop with Ctrl-C.
        G_lambda: Weight of the reconstruction loss in the generator
            objective.

    A final ``checkpoint.pt`` is written when the loop ends or is interrupted
    with Ctrl-C.
    '''

    assert image_size == 224

    ei.patch()
    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    # Snapshot the call arguments before any other local name is bound, so
    # ``config`` holds exactly the parameters of this call.
    config = locals().copy()
    config['argv'] = sys.argv.copy()
    warnings.filterwarnings("ignore")

    import random
    import torch
    import torch.nn.functional as F
    from torch.utils.data import DataLoader
    from torch.utils.tensorboard.writer import SummaryWriter
    from torchvision.utils import make_grid
    from tqdm import tqdm

    from iconflow.utils import cycle_iter, random_sampler

    os.makedirs(output_dir, exist_ok=True)
    device = torch.device(device)
    writer = SummaryWriter(output_dir)

    # ---------------------------------------------------------------- Dataset

    d = get_icon_dataset(
        image_size,
        train_ratio=train_ratio,
    )

    train_loader = DataLoader(
        d['aug_train'],
        batch_size=batch_size,
        sampler=random_sampler(len(d['aug_train'])),
        pin_memory=(device.type == 'cuda'),
        num_workers=num_workers
    )

    # ------------------------------------------------------------------ Model

    net = get_net()
    opt_G, opt_D = get_opt(net)
    net.to(device)

    G = net['G']
    D = net['D']

    # Resume only when a checkpoint is actually present. Load errors are
    # allowed to propagate: swallowing them would restart at it == 0 and the
    # first loop iteration would overwrite the existing checkpoint.
    it = 0
    ckpt_path = os.path.join(output_dir, 'checkpoint.pt')
    if os.path.isfile(ckpt_path):
        state = torch.load(ckpt_path, map_location=device)
        net.load_state_dict(state['net'])
        opt_G.load_state_dict(state['opt_G'])
        opt_D.load_state_dict(state['opt_D'])
        it = state['it']
        print(f'loaded from checkpoint, it: {it}')
        del state

    # ------------------------------------------------------------ Save config

    print(config)

    with open(os.path.join(output_dir, 'config.json'), 'w') as f:
        json.dump(config, f, indent=2)

    # ----------------------------------------------------------------- Sample

    @torch.no_grad()
    def sample():
        '''Render a grid of generator outputs for TensorBoard.

        Draws 2 items from ``d['train']`` and 2 from ``d['test']`` with an RNG
        seeded from the current iteration, then runs ``G`` on ``c1`` paired
        with every cyclic shift of ``x1`` along the batch dimension.
        '''
        G.eval()

        n_train, n_test = 2, 2

        rng = random.Random(it % 300000 + 1337)
        x1, c1 = zip(*rng.choices(d['train'], k=n_train),
                     *rng.choices(d['test'], k=n_test))
        x1, c1 = map(torch.stack, (x1, c1))
        x1 = x1.to(device)
        c1 = c1.to(device)

        r12_list = [
            G(c1, x1.roll(i, 0).contiguous())
            for i in range(len(x1))
        ]

        rows = torch.stack([x1, c1.expand_as(x1), *r12_list])
        # Swap the first two axes so each grid row holds one sample's images.
        images = rows.permute(1, 0, 2, 3, 4).reshape(-1, *rows.shape[-3:])

        return make_grid(images, nrow=len(rows))

    # ------------------------------------------------------------------- Save

    def save(add_postfix=False):
        '''Write networks, optimizers, iteration and config to ``output_dir``.

        With ``add_postfix`` the file is ``checkpoint_{it}.pt``; otherwise the
        rolling ``checkpoint.pt`` is overwritten.
        '''
        if add_postfix:
            file_name = f'checkpoint_{it}.pt'
        else:
            file_name = 'checkpoint.pt'
        output_path = os.path.join(output_dir, file_name)

        torch.save({
            'net': net.state_dict(),
            'opt_G': opt_G.state_dict(),
            'opt_D': opt_D.state_dict(),
            'it': it,
            'config': config,
        }, output_path)

    # ------------------------------------------------------------------ Train

    def step(X: torch.Tensor, C: torch.Tensor):
        '''Run one generator update followed by one discriminator update.

        Returns a dict of scalar losses for logging.
        '''
        x1 = X.to(device)
        c1 = C.to(device)

        log_dict = {}

        G.train()
        D.train()

        output_G = G(c1, x1)

        # Update G: adversarial term plus weighted reconstruction term.
        G_rec = F.mse_loss(output_G, x1)
        log_dict['G_rec'] = G_rec.item()

        G_gan = -D(output_G).mean()
        log_dict['G_gan'] = G_gan.item()

        opt_G.zero_grad()
        (G_gan + G_lambda * G_rec).backward()
        opt_G.step()

        del G_gan, G_rec

        # Update D: hinge loss. The generator output is detached so this
        # backward pass does not reach G; opt_D.zero_grad() also clears the
        # gradients D accumulated during the generator backward pass above.
        D_gan_real = torch.relu(1 - D(x1)).mean()
        D_gan_fake = torch.relu(1 + D(output_G.detach())).mean()
        D_gan = D_gan_fake + D_gan_real
        log_dict['D_gan'] = D_gan.item()
        log_dict['D_fake'] = D_gan_fake.item()
        log_dict['D_real'] = D_gan_real.item()

        opt_D.zero_grad()
        D_gan.backward()
        opt_D.step()

        return log_dict

    # A non-int ``end_iter`` means "no limit"; tqdm takes total=None for that.
    has_end = isinstance(end_iter, int)
    total = end_iter - it if has_end else None

    try:
        try:
            with tqdm(cycle_iter(train_loader), total=total) as iter_loader:
                for X, C in iter_loader:
                    if has_end and it >= end_iter:
                        break

                    log_dict = step(X, C)

                    iter_loader.set_postfix({'it': it, **log_dict})

                    if it % log_int == 0:
                        for key, value in log_dict.items():
                            writer.add_scalar(f'comi/{key}', value, it)

                    if (it < 1000 and it % 200 == 0) or (it >= 1000 and it % sample_int == 0):
                        writer.add_image('comi/sample', sample(), it)

                    if it % save_int == 0:
                        save()

                    if (it + 1) in save_iters:
                        save(add_postfix=True)

                    it += 1
        except KeyboardInterrupt:
            # Ctrl-C is the supported way to stop early; fall through to the
            # final save. Any other exception propagates without saving.
            pass

        print('saving checkpoint')
        save()
    finally:
        # Flush pending TensorBoard events however the run ended.
        writer.close()


if __name__ == '__main__':
    # Script entry point: hand ``train`` to python-fire, which builds the
    # command-line interface from it (see the usage line in the module
    # docstring).
    from fire import Fire

    Fire(train)
