import cv2
import os
import shutil
import numpy as np
import matplotlib.pyplot as plt
import torch
from PIL import Image
import time
from torchvision import transforms
'''
from functools import reduce
import time
import random
from ae import Autoencoder
import torch
'''
#torch.cuda.set_device(0)

# Map every liver/slide name listed in grad2name.npy to the index (0-4) of the
# grade list it appears in. The file is loaded once; it used to be re-read on
# every loop iteration.
# allow_pickle=True is needed because the file stores Python objects; only load
# trusted files this way.
liver2grad = {}
grad2name = np.load('/nfs/yuxiaotian/HCC/grad2name.npy', allow_pickle=True)
for grad in range(5):
    for file in grad2name[grad]:
        liver2grad[file] = grad
    
def liver_name_to_grad(liver_name):
    """Return the grade of a liver slide from its name.

    * 5-character names starting with '0' -> grade 0.
    * 4-character names are decoded from their first character (and, for
      prefixes '1' and '4', the trailing two-digit number; '3M01' is a
      special case).
    * Every other name -- including 4-character names with an unrecognised
      first character, which previously raised UnboundLocalError -- is looked
      up in the module-level ``liver2grad`` table (KeyError if absent).

    Raises:
        KeyError: the name is not decodable and not in ``liver2grad``.
        ValueError: the trailing digits of a '1xxx'/'4xxx' name are not numeric.
    """
    if len(liver_name) == 5 and liver_name[0] == '0':
        return 0
    if len(liver_name) == 4:
        prefix = liver_name[0]
        if prefix == '1':
            return 1 if int(liver_name[-2:]) < 20 else 2
        if prefix == '2':
            return 2
        if prefix == '3':
            return 2 if liver_name == '3M01' else 3
        if prefix == '4':
            return 3 if int(liver_name[-2:]) < 6 else 4
        if prefix == '5':
            return 4
    # Fall back to the table built from grad2name.npy.
    return liver2grad[liver_name]

def show_example(data_name='slide003_core004'):
    """Display a training image together with its (up to six) annotator maps.

    The image goes in the top-middle cell of a 3x3 grid; annotator map k
    (data/Maps{k}_T/<name>_classimg_nonconvex.png) goes in cell (k-1)//3+1,
    (k-1)%3 when it exists. Each map is titled with the set of its raw values.

    Raises:
        FileNotFoundError: the training image cannot be read.
    """
    img_file = 'data/Train Imgs/{}.jpg'.format(data_name)
    train_img = cv2.imread(img_file)
    if train_img is None:
        raise FileNotFoundError(img_file)
    plt.figure()
    ax_img = plt.subplot2grid((3, 3), (0, 1), colspan=1, rowspan=1)
    # OpenCV loads BGR; matplotlib expects RGB.
    ax_img.imshow(cv2.cvtColor(train_img, cv2.COLOR_BGR2RGB))
    ax_img.set_title(data_name)
    ax_img.axis('off')
    for i in range(6):
        mask_img = cv2.imread('data/Maps{0}_T/{1}_classimg_nonconvex.png'.format((i+1), data_name))
        if mask_img is None:
            continue
        ax = plt.subplot2grid((3, 3), (i // 3 + 1, i % 3))
        ax.set_title(str(set(mask_img.flatten())))
        # Clip to 5 before scaling so values > 5 cannot wrap around in uint8.
        ax.imshow(np.clip(mask_img, 0, 5) * (255 // 5))
        ax.axis('off')
    plt.show()


def label_processing(maps_num=6, num_classes=7):
    """Compute per-class pixel fractions for every annotator map and image.

    For each image in 'data/Train Imgs/' and each annotator k = 1..maps_num,
    reads data/Maps{k}_T/<name>_classimg_nonconvex.png and stores the fraction
    of pixels holding each class value 0..num_classes-1. A missing map gets a
    row of -1. The (maps_num, n_images, num_classes) array is saved to
    'label_array.npy'.

    Raises:
        ValueError: a map contains a class value >= num_classes.
    """
    maps_name_list = [file[:-4] for file in os.listdir('data/Train Imgs/')]
    label_arr = np.empty((maps_num, len(maps_name_list), num_classes))
    for i in range(maps_num):
        for j, maps_name in enumerate(maps_name_list):
            mask_img = cv2.imread('data/Maps{0}_T/{1}_classimg_nonconvex.png'.format((i+1), maps_name))
            if mask_img is None:
                label_arr[i, j] = -1
            else:
                # One bincount replaces a Python loop over the distinct values
                # (and the previously un-imported functools.reduce).
                counts = np.bincount(mask_img.ravel(), minlength=num_classes)
                if len(counts) > num_classes:
                    raise ValueError('{0}: class value {1} >= num_classes={2}'.format(
                        maps_name, len(counts) - 1, num_classes))
                label_arr[i, j] = counts / mask_img.size
            print('\r{0}-{1}:done'.format(i, j), end='')
    print()
    np.save('label_array.npy', label_arr)


def label_analysis():
    """Plot average class coverage from 'label_array.npy'.

    For each of the 6 annotators and 7 classes, averages the class fraction
    over the images where that annotator's map exists (entries >= 0).
    Left panel: mean over annotators. Right panel: grouped bars per annotator.
    """
    label_arr = np.load('label_array.npy')
    rows = []
    for map_idx in range(6):
        row = []
        for cls in range(7):
            column = label_arr[map_idx, :, cls]
            row.append(column[column >= 0].mean())
        rows.append(row)
    label_ratio = np.array(rows)

    plt.figure()
    left = plt.subplot2grid((1, 2), (0, 0), colspan=1, rowspan=1)
    left.bar(np.arange(7), label_ratio.mean(axis=0))
    right = plt.subplot2grid((1, 2), (0, 1), colspan=1, rowspan=1)
    bar_width = 0.15
    for map_idx, ratios in enumerate(label_ratio):
        right.bar(np.arange(7) + map_idx * bar_width, ratios, width=bar_width, label=str(map_idx))
    plt.show()

    
def maps_processing(maps_num=6):
    """Write 0-255 scaled copies of the raw annotator maps.

    For k = 1..maps_num, every data/Maps{k}_T/<name>_classimg_nonconvex.png is
    read, its values are clipped to 5 and multiplied by 51, and the result is
    written to data/Maps{k}_T/<name>.png.
    """
    suffix = '_classimg_nonconvex.png'
    for map_id in range(1, maps_num + 1):
        main_path = 'data/Maps{}_T/'.format(map_id)
        for map_path in os.listdir(main_path):
            # Only raw maps. The previous version also picked up its own
            # outputs on a re-run, and re-thresholding them (>5 -> 5 -> 255)
            # collapsed every nonzero grade to 255.
            if not map_path.endswith(suffix):
                continue
            map_img = cv2.imread(main_path + map_path)
            if map_img is None:
                continue
            map_img[map_img > 5] = 5
            map_img *= 51
            # Strip the suffix instead of assuming a 16-character stem.
            cv2.imwrite(main_path + map_path[:-len(suffix)] + '.png', map_img)
    

def ground_truth(data_file='data/Train Imgs/', maps_num=6, start_idx=0):
    """Write a majority-vote ground-truth map for every training image.

    For each image in ``data_file`` (listing entries from ``start_idx`` on),
    the scaled annotator maps ``data/Maps{k}_T/<name>.png`` (k = 1..maps_num)
    are stacked. Values <= 102 (= 51*2) become 0. Each pixel then takes the
    most frequent value among the annotators that provided a map, with ties
    going to the larger value. The result is written to
    ``data/GroundTruth/<name>.png``. Images with no annotator map are skipped.

    Raises:
        ValueError: an annotator map's three channels differ.
    """
    img_name_list = os.listdir(data_file)[start_idx:]
    for name_idx, img_name in enumerate(img_name_list):
        t0 = time.time()
        img = cv2.imread(data_file + img_name)
        if img is None:
            continue
        map_name = img_name[:-4]
        # -1 marks "this annotator has no map for this image".
        maps_arr = np.full((maps_num, *img.shape), -1.0)
        for map_id in range(maps_num):
            # Scaled maps, as in ground_truth_binary/multi. Raw
            # *_classimg_nonconvex.png maps hold small class indices that the
            # <= 102 threshold below would set entirely to 0.
            mask_img = cv2.imread('data/Maps{0}_T/{1}.png'.format(map_id + 1, map_name))
            if mask_img is not None:
                maps_arr[map_id] = mask_img

        if not (np.array_equal(maps_arr[..., 0], maps_arr[..., 1])
                and np.array_equal(maps_arr[..., 0], maps_arr[..., 2])):
            raise ValueError('annotation maps of {} are not single-valued across channels'.format(map_name))
        maps_arr = maps_arr[..., 0].transpose((1, 2, 0))  # (H, W, maps_num)
        valid = maps_arr >= 0
        if not valid.any():
            print('\n{} has no annotation map, skipped'.format(map_name))
            continue
        # Threshold only real votes; the previous version also turned the -1
        # "missing" markers into 0 votes.
        maps_arr[valid & (maps_arr <= 51 * 2)] = 0

        # Vectorised mode: scan grades in ascending order and let '>=' hand
        # ties to the later (larger) grade.
        best_grade = np.zeros(maps_arr.shape[:2])
        best_count = np.full(maps_arr.shape[:2], -1)
        for grade in np.unique(maps_arr[valid]):
            count = (maps_arr == grade).sum(axis=-1)
            take = count >= best_count
            best_grade[take] = grade
            best_count[take] = count[take]
        cv2.imwrite('data/GroundTruth/{}.png'.format(map_name), best_grade.astype(np.uint8))
        t1 = time.time()
        print('\r{0}/{1} done\t time:{2:.2f}s'.format(name_idx, len(img_name_list), t1-t0), end='')

        
def ground_truth_binary(data_file='data/Train Imgs/', maps_num=6, start_idx=0):
    """Write a binary (0 / 255) majority-vote ground truth for each image.

    The scaled annotator maps ``data/Maps{k}_T/<name>.png`` (k = 1..maps_num)
    are binarised: > 102 -> 255 (cancer) and <= 102 -> 0. Each pixel takes
    the majority over the annotators that provided a map, with ties going to
    255. Output goes to ``data/GroundTruth_binary/<name>.png``. Images whose
    output already exists, or with no annotator map, are skipped.

    Raises:
        ValueError: an annotator map's three channels differ.
    """
    img_name_list = os.listdir(data_file)[start_idx:]
    for name_idx, img_name in enumerate(img_name_list):
        t0 = time.time()
        map_name = img_name[:-4]
        if os.path.exists('data/GroundTruth_binary/{}.png'.format(map_name)):
            continue
        img = cv2.imread(data_file + img_name)
        if img is None:
            continue
        # -1 marks "this annotator has no map for this image".
        maps_arr = np.full((maps_num, *img.shape), -1.0)
        for map_id in range(maps_num):
            mask_img = cv2.imread('data/Maps{0}_T/{1}.png'.format(map_id + 1, map_name))
            if mask_img is not None:
                maps_arr[map_id] = mask_img

        if not (np.array_equal(maps_arr[..., 0], maps_arr[..., 1])
                and np.array_equal(maps_arr[..., 0], maps_arr[..., 2])):
            raise ValueError('annotation maps of {} are not single-valued across channels'.format(map_name))
        maps_arr = maps_arr[..., 0].transpose((1, 2, 0))  # (H, W, maps_num)
        valid = maps_arr >= 0
        if not valid.any():
            print('\n{} has no annotation map, skipped'.format(map_name))
            continue
        # Count only real votes; the previous version turned -1 (missing) into
        # 0 and counted it as a "normal" vote.
        cancer_votes = (valid & (maps_arr > 51 * 2)).sum(axis=-1)
        normal_votes = (valid & (maps_arr <= 51 * 2)).sum(axis=-1)
        gt_map = np.where(cancer_votes >= normal_votes, 255, 0).astype(np.uint8)
        cv2.imwrite('data/GroundTruth_binary/{}.png'.format(map_name), gt_map)
        t1 = time.time()
        print('\r{0}/{1} done\t time:{2:.2f}s'.format(name_idx, len(img_name_list), t1-t0), end='')


def ground_truth_multi(data_file='data/Train Imgs/', maps_num=6, start_idx=0):
    """Write per-grade annotator-agreement maps for each image.

    The scaled maps ``data/Maps{k}_T/<name>.png`` are read, and values
    <= 102 become 0. For each grade g in (0, 3, 4, 5), the image
    ``data/GroundTruth_multi/<name>_<g>.png`` stores
    round(255 * fraction of the available annotators whose value is g*51).
    Images whose ``<name>_0.png`` already exists, or with no annotator map,
    are skipped.

    Raises:
        ValueError: an annotator map's three channels differ.
    """
    img_name_list = os.listdir(data_file)[start_idx:]
    for name_idx, img_name in enumerate(img_name_list):
        t0 = time.time()
        map_name = img_name[:-4]
        if os.path.exists('data/GroundTruth_multi/{}_0.png'.format(map_name)):
            continue
        img = cv2.imread(data_file + img_name)
        if img is None:
            continue
        # -1 marks "this annotator has no map for this image".
        maps_arr = np.full((maps_num, *img.shape), -1.0)
        for map_id in range(maps_num):
            mask_img = cv2.imread('data/Maps{0}_T/{1}.png'.format(map_id + 1, map_name))
            if mask_img is not None:
                maps_arr[map_id] = mask_img

        if not (np.array_equal(maps_arr[..., 0], maps_arr[..., 1])
                and np.array_equal(maps_arr[..., 0], maps_arr[..., 2])):
            raise ValueError('annotation maps of {} are not single-valued across channels'.format(map_name))
        maps_arr = maps_arr[..., 0].transpose((1, 2, 0))  # (H, W, maps_num)
        valid = maps_arr >= 0
        if not valid.any():
            print('\n{} has no annotation map, skipped'.format(map_name))
            continue
        # Threshold only real votes so missing maps (-1) are neither counted
        # as grade 0 nor included in the denominator.
        maps_arr[valid & (maps_arr <= 51 * 2)] = 0
        n_valid = valid.sum(axis=-1)
        for grade_idx in (0, 3, 4, 5):
            conf = (maps_arr == grade_idx * 51).sum(axis=-1) / n_valid
            # Round instead of truncating (e.g. 1/3*255 = 84.99... -> 85) and
            # write an explicit 8-bit image.
            cv2.imwrite('data/GroundTruth_multi/{0}_{1}.png'.format(map_name, grade_idx),
                        np.rint(conf * 255).astype(np.uint8))
        t1 = time.time()
        print('\r{0}/{1} done\t time:{2:.2f}s'.format(name_idx, len(img_name_list), t1-t0), end='')


def split_latticed_instances(img, m=224*4):
    """Cut ``img`` into a (H // m) x (W // m) grid of m x m tiles.

    When the image is not an exact multiple of ``m``, each grid row (and each
    grid column) gets one random offset in [0, slack), where slack is the
    leftover extent. Row offsets are drawn before column offsets, using
    ``np.random``.

    Returns:
        (tiles, locations): tiles has shape (rows * cols, m, m, 3) and the
        input dtype; locations[k] is [x_start, y_start] of tile k, in
        row-major order.
    """
    height, width = img.shape[:2]
    rows, cols = height // m, width // m

    def _jitter(extent, count):
        slack = extent - count * m
        if slack == 0:
            return np.zeros(count, dtype=int)
        return np.random.randint(0, slack, count)

    row_offsets = _jitter(height, rows)
    col_offsets = _jitter(width, cols)

    tiles = np.empty((rows * cols, m, m, 3), dtype=img.dtype)
    locations = []
    for tile_idx in range(rows * cols):
        r, c = divmod(tile_idx, cols)
        x0 = c * m + col_offsets[c]
        y0 = r * m + row_offsets[r]
        tiles[tile_idx] = img[y0:y0 + m, x0:x0 + m]
        locations.append([x0, y0])
    return tiles, locations

def create_cv_file(main_path, cv_num=10):
    """Create the per-fold directory tree under ``main_path``.

    For each fold i in range(cv_num), creates
    ``cv_i/{train,unlabeled,test,valid,extra,unlabeled_extra}/{0,1}`` plus
    ``cv_i/extra/2`` (split_latticed_patches writes intermediate-rate test
    patches there). Idempotent: existing directories are left alone instead
    of raising FileExistsError.

    Args:
        main_path: root directory; must end with a path separator because
            fold names are appended by string concatenation.
        cv_num: number of folds.
    """
    splits = ('train', 'unlabeled', 'test', 'valid', 'extra', 'unlabeled_extra')
    os.makedirs(main_path, exist_ok=True)
    for cv_idx in range(cv_num):
        for split in splits:
            labels = ('0', '1', '2') if split == 'extra' else ('0', '1')
            for label in labels:
                os.makedirs(main_path + 'cv_{0}/{1}/{2}'.format(cv_idx, split, label), exist_ok=True)

        
def transfer_samples(main_path, num, source, target, cv_num=10, num_mode='remain'):
    """Move randomly chosen files from ``source`` to ``target`` in every fold.

    For each fold and each class folder ('0/', '1/'):
    * num_mode == 'remain': move files until ``num`` remain in the source
      (nothing is moved if there are already <= num);
    * any other mode: move exactly ``num`` files.

    The selection is reproducible: file names are sorted and drawn with a
    RandomState seeded with 0. The previous version seeded Python's
    ``random`` module, which was never imported and does not affect
    ``np.random`` anyway.

    Args:
        main_path: root containing ``cv_<i>/`` folders (trailing separator
            required).
        source, target: sub-paths inside a fold, each ending with a separator.

    Raises:
        ValueError: num exceeds the available files outside 'remain' mode.
    """
    rng = np.random.RandomState(0)
    for cv_idx in range(cv_num):
        fold_dir = main_path + 'cv_{}/'.format(cv_idx)
        source_path = fold_dir + source
        target_path = fold_dir + target
        for label in ('0/', '1/'):
            # Sorted so the draw does not depend on filesystem listing order.
            files = sorted(os.listdir(source_path + label))
            if num_mode == 'remain':
                n_move = max(len(files) - num, 0)
            else:
                n_move = num
            for file in rng.choice(files, n_move, replace=False):
                os.rename(source_path + label + file, target_path + label + file)
    
    
def split_latticed_patches(img_file, gt_file, cv_idx, main_path='/nfs/yuxiaotian/Gleason/normal&cancer/semi_impure/', train=True, m=448, overlap=50, edge=1000, stride=10, start_idx=0, rgb=True):
    """Cut one image into m x m patches and write them into fold ``cv_idx``.

    Top-left corners (h, w) are visited on a ``stride`` grid starting at
    ``edge``. After a patch is written at (h, w), later corners inside the box
    [h-m+overlap, h+m-overlap] x [w-m+overlap, w+m-overlap] are skipped.
    Patches whose every channel mean exceeds 200 (near-white background)
    are skipped. The label comes from ``normal_rate``, the fraction of
    pixels in the patch whose first ground-truth channel is 0:

    * train=True:  rate <= 0.75 -> train/1, rate >= 0.99 -> train/0,
      otherwise skipped.
    * train=False: rate < 0.75 -> test/1, rate >= 0.99 -> test/0,
      otherwise extra/2; file names also carry the rate.

    Args:
        img_file, gt_file: image and ground-truth paths (read with cv2).
        cv_idx: fold index; output goes to ``main_path + 'cv_<cv_idx>/'``.
        rgb: if False, the image is converted to a 1-channel gray image.
        start_idx: running counter appended to every written file name.

    Returns:
        (start_idx after the last written patch, rate_list). rate_list is
        always [0] * 6 and is returned only for callers that unpack it.
    """
    img = cv2.imread(img_file)
    if not rgb:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY).reshape(*img.shape[:2], 1)
    img_gt = cv2.imread(gt_file)
    H, W = img_gt.shape[:2]
    stem = img_file[-20:-4]
    fold_dir = main_path + 'cv_{0}/'.format(cv_idx)
    rate_list = [0, 0, 0, 0, 0, 0]

    def _box(h, w):
        # Exclusion box (h_lo, w_lo, h_hi, w_hi) around a written patch.
        return (h - m + overlap, w - m + overlap, h + m - overlap, w + m - overlap)

    def _covered(h, w, boxes):
        # True if corner (h, w) lies inside any exclusion box.
        return any(h_lo <= h <= h_hi and w_lo <= w <= w_hi for h_lo, w_lo, h_hi, w_hi in boxes)

    def _is_background(patch):
        # True if every channel mean is > 200. Works for 1-channel patches,
        # where the previous [:, :, 1] / [:, :, 2] indexing raised IndexError.
        return bool(np.all(patch.reshape(-1, patch.shape[-1]).mean(axis=0) > 200))

    def _normal_rate(h, w):
        # Fraction of ground-truth pixels equal to 0 in the patch at (h, w).
        return (img_gt[h:h+m, w:w+m, 0] == 0).sum() / m / m

    if train:
        labeled, unlabeled = [], []
        for h in range(edge, H - m, stride):
            for w in range(edge, W - m, stride):
                if not _covered(h, w, labeled):
                    normal_rate = _normal_rate(h, w)
                    img_patch = img[h:h+m, w:w+m, :]
                    if _is_background(img_patch):
                        continue
                    if normal_rate <= 0.75:
                        label = 1
                    elif normal_rate >= 0.99:
                        label = 0
                    else:
                        continue
                    cv2.imwrite(fold_dir + 'train/{0}/{1}_Y{2}_X{3}_{4}.png'.format(label, stem, h, w, start_idx), img_patch)
                    labeled.append(_box(h, w))
                    start_idx += 1

                # NOTE(review): when overlap <= m (true for the defaults), every
                # path that reaches this point has (h, w) inside a labeled box,
                # so no unlabeled patch is ever written. Behaviour kept as-is;
                # confirm whether the `continue`s above were meant to fall
                # through to here.
                if not _covered(h, w, unlabeled) and not _covered(h, w, labeled):
                    normal_rate = _normal_rate(h, w)
                    img_patch = img[h:h+m, w:w+m, :]
                    label = 1 if normal_rate < 0.99 else 0
                    cv2.imwrite(fold_dir + 'unlabeled/{0}/{1}_Y{2}_X{3}_{4}.png'.format(label, stem, h, w, start_idx), img_patch)
                    unlabeled.append(_box(h, w))
                    start_idx += 1
    else:
        boxes = []
        # Make sure extra/2/ exists: cv2.imwrite returns False instead of
        # raising when the target folder is missing, so patches were lost.
        os.makedirs(fold_dir + 'extra/2/', exist_ok=True)
        for h in range(edge, H - edge - m, stride):
            for w in range(edge, W - edge - m, stride):
                if _covered(h, w, boxes):
                    continue
                normal_rate = _normal_rate(h, w)
                img_patch = img[h:h+m, w:w+m, :]
                if _is_background(img_patch):
                    continue
                if normal_rate < 0.75:
                    sub_dir = 'test/1/'
                elif normal_rate >= 0.99:
                    sub_dir = 'test/0/'
                else:
                    sub_dir = 'extra/2/'
                cv2.imwrite(fold_dir + sub_dir + '{0}_Y{1}_X{2}_{3}_{4:.2f}.png'.format(stem, h, w, start_idx, normal_rate), img_patch)
                boxes.append(_box(h, w))
                start_idx += 1
    return start_idx, rate_list


def patch_split(img_path='data/Train Imgs/', gt_path='data/GroundTruth/', main_path='/nfs/yuxiaotian/Gleason/normal&cancer/semi_gray/', cv_num=10, rgb=True):
    """Build the cross-validation patch dataset.

    Images are shuffled (``random.seed(0)``) and split into ``cv_num`` folds.
    The hand-tuned 10-fold sizes are used when they match (10 folds, 244
    images); otherwise the folds are split as evenly as possible. For each fold
    i, patches from fold i go to test/extra and patches from every other fold
    go to train (via ``split_latticed_patches``). Afterwards:
    * train classes are trimmed to the smaller class;
    * unlabeled classes are trimmed to min(both, 10000);
    * test classes are trimmed to min(both, 1280);
    * up to 128 test files per class are moved to valid/.
    Trimming picks files with the global ``np.random`` state.
    """
    import random

    def _count(dir_path):
        # Number of entries currently in dir_path.
        return len(os.listdir(dir_path))

    def _trim(dir_path, keep):
        # Randomly delete files until at most `keep` remain in dir_path.
        files = os.listdir(dir_path)
        if len(files) > keep:
            for file in np.random.choice(files, len(files) - keep, replace=False):
                os.remove(dir_path + file)

    random.seed(0)
    all_images = os.listdir(img_path)
    num_images = len(all_images)
    shuffled = random.sample(all_images, num_images)
    fold_sizes = [23, 24, 27, 23, 24, 25, 24, 25, 25, 24]
    if cv_num != len(fold_sizes) or sum(fold_sizes) != num_images:
        base, remainder = divmod(num_images, cv_num)
        fold_sizes = [base + 1] * remainder + [base] * (cv_num - remainder)
    fold_starts = [sum(fold_sizes[:i]) for i in range(cv_num)]
    folds = [shuffled[start:start + size] for start, size in zip(fold_starts, fold_sizes)]

    # range(cv_num), not a hard-coded 10, so the loop matches the fold list.
    for cv_idx in range(cv_num):
        next_idx = 0
        for group_idx, group in enumerate(folds):
            for img_idx, img_name in enumerate(group):
                gt_file = gt_path + img_name[:-4] + '.png'
                img_file = img_path + img_name
                if not (os.path.exists(img_file) and os.path.exists(gt_file)):
                    continue
                # The held-out fold yields test/extra patches; the rest yield training patches.
                next_idx, _ = split_latticed_patches(img_file, gt_file, main_path=main_path, start_idx=next_idx,
                                                     cv_idx=cv_idx, train=(group_idx != cv_idx), rgb=rgb)
                print('\rcv{0} {1}/{2} done {3}'.format(cv_idx, img_idx + fold_starts[group_idx], num_images, next_idx), end='')
        print()

        fold_dir = main_path + 'cv_{0}/'.format(cv_idx)
        train_num = min(_count(fold_dir + 'train/1/'), _count(fold_dir + 'train/0/'))
        _trim(fold_dir + 'train/1/', train_num)
        _trim(fold_dir + 'train/0/', train_num)

        unlabeled_num = min(_count(fold_dir + 'unlabeled/1/'), _count(fold_dir + 'unlabeled/0/'), 10000)
        _trim(fold_dir + 'unlabeled/1/', unlabeled_num)
        _trim(fold_dir + 'unlabeled/0/', unlabeled_num)

        test_num = min(_count(fold_dir + 'test/1/'), _count(fold_dir + 'test/0/'), 1280)
        _trim(fold_dir + 'test/1/', test_num)
        _trim(fold_dir + 'test/0/', test_num)

        # Validation split; capped at the available count instead of crashing
        # when a class has fewer than 128 test patches.
        for label in ('0', '1'):
            test_dir = fold_dir + 'test/{}/'.format(label)
            valid_dir = fold_dir + 'valid/{}/'.format(label)
            test_files = os.listdir(test_dir)
            for file in np.random.choice(test_files, min(128, len(test_files)), replace=False):
                os.rename(test_dir + file, valid_dir + file)


def ground_truth_split(data_file='data/GroundTruth/'):
    """Tile each ground-truth map and sort the tiles into cancer / non-cancer.

    Each map is split into 896 x 896 tiles (``split_latticed_instances``) and
    binarised (> 255/6 -> 255, else 0). Tiles whose mean is > 80% of 255 are
    dropped. Tiles with mean >= 255/100 go to data/gt_cancer/1/, the rest to
    data/gt_cancer/0/, named '<map>-<x>-<y>.png'. Unreadable files are skipped.
    """
    gt_names = os.listdir(data_file)
    # cv2.imwrite fails silently when the target folder is missing.
    for label in ('0', '1'):
        os.makedirs('data/gt_cancer/{}/'.format(label), exist_ok=True)
    for img_idx, img_name in enumerate(gt_names):
        img = cv2.imread(data_file + img_name)
        if img is None:
            continue
        latticed_imgs, loc_list = split_latticed_instances(img, 224*4)
        latticed_imgs = np.where(latticed_imgs > 255/6, 255, 0).astype(np.uint8)
        for latticed_img, (x_start, y_start) in zip(latticed_imgs, loc_list):
            tile_mean = latticed_img.mean()
            if tile_mean > 255*0.8:
                continue
            label = '1' if tile_mean >= 255/100 else '0'
            cv2.imwrite('data/gt_cancer/{0}/{1}-{2}-{3}.png'.format(label, img_name[:-4], x_start, y_start), latticed_img)
        print('{0}/{1} done'.format(img_idx, len(gt_names)))



def create_cancer_dataset(img_path='data/Train Imgs/', gt_path='data/gt_cancer/', img_size=224*4, chunk_size=512):
    """Build shuffled image/label chunks from the tiles in ``gt_path``.

    Tile names '<image>-<x>-<y>.png' under gt_path/1 (cancer, label 1) and
    gt_path/0 (label 0) are shuffled with ``np.random``. For each tile, the
    matching img_size x img_size crop of ``img_path/<image>.jpg`` is stored
    as a (3, H, W) float array scaled to [0, 1]. Crops that are unreadable or
    smaller than img_size are skipped. Every ``chunk_size`` kept samples are
    saved to data/cancer_{image,label}set_<k>.npy; the remainder goes into
    the final chunk. The shuffled name list is saved to
    data/cancer_filename.npy.
    """
    noncancer_path = os.listdir(gt_path + '0/')
    cancer_path = os.listdir(gt_path + '1/')
    image_path = cancer_path + noncancer_path
    label_all = [1] * len(cancer_path) + [0] * len(noncancer_path)
    random_idx = np.random.choice(np.arange(len(label_all)), len(label_all), replace=False)
    image_path = np.array(image_path)[random_idx]
    label_all = np.array(label_all)[random_idx]

    idx = 0
    imageset = []
    labelset = []
    for i, path in enumerate(image_path):
        # rsplit keeps any '-' inside <image> intact.
        img_name, x_start, y_start = os.path.splitext(path)[0].rsplit('-', 2)
        img = cv2.imread(img_path + img_name + '.jpg')
        if img is None:
            continue
        x_start, y_start = int(x_start), int(y_start)
        # HWC -> CHW. The previous transpose((2, 1, 0)) produced (C, W, H),
        # i.e. a spatially transposed patch.
        patch = img[y_start:y_start+img_size, x_start:x_start+img_size].transpose((2, 0, 1))
        if patch.shape != (3, img_size, img_size):
            continue
        imageset.append(patch / 255.)
        labelset.append(label_all[i])
        print('{0}/{1} done'.format(i, len(label_all)))
        # Flush on kept samples, so a skipped sample at a boundary cannot
        # silently double the next chunk.
        if len(imageset) == chunk_size:
            np.save('data/cancer_imageset_{}.npy'.format(idx), imageset)
            np.save('data/cancer_labelset_{}.npy'.format(idx), labelset)
            idx += 1
            imageset = []
            labelset = []
    np.save('data/cancer_imageset_{}.npy'.format(idx), imageset)
    np.save('data/cancer_labelset_{}.npy'.format(idx), labelset)
    np.save('data/cancer_filename.npy', image_path)


def load_image_dataset(idx, dataset_path='data/'):
    """Load chunk ``idx`` saved as cancer_imageset_<idx>.npy / cancer_labelset_<idx>.npy.

    Returns:
        (image_set, label_set) as loaded by ``np.load``.
    """
    suffix = '_{}.npy'.format(idx)
    images = np.load(dataset_path + 'cancer_imageset' + suffix)
    labels = np.load(dataset_path + 'cancer_labelset' + suffix)
    return images, labels


def load_instance_dataset(dataset_path='data/'):
    """Load instance_image_set.npy and instance_label_set.npy from ``dataset_path``.

    Returns:
        (image_set, label_set) as loaded by ``np.load``.
    """
    images, labels = (np.load(dataset_path + 'instance_{}_set.npy'.format(kind))
                      for kind in ('image', 'label'))
    return images, labels


def create_patch_dataset(
        img_path='data/Train Imgs/',
        target_path='data/Train_patch/',
        gt_path='data/GroundTruth/',
        label_path='data/GroundTruth_patch/',
        patch_size=224,
        dataset_num=100000,
        dataset_size=10000):
    """Two-stage patch dataset builder.

    Stage 1 (``target_path`` empty): tile every ground-truth map in
    ``gt_path`` with ``split_latticed_instances`` and write each ground-truth
    tile to ``label_path`` (.png) and the matching image crop to
    ``target_path`` (.jpg). Pairs with an unreadable image or map are skipped.

    Stage 2 (otherwise): sample up to ``dataset_num`` patches without
    replacement. Each patch gets a 4-way multi-hot label: grade value
    v in (0, 153, 204, 255) is marked when it covers >= 1% of the tile.
    Chunks of ``dataset_size`` are saved to data/patch_train_<k>.npy and
    data/patch_label_<k>.npy. The final partial chunk is saved trimmed; it
    used to be dropped. The index of any tile containing other values is
    printed.
    """
    grade_values = [0, 153, 204, 255]
    # create image
    if len(os.listdir(target_path)) == 0:
        img_file_list = os.listdir(gt_path)
        for img_num, img_file in enumerate(img_file_list):
            gt_img = cv2.imread(gt_path + img_file)
            img = cv2.imread(img_path + img_file[:-4] + '.jpg')
            if gt_img is None or img is None:
                continue
            stem = img_file[:-4]
            split_img, split_loc = split_latticed_instances(gt_img, m=patch_size)
            for patch, (loc_x, loc_y) in zip(split_img, split_loc):
                cv2.imwrite('{0}{1}_{2}_{3}.png'.format(label_path, stem, loc_x, loc_y), patch)
                cv2.imwrite('{0}{1}_{2}_{3}.jpg'.format(target_path, stem, loc_x, loc_y),
                            img[loc_y:loc_y+patch_size, loc_x:loc_x+patch_size])
            print('{}/{}'.format(img_num, len(img_file_list)))
    # create npy
    else:
        input_file_list = os.listdir(target_path)
        # Never ask for more samples than exist (np.random.choice would raise).
        n_samples = min(dataset_num, len(input_file_list))
        input_file_list = np.random.choice(input_file_list, n_samples, replace=False)

        imageset = labelset = None
        for i, input_file in enumerate(input_file_list):
            if i % dataset_size == 0:
                imageset = np.zeros((dataset_size, patch_size, patch_size, 3))
                labelset = np.zeros((dataset_size, 4))

            stem = input_file[:-4]
            image = cv2.imread('{}.jpg'.format(target_path + stem))
            gt_image = cv2.imread('{}.png'.format(label_path + stem))[:, :, 0].flatten()
            if len(set(gt_image).union(set(grade_values))) != 4:
                print(i)
            label = np.array([(gt_image == value).sum() >= patch_size ** 2 / 100 for value in grade_values],
                             dtype=float)
            imageset[i % dataset_size] = image
            labelset[i % dataset_size] = label

            if i % dataset_size == (dataset_size - 1):
                np.save('data/patch_train_{}.npy'.format(i//dataset_size), imageset)
                np.save('data/patch_label_{}.npy'.format(i//dataset_size), labelset)
                print((i+1), 'done')

        remainder = n_samples % dataset_size
        if remainder:
            chunk_idx = n_samples // dataset_size
            np.save('data/patch_train_{}.npy'.format(chunk_idx), imageset[:remainder])
            np.save('data/patch_label_{}.npy'.format(chunk_idx), labelset[:remainder])
            print(n_samples, 'done')


# liver patches
# xml file

def read_points_from_xml(liver_name, xml_path='/nfs/yuxiaotian/HCC/xml/'):
    """Read annotation polygons from ``<xml_path><liver_name>_Annotations.xml``.

    Every <annotation> element yields one polygon: the list of its <p>
    descendants, each converted from attributes x, y to an integer [y, x]
    pair.

    Returns:
        list of polygons, each a list of [y, x] integer pairs.
    """
    import xml.dom.minidom as dom

    def _point(node):
        x = int(node.getAttribute('x'))
        y = int(node.getAttribute('y'))
        return [y, x]

    document = dom.parse(xml_path + liver_name + '_Annotations.xml')
    annotations = document.documentElement.getElementsByTagName('annotation')
    return [[_point(node) for node in anno.getElementsByTagName('p')] for anno in annotations]


def read_points_from_json(liver_name, json_path='/nfs3-p2/yuxiaotian/CAMELYON16/training/lesion_annotations/'):
    """Read the 'positive' polygons from ``<json_path><liver_name>.json``.

    Each entry's 'vertices' list of (x, y) pairs is converted to [y, x].

    Returns:
        list of polygons, or None if the JSON file does not exist.
    """
    import json
    json_file = json_path + liver_name + '.json'
    if not os.path.exists(json_file):
        return None
    with open(json_file) as handle:
        annotation = json.load(handle)
    return [[[y, x] for x, y in region['vertices']] for region in annotation['positive']]


# polygon & point: (y, x)
def point_in_polygon(polygon, point):
    """Return True if ``point`` is inside or on the boundary of ``polygon``.

    Both use (y, x) ordering. ``polygon`` is an (N, 2) array-like of vertices
    and is treated as closed: the edge from the last vertex back to the first
    is always included, whether or not the first vertex is repeated.

    Uses even-odd ray casting with a half-open crossing rule, so a ray
    through a vertex is counted once. Points exactly on an edge or vertex
    return True. Previous defects fixed:
    * the closing edge was missing;
    * vertices were double-counted;
    * any point at the y of a horizontal edge returned True, whatever its x.
    An empty polygon contains nothing.
    """
    # float64 keeps the cross product exact for integer pixel coordinates and
    # avoids int32 overflow on platforms where that is the default int.
    polygon = np.asarray(polygon, dtype=float)
    if len(polygon) == 0:
        return False
    y, x = point
    # Quick reject outside the bounding box.
    if y < polygon[:, 0].min() or y > polygon[:, 0].max() or x < polygon[:, 1].min() or x > polygon[:, 1].max():
        return False

    inside = False
    num_vertices = len(polygon)
    for i in range(num_vertices):
        start_y, start_x = polygon[i]
        end_y, end_x = polygon[(i + 1) % num_vertices]
        # On the boundary: collinear with the edge and within its extent.
        if (min(start_y, end_y) <= y <= max(start_y, end_y)
                and min(start_x, end_x) <= x <= max(start_x, end_x)
                and (end_y - start_y) * (x - start_x) == (end_x - start_x) * (y - start_y)):
            return True
        # Half-open rule: count the edge only if it straddles the horizontal
        # line through the point (this also ignores horizontal edges).
        if (start_y > y) != (end_y > y):
            cross_x = start_x + (end_x - start_x) * (y - start_y) / (end_y - start_y)
            if x > cross_x:
                inside = not inside
    return inside

def point_in_polygons(polygons, point):
    """Return True if ``point`` (y, x) is in any polygon, as judged by ``point_in_polygon``.

    Polygons are checked in order and the search stops at the first hit.
    """
    return any(point_in_polygon(np.array(polygon), point) for polygon in polygons)

def patch_in_polygons(polygons, patch_point, patch_size=448):
    """Return True if any corner of the patch lies in any polygon.

    ``patch_point`` is the (y, x) top-left corner. The four corners
    (offset by 0 or ``patch_size`` in each axis) are tested against each
    polygon with ``point_in_polygon``. Note that only corners are tested: a
    polygon lying strictly inside the patch is not detected.
    """
    corner_offsets = ((0, 0), (0, patch_size), (patch_size, 0), (patch_size, patch_size))
    for polygon in polygons:
        vertices = np.array(polygon)
        origin = np.array(patch_point)
        if any(point_in_polygon(vertices, origin + offset) for offset in corner_offsets):
            return True
    return False


def cluster_in_polygons(polygons, cluster, init_point, ratio=1.0):
    """Return True if the cluster's scaled position lies in any polygon.

    The tested point is (init_point[0] + cluster.h * ratio,
    init_point[1] + cluster.w * ratio) in (y, x) order. ``cluster`` only
    needs ``h`` and ``w`` attributes.
    """
    for vertices in map(np.array, polygons):
        target = (init_point[0] + cluster.h * ratio, init_point[1] + cluster.w * ratio)
        if point_in_polygon(vertices, target):
            return True
    return False



def get_all_patch(liver_name, source_path, target_path, stride=5000):
    """Tile the annotated area of a whole-slide image into stride x stride PNGs.

    HCC slides (``'HCC' in source_path``) are read from
    ``<source_path>/<grade>/<name>.mrxs`` with XML annotations. CAMELYON16
    slides are read from ``<source_path><name>.tif`` with JSON annotations.
    The bounding box of all polygons is snapped to the stride grid and padded
    by 2 strides before and 3 after in each axis. Every level-0 region is
    written as ``<target_path><name>_Y<y>_X<x>_SIZE<stride>.png``.

    Raises:
        ValueError: unrecognised ``source_path``, or no annotation polygons.
    """
    #source_path = '/nfs3-p2/yuxiaotian/HCC_data/HCC/'
    #target_path = '/nfs3-p2/yuxiaotian/slic/train_image/'
    import openslide
    if 'HCC' in source_path:
        grad = liver_name_to_grad(liver_name)
        # Previously os.path.join(source_path, grad) + name, which dropped the
        # separator ('.../2' + '3M01.mrxs' -> '.../23M01.mrxs').
        liver_file = os.path.join(source_path, str(grad), liver_name + '.mrxs')
        polygons = read_points_from_xml(liver_name)
    elif 'CAMELYON16' in source_path:
        liver_file = source_path + liver_name + '.tif'
        polygons = read_points_from_json(liver_name)
    else:
        raise ValueError('unrecognised source_path (expected an HCC or CAMELYON16 path): {}'.format(source_path))
    arrays = [np.array(poly) for poly in (polygons or []) if len(poly)]
    if not arrays:
        raise ValueError('no annotation polygons found for {}'.format(liver_name))

    ymin = min(poly[:, 0].min() for poly in arrays)
    ymax = max(poly[:, 0].max() for poly in arrays)  # was min(): clipped the box
    xmin = min(poly[:, 1].min() for poly in arrays)
    xmax = max(poly[:, 1].max() for poly in arrays)
    ymin = (ymin//stride-2)*stride
    ymax = (ymax//stride+3)*stride
    xmin = (xmin//stride-2)*stride
    xmax = (xmax//stride+3)*stride

    slide = openslide.open_slide(liver_file)
    try:
        for y in range(ymin, ymax, stride):
            for x in range(xmin, xmax, stride):
                region = slide.read_region((x, y), 0, (stride, stride)).convert('RGB')
                # RGB (PIL) -> BGR (OpenCV) before writing.
                cv2.imwrite(target_path+'{0}_Y{1}_X{2}_SIZE{3}.png'.format(liver_name, y, x, stride), np.array(region)[:,:,::-1])
    finally:
        slide.close()


def split_liver_patches(liver_name, main_path='/nfs3-p2/yuxiaotian/CAMELYON16/training/tumor/', target_path='/nfs3-p2/yuxiaotian/CAMELYON16/training/valid_patch/', cancer_grad=1, train=True, m=448, stride='same', start_idx=0, keep_cancer=300, keep_normal=150):
    """Tile one ``.mrxs`` slide into ``m`` x ``m`` PNG patches sorted by annotation.

    Patches whose square touches an annotated polygon (``patch_in_polygons``)
    go to ``target_path/<cancer_grad>/``; all others go to ``target_path/0/``.
    Tiles where more than 80% of pixels are near-gray are dropped. Afterwards a
    random surplus is deleted so that at most ``keep_cancer`` cancer and
    ``keep_normal`` normal patches of this slide remain.

    Args:
        liver_name: slide stem; the slide file is ``main_path + liver_name + '.mrxs'``.
        main_path, target_path: directories with a trailing slash.
        cancer_grad: sub-folder index used for in-polygon patches.
        train, start_idx: unused; kept for call compatibility.
        m: patch side in level-0 pixels.
        stride: step between patches; ``'same'`` means ``m``.
        keep_cancer, keep_normal: per-slide caps applied after tiling.

    Returns None. It returns early, after printing a message, if the slide
    cannot be opened. It also returns early, printing ``'<name> done'``, if
    ``target_path/0/`` already holds exactly 1000 patches of this slide.
    """
    import openslide
    if stride == 'same':
        stride = m
    liver_file = main_path + liver_name + '.mrxs'
    try:
        slide = openslide.open_slide(liver_file)
    except Exception as err:  # best effort: skip slides that cannot be opened
        print('cannot open', liver_file, err)
        return
    target_path_n = target_path + '0/'
    target_path_c = target_path + '{}/'.format(cancer_grad)

    def _slide_patches(folder):
        # Full paths of this slide's patches inside ``folder``.
        return [folder + f for f in os.listdir(folder) if f.split('_')[0] == liver_name]

    try:
        if len(_slide_patches(target_path_n)) == 1000:
            print(liver_name, 'done')
            return

        polygons = read_points_from_xml(liver_name)
        #polygons = read_points_from_json(liver_name)
        if polygons is None or len(polygons) == 0:
            # No annotation: scan a fixed window with a 5x coarser stride; all tiles are "normal".
            polygons = None
            xmin, xmax, ymin, ymax = 20000, 80000, 100000, 170000
            stride *= 5
        else:
            polys = [np.asarray(poly) for poly in polygons]
            # int() so that float polygon coordinates still work with range().
            xmin = int(np.floor(min(p[:, 1].min() for p in polys)))
            ymin = int(np.floor(min(p[:, 0].min() for p in polys)))
            xmax = int(np.ceil(max(p[:, 1].max() for p in polys)))
            ymax = int(np.ceil(max(p[:, 0].max() for p in polys)))

        for y in range(ymin, ymax, stride):
            for x in range(xmin, xmax, stride):
                out_name = '{0}_Y{1}_X{2}.png'.format(liver_name, y, x)
                # Resume support: skip tiles already written to either class folder.
                if os.path.exists(target_path_n + out_name) or os.path.exists(target_path_c + out_name):
                    continue
                pixels = np.array(slide.read_region((x, y), 0, (m, m)).convert('RGB'))
                print('\r{0} {1}'.format(x, y), end='')
                flat = pixels.reshape(-1, 3).astype(np.float64)
                # Per-pixel mean absolute deviation from that pixel's own channel mean;
                # a tile with > 80% near-gray pixels (< 10) is treated as background.
                gray = np.abs(flat - flat.mean(axis=1, keepdims=True)).mean(axis=1) < 10
                if gray.mean() > 0.8:
                    continue
                bgr = np.ascontiguousarray(pixels[:, :, ::-1])
                if polygons is not None and patch_in_polygons(polygons, (y, x), patch_size=m):
                    cv2.imwrite(target_path_c + out_name, bgr)  # cancer
                else:
                    cv2.imwrite(target_path_n + out_name, bgr)  # normal
    finally:
        slide.close()

    print(liver_name, len(_slide_patches(target_path_c)), len(_slide_patches(target_path_n)))
    # Keep at most ``keep`` patches per class for this slide by deleting a random surplus.
    for folder, keep in ((target_path_c, keep_cancer), (target_path_n, keep_normal)):
        files = _slide_patches(folder)
        surplus = len(files) - keep
        if surplus > 0:
            for file in np.random.choice(files, surplus, replace=False):
                os.remove(file)
    
    
def split_PANDA_patches(name, wsi_path='/nfs3-p2/yuxiaotian/PANDA/train_images/', mask_path='/nfs3-p2/yuxiaotian/PANDA/train_label_masks/', target_path='/nfs3-p2/yuxiaotian/PANDA/train_region/', m=2500, stride='same'):
    """Cut one PANDA slide into ``m`` x ``m`` tissue regions plus their label masks.

    A region is kept when at least half of its mask pixels (first mask channel)
    are non-zero and it is not mostly near-gray. Each kept region is written as
    ``<name>_Y<y>_X<x>_r0<a>_r3<b>_r4<c>_r5<d>.png`` plus a matching
    ``..._mask.png``. Here ``a`` is the integer percentage of mask values < 3,
    and ``b``, ``c``, ``d`` are the percentages of values 3, 4 and 5.

    Regions whose ``<name>_Y<y>_X<x>`` prefix already exists in ``target_path``
    are skipped, so an interrupted run can be resumed. If either slide file
    cannot be opened, a message is printed and the function returns.
    """
    import openslide
    if stride == 'same':
        stride = m
    wsi_file = wsi_path + name + '.tiff'
    mask_file = mask_path + name + '_mask.tiff'
    try:
        slide = openslide.open_slide(wsi_file)
    except Exception as err:  # best effort: skip unreadable slides
        print('skip {0}: {1}'.format(name, err))
        return
    try:
        mask = openslide.open_slide(mask_file)
    except Exception as err:
        slide.close()
        print('skip {0}: {1}'.format(name, err))
        return

    try:
        X, Y = slide.level_dimensions[0]
        # Output names carry an _r0.._r5 suffix, so resume on the "<name>_Y<y>_X<x>" prefix.
        done = {'_'.join(f.split('_')[:3]) for f in os.listdir(target_path) if f.split('_')[0] == name}
        for y in range(0, Y - m + 1, stride):
            for x in range(0, X - m + 1, stride):
                if '{0}_Y{1}_X{2}'.format(name, y, x) in done:
                    continue
                region = np.array(slide.read_region((x, y), 0, (m, m)).convert('RGB'))
                try:
                    region_mask = mask.read_region((x, y), 0, (m, m))
                except Exception:
                    continue
                region_mask = np.ascontiguousarray(np.array(region_mask)[:, :, 0])  # 0/1/2/3/4/5
                if (region_mask > 0).mean() < 0.5:
                    continue
                print('\r{0} {1}'.format(x, y), end='')
                if region.shape[:2] != (m, m):
                    continue
                flat = region.reshape(-1, 3).astype(np.float64)
                # Skip regions where > 80% of pixels are near-gray (mean abs deviation < 10).
                gray = np.abs(flat - flat.mean(axis=1, keepdims=True)).mean(axis=1) < 10
                if gray.mean() > 0.8:
                    continue
                r = [int((region_mask < 3).mean() * 100)] + [int((region_mask == i).mean() * 100) for i in range(3, 6)]
                stem = target_path + '{0}_Y{1}_X{2}_r0{3}_r3{4}_r4{5}_r5{6}'.format(name, y, x, r[0], r[1], r[2], r[3])
                cv2.imwrite(stem + '.png', np.ascontiguousarray(region[:, :, ::-1]))
                cv2.imwrite(stem + '_mask.png', region_mask)
    finally:
        slide.close()
        mask.close()
    print(name)
            
# Grade by slide name: 1M01-1M19 -> 1; 1M20-3M01 -> 2; 3M02-4M05 -> 3; 4M06-5M07 -> 4
def _grading_name_to_grad(liver_name):
    """Return the grade (1-4) encoded in a Grading slide name such as '1M07', or -1 if unknown."""
    first = liver_name[:1]
    try:
        if first == '1':
            return 1 if int(liver_name[-2:]) < 20 else 2
        if first == '2':
            return 2
        if first == '3':
            return 2 if liver_name == '3M01' else 3
        if first == '4':
            return 3 if int(liver_name[-2:]) < 6 else 4
        if first == '5':
            return 4
    except ValueError:  # the last two characters are not a number
        pass
    return -1


def liver_patch_dataset(patch_size=448*10):
    """Split every slide of the Grading cohort into patches with ``split_liver_patches``.

    Entries of ``/nfs2/wzh/Grading/`` that do not end in ``.mrxs`` are
    processed in reverse listing order. The ``cancer_grad`` passed on is
    derived from the slide name. Names whose grade cannot be derived are
    reported and skipped.
    """
    main_path = '/nfs2/wzh/Grading/'
    liver_name_list = [file for file in os.listdir(main_path) if not file.endswith('.mrxs')][::-1]
    for idx, liver_name in enumerate(liver_name_list):
        print(liver_name)
        cancer_grad = _grading_name_to_grad(liver_name)
        if cancer_grad == -1:
            print('LIVER NAME ERROR:', liver_name)
            continue  # would otherwise write patches into a '-1/' folder
        split_liver_patches(liver_name, main_path=main_path, cancer_grad=cancer_grad, m=patch_size)
        print('done {}'.format(idx))
    # Alternative (disabled) source layout, kept for reference:
    '''
    main_path='/home/disk1/yuxiaotian/HCC_data/HCC/'
    total_num = 0
    for grade in range(1,5):
        liver_name_list = [file for file in os.listdir(main_path+'{}/'.format(grade)) if not file.endswith('.mrxs')]
        for liver_name in liver_name_list:
            print('\r{}: done'.format(total_num), end='')
            total_num += 1
            split_liver_patches(liver_name, main_path=main_path+'{}/'.format(grade), cancer_grad=grade, m=patch_size)
    '''

    
# Whole-slide patches -> cross-validation dataset
def get_cv_dataset_info(whole_path='/nfs/yuxiaotian/HCC/HCC_whole_patch_5x/', select_file_num=None):
    """Group whole-slide patches by slide and suggest a per-slide patch quota.

    Files in ``whole_path/<i>/`` (i = 0..4) are grouped by the slide name before
    the first '_'. Using class 0, the quota q that maximises
    ``q * (#slides with >= q patches)`` is printed as
    ``each file num:<q>\\ttotal num:<q*count>``. Then, for each class, the
    number of slides that have at least ``select_file_num`` patches (or q, if
    that is None) is printed on one line.

    Returns:
        list of 5 dicts; element i maps slide name -> list of file names in ``whole_path/<i>/``.
    """
    whole_file_dict = []
    for i in range(5):
        groups = {}
        for file in os.listdir(whole_path + '{}/'.format(i)):
            groups.setdefault(file.split('_')[0], []).append(file)
        whole_file_dict.append(groups)

    counts0 = [len(files) for files in whole_file_dict[0].values()]
    max_sum = 0
    best_file_num = 0
    # The range is inclusive of the largest count so that quota is also considered.
    for file_num in range(1, max(counts0, default=0) + 1):
        amount = sum(c >= file_num for c in counts0)
        if amount * file_num > max_sum:
            max_sum = amount * file_num
            best_file_num = file_num
    print('each file num:{0}\ttotal num:{1}'.format(best_file_num, max_sum))

    threshold = best_file_num if select_file_num is None else select_file_num
    for groups in whole_file_dict:
        print(sum(len(files) >= threshold for files in groups.values()), end=' ')
    return whole_file_dict


def cv_dataset(num_list, file_num, whole_file_dict, whole_path='/nfs/yuxiaotian/HCC/HCC_whole_patch_5x/', target_path='/nfs/yuxiaotian/HCC/HCC_binary_CV_5x/'):
    """Copy a binary (normal vs. cancer) dataset out of the per-class patch folders.

    For each class ``i`` in 0..4, slides in ``whole_file_dict[i]`` with at least
    ``file_num`` patches are visited in dict order. For each such slide, the
    ``file_num`` patches with the lowest blank fraction are copied from
    ``whole_path/<i>/`` to ``target_path/0/`` (class 0) or ``target_path/1/``
    (classes 1-4). At most ``num_list[i]`` slides are taken per class; a value
    of ``None`` means no limit. A pixel counts as blank when all of its
    channels are >= 220. Images that ``cv2.imread`` cannot read are skipped.
    """
    for i in range(5):
        label = 0 if i == 0 else 1
        limit = num_list[i]
        src_dir = whole_path + '{}/'.format(i)
        dst_dir = target_path + '{}/'.format(label)
        n = 0
        for files in whole_file_dict[i].values():
            # Check before processing so that a limit of 0 copies nothing.
            if limit is not None and n >= limit:
                break
            if len(files) < file_num:
                continue
            # filter blank image
            blank_rate_dict = {}
            for file in files:
                img = cv2.imread(src_dir + file)
                if img is None:  # unreadable / corrupt file
                    continue
                blank_rate_dict[file] = np.all(img >= 220, axis=-1).mean()
            for file in sorted(blank_rate_dict, key=blank_rate_dict.get)[:file_num]:
                shutil.copy(src_dir + file, dst_dir + file)
            n += 1
            
            
# Patch transferring: re-extract patches at a different magnification
def patch_transfer(main_path_list=('/home/disk1/yuxiaotian/HCC_data/HCC/', '/nfs2/wzh/Grading/', '/nfs/yuxiaotian/HCC/slide_0/'), source_path='/nfs/yuxiaotian/HCC/HCC_screened_grading_50x/test/', target_path='/nfs/yuxiaotian/HCC/HCC_screened_grading_5x/test/'):
    """Re-extract screened patches with a larger, centred field of view.

    Consider every ``source_path/<label>/<slide>_Y<y>_X<x>....png`` that is not
    yet present in ``target_path/<label>/``. For each, a window of
    ``448 * 50 // target_scale`` level-0 pixels, centred on the source patch,
    is read from the slide, resized to 448x448 and written to
    ``target_path/<label>/`` under the same file name.

    Slide lookup by name:
      - 5-char names starting with '0' -> ``main_path_list[2]``, with the leading '0' dropped;
      - other names longer than 4 chars -> ``main_path_list[0]/<grade>/``;
      - remaining names -> ``main_path_list[1]``.

    Raises:
        FileNotFoundError: if a long slide name is not found under any grade
            folder of ``main_path_list[0]``.
    """
    import openslide
    source_scale = 50  # int(source_path.split('_')[-1][:-2])
    target_scale = 5   # int(target_path.split('_')[-1][:-2])
    source_patch_size = (448 * 50) // source_scale
    patch_size = (448 * 50) // target_scale
    # Shift from the source patch corner to the corner of the centred target window.
    offset = (source_patch_size - patch_size) // 2

    target_file_sets = [set(os.listdir(target_path + '{}/'.format(i))) for i in range(5)]
    mrxs_sets = [set(os.listdir(main_path_list[0] + '{}/'.format(i))) for i in range(1, 5)]
    hcc_name = ''
    slide = None
    try:
        for label in range(5):
            for file in os.listdir(source_path + '{}/'.format(label)):
                if file in target_file_sets[label]:
                    continue
                name = file.split('_')[0]
                if name != hcc_name:
                    hcc_name = name
                    if hcc_name[0] == '0' and len(hcc_name) == 5:
                        liver_file = main_path_list[2] + '{}.mrxs'.format(hcc_name[1:])
                    elif len(hcc_name) > 4:
                        grades = [i for i in range(1, 5) if hcc_name + '.mrxs' in mrxs_sets[i - 1]]
                        if not grades:
                            # Never fall back to the previously opened slide.
                            raise FileNotFoundError('{0}.mrxs not found under {1}'.format(hcc_name, main_path_list[0]))
                        liver_file = main_path_list[0] + '{0}/{1}.mrxs'.format(grades[-1], hcc_name)
                    else:
                        liver_file = main_path_list[1] + '{}.mrxs'.format(hcc_name)
                    if slide is not None:
                        slide.close()
                        slide = None
                    slide = openslide.open_slide(liver_file)
                y, x = [int(item[1:]) for item in file[:-4].split('_')[1:3]]
                region = slide.read_region((x + offset, y + offset), 0, (patch_size, patch_size)).convert('RGB')
                # NOTE: no background filter here; every patch is written.
                cv2.imwrite(target_path + '{0}/{1}'.format(label, file),
                            cv2.resize(np.array(region)[:, :, ::-1], (448, 448)))
    finally:
        if slide is not None:
            slide.close()


# liver sample screening
def sample_screening(source_path='/nfs/yuxiaotian/grading/liver_whole_patch/', target_path='/nfs/yuxiaotian/grading/liver_maxvar_patch/', checkpoint_path='checkpoints/liver_ae_flatten.pth', patch_per_img=100):
    """Keep up to ``patch_per_img`` mutually dissimilar patches per slide.

    Patches in ``source_path/<cls>/`` (cls = 0..4) are grouped by the first 4
    characters of their file name. Each patch is encoded with the
    autoencoder's encoder.

    The first ``patch_per_img`` patches seed the selection. Each later patch
    replaces the currently most redundant selected patch (the one with the
    highest cosine similarity to another selected patch) when its own highest
    similarity to the remaining selection is lower.

    Selected files are copied to ``target_path/<cls>/`` under the same name.
    Requires CUDA.
    """
    from ae import Autoencoder
    transform = transforms.Compose([transforms.ToTensor()])

    def cos_similarity(a, b):
        # Cosine similarity of two flat feature vectors.
        return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

    with torch.no_grad():
        model = Autoencoder().cuda()
        model.load_state_dict(torch.load(checkpoint_path))
        model.eval()  # inference mode, as in the other feature extractors
        for cls_idx in range(5):
            cls_dir = source_path + '{}/'.format(cls_idx)
            whole_patch_list = os.listdir(cls_dir)
            whole_name_list = list(set([item[:4] for item in whole_patch_list]))
            for name in whole_name_list:
                patch_file_list = [cls_dir + item for item in whole_patch_list if item[:len(name)] == name]
                print('name:{0}, num:{1}'.format(name, len(patch_file_list)))
                select_file_list = []
                select_fea_list = []
                max_sim_list = []  # per selected patch: max similarity to the other selected ones
                for file in patch_file_list:
                    with Image.open(file) as im:
                        img = transform(im.convert('RGB')).unsqueeze(0)
                    fea = model.encoder(img.cuda()).reshape(-1).cpu().numpy()  # size: 7*7*64=3136
                    if len(select_file_list) < patch_per_img:
                        select_file_list.append(file)
                        select_fea_list.append(fea)
                        if len(select_file_list) == patch_per_img:
                            for i, fea_i in enumerate(select_fea_list):
                                sims = [cos_similarity(fea_i, fea_j) for j, fea_j in enumerate(select_fea_list) if i != j]
                                max_sim_list.append(max(sims, default=-np.inf))
                    else:
                        sim_list = [cos_similarity(fea, fea_j) for fea_j in select_fea_list]
                        max_max_idx = int(np.argmax(max_sim_list))
                        # Exclude the slot the new patch would replace (-inf, not 0:
                        # cosine similarities can be negative).
                        sim_list[max_max_idx] = -np.inf
                        if max(sim_list) < max(max_sim_list):
                            max_sim_list[max_max_idx] = max(sim_list)
                            select_file_list[max_max_idx] = file
                            select_fea_list[max_max_idx] = fea
                for file in select_file_list:
                    shutil.copyfile(file, target_path + '{}/'.format(cls_idx) + os.path.basename(file))

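# Feature assembly. The triple-quoted string below holds older (Gleason-dataset)
# versions of the helpers that follow; it is a bare string and is never executed.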
'''
from ae import AutoEncoder

def get_patch_list(patch_list='/nfs/yuxiaotian/Gleason/normal&cancer/semi_pure_binary_loc/'):
    patch_file_list = []
    for cv_idx in range(10):
        for file_type in ['train/', 'test/', 'valid/']:
            patch_file_list += os.listdir(patch_list+'cv_{}/'.format(cv_idx)+file_type+'0/')
            patch_file_list += os.listdir(patch_list+'cv_{}/'.format(cv_idx)+file_type+'1/')
    patch_file_list = ['_'.join(item.split('_')[:4]) for item in patch_file_list]
    patch_file_list = list(set(patch_file_list))
    return patch_file_list


def check_max_around(patch_size=448, around_size=448, img_path='data/Train Imgs/'):
    patch_file_list = get_patch_list()
    img_file_dict = {}
    for item in patch_file_list:
        img_file = item[:16]
        y, x = [int(pos[1:]) for pos in item.split('_')[-2:]]
        if img_file in img_file_dict:
            img_file_dict[img_file].append([y, x])
        else:
            img_file_dict[img_file] = [[y,x]]
    MIN_BORDER = 10000
    for img_file in img_file_dict.keys():
        img = cv2.imread(img_path+img_file+'.jpg')
        img_y, img_x = img.shape[:2]
        for (y, x) in img_file_dict[img_file]:
            min_border = min(x, y, img_x-x-patch_size, img_y-y-patch_size)
            if min_border < MIN_BORDER:
                MIN_BORDER = min_border
                print(min_border, img_file, x, y)
    return MIN_BORDER // around_size
    

def get_around_patch(img, x_, y_, around_size=448, data_type=np.float):
    img_patch = np.ones((around_size, around_size, 3), dtype=data_type)
    max_y, max_x = img.shape[:2]
    if y_+around_size <= 0 or y_ >= max_y:
        None
    elif y_ < 0 and y_+around_size > 0:
        if x_+around_size <= 0 or x_ >= max_x:
            None
        elif x_ < 0:
            img_patch[-y_:, -x_:] = img[:y_+around_size, :x_+around_size]
        elif x_+around_size > max_x:
            img_patch[-y_:, :max_x-x_] = img[:y_+around_size, x_:]
        elif x_ >= 0 and x_+around_size <= max_x:
            img_patch[-y_:, :] = img[:y_+around_size, x_:x_+around_size]
    elif y_ >= 0 and y_+around_size <= max_y:
        if x_+around_size <= 0 or x_ >= max_x:
            None
        elif x_ < 0:
            img_patch[:, -x_:] = img[y_:y_+around_size, :x_+around_size]
        elif x_+around_size > max_x:
            img_patch[:, :max_x-x_] = img[y_:y_+around_size, x_:]
    elif y_+around_size > max_y:
        if x_+around_size <= 0 or x_ >= max_x:
            None
        elif x_ < 0:
            img_patch[:max_y-y_, -x_:] = img[y_:, :x_+around_size]
        elif x_+around_size > max_x:
            img_patch[:max_y-y_, :max_x-x_] = img[y_:, x_:]
        elif x_ >= 0 and x_+around_size <= max_x:
            img_patch[:max_y-y_, :] = img[y_:, x_:x_+around_size]
    return img_patch

    
def patch_around_feature(patch_size=448, around_size=448, n_around=5, checkpoint_path='checkpoints/vae_0_.pth', img_path='data/Train Imgs/', save_file='/nfs/yuxiaotian/Gleason/normal&cancer/binary_feature_polymer/'):
    model = Autoencoder().cuda()
    model.load_state_dict(torch.load(checkpoint_path))
    fea_channal = 256
    fea_size = 7
    
    patch_file_list = get_patch_list()
    patch_file_list.sort()
    img_file_dict = {}
    for item in patch_file_list:
        img_file = item[:16]
        y, x = [int(pos[1:]) for pos in item.split('_')[-2:]]
        if img_file in img_file_dict:
            img_file_dict[img_file].append([y, x])
        else:
            img_file_dict[img_file] = [[y,x]]
    
    for file_idx, img_file in enumerate(img_file_dict.keys()):
        img = cv2.imread(img_path+img_file+'.jpg')/255
        img = img.astype('float32')
        for (y, x) in img_file_dict[img_file]:
            around_img = []
            start_y = y-((n_around*around_size//2)-patch_size)
            start_x = x-((n_around*around_size//2)-patch_size)
            end_y = start_y + n_around*around_size
            end_x = start_x + n_around*around_size
            
            for x_ in range(start_x, end_x, around_size):
                for y_ in range(start_y, end_y, around_size):
                    img_patch = img[y_:y_+around_size, x_:x_+around_size, :]
                    # n_around = 5
                    if img_patch.shape != (around_size, around_size, 3):
                        img_patch = get_around_patch(img, x_, y_, around_size=around_size, data_type=img_patch.dtype)
                    img_patch = cv2.resize(img_patch, (224, 224))
                    around_img.append(img_patch)
            around_img = torch.tensor(np.array(around_img).transpose((0, 3, 1, 2))).cuda() # size: (5*5, 3, 224, 224)
            with torch.no_grad():
                around_fea = model.encoder(around_img).reshape(n_around, n_around, fea_channal, fea_size, fea_size) # size: (5, 5, 256, 7, 7)
                around_fea = around_fea.cpu().numpy().transpose((0, 3, 1, 4, 2)).reshape((n_around*fea_size, n_around*fea_size, fea_channal)) # size: (35, 35, 256)
            np.save('{0}{1}_Y{2}_X{3}.png'.format(save_file, img_file, y, x), around_fea)
        print('\r{0}/{1}'.format(file_idx, len(img_file_dict)), end='')
'''

# liver
def get_patch_list(patch_list='/nfs/yuxiaotian/HCC/HCC_binary_CV/', n_class=2):
    """Return the unique patch stems found in class folders ``0`` .. ``n_class - 1``.

    A stem is a file name with its last four characters (the extension) removed.
    The list is built from a set, so its order is unspecified.
    """
    stems = (
        entry[:-4]
        for class_idx in range(n_class)
        for entry in os.listdir(patch_list + '{}/'.format(class_idx))
    )
    return list(set(stems))


def check_max_around(patch_size=448, around_size=448, img_path='/nfs2/wzh/Grading/'):
    """Return how many ``around_size`` context rings fit around every listed patch.

    Each patch comes from ``get_patch_list()``. Its slide is the first 4
    characters of the stem, and its position is the trailing ``_Y<y>_X<x>``.
    For every patch, the distance to the nearest edge of its
    ``img_path/<slide>.mrxs`` slide is measured. The smallest distance
    (capped at 10000) is floor-divided by ``around_size`` and returned.
    Each new minimum is printed.
    """
    import openslide  # local import, as in the other slide helpers
    img_file_dict = {}
    for item in get_patch_list():
        y, x = [int(pos[1:]) for pos in item.split('_')[-2:]]
        img_file_dict.setdefault(item[:4], []).append([y, x])
    MIN_BORDER = 10000
    for img_file, positions in img_file_dict.items():
        slide = openslide.open_slide(img_path + img_file + '.mrxs')
        try:
            img_x, img_y = slide.level_dimensions[0]
        finally:
            slide.close()
        for y, x in positions:
            min_border = min(x, y, img_x - x - patch_size, img_y - y - patch_size)
            if min_border < MIN_BORDER:
                MIN_BORDER = min_border
                print(min_border, img_file, x, y)
    return MIN_BORDER // around_size
    

def get_around_patch(slide, x_, y_, around_size=448, data_type=float):
    """Read an ``around_size`` square at level-0 ``(x_, y_)``, padding pixels outside the slide.

    Only the part of the window that overlaps the slide is read, in a single
    ``read_region`` call. The rest of the returned array keeps the fill value
    1. For a uint8 ``data_type`` this is a near-black value, not white.

    Args:
        slide: OpenSlide-like object exposing ``level_dimensions`` ((width, height)
            per level) and ``read_region(location, level, size)``.
        x_, y_: top-left corner of the window in level-0 pixels; may be negative
            or beyond the slide edge.
        around_size: side length of the square window.
        data_type: dtype of the result. The old default ``np.float`` was an alias
            of ``float`` and no longer exists in NumPy >= 1.24.

    Returns:
        ``(around_size, around_size, 3)`` array of ``data_type``.
    """
    img_patch = np.ones((around_size, around_size, 3), dtype=data_type)
    max_x, max_y = slide.level_dimensions[0]
    # Intersection of the window with the slide, in slide coordinates.
    x0, x1 = max(x_, 0), min(x_ + around_size, max_x)
    y0, y1 = max(y_, 0), min(y_ + around_size, max_y)
    if x0 >= x1 or y0 >= y1:
        return img_patch  # window lies completely outside the slide
    region = slide.read_region((x0, y0), 0, (x1 - x0, y1 - y0)).convert('RGB')
    img_patch[y0 - y_:y1 - y_, x0 - x_:x1 - x_] = np.array(region)
    return img_patch


def patch_around_feature(patch_size=448, around_size=448, n_around=9, checkpoint_path='checkpoints/hcc_grading_ae.pth', img_path_list=('/nfs/yuxiaotian/HCC/slide_0/', '/nfs2/wzh/Grading/', '/home/disk1/yuxiaotian/HCC_data/HCC/'), save_file='/nfs/yuxiaotian/HCC/HCC_AF/'):
    """Save, for every screened patch, the encoder features of its context window.

    Patches come from
    ``get_patch_list('/nfs/yuxiaotian/HCC/A2B_binary_CV_screen_50x/', n_class=5)``
    and are grouped by slide (the name before the first '_'). A slide is looked
    up as ``img_path_list[0]/<name[1:]>.mrxs``, then
    ``img_path_list[1]/<name>.mrxs``, then ``img_path_list[2]/<1..4>/<name>.mrxs``.
    Slides that are not found are skipped.

    For each slide, the bounding box of all its patches' n_around x n_around
    context windows is tiled with ``around_size`` tiles. Each tile is encoded
    once into a (64, 7, 7) feature grid. The window of every patch is then
    sliced out as a (64, n_around*7, n_around*7) array and saved to
    ``save_file/<slide>_Y<y>_X<x>.npy``, overwriting existing files.
    Requires CUDA.
    """
    import openslide
    from ae import Autoencoder
    import time

    #torch.cuda.set_device(0)
    t0 = time.time()
    model = Autoencoder().cuda()
    model.load_state_dict(torch.load(checkpoint_path))
    model.eval()
    fea_channal = 64
    fea_size = 7

    patch_file_list = get_patch_list(patch_list='/nfs/yuxiaotian/HCC/A2B_binary_CV_screen_50x/', n_class=5)
    patch_file_list.sort()
    img_file_dict = {}
    for file in patch_file_list:
        y, x = [int(pos[1:]) for pos in file.split('_')[-2:]]
        img_file_dict.setdefault(file.split('_')[0], []).append([y, x])

    # Distance from a patch corner to its context window's top-left (lead) and
    # to the window's far edge (trail); even windows are centred on the patch centre.
    if n_around % 2 == 0:
        lead = n_around // 2 * around_size - patch_size // 2
        trail = n_around // 2 * around_size + patch_size // 2
    else:
        lead = n_around // 2 * around_size
        trail = (n_around // 2 + 1) * around_size

    for total_idx, (img_file, positions) in enumerate(img_file_dict.items()):
        print('\r{0}/{1} continue. '.format(total_idx, len(img_file_dict)), end='')
        print(img_file)
        candidates = [img_path_list[0] + img_file[1:] + '.mrxs', img_path_list[1] + img_file + '.mrxs']
        candidates += [img_path_list[2] + '{}/'.format(i) + img_file + '.mrxs' for i in range(1, 5)]
        slide_file = next((path for path in candidates if os.path.exists(path)), None)
        if slide_file is None:
            # Skip rather than silently reuse the previously opened slide.
            print('no slide found for', img_file)
            continue

        pos = np.array(positions)
        y_min = int(pos[:, 0].min()) - lead
        y_max = int(pos[:, 0].max()) + trail
        x_min = int(pos[:, 1].min()) - lead
        x_max = int(pos[:, 1].max()) + trail
        ys = range(y_min, y_max, around_size)
        xs = range(x_min, x_max, around_size)
        # Sized from the tile loop itself so every tile has a slot.
        whole_feature_polymer = np.zeros((fea_channal, len(ys) * fea_size, len(xs) * fea_size), dtype=np.float64)

        slide = openslide.open_slide(slide_file)
        try:
            with torch.no_grad():
                for x_idx, x in enumerate(xs):
                    for y_idx, y in enumerate(ys):
                        img_patch = np.array(slide.read_region((x, y), 0, (around_size, around_size)).convert('RGB'))
                        if img_patch.shape != (around_size, around_size, 3):
                            img_patch = get_around_patch(slide, x, y, around_size=around_size, data_type=img_patch.dtype)
                        img_patch = img_patch / 255.
                        tensor = torch.tensor(img_patch.transpose((2, 0, 1))).unsqueeze(0).float().cuda()
                        fea = model.encoder(tensor).reshape(fea_channal, fea_size, fea_size)  # size: (64, 7, 7)
                        whole_feature_polymer[:, y_idx * fea_size:(y_idx + 1) * fea_size,
                                              x_idx * fea_size:(x_idx + 1) * fea_size] = fea.cpu().numpy()
        finally:
            slide.close()
        print('{0}/{1}: polymer done. '.format(total_idx, len(img_file_dict)), end='')

        for y, x in positions:
            start_y_idx = (y - lead - y_min) // around_size
            start_x_idx = (x - lead - x_min) // around_size
            around_fea = whole_feature_polymer[:, start_y_idx * fea_size:(start_y_idx + n_around) * fea_size,
                                               start_x_idx * fea_size:(start_x_idx + n_around) * fea_size]
            np.save('{0}{1}_Y{2}_X{3}.npy'.format(save_file, img_file, y, x), around_fea)

        t1 = time.time()
        print('{0:.4f}'.format(t1 - t0), end='')


def whole_slide_feature(patch_size=448, around_size=448, checkpoint_path='checkpoints/hcc_grading_ae.pth', img_path_list=('/nfs/yuxiaotian/HCC/slide_0/', '/nfs2/wzh/Grading/', '/nfs3-p2/yuxiaotian/HCC_data/HCC/', '/nfs3-p2/yuxiaotian/CAMELYON16/training/tumor/'), save_file='/nfs3-p2/yuxiaotian/CAMELYON16/training/FP/', start_idx=13):
    """Encode whole slides tile by tile into a feature grid and save it as ``.npy``.

    Every entry of the ``img_path_list[3]`` directory, from listing index
    ``start_idx`` on, is treated as a slide name. ``.tif`` entries are opened
    from ``img_path_list[3]``. Other names are looked up as ``.mrxs`` in
    ``img_path_list[0]`` (leading character dropped), then ``img_path_list[1]``,
    then ``img_path_list[2]/<1..4>/``. Names that are not found are skipped.

    Each ``around_size`` tile is passed through the autoencoder's encoder. Its
    (64, 7, 7) feature is written into a float64 grid, which is saved as
    ``save_file + 'FP_<name>.npy'``. Requires CUDA.
    """
    import openslide
    from ae import Autoencoder
    import time

    #torch.cuda.set_device(0)
    t0 = time.time()
    model = Autoencoder().cuda()
    model.load_state_dict(torch.load(checkpoint_path))
    model.eval()
    fea_channal = 64
    fea_size = 7

    img_file_list = os.listdir(img_path_list[3])[start_idx:]
    for total_idx, img_file in enumerate(img_file_list):
        if img_file.endswith('.tif'):
            slide_file = img_path_list[3] + img_file
        else:
            candidates = [img_path_list[0] + img_file[1:] + '.mrxs', img_path_list[1] + img_file + '.mrxs']
            candidates += [img_path_list[2] + '{}/'.format(i) + img_file + '.mrxs' for i in range(1, 5)]
            slide_file = next((path for path in candidates if os.path.exists(path)), None)
            if slide_file is None:
                # Skip rather than reuse the previous slide under a new name.
                print('no slide found for', img_file)
                continue
            if slide_file in candidates[2:]:
                print(slide_file)

        slide = openslide.open_slide(slide_file)
        try:
            X, Y = slide.level_dimensions[0]
            xs = range(0, X, around_size)
            ys = range(0, Y, around_size)
            # Keep the original grid size (one spare tile row/column) but never
            # fewer slots than the tile loop visits.
            n_y = max(Y // patch_size + 1, len(ys))
            n_x = max(X // patch_size + 1, len(xs))
            whole_feature_polymer = np.zeros((fea_channal, n_y * fea_size, n_x * fea_size), dtype=np.float64)
            with torch.no_grad():
                for x_idx, x in enumerate(xs):
                    for y_idx, y in enumerate(ys):
                        img_patch = np.array(slide.read_region((x, y), 0, (around_size, around_size)).convert('RGB'))
                        if img_patch.shape != (around_size, around_size, 3):
                            img_patch = get_around_patch(slide, x, y, around_size=around_size, data_type=img_patch.dtype)
                        img_patch = img_patch / 255.
                        tensor = torch.tensor(img_patch.transpose((2, 0, 1))).unsqueeze(0).float().cuda()
                        fea = model.encoder(tensor).reshape(fea_channal, fea_size, fea_size)  # size: (64, 7, 7)
                        whole_feature_polymer[:, y_idx * fea_size:(y_idx + 1) * fea_size,
                                              x_idx * fea_size:(x_idx + 1) * fea_size] = fea.cpu().numpy()
        finally:
            slide.close()
        print('{0}/{1}: polymer done. '.format(total_idx, len(img_file_list)), end='')

        np.save('{0}FP_{1}.npy'.format(save_file, img_file), whole_feature_polymer)
        t1 = time.time()
        print('{0:.4f}'.format(t1 - t0), end='')


def sim_map(patch_size=448, around_size=448, n_around=32, checkpoint_path='checkpoints/liver_ae.pth', patch_path='/nfs/yuxiaotian/HCC/HCC_binary_CV/', patch_bank_paths=('/nfs/yuxiaotian/HCC/HCC_select_patch/', '/nfs/yuxiaotian/grading/liver_whole_patch/'), save_file='/nfs/yuxiaotian/HCC/HCC_AF/'):
    """Write similarity maps between a patch's saved context features and a reference patch.

    The bank patches of each slide are split into two lists:
    ``reverse_dict[slide][0]`` holds those from class folders 1-4, and
    ``reverse_dict[slide][1]`` holds those from folder 0. The dataset label
    indexes into this pair; presumably this picks the opposite class — confirm
    the label convention of ``LiverNeighborDataset``.

    For every dataset patch whose ``save_file/<stem>.npy`` exists:
      1. one random reference patch of the same slide is chosen from that list;
      2. the reference patch is encoded;
      3. its cosine similarity to every ``fea_size`` x ``fea_size`` cell of the
         saved features fills an ``n_around`` x ``n_around`` map;
      4. the map, scaled by 255, is written to ``save_file/r_<patch file name>``.

    Requires CUDA.
    """
    from ae import Autoencoder
    from dataset import LiverNeighborDataset

    #torch.cuda.set_device(0)
    model = Autoencoder().cuda()
    model.load_state_dict(torch.load(checkpoint_path))
    model.eval()
    fea_channal = 64
    fea_size = 7

    transform = transforms.Compose([transforms.Resize(448), transforms.ToTensor()])
    datasets = LiverNeighborDataset(patch_path, n_class=2, transform=transform)
    dataloader = torch.utils.data.DataLoader(dataset=datasets, batch_size=64, shuffle=True)

    reverse_dict = {}
    for i in range(5):
        bank0 = patch_bank_paths[0] + '{}/'.format(i)
        bank1 = patch_bank_paths[1] + '{}/'.format(i)
        list_1 = os.listdir(bank0)
        list_2 = os.listdir(bank1)
        in_bank0 = set(list_1)  # O(1) membership instead of scanning list_1 per file
        idx = 1 if i == 0 else 0
        for file in list_1 + list_2:
            path = bank0 if file in in_bank0 else bank1
            reverse_dict.setdefault(file.split('_')[0], [[], []])[idx].append(path + file)
    print('start')
    with torch.no_grad():
        for batch_idx, (_image, label_list, _, file_list) in enumerate(dataloader):
            for idx, file in enumerate(file_list):
                label = int(label_list[idx])
                file_name = os.path.basename(file)
                af_file = save_file + file_name[:-4] + '.npy'
                if not os.path.exists(af_file):
                    print(af_file)
                    continue
                # Slide key from the bare file name: directory names in ``file`` may contain '_'.
                candidates = reverse_dict.get(file_name.split('_')[0], [[], []])[label]
                if len(candidates) == 0:
                    continue
                sim_mat = np.zeros((n_around, n_around), dtype=np.float64)
                file_r = np.random.choice(candidates, 1)[0]
                with Image.open(file_r) as im:
                    ref = transform(im.convert('RGB')).float().reshape(1, 3, 448, 448).cuda()
                fea = model.encoder(ref).reshape(-1, fea_channal, fea_size, fea_size).cpu().numpy()
                fea_flat = fea.flatten()
                fea_norm = np.linalg.norm(fea)
                af = np.load(af_file)  # presumably (64, n_around*7, n_around*7), e.g. (64, 224, 224)
                for i in range(n_around):
                    for j in range(n_around):
                        af_ij = af[:, i * fea_size:(i + 1) * fea_size, j * fea_size:(j + 1) * fea_size]
                        sim_mat[i, j] = (af_ij.flatten() * fea_flat).sum() / np.linalg.norm(af_ij) / fea_norm
                cv2.imwrite(save_file + 'r_' + file_name, sim_mat * 255)
            print('\r{0}/{1}'.format(batch_idx, len(datasets) // 64))
    
    
def class0_transfer(target, main_path='/nfs/yuxiaotian/HCC/HCC_screened_grading_50x/'):
    """Replace the class-0 images of every split with a chosen source set.

    For each of the ``train``, ``valid`` and ``test`` splits, every file in
    ``<main_path>/<split>/0/`` is deleted, then every file in
    ``<main_path>/<split>_0/<target>/`` is copied into it. Afterwards class 0
    holds exactly the ``target`` images.

    Args:
        target: which source set to use; must be ``'cancer'`` or ``'normal'``.
        main_path: dataset root directory, with or without a trailing slash.

    Returns:
        None. An unknown ``target`` prints ``'ERROR TARGET'`` and returns
        without touching any file.
    """
    if target not in ('cancer', 'normal'):
        print('ERROR TARGET')
        return
    for dataset in ['train', 'valid', 'test']:
        class0_dir = os.path.join(main_path, dataset, '0')
        source_dir = os.path.join(main_path, dataset + '_0', target)

        exist_set = os.listdir(class0_dir)
        for i, file in enumerate(exist_set, 1):
            os.remove(os.path.join(class0_dir, file))
            print('\rdeleting for {2} set: {0}/{1}'.format(i, len(exist_set), dataset), end='')
        print()

        target_set = os.listdir(source_dir)
        for i, file in enumerate(target_set, 1):
            shutil.copy(os.path.join(source_dir, file), os.path.join(class0_dir, file))
            # Report progress against the files being copied, not the ones deleted.
            print('\rcopying for {2} set: {0}/{1}'.format(i, len(target_set), dataset), end='')
        print()
        print()
    

def slide_level_acc(gt_file, pred_file, cancer_target=255, normal_target=15):
    """Compute pixel-level accuracy, sensitivity and specificity for one slide.

    The ground-truth image is resized to the prediction's size. Every pixel
    the prediction covers (``pred > 0``) is then relabelled in the ground
    truth: values below 128 become ``normal_target`` and values of 128 or more
    become ``cancer_target``. Pixels the prediction does not cover keep their
    original ground-truth value.

    Args:
        gt_file: path to the ground-truth image.
        pred_file: path to the prediction image. Its pixel values are compared
            directly against ``cancer_target`` / ``normal_target``.
        cancer_target: label value marking cancer pixels.
        normal_target: label value marking normal pixels.

    Returns:
        ``(acc, sen, spe)``:
        - ``acc``: accuracy over pixels whose (relabelled) ground truth is
          non-zero.
        - ``sen``: sensitivity, ``tp / (tp + fn)``.
        - ``spe``: specificity, ``tn / (tn + fp)``.
        Any ratio whose denominator is zero is returned as ``nan``.
    """
    def _ratio(num, den):
        # Avoid NumPy's divide-by-zero RuntimeWarning; report nan explicitly.
        return num / den if den else float('nan')

    pred = np.array(Image.open(pred_file))
    # PIL's resize takes (width, height), hence the reversed numpy shape.
    # NOTE(review): assumes both images are single-channel (2-D) - TODO confirm.
    gt = np.array(Image.open(gt_file).resize(pred.shape[::-1]))

    # Build both masks from the untouched ground truth. Otherwise the second
    # relabelling would see values written by the first (this happens
    # whenever normal_target >= 128).
    covered = pred > 0
    normal_mask = covered & (gt < 128)
    cancer_mask = covered & (gt >= 128)
    gt[normal_mask] = normal_target
    gt[cancer_mask] = cancer_target

    correct = pred == gt
    is_cancer = gt == cancer_target
    is_normal = gt == normal_target

    tp = (is_cancer & correct).sum()
    tn = (is_normal & correct).sum()
    fp = (is_normal & ~correct).sum()
    fn = (is_cancer & ~correct).sum()

    labelled = gt > 0
    acc = _ratio((labelled & correct).sum(), labelled.sum())
    sen = _ratio(tp, tp + fn)
    spe = _ratio(tn, tn + fp)

    return acc, sen, spe
    
if __name__ == '__main__':
    # Resume patch extraction for the PANDA slides from the 'radboud'
    # provider. Slides are processed in train.csv order. Processing restarts
    # after the highest-indexed slide that already has files in train_region/.
    import csv

    with open('/nfs3-p2/yuxiaotian/PANDA/train.csv', newline='') as csv_file:
        rows = csv.reader(csv_file)
        next(rows, None)  # skip the header row
        # Column 0 is the slide name, column 1 the data provider.
        # Short or blank rows are ignored.
        name_list = [row[0] for row in rows if len(row) > 1 and row[1] == 'radboud']

    # First-occurrence index of every slide name (matches list.index) for O(1)
    # lookup.
    name_to_index = {}
    for index, name in enumerate(name_list):
        name_to_index.setdefault(name, index)

    # Region files are named '<slide name>_...'. Files whose prefix is not a
    # radboud slide are ignored instead of aborting the resume.
    region_dir = '/nfs3-p2/yuxiaotian/PANDA/train_region/'
    done_indices = {
        name_to_index[prefix]
        for prefix in (file.split('_')[0] for file in os.listdir(region_dir))
        if prefix in name_to_index
    }
    # On a fresh run (nothing extracted yet) start from the first slide.
    start = max(done_indices) + 1 if done_indices else 0
    for name in name_list[start:]:
        split_PANDA_patches(name)
