# -*- coding: utf-8 -*-

import numpy as np
import os
import torch
import torch.nn as nn
import torch.utils.data as data
import torchvision.transforms as transforms
from torch.autograd import Variable
from collections import OrderedDict
from model.utils import DataLoader
from memory import convAE
from flownet import UNet, getFlowCoeff, backWarp
import glob
from utils import psnr, anomaly_score_list, score_sum, AUC, filter, calc, score_mul

def _set_visible_gpus(gpus):
    """Expose only the requested CUDA devices to this process, in PCI bus order.

    Args:
        gpus: Iterable of device-id strings (a string of single-character ids
            also works, one id per character), or None to select device "0".
    """
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    # NOTE(review): these variables only take effect if CUDA has not been
    # initialised in this process yet — confirm for callers that trained first.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0" if gpus is None else ",".join(gpus)


def _count_video_frames(test_folder):
    """Map each entry under ``test_folder`` (sorted by path) to its number of ``.jpg`` frames."""
    frame_counts = OrderedDict()
    for path in sorted(glob.glob(os.path.join(test_folder, '*'))):
        frame_counts[os.path.basename(path)] = len(glob.glob(os.path.join(path, '*.jpg')))
    return frame_counts


def _interpolate_middle_frame(flownet, back_warp, first, last):
    """Estimate flows from the middle time step to both outer frames plus a warped blend.

    Returns:
        Tuple ``(flow_t_0, flow_t_2, blended)``, where ``blended`` is the average of
        ``first`` and ``last`` after backward-warping each with its flow.
    """
    flow = flownet(torch.cat((first, last), dim=1))
    flow_0_2 = flow[:, :2, :, :]
    flow_2_0 = flow[:, 2:, :, :]
    # NOTE(review): Super SloMo (eq. 4) at t = 0.5 gives
    # F_t->0 = -0.25 * F_0->1 + 0.25 * F_1->0; the signs below are the opposite.
    # Kept as-is because the flownet was presumably trained with this same
    # convention — confirm against the training code before changing.
    flow_t_0 = 0.25 * flow_0_2 - 0.25 * flow_2_0
    flow_t_2 = -0.25 * flow_0_2 + 0.25 * flow_2_0
    blended = 0.5 * back_warp(first, flow_t_0) + 0.5 * back_warp(last, flow_t_2)
    return flow_t_0, flow_t_2, blended


