import ei
import os, sys
import json
import warnings


def get_icon_dataset(image_size, random_color=True, train_ratio=0.9,
                     root='datasets/icon/data/in_memory'):
    """Build the icon contour dataset splits used by training and sampling.

    Args:
        image_size: Image side length forwarded to ``IconContourDataset``.
        random_color: Forwarded to the augmented ``'aug_train'`` split only.
        train_ratio: End of the ``'aug_train'`` split range.  The plain
            ``'train'`` split is capped at 0.9, so ``'test'`` always covers at
            least the range ``(0.9, 1.0)`` (assuming ``split`` is a fractional
            ``(start, end)`` range — see ``IconContourDataset``).
        root: Dataset directory passed to ``IconContourDataset``.

    Returns:
        dict mapping split name to dataset.  ``'train'`` and ``'test'`` are
        always present.  ``'aug_train'`` (random crop / transpose / color) is
        present only when ``train_ratio > 0``.
    """
    from iconflow.dataset import IconContourDataset

    # Cap the non-augmented train split so a held-out 'test' range always exists.
    n = min(train_ratio, 0.9)

    splits = {
        'train': IconContourDataset(root, image_size, split=(0, n), legacy_normalize=True),
        'test': IconContourDataset(root, image_size, split=(n, 1.0), legacy_normalize=True),
    }

    # NOTE(review): 'aug_train' uses the uncapped train_ratio, so for
    # train_ratio > 0.9 it overlaps the 'test' range — presumably intended
    # for final models trained on (almost) all data; confirm.
    if train_ratio > 0.0:
        splits['aug_train'] = IconContourDataset(
            root, image_size,
            random_crop=True,
            random_transpose=True,
            random_color=random_color,
            split=(0, train_ratio),
            legacy_normalize=True,
        )

    return splits

