# -*- coding: utf-8 -*-

import torch.nn.functional as F
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision.transforms as transforms
from torch.autograd import Variable
from model.utils import DataLoader
from memory import convAE,unet
from flownet import UNet, getFlowCoeff, backWarp
import argparse
from tqdm import tqdm
from testing import evaluation
import numpy as np
import random
from loss import SSIM

import warnings

# Hide warnings whose text begins with "Was asked to gather". The `message`
# argument is a regular expression matched against the start of the warning
# text, so the trailing `*` only makes the final "r" repeatable.
# NOTE(review): presumably aimed at the scalar-gather warning emitted by
# nn.DataParallel — confirm.
warnings.filterwarnings("ignore", message="Was asked to gather*")


def setup_seed(seed):
    """Seed every random number generator used by the script.

    Seeds PyTorch (CPU and all CUDA devices), NumPy and the ``random``
    module with the same value, then switches cuDNN to deterministic mode
    with auto-tuning (benchmark) turned off.

    Args:
        seed: Integer seed passed unchanged to each generator.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed_all,
        np.random.seed,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


def _cuda_visible_devices(gpus):
    """Build the CUDA_VISIBLE_DEVICES value from the ``--gpus`` option.

    Args:
        gpus: ``None``, a sequence of device-id strings (e.g. ``['0', '1']``)
            or a single string. A string containing commas (``"0,1"``) is
            split on them; a string without commas is read one character per
            device id, which is how the value was always interpreted.

    Returns:
        Comma-separated device ids with no empty entries; ``"0"`` when
        ``gpus`` is ``None``.
    """
    if gpus is None:
        return "0"
    if isinstance(gpus, str) and "," in gpus:
        ids = gpus.split(",")
    else:
        ids = list(gpus)
    return ",".join(i.strip() for i in ids if i.strip())


def train(args):
    """Train the flow network and the memory auto-encoder jointly.

    For every batch the flow network predicts flows between the first and
    last frame, both frames are back-warped and blended, and the memory
    network refines the blend into a prediction of the middle frame. After
    each epoch ``evaluation`` is run and the weights with the best AUC so far
    are written to ``./exp/<dataset_type>/<method>/<exp_dir>/model.pth``.

    Args:
        args: Parsed options. This function reads ``gpus``, ``dataset_path``,
            ``dataset_type``, ``h``, ``w``, ``t_length``, ``batch_size``,
            ``num_workers``, ``lr``, ``epochs``, ``method`` and ``exp_dir``;
            the whole object is also forwarded to ``evaluation``.

    Raises:
        ValueError: If the training loader yields no batches (the dataset is
            smaller than ``batch_size`` while ``drop_last`` is enabled).
    """
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = _cuda_visible_devices(args.gpus)

    torch.backends.cudnn.enabled = True

    # Plain "/" concatenation is kept so the path string handed to the
    # project DataLoader is exactly what it was before.
    train_folder = args.dataset_path + "/" + args.dataset_type + "/training/frames"
    train_dataset = DataLoader(
        train_folder,
        transforms.Compose([transforms.ToTensor()]),
        resize_height=args.h,
        resize_width=args.w,
        time_step=args.t_length - 1,
    )
    train_batch = data.DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        drop_last=True,
    )
    if len(train_batch) == 0:
        raise ValueError(
            "No training batches: the dataset under {!r} has fewer than "
            "batch_size={} samples".format(train_folder, args.batch_size)
        )

    # Model setting
    # NOTE(review): the convAE arguments are hard-coded and do not use
    # args.msize / args.fdim / args.mdim — confirm this is intended.
    memory = convAE(3, 7, 3, 10, 128, 128)
    memory.cuda()
    flownet = unet(6, 4)
    flownet.cuda()
    trainFlowBackWarp = backWarp(args.h, args.w).cuda()

    # One optimizer over both networks; the parameter order is unchanged.
    params_D = (
        list(memory.encoder.parameters())
        + list(memory.decoder.parameters())
        + list(memory.ohead.parameters())
        + list(memory.prototype.parameters())
        + list(flownet.parameters())
    )
    optimizer = torch.optim.Adam(params_D, lr=args.lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)

    # Replicate over every visible GPU (the DataParallel default) rather than
    # a fixed [0, 1], which fails when only one device is exposed.
    memory = nn.DataParallel(memory)
    flownet = nn.DataParallel(flownet)

    # Directory that receives the best checkpoint.
    log_dir = os.path.join('./exp', args.dataset_type, args.method, args.exp_dir)
    os.makedirs(log_dir, exist_ok=True)

    loss_func_mse = nn.MSELoss(reduction='mean')
    loss_func_ssim = SSIM()

    # Training
    best_auc = 0
    best_epo = 0

    for epoch in range(args.epochs):
        # evaluation() runs at the end of every epoch, so training mode is
        # restored here each time.
        memory.train()
        flownet.train()

        loss_pixel = 0

        for imgs in tqdm(train_batch):
            # Assumes imgs stacks three 3-channel frames along dim 1
            # (first | middle | last) — TODO confirm against the dataset.
            targets = imgs[:, 3:6].cuda()

            # optical flow
            I0, I2 = imgs[:, :3].cuda(), imgs[:, 6:].cuda()
            flowOut = flownet(torch.cat((I0, I2), dim=1))
            F_0_2 = flowOut[:, :2, :, :]
            F_2_0 = flowOut[:, 2:, :, :]
            # NOTE(review): these signs are the reverse of the Super SloMo
            # t=0.5 intermediate-flow formulas — confirm they match the
            # convention backWarp expects.
            F_t_0 = 0.25 * F_0_2 - 0.25 * F_2_0
            F_t_2 = -0.25 * F_0_2 + 0.25 * F_2_0
            g_I0_F_t_0 = trainFlowBackWarp(I0, F_t_0)
            g_I2_F_t_2 = trainFlowBackWarp(I2, F_t_2)
            # Equal-weight blend of the two warped frames.
            Ft_f = 0.5 * g_I0_F_t_0 + 0.5 * g_I2_F_t_2

            # reconstruction: 2 + 2 flow channels plus the 3-channel blend
            inputs = torch.cat((F_t_0, F_t_2, Ft_f), dim=1)
            outputs, _, _, _, fea_loss = memory(inputs, None, True)
            Ft_p = outputs

            # objective function
            pixel_loss = loss_func_mse(Ft_p, targets)
            ssim_loss = 1 - loss_func_ssim(Ft_p, targets)
            warpLoss = (
                loss_func_mse(trainFlowBackWarp(I0, F_2_0), I2)
                + loss_func_mse(trainFlowBackWarp(I2, F_0_2), I0)
            )
            loss = pixel_loss + warpLoss + (1e-4) * ssim_loss + fea_loss
            # Reduce to a scalar in case fea_loss is not already one.
            loss = loss.mean()
            loss_pixel += loss.item()

            optimizer.zero_grad()
            # NOTE(review): retain_graph=True keeps the graph buffers alive
            # after backward. It is left in place because whether `memory`
            # reuses graph state between iterations is not visible here;
            # drop it if nothing does.
            loss.backward(retain_graph=True)
            optimizer.step()

        scheduler.step()

        print('Epoch:', epoch + 1, 'Loss: Frame interpolation {:.4f}'.format(loss_pixel / (len(train_batch))))

        cur_auc = evaluation(args, memory=memory, flownet=flownet)
        if cur_auc > best_auc:
            best_auc = cur_auc
            best_epo = epoch
            # Save the unwrapped modules. NOTE(review): the 'flownext' key
            # looks like a typo, but it is kept because code that loads the
            # checkpoint may look it up by this name.
            torch.save({'memory': memory.module.state_dict(),
                        'flownext': flownet.module.state_dict()}, os.path.join(log_dir, 'model.pth'))
        print('The result of ', args.dataset_type, '- AUC: ', cur_auc * 100, '%')

    print('The result of ', args.dataset_type, '- AUC: ', best_auc * 100, '%', '- Epoch: ', best_epo)

    print('Training is finished')


if __name__ == '__main__':
    # Command-line options as (flag, type, default, help) rows. They are
    # registered in this order so the --help listing stays the same.
    # NOTE: --gpus has type=str and no nargs, so a value given on the command
    # line arrives as one string, while the default is a list of strings.
    option_table = [
        ('--gpus', str, ['0', '1'], 'gpus'),
        ('--batch_size', int, 4, 'batch size for training'),
        ('--test_batch_size', int, 1, 'batch size for test'),
        ('--epochs', int, 60, 'number of epochs for training'),
        ('--h', int, 256, 'height of input images'),
        ('--w', int, 256, 'width of input images'),
        ('--c', int, 3, 'channel of input images'),
        ('--lr', float, 2e-4, 'initial learning rate'),
        ('--method', str, 'inte', 'The target task for anoamly detection'),
        ('--t_length', int, 3, 'length of the frame sequences'),
        ('--fdim', int, 512, 'channel dimension of the features'),
        ('--mdim', int, 512, 'channel dimension of the memory items'),
        ('--msize', int, 10, 'number of the memory items'),
        ('--num_workers', int, 4, 'number of workers for the train loader'),
        ('--num_workers_test', int, 1, 'number of workers for the test loader'),
        ('--dataset_type', str, 'ped2', 'type of dataset: ped2, avenue, shanghai'),
        ('--dataset_path', str, '../VIAD/dataset', 'directory of data'),
        ('--exp_dir', str, 'log', 'directory of log'),
    ]

    parser = argparse.ArgumentParser(description="VIAD")
    for flag, value_type, default_value, help_text in option_table:
        parser.add_argument(flag, type=value_type, default=default_value, help=help_text)

    print("Start:")

    # Seed all generators before any model or data loader is created.
    setup_seed(42)

    args = parser.parse_args()
    train(args)
