import argparse
from PIL import Image
import time
import logging
import os
import numpy as np
import random
from datetime import datetime
from collections import OrderedDict

import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_
from utils.scheduler import GradualWarmupScheduler
from TM_model import Model
import utils.losses as losses
from utils.utils_image import eval_tensor_imgs
from utils.general import create_log_folder, get_cuda_info, find_latest_checkpoint, change_checkpoint
from data.dataset_LMDB_train import DataLoaderTurbVideo
import lpips


def get_args():
    """Parse the command-line options for training.

    Path-like options (``--train_path``, ``--train_info``, ``--val_path``,
    ``--val_info``, ``--log_path`` and a file path given to ``--load``) are
    passed through ``os.path.expanduser`` so the ``~``-based defaults resolve to
    the user's home directory instead of a literal ``~`` folder (the os/os.path
    functions never expand ``~`` on their own).

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(description='Train the turbulence mitigation video model')
    parser.add_argument('--iters', type=int, default=400000, help='Number of training iterations')
    parser.add_argument('--batch-size', '-b', dest='batch_size', metavar='B', type=int, default=8, help='Batch size')
    parser.add_argument('--patch-size', '-ps', dest='patch_size', type=int, default=256,
                        help='Spatial size of the square training patches')
    parser.add_argument('--print-period', '-pp', dest='print_period', type=int, default=1000,
                        help='number of iterations between training logs and checkpoint saves')
    parser.add_argument('--val-period', '-vp', dest='val_period', type=int, default=5000,
                        help='number of iterations for validation')
    parser.add_argument('--learning-rate', '-l', metavar='LR', type=float, default=0.0001, help='Learning rate', dest='lr')
    parser.add_argument('--warmup_iters', type=int, default=10000, help='warm up iterations')
    parser.add_argument('--num_frames', type=int, default=16, help='number of frames for the model')
    parser.add_argument('--num_workers', type=int, default=16, help='number of workers in dataloader')
    parser.add_argument('--train_path', type=str, default='~/data/lmdb_ATSyn/train_lmdb/', help='path of training imgs')
    parser.add_argument('--train_info', type=str, default='~/data/lmdb_ATSyn/train_lmdb/train_info.json', help='info of training imgs')
    parser.add_argument('--val_path', type=str, default='~/data/lmdb_ATSyn/test_lmdb/', help='path of validation imgs')
    parser.add_argument('--val_info', type=str, default='~/data/lmdb_ATSyn/test_lmdb/test_info.json', help='info of testing imgs')
    parser.add_argument('--load', '-f', type=str, default=False,
                        help='Load model from a .pth file, or "latest" to search --log_path for the newest checkpoint of --run_name')
    parser.add_argument('--log_path', type=str, default='~/data/log_rnn/Mamba/', help='path to save logging files and images')
    parser.add_argument('--task', type=str, default='turb', help='choose turb or blur or both')
    parser.add_argument('--run_name', type=str, default='MambaTM', help='name of this running')
    parser.add_argument('--start_over', action='store_true')

    parser.add_argument('--model', type=str, default='MambaTM_NOLPD', help='type of model to construct')
    parser.add_argument('--output_full', action='store_true', help='output # of frames is the same as the input')
    parser.add_argument('--n_features', type=int, default=16, help='base # of channels for Conv')
    parser.add_argument('--n_blocks', type=int, default=4, help='# of blocks in middle part of the model')
    parser.add_argument('--future_frames', type=int, default=2, help='use # of future frames')
    parser.add_argument('--past_frames', type=int, default=2, help='use # of past frames')
    parser.add_argument('--seed', type=int, default=3275, help='random seed')
    args = parser.parse_args()

    # Expand "~" in every filesystem path option; no-op for absolute/relative paths.
    for name in ('train_path', 'train_info', 'val_path', 'val_info', 'log_path'):
        setattr(args, name, os.path.expanduser(getattr(args, name)))
    # --load is either False (not given), the keyword "latest", or a file path.
    if args.load and args.load != 'latest':
        args.load = os.path.expanduser(args.load)
    return args


def validate(args, model, val_loader, criterion, c_lpips, iter_count, im_save_freq, im_save_path, device, level):
    """Evaluate *model* on one validation split.

    Args:
        args: parsed CLI namespace; ``output_full``, ``past_frames``,
            ``num_frames`` and ``future_frames`` decide which frames are scored.
        model: network to evaluate. It is switched to ``eval()`` mode and left
            there; the caller must call ``model.train()`` afterwards.
        val_loader: iterable of batches; ``data[0]`` is the input clip and
            ``data[1]`` the ground-truth clip.
        criterion: pixel loss called as ``criterion(output, target)``.
        c_lpips: LPIPS module. Outputs/targets are mapped with ``x * 2 - 1``
            (presumably from [0, 1] to [-1, 1] -- TODO confirm data range).
        iter_count: training iteration, forwarded to ``eval_tensor_imgs`` when
            images are saved.
        im_save_freq: save comparison images every ``im_save_freq`` batches.
        im_save_path: directory handed to ``eval_tensor_imgs`` for saved images.
        device: device the batches are moved to.
        level: split name (e.g. 'weak'); ``level + 'val'`` tags saved images.

    Returns:
        ``(psnr, ssim, loss, lpips)``: PSNR/SSIM averaged over every value
        returned by ``eval_tensor_imgs``; loss and LPIPS averaged per batch.

    Raises:
        ValueError: if the loader yields no batches (nothing to average).
    """
    psnr_all = []
    ssim_all = []
    eval_loss = 0.0
    eval_lpips = 0.0
    num_batches = 0
    model.eval()
    # Frames the model predicts when it does not output the full clip.
    center = slice(args.past_frames, args.num_frames - args.future_frames)

    for s, data in enumerate(val_loader):
        # Use the same device for input and target (the original mixed .cuda() and .to(device)).
        input_ = data[0].to(device)
        target = data[1] if args.output_full else data[1][:, center, ...]
        target = target.to(device)
        with torch.no_grad():
            output = model(input_)
            if not args.output_full:
                # Align the input with the predicted frames for the image comparison.
                input_ = input_[:, center, ...]
            loss = criterion(output, target)
            loss_lpips = c_lpips(output.flatten(0, 1) * 2 - 1, target.flatten(0, 1) * 2 - 1).mean()
        eval_loss += loss.item()
        eval_lpips += loss_lpips.item()
        num_batches += 1

        if s % im_save_freq == 0:
            psnr_batch, ssim_batch = eval_tensor_imgs(target, output, input_, save_path=im_save_path,
                                                      kw=level + 'val', iter_count=iter_count)
        else:
            psnr_batch, ssim_batch = eval_tensor_imgs(target, output, input_)
        psnr_all += psnr_batch
        ssim_all += ssim_batch

    if num_batches == 0 or not psnr_all:
        raise ValueError(f'validation loader for level {level!r} produced no samples')

    psnr = sum(psnr_all) / len(psnr_all)
    ssim = sum(ssim_all) / len(ssim_all)
    eval_loss /= num_batches
    eval_lpips /= num_batches
    return psnr, ssim, eval_loss, eval_lpips

def _make_val_loader(args, level):
    """Build the validation dataset and its DataLoader for one turbulence *level*."""
    dataset = DataLoaderTurbVideo(args.val_path, args.val_info, turb=True, tilt=False, blur=False, level=level,
                                  num_frames=args.num_frames, patch_size=args.patch_size, noise=0.0001,
                                  is_train=False)
    loader = DataLoader(dataset=dataset, batch_size=args.batch_size * 2, shuffle=True,
                        num_workers=args.num_workers, drop_last=False, pin_memory=True, prefetch_factor=2)
    return dataset, loader


def _save_checkpoint(path, model, optimizer, scheduler, iter_count, psnr, **extra):
    """Write iteration, psnr, model/optimizer/scheduler state (plus any *extra* entries) to *path*.

    A ``DataParallel`` wrapper is unwrapped so the weights load into a bare model.
    """
    net = model.module if isinstance(model, torch.nn.DataParallel) else model
    state = {
        'iter': iter_count,
        'psnr': psnr,
        'state_dict': net.state_dict(),
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
    }
    state.update(extra)
    torch.save(state, path)


def _resume(args, model, optimizer, scheduler, iters_per_epoch):
    """Load weights from ``args.load`` and, unless ``--start_over``, restore training progress.

    Returns:
        ``(start_iter, lr)``: iteration to continue from and the learning rate to report.

    Raises:
        FileNotFoundError: if ``--load latest`` finds no checkpoint for ``args.run_name``.
    """
    if args.load == 'latest':
        load_path = find_latest_checkpoint(args.log_path, args.run_name)
        if not load_path:
            # Previously this only printed and then crashed inside torch.load.
            raise FileNotFoundError(f'search for the latest checkpoint of {args.run_name} failed!')
    else:
        load_path = args.load
    checkpoint = torch.load(load_path)

    def _weights():
        # Re-read on every call: change_checkpoint may rewrite the checkpoint contents.
        return checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint

    try:
        model.load_state_dict(_weights())
    except RuntimeError:
        # load_state_dict raises RuntimeError on missing/unexpected/mismatched keys.
        logging.info('checkpoint keys do not match the model; remapping with change_checkpoint')
        change_checkpoint(model, checkpoint, logging)
        model.load_state_dict(_weights())

    if args.start_over:
        return 0, args.lr

    if 'epoch' in checkpoint:
        # Epoch-based checkpoints: convert epochs to optimizer steps (batches per epoch),
        # not to samples as len(train_dataset) would.
        start_iter = checkpoint['epoch'] * iters_per_epoch
    else:
        start_iter = checkpoint.get('iter', 0)
    if checkpoint.get('optimizer') is not None:
        optimizer.load_state_dict(checkpoint['optimizer'])
    # The saved scheduler state is intentionally not loaded; replay the schedule
    # step by step so the learning rate matches start_iter.
    for _ in range(start_iter):
        scheduler.step()

    lr = optimizer.param_groups[0]['lr']
    print('------------------------------------------------------------------------------')
    print("==> Resuming Training with learning rate:", lr)
    logging.info(f'==> Resuming Training with learning rate: {lr}')
    print('------------------------------------------------------------------------------')
    return start_iter, lr


def main():
    """Train the model for ``args.iters`` optimizer steps with periodic validation.

    Steps: parse args, seed RNGs, create a timestamped run folder with a log
    file, build the train loader and weak/medium/strong validation loaders,
    model, Adam optimizer and a warm-up + cosine LR schedule, optionally resume
    from ``--load``, then train. Every ``print_period`` iterations training
    metrics are logged and ``model_<iter>.pth`` / ``latest.pth`` are written;
    every ``val_period`` iterations all validation splits are evaluated and
    ``model_best.pth`` is written when the mean validation PSNR improves.
    """
    train_img_save_freq = 500  # iterations between saved training image comparisons
    val_img_save_freq = 200    # validation batches between saved image comparisons

    args = get_args()
    # Seed before building datasets/model so weight init and data order are reproducible.
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    run_name = args.run_name + '_' + datetime.now().strftime("%m-%d-%Y-%H-%M-%S")
    run_path = os.path.join(args.log_path, run_name)
    if os.path.exists(run_path):
        # create_log_folder provides the image/checkpoint paths used below; without it they are undefined.
        raise FileExistsError(f'run folder already exists: {run_path}')
    result_img_path, path_ckpt, _ = create_log_folder(run_path)
    logging.basicConfig(filename=f'{run_path}/recording.log',
                        level=logging.INFO, format='%(levelname)s: %(message)s')
    gpu_count = torch.cuda.device_count()
    get_cuda_info(logging)

    train_dataset = DataLoaderTurbVideo(args.train_path, args.train_info, turb=True, tilt=False, blur=False,
                                        num_frames=args.num_frames, patch_size=args.patch_size, noise=0.0001,
                                        is_train=True)
    train_loader = DataLoader(dataset=train_dataset, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.num_workers, drop_last=True, pin_memory=True, prefetch_factor=3)
    if len(train_loader) == 0:
        # With drop_last=True an undersized dataset yields no batches and training would never progress.
        raise ValueError('training set is smaller than one batch; nothing to train on')
    # level -> (dataset, loader); insertion order fixes the validation order.
    val_sets = {level: _make_val_loader(args, level) for level in ('weak', 'medium', 'strong')}

    model = Model(args, input_size=(args.patch_size, args.patch_size, args.num_frames)).to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.99), eps=1e-8)
    ######### Scheduler ###########
    total_iters = args.iters
    warmup_iter = args.warmup_iters
    scheduler_cosine = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, total_iters - warmup_iter,
                                                                      eta_min=1e-6)
    scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_iter,
                                       after_scheduler=scheduler_cosine)

    ######### Resume ###########
    start_iter, new_lr = 0, args.lr
    if args.load:
        start_iter, new_lr = _resume(args, model, optimizer, scheduler, len(train_loader))

    if gpu_count > 1:
        model = torch.nn.DataParallel(model, device_ids=list(range(gpu_count))).cuda()

    ######### Loss ###########
    criterion_char = losses.CharbonnierLoss()
    criterion_lpips = lpips.LPIPS(net='vgg').to(device)

    logging.info(f'''Starting training:
        Total_iters:     {total_iters}
        Start_iters:     {start_iter}
        Batch size:      {args.batch_size}
        Learning rate:   {new_lr}
        Training size:   {len(train_dataset)}
        val_dataset_weak size: {len(val_sets['weak'][0])}
        val_dataset_medium size: {len(val_sets['medium'][0])}
        val_dataset_strong size: {len(val_sets['strong'][0])}
        Checkpoints:     {path_ckpt}
    ''')

    ######### train ###########
    best_psnr = 0
    iter_count = start_iter
    current_start_time = time.time()
    current_loss = 0.0
    current_lpips = 0.0
    period_iters = 0  # iterations accumulated since the last training log
    train_psnr, train_ssim = [], []
    # Frames the model predicts when it does not output the full clip.
    center = slice(args.past_frames, args.num_frames - args.future_frames)

    model.train()
    # Stop after total_iters steps; the cosine schedule's period ends there, so
    # continuing would silently start a warm restart.
    while iter_count < total_iters:
        for data in train_loader:
            model.zero_grad()

            input_ = data[0].to(device)
            output = model(input_)
            if args.output_full:
                target = data[1].to(device)
            else:
                target = data[1][:, center, ...].to(device)
                # Cropped input is only used for the image comparison below.
                input_ = input_[:, center, ...]
            loss = criterion_char(output, target)
            loss_lpips = criterion_lpips(output.flatten(0, 1) * 2 - 1, target.flatten(0, 1) * 2 - 1).mean()
            loss_all = loss + 0.02 * loss_lpips
            loss_all.backward()
            clip_grad_norm_(model.parameters(), max_norm=20, norm_type=2)

            optimizer.step()
            scheduler.step()
            current_loss += loss.item()
            current_lpips += loss_lpips.item()
            period_iters += 1
            iter_count += 1

            if iter_count % train_img_save_freq == 0:
                psnr_batch, ssim_batch = eval_tensor_imgs(target, output, input_, save_path=result_img_path,
                                                          kw='train', iter_count=iter_count)
            else:
                psnr_batch, ssim_batch = eval_tensor_imgs(target, output, input_)
            train_psnr += psnr_batch
            train_ssim += ssim_batch

            if iter_count % args.print_period == 0:
                psnr = sum(train_psnr) / len(train_psnr)
                ssim = sum(train_ssim) / len(train_ssim)
                # Average over the iterations actually accumulated (a resumed run's
                # first period can be shorter than print_period).
                logging.info('Training: iters {:d}/{:d} -Time:{:.6f} -LR:{:.7f} -Loss {:8f} LPIPS {:8f} -PSNR: {:.2f} dB; SSIM: {:.4f}'.format(
                    iter_count, total_iters, time.time() - current_start_time, optimizer.param_groups[0]['lr'],
                    current_loss / period_iters, current_lpips / period_iters, psnr, ssim))

                _save_checkpoint(os.path.join(path_ckpt, f"model_{iter_count}.pth"),
                                 model, optimizer, scheduler, iter_count, psnr)
                _save_checkpoint(os.path.join(path_ckpt, "latest.pth"),
                                 model, optimizer, scheduler, iter_count, psnr)
                current_start_time = time.time()
                current_loss = 0.0
                current_lpips = 0.0
                period_iters = 0
                train_psnr, train_ssim = [], []

            #### Evaluation ####
            if iter_count % args.val_period == 0:
                results = []
                for level, (_, loader) in val_sets.items():
                    psnr_l, ssim_l, loss_l, lpips_l = validate(args, model, loader, criterion_char, criterion_lpips,
                                                               iter_count, val_img_save_freq, result_img_path,
                                                               device, level)
                    logging.info('Validation {}: Iters {:d}/{:d} - Loss {:8f} - LPIPS {:8f} - PSNR: {:.2f} dB; SSIM: {:.4f}'.format(
                        level[0].upper(), iter_count, total_iters, loss_l, lpips_l, psnr_l, ssim_l))
                    results.append((psnr_l, ssim_l, loss_l, lpips_l))
                # Unweighted mean over the validation levels.
                psnr, ssim, val_loss, val_lpips = (sum(col) / len(col) for col in zip(*results))
                logging.info('Validation All: Iters {:d}/{:d} - Loss {:8f} - LPIPS {:8f} - PSNR: {:.2f} dB; SSIM: {:.4f}'.format(
                    iter_count, total_iters, val_loss, val_lpips, psnr, ssim))
                if psnr > best_psnr:
                    best_psnr = psnr
                    _save_checkpoint(os.path.join(path_ckpt, "model_best.pth"),
                                     model, optimizer, scheduler, iter_count, psnr,
                                     after_scheduler=scheduler.after_scheduler.state_dict())
                model.train()

            if iter_count >= total_iters:
                break
                
if __name__ == "__main__":
    # Start training only when run as a script; importing this module has no side effects.
    main()
