import os
import gc

import cv2
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from PIL import Image
from tqdm import tqdm

import fhog, colorname
from data_utils import gaussian_shaped_labels, get_det, get_connection
from correlation import Correlation2


def _template_features(img):
    """Return FHOG features (cell size 4) of ``img`` with colour-name features appended.

    The FHOG map is reshaped to (hs, ws, numFeatures), where (hs, ws) is the cell
    grid reported by ``fhog.getFeatureMaps``. The colour-name features from
    ``colorname.getcolorname(img, hs)`` are concatenated along axis 2.
    """
    mapp = {'sizeX': 0, 'sizeY': 0, 'numFeatures': 0, 'map': 0}
    mapp = fhog.getFeatureMaps(img, 4, mapp)
    hs = int(mapp['sizeY'])
    ws = int(mapp['sizeX'])
    ds = int(mapp['numFeatures'])
    hog = mapp['map'].reshape((hs, ws, ds))
    cn = colorname.getcolorname(img, hs)
    return np.concatenate((hog, cn), axis=2)


def _train_filter(features, sigma, lam=1e-3):
    """Fit a correlation filter mapping ``features`` onto a Gaussian label of width ``sigma``.

    In the Fourier domain (2-D FFT over the two spatial axes, ``norm='forward'``)
    this computes, element-wise and independently for every channel:

        w_hat = conj(x_hat) * y_hat / (conj(x_hat) * x_hat + lam)

    Returns the real part of the inverse FFT as float32, flipped along both spatial
    axes (so it can be applied as a convolution kernel that performs correlation).
    """
    labels = torch.from_numpy(gaussian_shaped_labels(sigma, features.shape[:2]))
    if labels.ndim == 2:
        # Add a channel axis so the label broadcasts over every feature channel.
        labels = labels.unsqueeze(2)
    labels_f = torch.fft.fftn(labels, dim=[0, 1], norm='forward')
    feats_f = torch.fft.fftn(torch.from_numpy(features), dim=[0, 1], norm='forward')
    feats_f_conj = feats_f.conj()
    numerator = feats_f_conj * labels_f            # \hat{x}^* \odot \hat{y}
    denominator = feats_f_conj * feats_f + lam     # \hat{x}^* \odot \hat{x} + \lambda
    w = torch.fft.ifftn(numerator / denominator, dim=[0, 1], norm='forward')
    # Take the real part explicitly; casting a complex tensor with .float() gives
    # the same values but emits a "discards the imaginary part" warning.
    return w.real.float().flip((0, 1))


def preprocess_template(template_path = '',
                        template_list = None,
                        save_weight_file = '',
                        sigma = 5):
    """Train multi-scale correlation filters from template images and save them.

    Every file in ``template_path`` whose name is in ``template_list`` is loaded
    (in sorted filename order) and resized to seven square sizes. For each size a
    filter is trained on FHOG + colour-name features, with the Gaussian label width
    scaled along with the template size.

    Args:
        template_path: directory containing the template images.
        template_list: filenames to use. Defaults to
            ``['01.png', '02.png', '03.png', '04.png']``. ``None`` is used as the
            sentinel so the default list is not shared between calls.
        save_weight_file: destination passed to ``torch.save``.
        sigma: Gaussian label width for the base (68x68) scale.

    The saved object is a dict with one list per scale key ('single', 'single_up',
    'single_50', 'single_down', 'single_30', 'single_max', 'single_min'). Each list
    holds one filter tensor per template. The dict also has 'template_list'.

    Raises:
        OSError: if a listed template file cannot be decoded by ``cv2.imread``.
    """
    if template_list is None:
        template_list = ['01.png', '02.png', '03.png', '04.png']

    # (w_dict key, side length the template is resized to, multiplier on sigma).
    # The order also fixes the key order of the saved dict.
    scales = (
        ('single', 68, 1),
        ('single_up', 100, 1.5),
        ('single_50', 84, 1.25),
        ('single_down', 36, 0.5),
        ('single_30', 52, 0.75),
        ('single_max', 108, 1.75),
        ('single_min', 20, 0.25),
    )
    w_dict = {key: [] for key, _, _ in scales}
    w_dict['template_list'] = template_list

    for f in sorted(os.listdir(template_path)):
        if f not in template_list:
            continue
        print(f'Processing template: {f}')

        path = os.path.join(template_path, f)
        template = cv2.imread(path)
        if template is None:
            raise OSError(f'Could not read template image: {path}')
        print('Template:', f, '  Size:', template.shape)

        # Every scale is resized from the original image, never from another scale.
        for key, side, sigma_scale in scales:
            resized = cv2.resize(template, (side, side))
            features = _template_features(resized)
            w_dict[key].append(_train_filter(features, sigma * sigma_scale))

    torch.save(w_dict, save_weight_file)