def evaluation(args, memory=None, flownet=None):
    """Evaluate the memory + flow models on the test split and report frame-level AUCs.

    Each test sample is split along the channel axis into an outer pair of frames
    (channels 0-2 and channels 6 onward) and a target (channels 3-5); presumably
    this means ``args.t_length == 3`` RGB frames — TODO confirm.  ``flownet``
    warps the outer frames towards the middle time step, ``memory`` reconstructs
    from the two flows plus that warped blend, and two values are recorded per
    sample: PSNR of the reconstruction-vs-target MSE and PSNR of the model's
    feature loss.  Per video these are passed through ``anomaly_score_list`` and
    ``filter`` (with a ``calc(15, 2)`` template), fused with ``score_sum``, and
    scored with ``AUC`` against the inverted frame labels loaded from ``./data``.

    Args:
        args: Parsed options; reads ``gpus``, ``dataset_path``, ``dataset_type``,
            ``h``, ``w``, ``t_length``, ``test_batch_size``, ``num_workers_test``,
            ``method`` and, when ``memory`` is None, ``model_dir``.
        memory: Reconstruction model, called as ``memory(inputs, None, True)``.
            When None, a ``convAE`` is built, loaded from ``args.model_dir`` and
            moved to the GPU.
        flownet: Optical-flow network; required.

    Returns:
        AUC of the reconstruction-only score, rounded to 4 decimals.

    Raises:
        ValueError: if ``flownet`` is None, ``args.test_batch_size`` is not 1,
            or no test videos are found.
    """
    if flownet is None:
        raise ValueError("evaluation() requires a flownet model; got None")
    if args.test_batch_size != 1:
        # Video boundaries are tracked by batch index and each batch yields a
        # single score, so any other batch size misaligns scores with labels.
        raise ValueError("evaluation() requires args.test_batch_size == 1, got %r"
                         % (args.test_batch_size,))

    _set_visible_gpus(args.gpus)
    torch.backends.cudnn.enabled = True

    # Built with "/" (not os.path.join) to keep the exact path the project
    # DataLoader has always received.
    test_folder = args.dataset_path + "/" + args.dataset_type + "/testing/frames"
    test_dataset = DataLoader(test_folder, transforms.Compose([transforms.ToTensor()]),
                              resize_height=args.h, resize_width=args.w,
                              time_step=args.t_length - 1)
    test_batch = data.DataLoader(test_dataset, batch_size=args.test_batch_size,
                                 shuffle=False, num_workers=args.num_workers_test,
                                 drop_last=False)
    loss_func_mse = nn.MSELoss(reduction='none')

    if memory is None:
        # Previously this model was loaded into a throwaway local and the loop
        # then called the still-None ``memory``.
        memory = convAE()
        memory.load_state_dict(torch.load(args.model_dir))
        memory.cuda()
    memory.eval()
    flownet.eval()

    labels = np.load('./data/frame_labels_' + args.dataset_type + '.npy')
    frame_counts = _count_video_frames(test_folder)
    video_names = list(frame_counts)
    if not video_names:
        raise ValueError("no test videos found under %s" % test_folder)

    # In 'inte' mode the first and last frame of each video are not scored, so
    # their labels are dropped as well.
    trim = 1 if args.method == 'inte' else 0
    label_chunks = []
    offset = 0
    for name in video_names:
        length = frame_counts[name]
        label_chunks.append(labels[0][offset + trim:offset + length - trim])
        offset += length
    # The leading float64 empty array keeps the result dtype float64, as np.append did.
    labels_list = np.concatenate([np.empty(0)] + label_chunks)

    feature_distance_list = {name: [] for name in video_names}
    optical_flow_list = {name: [] for name in video_names}

    print('Evaluation of', args.dataset_type)

    back_warp = backWarp(args.h, args.w).cuda()

    video_num = 0
    frames_seen = frame_counts[video_names[0]]
    for k, imgs in enumerate(test_batch):
        # Move to the next video once batch index k reaches the current video's
        # end.  In 'inte' mode each video is assumed to yield (frames - 2)
        # samples, hence the 2 * (videos so far) correction.
        # NOTE(review): the non-'inte' branch assumes one sample per frame —
        # confirm against the dataset's windowing (time_step = t_length - 1).
        end = frames_seen - 2 * (video_num + 1) if args.method == 'inte' else frames_seen
        if k == end:
            video_num += 1
            frames_seen += frame_counts[video_names[video_num]]

        with torch.no_grad():
            first = imgs[:, :3].cuda()
            targets = imgs[:, 3:6].cuda()
            last = imgs[:, 6:].cuda()
            flow_t_0, flow_t_2, blended = _interpolate_middle_frame(
                flownet, back_warp, first, last)
            inputs = torch.cat((flow_t_0, flow_t_2, blended), dim=1)
            outputs, _, _, _, fea_loss = memory(inputs, None, True)
            # Inputs appear to be in [-1, 1]; map to [0, 1] before the MSE (assumption).
            mse_imgs = torch.mean(loss_func_mse((outputs + 1) / 2, (targets + 1) / 2)).item()
            mse_feas = fea_loss.mean().item()

        name = video_names[video_num]
        feature_distance_list[name].append(psnr(mse_feas))
        optical_flow_list[name].append(psnr(mse_imgs))

    # Per-video normalisation and smoothing, then fusion of the two score streams.
    template = calc(15, 2)
    anomaly_score_flow_list = []
    anomaly_score_total_list = []
    for name in video_names:
        feat_scores = filter(anomaly_score_list(feature_distance_list[name]), template, 15)
        flow_scores = filter(anomaly_score_list(optical_flow_list[name]), template, 15)
        anomaly_score_flow_list += score_sum(flow_scores, flow_scores, 0.5)
        anomaly_score_total_list += score_sum(feat_scores, flow_scores, 0.3)

    # Labels are inverted, presumably because the PSNR-derived scores are higher
    # for normal frames — confirm against utils.AUC.
    total_accuracy = round(AUC(anomaly_score_total_list, np.expand_dims(1 - labels_list, 0)), 4)
    flow_accuracy = round(AUC(anomaly_score_flow_list, np.expand_dims(1 - labels_list, 0)), 4)
    # Feature-only and multiplicative fusions are currently disabled and reported as 0.
    feat_accuracy = 0
    mult_accuracy = 0
    print(args.dataset_type,
          '-Flow:', flow_accuracy * 100, '%',
          '-Feature:', feat_accuracy * 100, '%',
          '-Total:', total_accuracy * 100, '%',
          '-Mul:', mult_accuracy * 100, '%')

    return flow_accuracy

if __name__ == '__main__':
    # evaluation() needs parsed options plus trained memory/flow networks that this
    # module does not build, so a bare evaluation() call could only fail with a
    # TypeError about the missing ``args``. Fail with an explanation instead.
    raise SystemExit("This module has no command-line entry point: call "
                     "evaluation(args, memory, flownet) from the training/testing script.")
