


def train(
    dataset_root: str = 'datasets/icon4/data',
    output_dir: str = '/tmp/test',
    device: str = 'cpu',
    image_size: int = 64,
    num_workers: int = 8,
    in_ch_G: int = 3,
    out_ch_G: int = 3,
    in_ch_D: int = 3,
    out_ch_D: int = 1,
    batch_size: int = 512,
    base_dim_G: int = 64,
    base_dim_D: int = 64,
    nb: int = 8,
    train_epoch: int = 100,
    lr_D: float = 0.0002,
    lr_G: float = 0.0002,
    beta1: float = 0.5,
    beta2: float = 0.999,
    save_epoch: int = 100,
) -> None:
    """Train two encoder/decoder pairs and two discriminators on unpaired data.

    Each training iteration performs one update of both discriminators
    followed by one joint update of both encoders and both decoders.
    Scalars and sample grids are written to TensorBoard under the
    ``cimg2img/`` tag prefix, and checkpoints are written to ``output_dir``.
    If ``output_dir/checkpoint.pt`` exists, training resumes from it.

    Args:
        dataset_root: Root directory handed to ``UnpairedIconContourDataset``.
        output_dir: Directory for TensorBoard event files and checkpoints;
            created if missing.
        device: Device string passed to ``torch.device``.
        image_size: Image size handed to the dataset, the encoders and the
            discriminators.
        num_workers: Worker processes for the training ``DataLoader`` (the
            test loader always uses 0).
        in_ch_G: Input channels of both encoders.
        out_ch_G: Output channels of both decoders.
        in_ch_D: Input channels of both discriminators.
        out_ch_D: Output channels of both discriminators.
        batch_size: Training batch size; also the size of the label tensors.
        base_dim_G: ``nf`` argument of the encoders and decoders.
        base_dim_D: ``nf`` argument of the discriminators.
        nb: Unused by this function; kept for interface compatibility.
        train_epoch: Number of epochs to run in this invocation. When
            resuming, they are counted on top of the epochs already done.
        lr_D: Adam learning rate of each discriminator optimizer.
        lr_G: Adam learning rate of the shared encoder/decoder optimizer.
        beta1: First Adam beta, shared by all optimizers.
        beta2: Second Adam beta, shared by all optimizers.
        save_epoch: A sample grid and a numbered checkpoint are written after
            every epoch whose index is a multiple of this value (epoch 0
            included).
    """
    import itertools
    import os

    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torch.utils.data.dataloader import DataLoader
    from torch.utils.tensorboard import SummaryWriter
    from torchvision.utils import make_grid
    from tqdm import tqdm

    from .dataset import UnpairedIconContourDataset
    from .networks import Decoder, Discriminator, Encoder

    device = torch.device(device)

    # cuDNN benchmark mode lets cuDNN pick convolution algorithms by timing them.
    if torch.backends.cudnn.enabled:
        torch.backends.cudnn.benchmark = True

    os.makedirs(output_dir, exist_ok=True)

    def get_dataloader(dataset, batch_size, num_workers=0):
        # drop_last=True keeps every batch at exactly `batch_size` samples,
        # which the fixed-size `real` / `fake` label tensors rely on.
        return DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=num_workers)

    # NOTE(review): the positional booleans and the (0.0, 0.9) / (0.9, 1.0)
    # tuples presumably toggle augmentation and select a 90/10 split --
    # confirm against UnpairedIconContourDataset.
    train_loader = get_dataloader(
        UnpairedIconContourDataset(dataset_root, image_size, True, True, (0.0, 0.9)),
        batch_size=batch_size, num_workers=num_workers
    )
    test_loader = get_dataloader(
        UnpairedIconContourDataset(dataset_root, image_size, False, False, (0.9, 1.0)),
        batch_size=64, num_workers=0
    )

    # network
    En_A = Encoder(in_nc=in_ch_G, nf=base_dim_G, img_size=image_size)
    En_B = Encoder(in_nc=in_ch_G, nf=base_dim_G, img_size=image_size)
    De_A = Decoder(out_nc=out_ch_G, nf=base_dim_G)
    De_B = Decoder(out_nc=out_ch_G, nf=base_dim_G)
    Dis_A = Discriminator(in_nc=in_ch_D, out_nc=out_ch_D, nf=base_dim_D, img_size=image_size)
    Dis_B = Discriminator(in_nc=in_ch_D, out_nc=out_ch_D, nf=base_dim_D, img_size=image_size)
    # Module.to() moves parameters in place, so the temporary ModuleList is
    # only a convenient way to move all six networks with one call.
    nn.ModuleList([En_A, En_B, De_A, De_B, Dis_A, Dis_B]).to(device)
    generators = (En_A, En_B, De_A, De_B)

    # loss
    BCE_loss = nn.BCELoss().to(device)
    L1_loss = nn.L1Loss().to(device)

    # Adam optimizer: one shared optimizer for all encoders/decoders, one per
    # discriminator.
    opt_G = optim.Adam(
        itertools.chain(En_A.parameters(), De_A.parameters(), En_B.parameters(), De_B.parameters()),
        lr=lr_G, betas=(beta1, beta2),
    )
    opt_D_A = optim.Adam(Dis_A.parameters(), lr=lr_D, betas=(beta1, beta2))
    opt_D_B = optim.Adam(Dis_B.parameters(), lr=lr_D, betas=(beta1, beta2))

    # Constant BCE targets for "real" and "fake" discriminator outputs.
    real = torch.ones(batch_size, 1, 1, 1, device=device)
    fake = torch.zeros(batch_size, 1, 1, 1, device=device)

    # Everything with a state_dict, keyed by its name inside the checkpoint.
    components = {
        'opt_G': opt_G,
        'opt_D_A': opt_D_A,
        'opt_D_B': opt_D_B,
        'En_A': En_A,
        'En_B': En_B,
        'De_A': De_A,
        'De_B': De_B,
        'Dis_A': Dis_A,
        'Dis_B': Dis_B,
    }

    start_epoch = 0
    start_it = 0

    # Resume from the rolling checkpoint when one exists. Only the load itself
    # is guarded, so a malformed checkpoint still fails loudly.
    ckpt_path = os.path.join(output_dir, 'checkpoint.pt')
    try:
        state = torch.load(ckpt_path, map_location='cpu')
    except FileNotFoundError:
        state = None

    if state is not None:
        for name, component in components.items():
            component.load_state_dict(state[name])
        start_it = state['start_it']
        start_epoch = state['start_epoch']
        del state
        print('loaded from', ckpt_path)

    def get_state(next_epoch, next_it):
        """Snapshot all state dicts plus the position a resumed run starts at."""
        snapshot = {name: component.state_dict() for name, component in components.items()}
        # Store the *current* progress. Saving the values the run started with
        # would make every resume restart its step and epoch numbering.
        snapshot['start_it'] = next_it
        snapshot['start_epoch'] = next_epoch
        return snapshot

    def train_one_epoch(it):
        """Run one pass over ``train_loader`` and return the updated step."""
        for net in generators:
            net.train()

        for A, B in train_loader:
            log_dict = {}

            A, B = A.to(device), B.to(device)

            # train Disc_A & Disc_B
            # Disc real loss
            Dis_A_real_loss = BCE_loss(Dis_A(A), real)
            Dis_B_real_loss = BCE_loss(Dis_B(B), real)

            # Disc fake loss. The fakes are built without autograd: the
            # discriminator update must not backpropagate into the
            # encoders/decoders, and skipping their graph removes the need
            # for retain_graph=True.
            with torch.no_grad():
                in_A, sp_A = En_A(A)
                in_B, sp_B = En_B(B)

                # De_A == B2A decoder, De_B == A2B decoder
                B2A = De_A(in_B + sp_A)
                A2B = De_B(in_A + sp_B)

            Dis_A_fake_loss = BCE_loss(Dis_A(B2A), fake)
            Dis_B_fake_loss = BCE_loss(Dis_B(A2B), fake)

            Dis_A_loss = Dis_A_real_loss + Dis_A_fake_loss
            Dis_B_loss = Dis_B_real_loss + Dis_B_fake_loss

            opt_D_A.zero_grad()
            Dis_A_loss.backward()
            opt_D_A.step()

            opt_D_B.zero_grad()
            Dis_B_loss.backward()
            opt_D_B.step()

            log_dict['Dis_A_loss'] = Dis_A_loss.item()
            log_dict['Dis_B_loss'] = Dis_B_loss.item()

            # train Gen
            # Gen adversarial loss: the forward pass is recomputed with
            # autograd enabled and the fakes are scored against `real`.
            in_A, sp_A = En_A(A)
            in_B, sp_B = En_B(B)

            B2A = De_A(in_B + sp_A)
            A2B = De_B(in_A + sp_B)

            Gen_A_fake_loss = BCE_loss(Dis_A(B2A), real)
            Gen_B_fake_loss = BCE_loss(Dis_B(A2B), real)

            # Gen Dual loss: re-encode the translated images and decode them
            # back towards their source domain.
            in_A_hat, sp_B_hat = En_B(A2B)
            in_B_hat, sp_A_hat = En_A(B2A)

            A_hat = De_A(in_A_hat + sp_A)
            B_hat = De_B(in_B_hat + sp_B)

            Gen_gan_loss = Gen_A_fake_loss + Gen_B_fake_loss
            Gen_dual_loss = L1_loss(A_hat, A.detach()) ** 2 + L1_loss(B_hat, B.detach()) ** 2
            Gen_in_loss = L1_loss(in_A_hat, in_A.detach()) ** 2 + L1_loss(in_B_hat, in_B.detach()) ** 2
            Gen_sp_loss = L1_loss(sp_A_hat, sp_A.detach()) ** 2 + L1_loss(sp_B_hat, sp_B.detach()) ** 2

            Gen_loss = Gen_gan_loss + Gen_dual_loss + Gen_in_loss + Gen_sp_loss

            opt_G.zero_grad()
            Gen_loss.backward()
            opt_G.step()

            log_dict['Gen_loss'] = Gen_loss.item()

            for key, value in log_dict.items():
                writer.add_scalar(f'cimg2img/{key}', value, it)

            it += 1

        return it

    def log_samples(it):
        """Translate one test batch in both directions and log it as a grid."""
        with torch.no_grad():
            # The generators stay in eval mode until the next epoch switches
            # them back to train mode.
            for net in generators:
                net.eval()

            A, B = next(iter(test_loader))
            A, B = A.to(device), B.to(device)

            in_A, sp_A = En_A(A)
            in_B, sp_B = En_B(B)

            B2A = De_A(in_B + sp_A)
            A2B = De_B(in_A + sp_B)

            # For each chunk of 16 samples emit four grid rows: A, B, A2B, B2A.
            # NOTE(review): `/ 2.0 + 0.5` presumably maps [-1, 1] images to
            # [0, 1] for display -- confirm against the dataset normalisation.
            rows = torch.cat([
                torch.stack([A_chunk, B_chunk, A2B_chunk, B2A_chunk])
                for A_chunk, B_chunk, A2B_chunk, B2A_chunk
                in zip(A.split(16), B.split(16), A2B.split(16), B2A.split(16))
            ]) / 2.0 + 0.5

            result = make_grid(rows.reshape(-1, *rows.shape[2:]), nrow=rows.shape[1])

            writer.add_image('cimg2img/sample', result, it)

    it = start_it

    writer = SummaryWriter(output_dir)
    try:
        for i in tqdm(range(train_epoch)):
            epoch = start_epoch + i

            it = train_one_epoch(it)

            if epoch % save_epoch == 0:
                log_samples(it)
                torch.save(
                    get_state(epoch + 1, it),
                    os.path.join(output_dir, f'checkpoint_{epoch:04d}.pt'),
                )

            # Rolling checkpoint, overwritten after every epoch; this is the
            # file the resume logic above looks for.
            torch.save(get_state(epoch + 1, it), os.path.join(output_dir, 'checkpoint.pt'))
    finally:
        # Flush pending events and release the event file even when training
        # is interrupted.
        writer.close()

if __name__ == '__main__':
    # Script entry point: only runs when the file is executed directly.
    from fire import Fire
    from ei import patch as patch_ei

    # NOTE(review): `ei.patch()` presumably installs a debugging hook for
    # uncaught exceptions -- confirm against the `ei` package.
    patch_ei()

    # Hand `train` to Fire so its keyword arguments can be set from the
    # command line.
    Fire(train)
    