def _on_cpu_if_oom(fn, tensor):
    """Return ``fn(tensor)``; if that raises RuntimeError (e.g. CUDA OOM), retry on a CPU copy."""
    try:
        return fn(tensor)
    except RuntimeError:
        return fn(tensor.cpu())


def _free_memory():
    """Run the garbage collector and release cached CUDA blocks between pipeline stages."""
    gc.collect()
    torch.cuda.empty_cache()


def _normalize_255(t):
    """Min-max scale tensor ``t`` to [0, 255] and return it as a numpy array.

    A constant map (max == min) yields all zeros instead of a 0/0 NaN map.
    """
    lo, hi = t.min(), t.max()
    span = hi - lo
    if span <= 0:
        return torch.zeros_like(t).cpu().numpy()
    return ((t - lo) / span * 255).cpu().numpy()


def _overlay(img_bgr, mask):
    """Blend ``img_bgr`` 60/40 with a copy that keeps only pixels where ``mask == 255``.

    ``mask`` must have shape (H, W, 1); it is repeated across the 3 colour channels.
    Pixels outside the mask are blended with black.
    """
    keep = np.repeat(mask == 255, 3, axis=2)
    vis = np.zeros(img_bgr.shape)
    vis[keep] = img_bgr[keep]
    return 0.6 * img_bgr + 0.4 * vis


def _grow_regions(score, bboxes, thresh_absolute, thresh_center, rad):
    """Rasterise a binary (0/255) map of regions grown around each box centre.

    For every box (hmin, wmin, hmax, wmax) whose centre score is at least
    ``thresh_absolute``, mark each pixel that lies within Euclidean distance ``rad``
    of the centre and whose score is at least ``thresh_center`` times the centre
    score. ``score`` is a 2-D (H, W) array.
    """
    h, w = score.shape
    pool = np.zeros((h, w))
    for hmin, wmin, hmax, wmax in bboxes:
        ci, cj = int((hmax + hmin) / 2), int((wmax + wmin) / 2)
        peak = score[ci, cj]
        if peak < thresh_absolute:
            continue
        # +1 so the disc is symmetric: rows/cols at exactly centre + rad are included.
        i0, i1 = max(0, ci - rad), min(h, ci + rad + 1)
        j0, j1 = max(0, cj - rad), min(w, cj + rad + 1)
        ii, jj = np.ogrid[i0:i1, j0:j1]
        in_disc = (ii - ci) ** 2 + (jj - cj) ** 2 <= rad ** 2
        bright = score[i0:i1, j0:j1] >= thresh_center * peak
        pool[i0:i1, j0:j1][in_disc & bright] = 255
    return pool