def get_net(
    c_ch=1,
    x_ch=3,
    e_dim=16,
    s_dim=48,
    unet_arch='M51',
    cc_arch='S31',
    resnet_depth=50,
    decoder_width=32,
    decoder_depth=4,
):
    """Build the IconFlow network.

    The network has four parts: a content encoder, a content extractor, a
    style encoder and a per-pixel MLP decoder.

    Args:
        c_ch: Input channels of the content encoder (the content/contour input).
        x_ch: Input channels of the style encoder and output channels of the
            decoder.
        e_dim: Output channels of the content encoder (per-pixel embedding).
        s_dim: Output size of the style encoder (``num_classes`` of the ResNet).
        unet_arch: Architecture name for the content-encoder residual U-Net.
        cc_arch: Architecture name for the content-extractor residual U-Net.
            ``'NU3'`` is treated as an alias of ``'S31'``.
        resnet_depth: Depth passed to ``get_resnet_by_depth`` for the style encoder.
        decoder_width: Hidden width of the decoder MLP.
        decoder_depth: The decoder has ``decoder_depth - 1`` hidden
            Linear/LayerNorm/ReLU groups followed by a final Linear + Tanh.

    Returns:
        A ``Net`` module.  Its ``config`` attribute holds the arguments as
        passed (before the ``'NU3'`` remap).
    """
    # Snapshot the arguments first so the stored config keeps the requested
    # cc_arch name (e.g. 'NU3') rather than the remapped one.
    config = locals().copy()

    if cc_arch == 'NU3':
        cc_arch = 'S31'

    import torch
    import torch.nn as nn

    from .model.rescae import get_residual_unet
    from .model.resnet import get_resnet_by_depth
    from .model.nconv import NormConv2d

    # NOTE: attribute names and the order of layers inside each nn.Sequential
    # determine the state_dict keys that load_method()/train() restore from
    # checkpoints; renaming or reordering them breaks existing checkpoints.
    class Net(nn.Module):
        def __init__(self):
            super().__init__()

            self.config = config

            # Content input (c_ch channels) -> e_dim-channel embedding in (-1, 1).
            # get_residual_unet appears to take (in_ch, out_ch, arch).
            self.content_encoder = nn.Sequential(
                get_residual_unet(c_ch, e_dim, unet_arch),
                nn.Tanh()
            )

            # Image -> 1-channel content map in (-1, 1) (halved in extract_content).
            # NOTE(review): the 16 / 1 channel counts are hard-coded rather than
            # derived from x_ch / c_ch, and the meaning of NormConv2d's (16, 2)
            # arguments depends on .model.nconv — confirm against its definition.
            self.content_extractor = nn.Sequential(
                NormConv2d(16, 2),
                get_residual_unet(16, 1, cc_arch),
                nn.Tanh()
            )

            # Image (x_ch channels) -> style vector of size s_dim in (-1, 1).
            self.style_encoder = nn.Sequential(
                get_resnet_by_depth(resnet_depth, num_classes=s_dim, in_channels=x_ch),
                nn.Tanh()
            )

            # Per-pixel MLP: (e_dim + s_dim) -> ... -> x_ch, output in (-1, 1).
            decoder_layers = []
            last_width = e_dim + s_dim
            for _ in range(decoder_depth - 1):
                decoder_layers += [nn.Linear(last_width, decoder_width),
                                   nn.LayerNorm(decoder_width),
                                   nn.ReLU()]
                last_width = decoder_width
            decoder_layers += [nn.Linear(last_width, x_ch),
                               nn.Tanh()]
            self.decoder = nn.Sequential(*decoder_layers)

        def extract_content(self, x):
            """Predict the content map of image ``x``; values in (-0.5, 0.5)."""
            c = self.content_extractor(x) / 2
            return c

        def encode_content(self, c):
            """Encode content input ``c`` into a per-pixel embedding in (-1, 1)."""
            e = self.content_encoder(c)
            return e

        def encode_style(self, x):
            """Encode image ``x`` into a style vector in (-1, 1).

            ``decode`` indexes the result as ``s[:, :, None, None]``, so it is
            expected to be rank-2 ``(batch, s_dim)``.
            """
            s = self.style_encoder(x)
            return s

        def decode(self, e, s):
            """Decode embedding ``e`` with style ``s`` into an image in (-0.5, 0.5).

            ``e`` must be rank-4 ``(B, C_e, H, W)`` (it is permuted with
            ``(0, 2, 3, 1)``) and ``s`` rank-2 ``(B, C_s)``.  The style vector is
            broadcast to every pixel, concatenated with ``e`` along channels,
            and passed through the MLP decoder pixel by pixel.  The result is
            returned as ``(B, x_ch, H, W)``.
            """
            b = e.shape[0]
            ch = e.shape[1] + s.shape[1]
            size = e.shape[2:]
            # Broadcast the style vector over the spatial grid of e.
            s = s[:, :, None, None].expand(-1, -1, *size)
            h = torch.cat([e, s], 1)
            # Flatten to (B*H*W, ch) so the Linear layers act per pixel.
            h = h.permute(0, 2, 3, 1).reshape(-1, ch)
            r = self.decoder(h)
            # Back to channels-first (B, x_ch, H, W).
            r = r.view(b, *size, -1).permute(0, 3, 1, 2).contiguous()
            # Tanh output (-1, 1) -> (-0.5, 0.5), matching the +0.5 unnormalisation used by callers.
            r = r / 2
            return r

        def forward(self, c, x):
            """Render content ``c`` in the style of image ``x``."""
            e = self.encode_content(c)
            s = self.encode_style(x)
            r = self.decode(e, s)
            return r

    return Net()


