import ei
import os, sys
import json
import warnings



def get_flow(z_dim, width, depth, condition_size=2):
    """Build the conditional normalizing flow used for style codes.

    In this file the flow is run forward on style codes during training to
    obtain a latent and a log-density change, and run with ``reverse=True``
    to turn noise into style codes when sampling.

    Args:
        z_dim: Dimensionality of the vectors the flow transforms (the net's
            ``s_dim`` at every call site in this file).
        width: Value repeated in the hidden-size tuple.
        depth: Number of entries in the hidden-size tuple.
        condition_size: Dimensionality of the conditioning vector.

    Returns:
        Whatever ``create_conditional_flow`` builds.
    """
    # Deferred import: the model package is only loaded when a flow is built.
    from .model.flow import create_conditional_flow

    hidden_sizes = tuple([width] * depth)
    # NOTE(review): the meaning of the trailing positional `1` is defined by
    # create_conditional_flow; confirm it there before changing it.
    return create_conditional_flow(z_dim, hidden_sizes, condition_size, 1)


def get_icon_flow_dataset(image_size, max_samples):
    """Build the datasets used to train and visualise the style flow.

    Returns:
        A dict with:
        * ``'style'``: a ``StylePaletteDataset`` (the training loader's source),
        * ``'train'`` / ``'test'``: ``IconContourDataset`` instances over
          ``split=(0, 0.9)`` and ``split=(0.9, 1.0)`` of the same data
          directory (presumably fractional ranges; confirm in the dataset).

    All datasets are built with ``legacy_normalize=True``.
    """
    from .dataset import IconContourDataset, StylePaletteDataset

    data_root = 'datasets/icon/data/in_memory'

    def contour_split(lo, hi):
        # Train and test differ only in the split range they cover.
        return IconContourDataset(data_root, image_size, split=(lo, hi), legacy_normalize=True)

    style = StylePaletteDataset(
        data_root, image_size, 'datasets/icon', max_samples, legacy_normalize=True
    )
    return {'style': style, 'train': contour_split(0, 0.9), 'test': contour_split(0.9, 1.0)}


def load_method(output_dir='output/iconflow/final', flow_folder='flow_v2', up_folder='up_512'):
    """Load the trained IconFlow pipeline for inference.

    Three components are restored, each from a directory holding a
    ``config.json`` and a ``checkpoint.pt``:

    * the content/style net from ``output_dir`` (checkpoint key ``'net'``),
    * the conditional style flow from ``output_dir/flow_folder`` (key ``'flow'``),
    * optionally the upsampler from ``output_dir/up_folder`` (key ``'up'``).

    All weights are loaded onto the CPU. Move the returned module with
    ``.to(device)``.

    Args:
        output_dir: Directory of the trained net.
        flow_folder: Sub-directory of ``output_dir`` holding the flow.
        up_folder: Sub-directory of ``output_dir`` holding the upsampler, or
            None to run without upsampling at the net's own image size.

    Returns:
        An ``IconFlow`` module in eval mode.

    Raises:
        KeyError: if the upsampler config's ``image_size`` has no matching
            upsampler (only 256 and 512 are supported).
    """
    import torch
    import torch.nn as nn
    from torch.distributions import Normal
    from PIL import Image

    from .train_net import get_net
    from .train_up import Upsampler128to512, Upsampler128to256
    import torchvision.transforms.functional as T

    def read_config(directory):
        # Each component stores its hyper-parameters next to its weights.
        with open(os.path.join(directory, 'config.json')) as f:
            return json.load(f)

    def load_weights(module, directory, key):
        # Checkpoints are dicts; only the entry under `key` is this module's
        # state_dict. The rest is dropped when this helper returns.
        state = torch.load(os.path.join(directory, 'checkpoint.pt'), map_location='cpu')
        module.load_state_dict(state[key])

    # --- net ---
    net_config = read_config(output_dir)
    net = get_net(
        1, 3,  # (c_ch, x_ch), in the same positional order train() uses
        net_config['e_dim'],
        net_config['s_dim'],
        net_config['unet_arch'],
        net_config['cc_arch'],
        net_config['resnet_depth'],
        net_config['decoder_width'],
        net_config['decoder_depth']
    )
    load_weights(net, output_dir, 'net')

    # --- flow ---
    flow_dir = os.path.join(output_dir, flow_folder)
    flow_config = read_config(flow_dir)
    flow = get_flow(
        net_config['s_dim'],
        flow_config['flow_width'],
        flow_config['flow_depth']
    )
    load_weights(flow, flow_dir, 'flow')

    # --- optional upsampler ---
    if up_folder is not None:
        up_dir = os.path.join(output_dir, up_folder)
        up_config = read_config(up_dir)
        up_cls = {
            256: Upsampler128to256,
            512: Upsampler128to512
        }[up_config['image_size']]
        up = up_cls(3, 1, up_config['up_dim'])
        load_weights(up, up_dir, 'up')
    else:
        up = up_config = None

    class IconFlow(nn.Module):
        """Inference wrapper around the trained net, flow and optional upsampler."""

        def __init__(self, net, flow, up, net_config, flow_config, up_config):
            super().__init__()

            self.net = net.eval()
            self.flow = flow.eval()
            # `up` is None when load_method() is called with up_folder=None.
            self.up = up.eval() if up is not None else None

            self.net_config = net_config
            self.flow_config = flow_config
            self.up_config = up_config

            self.image_size = up_config['image_size'] if up_config is not None else net_config['image_size']
            self.base = Normal(0, 1)
            # One-element buffer used only to report which device the module is on.
            self.register_buffer('zero', torch.zeros([1]))

        @property
        def device(self):
            return self.zero.device

        @torch.no_grad()
        def sample_noise(self, n: int, t=1.0):
            """Draw an [n, s_dim] standard-normal latent batch scaled by temperature `t`."""
            zs = self.base.sample((n, self.net_config['s_dim']))
            zs = zs.to(self.device) * t
            return zs

        @torch.no_grad()
        def inverse_flow(self, zs: torch.Tensor, ps: torch.Tensor):
            """Map latents `zs` to style codes, conditioned on positions `ps`."""
            # NOTE(review): the /3 presumably matches the normalisation applied
            # by the dataset's position_to_condition during training; confirm.
            return self.flow(zs, ps / 3, reverse=True)

        @torch.no_grad()
        def sample_style(self, n: int, ps: torch.Tensor, t=1.0):
            """Sample `n` style codes for each of the M rows of `ps`.

            The same `n` noise vectors are shared by all M positions.
            Returns a tensor of shape [M, n, S].
            """
            m = ps.shape[0]
            zs = self.sample_noise(n, t)
            zs = zs[None, :, :].expand(m, -1, -1)
            ps = ps[:, None, :].expand(-1, n, -1)
            zs = zs.reshape(-1, zs.shape[-1])
            ps = ps.reshape(-1, ps.shape[-1])
            ss = self.inverse_flow(zs, ps)
            ss = ss.reshape(m, n, zs.shape[-1])
            return ss

        @torch.no_grad()
        def encode_content(self, cs: torch.Tensor):
            """Encode content images; `cs` is shifted by -0.5 first (to_tensor yields [0, 1])."""
            return self.net.encode_content(cs - 0.5)

        @torch.no_grad()
        def decode(self, es: torch.Tensor, ss: torch.Tensor, cs_up=None):
            """Render every (content, style) pair.

            Args:
                es: Content encodings, [M, E, H, W].
                ss: Style codes, [M, N, S].
                cs_up: Optional full-resolution content images [M, C, H, W]
                    (in [0, 1]) for the upsampler.

            Returns:
                [M, N, ...] image tensor offset by +0.5. It is not clamped.
            """
            assert es.shape[0] == ss.shape[0]
            assert es.dim() == 4
            if cs_up is not None:
                assert self.up is not None
                assert cs_up.shape[0] == es.shape[0]
            m, n, _ = ss.shape
            # Flatten to M*N pairs: es -> [M*N, E, H, W], ss -> [M*N, S].
            es = es[:, None].expand(-1, n, -1, -1, -1).reshape(m * n, *es.shape[-3:])
            ss = ss.reshape(m * n, ss.shape[-1])
            rs = self.net.decode(es, ss)
            if cs_up is not None:
                cs_up = (cs_up - 0.5)[:, None].expand(-1, n, -1, -1, -1).reshape(m * n, *cs_up.shape[1:])
                rs = self.up(rs, cs_up)
            rs = rs + 0.5
            rs = rs.reshape(m, n, *rs.shape[1:])
            return rs

        def _positions_to_tensor(self, ps):
            # Each position must be a sequence of plain Python numbers.
            assert all(all(isinstance(value, (float, int)) for value in p) for p in ps)
            return torch.FloatTensor(ps).to(self.device)

        def _contents_to_tensors(self, cs):
            # Returns (cs, cs_up). Without an upsampler the images are used
            # as-is and cs_up is None. With one, the full-size images become
            # cs_up and the net gets copies resized to its own image_size.
            def stack(images):
                return torch.stack([T.to_tensor(c) for c in images]).to(self.device)

            if self.up is None:
                return stack(cs), None
            size = (self.net_config['image_size'],) * 2
            cs_up = stack(cs)
            cs = stack([c.resize(size, Image.BICUBIC) for c in cs])
            return cs, cs_up

        def sample_inputs_to_tensors(self, cs, ps):
            """Convert PIL content image(s) and position(s) for sample().

            Accepts one image plus one position, or equal-length lists of each.
            Returns (cs, ps, cs_up, single).
            """
            single = isinstance(cs, Image.Image)
            if single:
                cs, ps = [cs], [ps]
            assert len(cs) == len(ps)
            assert all(isinstance(c, Image.Image) for c in cs)
            ps = self._positions_to_tensor(ps)
            cs, cs_up = self._contents_to_tensors(cs)
            return cs, ps, cs_up, single

        def interp_inputs_to_tensors(self, cs, p1s, p2s):
            """Like sample_inputs_to_tensors, but with start and end positions.

            Returns (cs, p1s, p2s, cs_up, single).
            """
            single = isinstance(cs, Image.Image)
            if single:
                cs, p1s, p2s = [cs], [p1s], [p2s]
            assert len(cs) == len(p1s) == len(p2s)
            assert all(isinstance(c, Image.Image) for c in cs)
            p1s = self._positions_to_tensor(p1s)
            p2s = self._positions_to_tensor(p2s)
            cs, cs_up = self._contents_to_tensors(cs)
            return cs, p1s, p2s, cs_up, single

        def output_from_tensors(self, rs, single=False):
            """Convert an [M, N, C, H, W] batch into nested lists of PIL images."""
            # Clamp first: to_pil_image scales floats by 255 and casts to
            # uint8, so values outside [0, 1] would overflow, not saturate.
            rs = [[T.to_pil_image(r.clamp(0, 1)) for r in rr] for rr in rs]
            return rs[0] if single else rs

        def sample(self, cs, ps, n, t):
            """Render `n` random styles (temperature `t`) per content image at position(s) `ps`."""
            cs, ps, cs_up, single = self.sample_inputs_to_tensors(cs, ps)
            es = self.encode_content(cs)
            ss = self.sample_style(n, ps, t)
            rs = self.decode(es, ss, cs_up)
            return self.output_from_tensors(rs, single)

        def interp(self, cs, p1s, p2s, n, t):
            """Render `n` images per content, linearly interpolating the position from p1 to p2.

            One noise vector per content image is shared along its interpolation path.
            """
            cs, p1s, p2s, cs_up, single = self.interp_inputs_to_tensors(cs, p1s, p2s)
            m = len(cs)
            zs = self.sample_noise(m, t)[:, None, :].expand(-1, n, -1).reshape(m*n, -1)  # [M*N, S]
            w = torch.linspace(0, 1, n)[None, :, None].to(self.device)
            ps = (p1s[:, None] * (1-w) + p2s[:, None] * w).reshape(m*n, -1)  # [M*N, 2]
            ss = self.inverse_flow(zs, ps).reshape(m, n, -1)  # [M, N, S]
            es = self.encode_content(cs)
            rs = self.decode(es, ss, cs_up)
            return self.output_from_tensors(rs, single)

    return IconFlow(net, flow, up, net_config, flow_config, up_config).eval()


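'''
Example command lines (train() below is exposed as a CLI through python-fire):

python -m iconflow.train_flow_net4 output/iconflow/icon128_v4_M51r50e16s48d32x4_NU3CC --device cuda:5 --batch-size 64 --end-iter 300000 --flow-width 512 --dirname flow1
python -m iconflow.train_flow_net4 output/iconflow/icon128_v4_M51r50e16s48d32x4_NU3CC --device cuda:6 --batch-size 64 --end-iter 300000 --flow-width 512 --dirname flow2 --pseudo-prob 0.3
python -m iconflow.train_flow output/iconflow/final --device cuda:6 --batch-size 64 --end-iter 100000 --flow-width 512 --dirname flow_512

Caveats: `--pseudo-prob` is not a parameter of train() in this file. Also,
train() asserts that '_v4_' appears in the net directory, which
'output/iconflow/final' does not satisfy.
'''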


def train(
    net_output_dir,
    dirname,
    device='cpu',
    batch_size=64,
    num_workers=8,
    log_int=50,
    sample_int=200,
    save_int=200,
    save_iters=None,
    end_iter=30000,
    flow_width=512,
    flow_depth=4,
    max_samples=1000,
):
    """Train the conditional style flow on top of a frozen, pre-trained net.

    The net in ``net_output_dir`` only encodes styles; its weights are never
    optimised (the optimiser covers ``flow.parameters()`` only). The flow is
    trained by maximum likelihood (reported in bits per dimension) on those
    style codes, conditioned on the ``L`` tensors yielded by the style loader.
    ``config.json``, checkpoints and TensorBoard logs go to
    ``net_output_dir/dirname``. An existing ``checkpoint.pt`` there is resumed.

    Args:
        net_output_dir: Directory with the trained net's ``config.json`` and
            ``checkpoint.pt``. Its path must contain ``'_v4_'``.
        dirname: Sub-directory of ``net_output_dir`` for this flow run.
        device: Torch device string.
        batch_size: DataLoader batch size.
        num_workers: DataLoader worker processes.
        log_int: Log the loss every ``log_int`` iterations.
        sample_int: Write a sample grid every ``sample_int`` iterations.
        save_int: Overwrite ``checkpoint.pt`` every ``save_int`` iterations.
        save_iters: Iterations at which an extra ``checkpoint_{it}.pt`` is
            written. None means none.
        end_iter: Stop once the iteration counter reaches this value.
        flow_width, flow_depth: Flow size (see ``get_flow``).
        max_samples: Forwarded to the style dataset.
    """
    output_dir = os.path.join(net_output_dir, dirname)
    # Avoid a shared mutable default; an explicitly passed sequence is used as is.
    if save_iters is None:
        save_iters = []

    ei.patch()
    os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
    warnings.filterwarnings("ignore")

    # Save config. locals() is snapshotted here, so config.json records every
    # argument plus output_dir. Introduce new locals only below this line.
    config = locals().copy()
    config['argv'] = sys.argv.copy()
    print(config)
    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, 'config.json'), 'w') as f:
        json.dump(config, f, indent=2)

    # Load the net's config.
    with open(os.path.join(net_output_dir, 'config.json')) as f:
        c = json.load(f)
    image_size = c['image_size']
    # NOTE(review): plain substring check on the path, so a directory such as
    # 'output/iconflow/final' is rejected even if it holds a v4 net.
    assert '_v4_' in net_output_dir
    s_dim = c['s_dim']
    e_dim = c['e_dim']
    unet_arch = c['unet_arch']
    resnet_depth = c['resnet_depth']
    cc_arch = c['cc_arch']
    decoder_width = c['decoder_width']
    decoder_depth = c['decoder_depth']

    import math
    import random
    import torch
    import numpy as np
    import torch.nn as nn
    import torch.optim as optim
    import torch.nn.functional as F
    from torch.utils.data import DataLoader
    from torch.utils.tensorboard.writer import SummaryWriter
    from torch.distributions import Normal
    from torchvision.utils import make_grid
    from tqdm import tqdm

    from .utils import cycle_iter, random_sampler
    from .train_net import get_net as get_net4

    device = torch.device(device)
    writer = SummaryWriter(output_dir)

    ckpt_path = os.path.join(net_output_dir, 'checkpoint.pt')
    state = torch.load(ckpt_path, map_location=device)

    # --- Dataset ---
    d = get_icon_flow_dataset(image_size, max_samples)

    train_loader = DataLoader(
        d['style'],
        batch_size=batch_size,
        sampler=random_sampler(len(d['style'])),
        pin_memory=(device.type == 'cuda'),
        num_workers=num_workers
    )

    # Channel counts of the (x, c) pairs in the contour dataset.
    x_ch, c_ch = map(lambda t: t.shape[0], d['train'][0])

    # --- Frozen net ---
    net = get_net4(c_ch, x_ch, e_dim, s_dim, unet_arch, cc_arch, resnet_depth, decoder_width, decoder_depth)
    net.to(device)
    net.load_state_dict(state['net'])
    print('loaded from checkpoint, it: {}'.format(state['it']))
    del state

    # --- Flow ---
    flow = get_flow(s_dim, flow_width, flow_depth)
    # Move the model before building the optimiser, as the torch.optim docs recommend.
    flow.to(device)
    opt = optim.Adam(flow.parameters(), lr=1e-3)
    it = 0
    try:
        state = torch.load(os.path.join(output_dir, 'checkpoint.pt'), map_location=device)
    except FileNotFoundError:
        pass  # no previous run in output_dir: start from scratch
    else:
        flow.load_state_dict(state['flow'])
        opt.load_state_dict(state['opt'])
        it = state['it']
        del state

    base_dist = Normal(0, 1)

    # --- Sample ---
    def sample():
        """Return an image grid: style swatches, then one row per temperature."""
        from utils.hist import draw_style
        from utils.test import add_text, Image, from_image1

        def get_style_img(cmb, text):
            img = Image.new('RGB', (128,)*2, (255,)*3)
            img.paste(draw_style(cmb).resize((128, 128-40)), (0, 40))
            img = add_text(img, text)
            return img

        net.eval()
        flow.eval()

        temps = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]

        # Fewer columns for larger images.
        if image_size > 256:
            n_train, n_test = 1, 1
        elif image_size > 128:
            n_train, n_test = 2, 2
        else:
            n_train, n_test = 4, 4

        # Seeded with the iteration so the chosen contents are reproducible per step.
        c_train = random.Random(it).choice(d['train'])[1]
        c_test = random.Random(it).choice(d['test'])[1]
        c_batch = [c_train] * n_train + [c_test] * n_test
        C = torch.stack(c_batch).to(device)

        style_dataset = d['style']  # the StylePaletteDataset from get_icon_flow_dataset

        selected_style_names = np.random.choice(style_dataset.style_names, size=(n_train + n_test))
        selected_style_images = torch.stack([
            from_image1(get_style_img(style_dataset.style_to_cmb[name], name))
            for name in selected_style_names
        ])

        R_list = [selected_style_images]
        L = torch.stack([
            style_dataset.position_to_condition(style_dataset.style_to_pos[name])
            for name in selected_style_names
        ]).to(device)

        for temp in temps:
            with torch.no_grad():
                Z = base_dist.sample([C.shape[0], s_dim]) * temp
                Z = Z.to(device)
                e = net.encode_content(C)
                s = flow(Z, L, reverse=True)
                R = net.decode(e, s)
                R_list.append(R.cpu())
                del Z, e, s, R

        R = torch.cat(R_list)
        R = torch.clamp(R + 0.5, 0, 1)
        image = make_grid(R, C.shape[0])

        return image

    # --- Save ---
    def save(add_postfix=False):
        if add_postfix:
            file_name = f'checkpoint_{it}.pt'
        else:
            file_name = 'checkpoint.pt'

        torch.save({
            'flow': flow.state_dict(),
            'opt': opt.state_dict(),
            'it': it,
            'config': config,
        }, os.path.join(output_dir, file_name))

    # --- Train ---
    def step(X, L):
        """Run one optimisation step and return the loss in bits per dimension."""
        X = X.to(device)
        L = L.to(device)

        net.eval()
        flow.train()

        # Use the real batch length, not `batch_size`: if the loader ever
        # yields a shorter batch, reshaping to `batch_size` would fail or
        # silently mix samples.
        n = X.shape[0]
        with torch.no_grad():
            S = net.encode_style(X).reshape(n, -1)

        zero = torch.zeros(n, 1, device=device)
        z, dlogp = flow(S, L, zero)

        # Change of variables: log p(S) = log N(z; 0, I) - dlogp.
        logpz = base_dist.log_prob(z).sum(-1)
        # Averaging each term separately gives the same mean NLL whether dlogp
        # is [n] or keeps the [n, 1] shape of `zero`. It never broadcasts the
        # difference to an [n, n] matrix.
        nll = dlogp.mean() - logpz.mean()
        bpd = nll / s_dim / math.log(2)

        opt.zero_grad()
        bpd.backward()
        opt.step()

        return bpd.item()

    try:
        with tqdm(cycle_iter(train_loader), total=max(end_iter - it, 0)) as iter_loader:
            for X, L in iter_loader:
                if it >= end_iter:
                    break

                loss = step(X, L)

                iter_loader.set_postfix({'loss': loss})

                if it % log_int == 0:
                    writer.add_scalar('training/loss', loss, it)

                if it % sample_int == 0:
                    writer.add_image('sampling/sample', sample(), it)

                it += 1

                if it % save_int == 0:
                    save()

                if it in save_iters:
                    save(add_postfix=True)
    except KeyboardInterrupt:
        pass  # stopped by hand: fall through and checkpoint the current state
    finally:
        writer.close()

    print('saving checkpoint')
    save()


if __name__ == '__main__':
    # Expose train()'s arguments as command-line flags via python-fire.
    from fire import Fire

    Fire(train)