def preprocess_image(device_str = '3',
                     input_train_path = '',
                     output_root_path = '',
                     time='mask_380_0.88_25',
                     save_weight_file = '',
                     thresh_absolute = 380,
                     thresh_center = 0.88,
                     rad = 25):
    """Correlate every image in ``input_train_path`` with the saved template filters.

    Outputs go under ``output_root_path/<time>/``:
      corr/           min-max normalised (0-255) response map
      mask/           pixels whose response >= ``thresh_absolute`` (255), else 0
      grad/           ``get_det`` output, binarised to 0/255
      pool/           regions grown around each detected component centre
      visualize/, visualize_mask/
                      overlays of pool/mask on the image, for the first 100 processed images
    The weight file and its template list are copied to ``template.pth`` and
    ``template.txt`` in the same directory.

    Args:
        device_str: value assigned to CUDA_VISIBLE_DEVICES.
        input_train_path: directory of input images (dotfiles and subdirs are skipped).
        output_root_path: root directory for outputs.
        time: name of the run sub-directory.
        save_weight_file: file produced by ``preprocess_template``.
        thresh_absolute: absolute response threshold for masks and region seeds.
        thresh_center: fraction of a seed's response a pixel needs to join its region.
        rad: maximum region radius in pixels.
    """
    out_dir = os.path.join(output_root_path, str(time))
    corr_path = os.path.join(out_dir, 'corr/')
    mask_path = os.path.join(out_dir, 'mask/')
    grad_path = os.path.join(out_dir, 'grad/')
    pool_path = os.path.join(out_dir, 'pool/')
    visualize_path = os.path.join(out_dir, 'visualize/')
    visualize_mask_path = os.path.join(out_dir, 'visualize_mask/')
    for path in (corr_path, mask_path, grad_path, pool_path,
                 visualize_path, visualize_mask_path):
        os.makedirs(path, exist_ok=True)

    # NOTE: CUDA only honours this variable if it has not been initialised yet in
    # this process.
    os.environ['CUDA_VISIBLE_DEVICES'] = device_str
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    w_dict = torch.load(save_weight_file)
    torch.save(w_dict, os.path.join(out_dir, 'template.pth'))
    with open(os.path.join(out_dir, 'template.txt'), 'w') as f:
        f.write(str(w_dict['template_list']))

    net = Correlation2(w_dict['single'],
                       w_dict['single_up'],
                       w_dict['single_down'],
                       w_dict['single_50'],
                       w_dict['single_30'],
                       w_dict['single_max'],
                       w_dict['single_min'],
                      ).to(device)

    files = sorted(os.listdir(input_train_path))
    process_count = 0
    with tqdm(total=len(files)) as pbar:
        for f in files:
            src = os.path.join(input_train_path, f)
            if f.startswith('.') or os.path.isdir(src):
                pbar.set_postfix_str(f'Ignore: {f}')
                pbar.update()
                continue
            pbar.set_postfix_str(f'Processing image: {f}')

            img_bgr = cv2.imread(src)
            if img_bgr is None:
                # Not an image OpenCV can decode; skip instead of crashing the run.
                pbar.set_postfix_str(f'Ignore unreadable: {f}')
                pbar.update()
                continue
            process_count += 1

            mapp = {'sizeX':0, 'sizeY':0, 'numFeatures':0, 'map':0}
            mapp = fhog.getFeatureMaps(img_bgr, 4, mapp)
            h = int(mapp['sizeY'])
            w = int(mapp['sizeX'])
            d = int(mapp['numFeatures'])
            features = mapp['map'].reshape((h,w,d))
            cn = colorname.getcolorname(img_bgr,h)
            features = np.concatenate((features,cn),axis=2)
            # Channels-first (d, h, w) for the correlation network.
            features = torch.from_numpy(features.transpose((2,0,1))).to(device)

            with torch.no_grad():
                try:
                    result = net(features)
                except RuntimeError:
                    # Fall back to CPU (e.g. CUDA OOM), then restore the net for the next image.
                    net.cpu()
                    result = net(features.cpu())
                    net.to(device)

            # Upsample responses to image resolution; [0] drops the batch axis -> (C, H, W).
            size = img_bgr.shape[0:2]
            result = _on_cpu_if_oom(
                lambda t: F.interpolate(t, size=size, mode='bilinear')[0], result)
            _free_memory()

            # Per pixel: mean of the 3 strongest channel responses (just the strongest
            # if there are <= 2 channels). topk avoids materialising a full sort + indices.
            k = 3 if result.shape[0] > 2 else 1
            avgpool = _on_cpu_if_oom(
                lambda t: torch.topk(t, k, dim=0).values.mean(dim=0).unsqueeze(-1), result)
            del result
            _free_memory()

            avgout = _on_cpu_if_oom(_normalize_255, avgpool)
            avgpool = avgpool.cpu().numpy()   # (H, W, 1)
            _free_memory()

            cv2.imwrite(os.path.join(corr_path, f), avgout)
            del avgout

            visualize = process_count <= 100

            avgmask = np.zeros_like(avgpool)
            avgmask[avgpool >= thresh_absolute] = 255
            cv2.imwrite(os.path.join(mask_path, f), avgmask)
            if visualize:
                cv2.imwrite(os.path.join(visualize_mask_path, f), _overlay(img_bgr, avgmask))
            del avgmask

            avgdet = get_det(avgpool)
            avgdetout = avgdet.copy()
            avgdetout[avgdetout > 0] = 255
            cv2.imwrite(os.path.join(grad_path, f), avgdetout)

            avgdet[avgdet > 0] = 1
            _, avgbboxs, _ = get_connection(avgdet, thresh=1)
            pool = _grow_regions(avgpool[:, :, 0], avgbboxs,
                                 thresh_absolute, thresh_center, rad)
            cv2.imwrite(os.path.join(pool_path, f), pool)

            if visualize:
                pbar.set_postfix_str(f'Visualizing image {f}')
                cv2.imwrite(os.path.join(visualize_path, f), _overlay(img_bgr, pool[:, :, None]))
            pbar.update()



if __name__ == '__main__':
    # Build the template filters first, then run detection with them. Both steps use
    # their default arguments; the default paths are empty strings, so edit them
    # (or the defaults) before running this file as a script.
    for step in (preprocess_template, preprocess_image):
        step()