def load_method(output_dir='output/iconflow/final'):
    """Load a trained IconFlow network from ``output_dir`` for inference.

    Rebuilds the network from ``config.json`` (as written by ``train``).  It
    then loads the ``'net'`` weights from ``checkpoint.pt`` onto the CPU.

    Returns:
        An ``IconFlow`` ``ComparedMethod`` wrapping the network in eval mode.
        Its ``forward(c, x)`` runs under ``torch.no_grad()``.  It shifts both
        inputs by -0.5 before the network and the output by +0.5 after.
    """
    import torch
    from others.api import ComparedMethod

    with open(os.path.join(output_dir, 'config.json'), encoding='utf-8') as f:
        config = json.load(f)

    net = get_net(
        # train() records the channel counts it actually trained with.  Fall
        # back to the previous hard-coded values for configs without them.
        config.get('c_ch', 1), config.get('x_ch', 3),
        config['e_dim'], config['s_dim'], config['unet_arch'], config['cc_arch'],
        config['resnet_depth'], config['decoder_width'], config['decoder_depth']
    )

    state = torch.load(os.path.join(output_dir, 'checkpoint.pt'), map_location='cpu')
    net.load_state_dict(state['net'])

    class IconFlow(ComparedMethod):
        image_size = 128

        def __init__(self, net):
            super().__init__()
            self.net = net.eval()

        @torch.no_grad()
        def forward(self, c: torch.Tensor, x: torch.Tensor):
            # The network works on inputs/outputs centred at 0 (range ~(-0.5, 0.5)).
            return self.net(c - 0.5, x - 0.5) + 0.5

    return IconFlow(net)



def train(
    image_size=128,

    device='cpu',
    batch_size=32,
    num_workers=8,
    output_dir='output/iconflow/final',
    train_ratio=0.9,

    log_int=100,
    sample_int=1000,
    save_int=500,
    save_iters=(100000, 200000, 300000, 400000, 500000),
    end_iter=600000,

    random_color=True,

    e_dim=16,
    s_dim=48,
    unet_arch='M51',
    cc_arch='S31',
    resnet_depth=50,
    decoder_width=32,
    decoder_depth=4,

    cc_loss_weight=1.0,
):
    """Train IconFlow, logging to TensorBoard and checkpointing into ``output_dir``.

    Resumes from ``<output_dir>/checkpoint.pt`` when it exists.  It (re)writes
    ``<output_dir>/config.json`` with all arguments plus ``argv``, ``c_ch``
    and ``x_ch``.  Training stops after ``end_iter`` iterations, or never if
    ``end_iter`` is None.  Ctrl-C also stops it.  A final ``checkpoint.pt`` is
    written in every one of these cases.

    Args:
        image_size: Must be 128.
        device: Torch device string.
        batch_size, num_workers: DataLoader settings for the augmented split.
        output_dir: Directory for config, checkpoints and TensorBoard events.
        train_ratio: Forwarded to ``get_icon_dataset``.  Must be > 0 so the
            ``'aug_train'`` split exists.
        log_int: Log losses every ``log_int`` iterations.
        sample_int: Log sample grids every ``sample_int`` iterations.  Before
            iteration 1000, grids are logged every 200 iterations instead.
        save_int: Overwrite ``checkpoint.pt`` every ``save_int`` iterations.
        save_iters: Step counts after which an extra ``checkpoint_{it}.pt`` is
            kept, where ``it`` is the 0-based index of the last completed step.
        end_iter: Total number of iterations (int or float), or None to run
            until interrupted.
        random_color: Forwarded to ``get_icon_dataset``.
        e_dim, s_dim, unet_arch, cc_arch, resnet_depth, decoder_width,
        decoder_depth: Network hyper-parameters forwarded to ``get_net``.
        cc_loss_weight: Weight of the content-consistency loss.  When <= 0,
            the extraction and consistency losses are disabled entirely.
    """
    assert image_size == 128, '128 for the best quality'

    ei.patch()
    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    # Snapshot the call arguments (no other locals exist yet at this point).
    config = locals().copy()
    config['argv'] = sys.argv.copy()
    warnings.filterwarnings("ignore")

    import random
    import torch
    import torch.optim as optim
    import torch.nn.functional as F
    from torch.utils.data import DataLoader
    from torch.utils.tensorboard.writer import SummaryWriter
    from torchvision.utils import make_grid
    from tqdm import tqdm

    from .utils import cycle_iter, random_sampler

    os.makedirs(output_dir, exist_ok=True)
    device = torch.device(device)
    writer = SummaryWriter(output_dir)

    # ---- Dataset ----

    d = get_icon_dataset(
        image_size,
        random_color=random_color,
        train_ratio=train_ratio,
    )

    train_loader = DataLoader(
        d['aug_train'],
        batch_size=batch_size,
        sampler=random_sampler(len(d['aug_train'])),
        pin_memory=(device.type == 'cuda'),
        num_workers=num_workers,
    )

    # Dataset items are (X, C) pairs; record their channel counts.
    x_ch, c_ch = (t.shape[0] for t in d['train'][0])
    config['c_ch'] = c_ch
    config['x_ch'] = x_ch

    # ---- Model ----

    net = get_net(c_ch, x_ch, e_dim, s_dim, unet_arch, cc_arch, resnet_depth, decoder_width, decoder_depth)
    opt = optim.Adam(net.parameters(), lr=1e-4)
    net.to(device)

    ckpt_path = os.path.join(output_dir, 'checkpoint.pt')
    try:
        state = torch.load(ckpt_path, map_location=device)
    except FileNotFoundError:
        it = 0
    else:
        net.load_state_dict(state['net'])
        opt.load_state_dict(state['opt'])
        # Iteration to resume from.  Checkpoints written by older versions of
        # this function stored the last completed step, so that step repeats.
        it = state['it']
        print(f'loaded from checkpoint, it: {it}')
        del state

    # ---- Save config ----

    print(config)

    with open(os.path.join(output_dir, 'config.json'), 'w', encoding='utf-8') as f:
        json.dump(config, f, indent=2)

    # ---- Sample ----

    def draw_samples():
        """Draw a batch of (X, C) pairs for visualisation, seeded by ``it``.

        Returns stacked ``(sX, sC)`` holding ``n_train`` pairs from the train
        split followed by ``n_test`` pairs from the test split.
        """
        if image_size > 256:
            n_train, n_test = 1, 1
        elif image_size > 128:
            n_train, n_test = 2, 2
        else:
            n_train, n_test = 4, 4

        rng = random.Random(it % 300000 + 1337)
        sX, sC = zip(*rng.choices(d['train'], k=n_train),
                     *rng.choices(d['test'], k=n_test))
        return torch.stack(sX), torch.stack(sC)

    @torch.no_grad()
    def sample():
        """Return a grid image with one row per sample.

        Row i shows X_i, then C_i, then C_i decoded with the style of X_{i-n}
        for n = 0..N-1.  The n = 0 entry is the plain reconstruction.
        """
        net.eval()
        sX, sC = draw_samples()

        sRs = [
            net(sC.to(device), torch.roll(sX, n, 0).to(device)).cpu()
            for n in range(sX.shape[0])
        ]

        sG = torch.stack([sX, sC.expand(-1, x_ch, -1, -1), *sRs], 1)
        n_cols = sG.shape[1]
        sG = sG.reshape(-1, *sG.shape[2:])
        sG = (sG + 0.5).clamp(0, 1)  # unnormalize

        return make_grid(sG, n_cols)

    @torch.no_grad()
    def sample_extract():
        """Return a grid image with one row per quantity and one column per sample.

        The rows are, in order:
        1. the content C;
        2. the image X;
        3. the content extracted from X;
        4. C rendered in the style of the previous sample (R);
        5. the content extracted from R.
        """
        net.eval()
        sX, sC = draw_samples()

        sEX = net.extract_content(sX.to(device)).cpu()
        sR = net(sC.to(device), sX.roll(1, 0).to(device))

        sER = net.extract_content(sR).cpu()
        sR = sR.cpu()

        sC = sC.expand(-1, x_ch, -1, -1)
        sEX = sEX.expand(-1, x_ch, -1, -1)
        sER = sER.expand(-1, x_ch, -1, -1)

        sG = torch.stack([sC, sX, sEX, sR, sER], 0)
        n_cols = sG.shape[1]
        sG = sG.reshape(-1, *sG.shape[2:])
        sG = (sG + 0.5).clamp(0, 1)  # unnormalize

        return make_grid(sG, n_cols)

    # ---- Save ----

    def save(add_postfix=False, next_it=None):
        """Save net/optimizer state and config to ``checkpoint.pt``.

        With ``add_postfix`` the file is ``checkpoint_{it}.pt`` instead.
        ``next_it`` is stored as the iteration to resume from.  It defaults to
        ``it``, which is correct when no step has run yet for ``it``.
        """
        file_name = f'checkpoint_{it}.pt' if add_postfix else 'checkpoint.pt'
        torch.save({
            'net': net.state_dict(),
            'opt': opt.state_dict(),
            'it': it if next_it is None else next_it,
            'config': config,
        }, os.path.join(output_dir, file_name))

    # ---- Train ----

    def step(X, C):
        """Run one optimisation step on batch (X, C) and return scalar losses.

        The losses are:
        - RE = MSE(decode(e1, s1), x1).
        - When ``cc_loss_weight > 0``, EE = MSE(extract_content(x1), c1).
        - Also in that case, CC = MSE(extract_content(r12), c1).  Here r12
          decodes each content code with the previous sample's style.  CC is
          weighted by ``cc_loss_weight`` and does not update the extractor.
        """
        x1 = X.to(device)
        c1 = C.to(device)

        log_dict = {}

        net.train()

        e1 = net.encode_content(c1)
        s1 = net.encode_style(x1)
        # Pair each content code with the style of the previous sample in the batch.
        s2 = s1.roll(1, 0).contiguous()
        r11 = net.decode(e1, s1)
        r12 = net.decode(e1, s2)

        # Reconstruction Error
        reconstruction_error = F.mse_loss(r11, x1)
        loss = reconstruction_error
        log_dict['RE'] = reconstruction_error.item()

        if cc_loss_weight > 0.0:
            # Extraction Error: the extractor should recover C from the real image.
            extraction_error = F.mse_loss(net.extract_content(x1), c1)
            loss = loss + extraction_error
            log_dict['EE'] = extraction_error.item()

            # Content Consistency: the stylised output should keep its content.
            # Freeze the extractor so this term does not train it, and restore
            # it even if the forward pass raises.
            net.content_extractor.requires_grad_(False)
            try:
                content_consistency = F.mse_loss(net.extract_content(r12), c1)
            finally:
                net.content_extractor.requires_grad_(True)
            loss = loss + content_consistency * cc_loss_weight
            log_dict['CC'] = content_consistency.item()

        opt.zero_grad()
        loss.backward()
        opt.step()
        opt.zero_grad()

        return log_dict

    def run_loop():
        """Run training steps from ``it`` until ``end_iter`` (forever if None)."""
        nonlocal it
        total = None if end_iter is None else end_iter - it
        with tqdm(cycle_iter(train_loader), total=total) as iter_loader:
            for X, C in iter_loader:
                # Plain numeric comparison, so a float end_iter (e.g. parsed
                # from '6e5' on the command line) still terminates training.
                if end_iter is not None and it >= end_iter:
                    return

                log_dict = step(X, C)

                iter_loader.set_postfix({'it': it, **log_dict})

                if it % log_int == 0:
                    for key, value in log_dict.items():
                        writer.add_scalar(f'training/{key}', value, it)

                # Sample densely early on, then every sample_int iterations.
                if (it < 1000 and it % 200 == 0) or (it >= 1000 and it % sample_int == 0):
                    writer.add_image('sampling/sample', sample(), it)
                    if cc_loss_weight > 0.0:
                        writer.add_image('sampling/extract', sample_extract(), it)

                # Step `it` has completed, so a resumed run must start at it + 1.
                if it % save_int == 0:
                    save(next_it=it + 1)

                if (it + 1) in save_iters:
                    save(add_postfix=True, next_it=it + 1)

                it += 1

    try:
        try:
            run_loop()
        except KeyboardInterrupt:
            # Manual interruption: fall through and save the current state.
            pass
        print('saving checkpoint')
        save()
    finally:
        writer.close()


if __name__ == '__main__':
    # Command-line entry point: python-fire turns train()'s keyword
    # arguments into flags (e.g. --device=cuda --batch_size=16).
    from fire import Fire

    Fire(train)
