import os
import torch
from torch.utils.data import Dataset
from PIL import Image
import random
import numpy as np
import cv2

# Raise Pillow's pixel-count threshold so very large images can be opened
# (Pillow uses this value for its decompression-bomb check).
Image.MAX_IMAGE_PIXELS = 3600000000

# Number of patch slots along one side of the square mosaic built per sample.
PATCH_EDGE = 5
# Maximum number of patches placed in one mosaic (one per slot).
MAX_PATCH_NUM = PATCH_EDGE ** 2
# Pixel edge length of one patch slot in the mosaic image.
PATCH_SIZE = 256
# Number of mask cells along one side of a single patch.
SPLIT_EDGE = 4
# Number of mask cells along one side of the whole mosaic.
MASK_EDGE = PATCH_EDGE*SPLIT_EDGE


def graph_coll_fn(batch):
    """Collate samples shaped ``(image, mask, link, label[, file_name])``.

    The batch is transposed so that each field is handled as one column:

    * positions 0 and 1 (image, mask) are stacked along a new leading dim;
    * positions 2 and 4 (link, file_name) are passed through as tuples;
    * position 3 (label) becomes a ``torch.long`` tensor.

    Args:
        batch: Sequence of per-sample tuples, all with the same field order.

    Returns:
        A list with one collated entry per field.

    Raises:
        ValueError: If a sample carries more than five fields.
    """
    collated = []
    for position, column in enumerate(zip(*batch)):
        if position in (0, 1):  # image, mask
            field = torch.stack(column, dim=0)
        elif position in (2, 4):  # link, file_name
            field = column
        elif position == 3:  # label
            field = torch.tensor(column, dtype=torch.long)
        else:
            raise ValueError(f'graph_coll_fn: i=={position}')
        collated.append(field)
    return collated


def no_link_coll_fn(batch):
    """Collate samples shaped ``(image, mask, label[, file_name])``.

    The batch is transposed so that each field is handled as one column:

    * positions 0 and 1 (image, mask) are stacked along a new leading dim;
    * position 2 (label) becomes a ``torch.long`` tensor;
    * position 3 (file_name) is passed through as a tuple.

    Args:
        batch: Sequence of per-sample tuples, all with the same field order.

    Returns:
        A list with one collated entry per field.

    Raises:
        ValueError: If a sample carries more than four fields.
    """
    res = []
    for i, elems in enumerate(zip(*batch)):
        if i in (0, 1):  # image, mask
            res.append(torch.stack(elems, dim=0))
        elif i == 2:  # label
            res.append(torch.tensor(elems, dtype=torch.long))
        elif i == 3:  # file_name
            res.append(elems)
        else:
            # Report this function's own name so the failure can be traced
            # to the right collate function.
            raise ValueError(f'no_link_coll_fn: i=={i}')
    return res


def get_weight_list(dataset):
    """Return one class-balancing weight per record of ``dataset.datas``.

    Each record's first element is treated as its identifier and its last
    element as a binary label. A record labelled 0 receives the number of
    label-1 records as weight, and a record labelled 1 receives the number
    of label-0 records, so the rarer class gets the larger weight.

    Args:
        dataset: Object exposing a ``datas`` sequence of indexable records.

    Returns:
        List of integer weights aligned with ``dataset.datas``.

    Raises:
        ValueError: If any record's label is neither 0 nor 1.
    """
    def _binary_label(record):
        # Validate and normalise the label of one record.
        identifier, target = record[0], record[-1]
        if target == 0:
            return 0
        if target == 1:
            return 1
        raise ValueError(f'{identifier}, {target}')

    # First pass: count records per class (index 0 = negative, 1 = positive).
    totals = [0, 0]
    for record in dataset.datas:
        totals[_binary_label(record)] += 1
    negative_total, positive_total = totals

    # Second pass: every record is weighted by the size of the other class.
    opposite_total = (positive_total, negative_total)
    return [opposite_total[_binary_label(record)] for record in dataset.datas]


class PatchGraphDataset(Dataset):
    """Mosaic dataset yielding an image, a coarse mask, link indices and a label.

    Directory layout read by the constructor::

        root/<category>/<fold>/image/<sample>/<patch>.<ext>
        root/<category>/<fold>/<link>/<sample>/<patch>.png   (patch mask)
        root/<category>/<fold>/<link>/<sample>/<patch>.pth   (patch links)

    Each item pastes up to ``MAX_PATCH_NUM`` patches of one sample into a
    ``PATCH_EDGE x PATCH_EDGE`` grid of slots. The class label is the
    position of the category inside ``cate_list``.
    """

    def __init__(self, root, cate_list, random, transform, link='link', fold=None, test=False):
        """Index every sample directory and build the link lookup table.

        Args:
            root: Dataset root directory.
            cate_list: Category sub-directory names (converted with ``str``).
            random: If truthy, patches and their slots are drawn at random.
            transform: Callable applied to ``[PIL image, PIL mask, link]``.
            link: Name of the sub-directory holding masks and link files.
            fold: Optional iterable of fold sub-directory names.
            test: If truthy, items additionally carry the sample name.

        Raises:
            ValueError: If a sample has no matching directory under ``link``.
        """
        self.random = random
        self.transform = transform
        self.test = test
        self.datas = []
        self.patch_nums = []
        self.patchs = []
        self.patch_index_list = list(range(MAX_PATCH_NUM))

        categories = [str(c) for c in cate_list]
        folds = [''] if fold is None else [str(f) for f in fold]
        for class_id, category in enumerate(categories):
            for fold_name in folds:
                base = os.path.join(category, fold_name)
                image_root = os.path.join(root, base, 'image')
                for name in sorted(os.listdir(image_root)):
                    link_dir = os.path.join(root, base, link, name)
                    if not os.path.exists(link_dir):
                        raise ValueError(f'link not exist: {name}')
                    image_dir = os.path.join(image_root, name)
                    # Patch files are tracked by stem; extensions are re-added on load.
                    stems = sorted(os.path.splitext(entry)[0] for entry in os.listdir(image_dir))
                    self.datas.append((image_dir, link_dir, class_id))
                    self.patch_nums.append(len(stems))
                    self.patchs.append(stems)

        # lut[row, col, k] maps the k-th cell of the patch in slot (row, col)
        # (row-major inside the patch) to its flat index in the full
        # MASK_EDGE x MASK_EDGE cell grid.
        self.lut = torch.zeros(PATCH_EDGE, PATCH_EDGE, SPLIT_EDGE**2, dtype=torch.long)
        local_rows = torch.arange(SPLIT_EDGE).repeat_interleave(SPLIT_EDGE)
        local_cols = torch.arange(SPLIT_EDGE).repeat(SPLIT_EDGE)
        row_stride = PATCH_EDGE * SPLIT_EDGE
        for slot_row in range(PATCH_EDGE):
            for slot_col in range(PATCH_EDGE):
                global_rows = local_rows + slot_row * SPLIT_EDGE
                global_cols = local_cols + slot_col * SPLIT_EDGE
                self.lut[slot_row, slot_col] = global_rows * row_stride + global_cols

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.datas)

    def __getitem__(self, index):
        """Assemble the mosaic for sample ``index``.

        Returns:
            ``(img, mask, link, label)``, followed by the sample directory
            name when the dataset was built with ``test=True``. ``img``,
            ``mask`` and ``link`` are whatever ``transform`` returns.
        """
        image_dir, link_dir, label = self.datas[index]
        count = min(self.patch_nums[index], MAX_PATCH_NUM)
        if self.random:
            chosen = random.sample(self.patchs[index], count)
            slots = random.sample(self.patch_index_list, count)
        else:
            chosen = self.patchs[index][:count]
            slots = list(range(count))

        canvas_edge = PATCH_EDGE * PATCH_SIZE
        # Unused slots keep the light-grey fill value 241.
        img = np.ones((canvas_edge, canvas_edge, 3), dtype=np.uint8) * 241
        mask = np.zeros((MASK_EDGE, MASK_EDGE), dtype=np.uint8)
        links = []
        for stem, slot in zip(chosen, slots):
            tile = np.array(Image.open(os.path.join(image_dir, stem + '.png')))
            tile_mask = np.array(Image.open(os.path.join(link_dir, stem + '.png')))
            tile_link = torch.load(os.path.join(link_dir, stem + '.pth'))
            row, col = divmod(slot, PATCH_EDGE)

            top, left = row * PATCH_SIZE, col * PATCH_SIZE
            img[top:top + PATCH_SIZE, left:left + PATCH_SIZE] = tile
            top, left = row * SPLIT_EDGE, col * SPLIT_EDGE
            mask[top:top + SPLIT_EDGE, left:left + SPLIT_EDGE] = tile_mask

            # Translate patch-local cell indices into mosaic-wide indices.
            # Comparisons use the untouched tensor so remapped values are
            # never remapped a second time.
            remapped = tile_link.clone()
            for value in torch.unique(tile_link):
                remapped[tile_link == value] = self.lut[row, col, value]
            links.append(remapped)
        # NOTE(review): presumably each link tensor is a (2, E) edge list,
        # hence the concatenation along dim 1 - confirm against the .pth files.
        link = torch.cat(links, dim=1)

        # Half of the samples with an all-zero mask get a uniform blank image.
        if mask.max() == 0 and random.random() < 0.5:
            img = np.ones_like(img) * random.randint(0, 200)
        img, mask, link = self.transform([Image.fromarray(img), Image.fromarray(mask), link])

        if not self.test:
            return img, mask, link, label
        return img, mask, link, label, os.path.basename(image_dir)


class PatchGraphClsDataset(Dataset):
    """Mosaic dataset for classification only: yields ``(image, label)``.

    Samples are read from ``root/<category>/<fold>/image/<sample>/``. The
    label is the position of the category inside ``cate_list``.
    """

    def __init__(self, root, cate_list, random, transform, fold=None):
        """Index every sample directory.

        Args:
            root: Dataset root directory.
            cate_list: Category sub-directory names (converted with ``str``).
            random: If truthy, patches and their slots are drawn at random.
            transform: Callable applied to the assembled PIL image.
            fold: Optional iterable of fold sub-directory names.
        """
        self.random = random
        self.transform = transform
        self.datas = []
        self.patch_nums = []
        self.patchs = []
        self.patch_index_list = list(range(MAX_PATCH_NUM))

        categories = [str(c) for c in cate_list]
        folds = [''] if fold is None else [str(f) for f in fold]
        for class_id, category in enumerate(categories):
            for fold_name in folds:
                image_root = os.path.join(root, os.path.join(category, fold_name), 'image')
                for name in sorted(os.listdir(image_root)):
                    sample_dir = os.path.join(image_root, name)
                    # Patch files are tracked by stem; '.png' is re-added on load.
                    stems = sorted(os.path.splitext(entry)[0] for entry in os.listdir(sample_dir))
                    self.datas.append((sample_dir, class_id))
                    self.patch_nums.append(len(stems))
                    self.patchs.append(stems)

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.datas)

    def __getitem__(self, index):
        """Assemble the mosaic for sample ``index``.

        Returns:
            ``(img, label)`` where ``img`` is the output of ``transform``.
        """
        sample_dir, label = self.datas[index]
        count = min(self.patch_nums[index], MAX_PATCH_NUM)
        if self.random:
            chosen = random.sample(self.patchs[index], count)
            slots = random.sample(self.patch_index_list, count)
        else:
            chosen = self.patchs[index][:count]
            slots = list(range(count))

        canvas_edge = PATCH_EDGE * PATCH_SIZE
        # Unused slots keep the light-grey fill value 241.
        img = np.ones((canvas_edge, canvas_edge, 3), dtype=np.uint8) * 241
        for stem, slot in zip(chosen, slots):
            tile = np.array(Image.open(os.path.join(sample_dir, stem + '.png')))
            row, col = divmod(slot, PATCH_EDGE)
            top, left = row * PATCH_SIZE, col * PATCH_SIZE
            img[top:top + PATCH_SIZE, left:left + PATCH_SIZE] = tile

        # Rarely (p = 1e-3) swap a label-0 sample for a uniform blank image.
        if label == 0 and random.random() < 1e-3:
            img = np.ones_like(img) * random.randint(0, 255)
        img = self.transform(Image.fromarray(img))

        return img, label


class PatchGraphBGDataset(Dataset):
    """Mosaic dataset that derives cell masks and links from pixel masks.

    Directory layout read by the constructor::

        root/<category>/image/<sample>/<patch>.<ext>
        root/<category>/<mask>/<sample>/<patch>.png   (full-resolution mask)

    For every item the pixel masks are pasted into one mosaic-sized mask,
    split into connected components, and reduced to a
    ``MASK_EDGE x MASK_EDGE`` grid of cells. Two 8-neighbouring cells are
    linked when they share a foreground component, or when both contain
    background only.
    """

    def __init__(self, root, cate_list, random, transform, mask='mask'):
        """Index every sample directory and pre-compute the neighbour pairs.

        Args:
            root: Dataset root directory.
            cate_list: Category sub-directory names. Entries are converted
                with ``str`` (as the sibling datasets do) so numeric
                category ids are accepted.
            random: If truthy, patches and their slots are drawn at random.
            transform: Callable applied to ``[PIL image, PIL mask, link]``.
            mask: Name of the sub-directory holding the pixel masks.

        Raises:
            ValueError: If a sample has no matching directory under ``mask``.
        """
        self.random = random
        self.transform = transform
        self.datas = []
        self.patch_nums = []
        self.patchs = []
        self.patch_index_list = list(range(MAX_PATCH_NUM))
        cate_list = [str(c) for c in cate_list]
        for i, sub_dir in enumerate(cate_list):
            image_root = os.path.join(root, sub_dir, 'image')
            for name in sorted(os.listdir(image_root)):
                mask_dir = os.path.join(root, sub_dir, mask, name)
                if not os.path.exists(mask_dir):
                    raise ValueError(f'mask not exist: {name}')
                image_dir = os.path.join(image_root, name)
                # Patch files are tracked by stem; '.png' is re-added on load.
                patch_list = sorted(os.path.splitext(x)[0] for x in os.listdir(image_dir))
                self.datas.append((image_dir, mask_dir, i))
                self.patch_nums.append(len(patch_list))
                self.patchs.append(patch_list)

        # All ordered (source, target) pairs of 8-neighbouring cells, as
        # flat row-major indices. Used as-is for samples without any
        # foreground and as the candidate set for all other samples.
        s_l, t_l = [], []
        for i in range(MASK_EDGE):
            for j in range(MASK_EDGE):
                for di in (-1, 0, 1):
                    ti = i + di
                    if not 0 <= ti < MASK_EDGE:
                        continue
                    for dj in (-1, 0, 1):
                        tj = j + dj
                        if (di == 0 and dj == 0) or not 0 <= tj < MASK_EDGE:
                            continue
                        s_l.append(i * MASK_EDGE + j)
                        t_l.append(ti * MASK_EDGE + tj)
        self.full_edge_index = torch.tensor([s_l, t_l], dtype=torch.long)

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.datas)

    def _link_cells(self, mask_ori, contours, hierarchy, mask, source):
        """Fill ``mask`` in place and return the links between cells.

        Args:
            mask_ori: Mosaic-sized uint8 pixel mask.
            contours: Contours found in ``mask_ori``.
            hierarchy: Per-contour hierarchy rows; column 3 is the parent.
            mask: ``MASK_EDGE x MASK_EDGE`` uint8 array; cells containing
                any foreground component are set to 255.
            source: Sample path, only used in the assertion message.

        Returns:
            ``torch.long`` tensor of shape ``(2, E)`` with source and target
            cell indices.
        """
        # Keep top-level contours only: each one is a separate component and
        # filling it also covers the holes described by its children.
        contours = [contour for contour, hi in zip(contours, hierarchy) if hi[3] < 0]
        # Component ids are drawn into a uint16 canvas, so they must fit.
        assert 0 < len(contours) < 65536, f'len(contours):{len(contours)}, {source}'
        canvas = np.zeros_like(mask_ori, dtype=np.uint16)
        for component_id, contour in enumerate(contours, start=1):
            canvas = cv2.drawContours(canvas, [contour], contourIdx=-1, color=component_id, thickness=cv2.FILLED, hierarchy=None)
        if canvas.shape[0] % MASK_EDGE != 0 or canvas.shape[1] % MASK_EDGE != 0:
            # Grow to the next multiple of MASK_EDGE so cells tile exactly.
            h = int(np.ceil(canvas.shape[0] / MASK_EDGE) * MASK_EDGE)
            w = int(np.ceil(canvas.shape[1] / MASK_EDGE) * MASK_EDGE)
            canvas = cv2.resize(canvas, (w, h), interpolation=cv2.INTER_NEAREST)
        h_size, w_size = canvas.shape[0] // MASK_EDGE, canvas.shape[1] // MASK_EDGE

        # Foreground component ids per cell, in row-major cell order. An
        # empty set means the cell holds background only.
        cell_components = []
        for i, sh in enumerate(range(0, canvas.shape[0], h_size)):
            for j, sw in enumerate(range(0, canvas.shape[1], w_size)):
                ids = set(np.unique(canvas[sh:sh + h_size, sw:sw + w_size]).tolist())
                ids.discard(0)
                cell_components.append(ids)
                if ids:
                    mask[i, j] = 255

        # Walk the pre-computed neighbour pairs (same order as the original
        # nested loops) and keep those that are pure background on both
        # sides or that share at least one foreground component.
        s_l, t_l = [], []
        for s, t in self.full_edge_index.t().tolist():
            src, dst = cell_components[s], cell_components[t]
            if (not src and not dst) or not src.isdisjoint(dst):
                s_l.append(s)
                t_l.append(t)
        if not s_l:
            return torch.zeros(2, 0, dtype=torch.long)
        return torch.tensor([s_l, t_l], dtype=torch.long)

    def __getitem__(self, index):
        """Assemble the mosaic for sample ``index``.

        Returns:
            ``(img, mask, link, label)`` where the first three entries are
            whatever ``transform`` returns.
        """
        image_dir, mask_dir, label = self.datas[index]
        count = min(self.patch_nums[index], MAX_PATCH_NUM)
        if self.random:
            patchs = random.sample(self.patchs[index], count)
            patch_index = random.sample(self.patch_index_list, count)
        else:
            patchs = self.patchs[index][:count]
            patch_index = list(range(count))

        canvas_edge = PATCH_EDGE * PATCH_SIZE
        # Unused slots keep the light-grey fill value 241 and an empty mask.
        img = np.ones((canvas_edge, canvas_edge, 3), dtype=np.uint8) * 241
        mask_ori = np.zeros((canvas_edge, canvas_edge), dtype=np.uint8)
        for patch, idx in zip(patchs, patch_index):
            p_img = np.array(Image.open(os.path.join(image_dir, patch + '.png')))
            p_mask = np.array(Image.open(os.path.join(mask_dir, patch + '.png')))
            row, col = divmod(idx, PATCH_EDGE)
            y, x = row * PATCH_SIZE, col * PATCH_SIZE
            img[y:y + PATCH_SIZE, x:x + PATCH_SIZE] = p_img
            mask_ori[y:y + PATCH_SIZE, x:x + PATCH_SIZE] = p_mask

        mask = np.zeros((MASK_EDGE, MASK_EDGE), dtype=np.uint8)
        # OpenCV 3.x returns (image, contours, hierarchy) while 2.x/4.x return
        # (contours, hierarchy); taking the last two items works for all.
        contours, hierarchy = cv2.findContours(mask_ori, mode=cv2.RETR_TREE, method=cv2.CHAIN_APPROX_NONE)[-2:]
        if len(contours) == 0:
            # No foreground at all: every neighbouring pair is linked.
            link = self.full_edge_index.clone()
        else:
            link = self._link_cells(mask_ori, contours, hierarchy[0], mask, image_dir)

        # Half of the samples with an all-zero mask get a uniform blank image.
        if mask.max() == 0 and random.random() < 0.5:
            img = np.ones_like(img) * random.randint(0, 200)
        img, mask, link = self.transform([Image.fromarray(img), Image.fromarray(mask), link])
        return img, mask, link, label


class PatchGraphTestDataset(Dataset):
    """Whole-image test dataset that tiles each image into square patches.

    Images are read from ``root/images`` and the label image with the same
    file name from ``root/<mask>``. Every image is padded to a multiple of
    ``img_patch_edge`` and cut into patches; the label is cut on the same
    grid with patches of edge ``msk_patch_edge = img_patch_edge // 64``.
    """

    def __init__(self, root, transform, mask='corr_masks', recover=False):
        """Index all files under ``root/images``.

        Args:
            root: Dataset root directory.
            transform: Callable applied to the list of PIL image patches.
            mask: Name of the sub-directory holding the label images.
            recover: If truthy, items also carry the geometry and file name
                needed to stitch predictions back together.
        """
        self.transform = transform
        self.recover = recover
        self.datas = []
        self.img_patch_edge = 256*5
        # Validate before deriving the label patch edge from it.
        assert self.img_patch_edge % 64 == 0, f'img_patch_edge must be a multiple of 64, got {self.img_patch_edge}'
        self.msk_patch_edge = self.img_patch_edge // 64
        img_path = os.path.join(root, 'images')
        msk_path = os.path.join(root, mask)
        for img in sorted(os.listdir(img_path)):
            self.datas.append([os.path.join(img_path, img), os.path.join(msk_path, img)])

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.datas)

    def __getitem__(self, index):
        """Tile image ``index`` and its label into aligned patches.

        Returns:
            ``(img_patch, lbl_patch)`` where ``img_patch`` is the output of
            ``transform`` on the row-major list of image patches and
            ``lbl_patch`` is the matching list of ``torch.long`` label
            patches (label values divided by 255). When ``recover`` is set,
            the tuple ``(ori_h, ori_w, padded_h, padded_w, rows, cols)`` and
            the image file name are appended.
        """
        img_edge, msk_edge = self.img_patch_edge, self.msk_patch_edge

        img_ori = np.array(Image.open(self.datas[index][0]))
        ori_h, ori_w = img_ori.shape[0], img_ori.shape[1]
        rows = int(np.ceil(ori_h / img_edge))
        cols = int(np.ceil(ori_w / img_edge))
        # np.full allocates the padded canvas once; ones(...) * 241 would
        # create a second full-size temporary, which hurts on huge images.
        img = np.full((rows * img_edge, cols * img_edge, 3), 241, dtype=img_ori.dtype)
        img[:ori_h, :ori_w] = img_ori
        del img_ori

        lbl_ori = np.array(Image.open(self.datas[index][1]))
        # The label canvas must cover the image's patch grid. Sizing it only
        # from the label's own shape yields truncated or empty label patches
        # when the label is slightly smaller than image/64 (e.g. floor-scaled).
        lbl_rows = max(rows, int(np.ceil(lbl_ori.shape[0] / msk_edge)))
        lbl_cols = max(cols, int(np.ceil(lbl_ori.shape[1] / msk_edge)))
        lbl = np.zeros((lbl_rows * msk_edge, lbl_cols * msk_edge), dtype=lbl_ori.dtype)
        lbl[:lbl_ori.shape[0], :lbl_ori.shape[1]] = lbl_ori
        del lbl_ori

        img_patch = []
        lbl_patch = []
        for i, y in enumerate(range(0, img.shape[0], img_edge)):
            for j, x in enumerate(range(0, img.shape[1], img_edge)):
                img_patch.append(Image.fromarray(img[y:y + img_edge, x:x + img_edge]))
                lbl_patch.append(lbl[i * msk_edge:(i + 1) * msk_edge, j * msk_edge:(j + 1) * msk_edge])
        img_patch = self.transform(img_patch)
        # NOTE(review): assumes labels are stored as 0/255 so that dividing
        # by 255 gives class ids 0/1 - confirm against the mask files.
        lbl_patch = [(torch.from_numpy(l)/255).long() for l in lbl_patch]
        if self.recover:
            return img_patch, lbl_patch, (ori_h, ori_w, img.shape[0], img.shape[1], img.shape[0]//img_edge, img.shape[1]//img_edge), os.path.basename(self.datas[index][0])
        else:
            return img_patch, lbl_patch


class PatchGraphNoLinkDataset(Dataset):
    """Mosaic dataset yielding an image, a coarse mask and a label (no links).

    Directory layout read by the constructor::

        root/<category>/image/<sample>/<patch>.<ext>
        root/<category>/<mask>/<sample>/<patch>.png

    The label is the position of the category inside ``cate_list``.
    """

    def __init__(self, root, cate_list, random, transform, mask='mask'):
        """Index every sample directory.

        Args:
            root: Dataset root directory.
            cate_list: Category sub-directory names. Entries are converted
                with ``str`` (as the sibling datasets do) so numeric
                category ids are accepted.
            random: If truthy, patches and their slots are drawn at random.
            transform: Callable applied to ``[PIL image, PIL mask]``.
            mask: Name of the sub-directory holding the patch masks.

        Raises:
            ValueError: If a sample has no matching directory under ``mask``.
        """
        self.random = random
        self.transform = transform
        self.datas = []
        self.patch_nums = []
        self.patchs = []
        self.patch_index_list = list(range(MAX_PATCH_NUM))
        # Not a constructor argument; set it on the instance afterwards to
        # make items carry the sample name as well.
        self.test = False

        cate_list = [str(c) for c in cate_list]
        for i, sub_dir in enumerate(cate_list):
            image_root = os.path.join(root, sub_dir, 'image')
            for name in sorted(os.listdir(image_root)):
                mask_dir = os.path.join(root, sub_dir, mask, name)
                if not os.path.exists(mask_dir):
                    raise ValueError(f'mask not exist: {name}')
                image_dir = os.path.join(image_root, name)
                # Patch files are tracked by stem; '.png' is re-added on load.
                patch_list = sorted(os.path.splitext(x)[0] for x in os.listdir(image_dir))
                self.datas.append((image_dir, mask_dir, i))
                self.patch_nums.append(len(patch_list))
                self.patchs.append(patch_list)

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.datas)

    def __getitem__(self, index):
        """Assemble the mosaic for sample ``index``.

        Returns:
            ``(img, mask, label)``, followed by the sample directory name
            when ``self.test`` is truthy. ``img`` and ``mask`` are whatever
            ``transform`` returns.
        """
        image_dir, mask_dir, label = self.datas[index]
        count = min(self.patch_nums[index], MAX_PATCH_NUM)
        if self.random:
            patchs = random.sample(self.patchs[index], count)
            patch_index = random.sample(self.patch_index_list, count)
        else:
            patchs = self.patchs[index][:count]
            patch_index = list(range(count))

        canvas_edge = PATCH_EDGE * PATCH_SIZE
        # Unused slots keep the light-grey fill value 241 and a zero mask.
        img = np.ones((canvas_edge, canvas_edge, 3), dtype=np.uint8) * 241
        mask = np.zeros((MASK_EDGE, MASK_EDGE), dtype=np.uint8)
        for patch, idx in zip(patchs, patch_index):
            p_img = np.array(Image.open(os.path.join(image_dir, patch + '.png')))
            p_mask = np.array(Image.open(os.path.join(mask_dir, patch + '.png')))
            row, col = divmod(idx, PATCH_EDGE)
            y, x = row * PATCH_SIZE, col * PATCH_SIZE
            img[y:y + PATCH_SIZE, x:x + PATCH_SIZE] = p_img
            y, x = row * SPLIT_EDGE, col * SPLIT_EDGE
            mask[y:y + SPLIT_EDGE, x:x + SPLIT_EDGE] = p_mask

        # Half of the samples with an all-zero mask get a uniform blank image.
        if mask.max() == 0 and random.random() < 0.5:
            img = np.ones_like(img) * random.randint(0, 255)
        img, mask = self.transform([Image.fromarray(img), Image.fromarray(mask)])

        if self.test:
            return img, mask, label, os.path.basename(image_dir)
        else:
            return img, mask, label


class NoLinkGeneratePseudoDataset(Dataset):
    """Dataset that turns all patches of one sample into ordered mosaics.

    Samples are read from ``root/<category>/image/<sample>/``. Unlike the
    training datasets, patch file names keep their extension and every
    patch is used: they are consumed in sorted order, ``MAX_PATCH_NUM`` per
    mosaic.
    """

    def __init__(self, root, cate_list, transform):
        """Index every sample directory.

        Args:
            root: Dataset root directory.
            cate_list: Category sub-directory names. Entries are converted
                with ``str`` (as the sibling datasets do) so numeric
                category ids are accepted.
            transform: Callable applied to the list of PIL mosaics; its
                result is passed to ``torch.stack``.
        """
        self.transform = transform
        self.datas = []
        self.patch_nums = []
        self.patchs = []
        self.patch_index_list = list(range(MAX_PATCH_NUM))

        cate_list = [str(c) for c in cate_list]
        for sub_dir in cate_list:
            image_root = os.path.join(root, sub_dir, 'image')
            for name in sorted(os.listdir(image_root)):
                image_dir = os.path.join(image_root, name)
                patch_list = sorted(os.listdir(image_dir))
                # Record: sample image dir, category dir, sample name.
                self.datas.append([image_dir, os.path.join(root, sub_dir), name])
                self.patch_nums.append(len(patch_list))
                self.patchs.append(patch_list)

    def __len__(self):
        """Return the number of indexed samples."""
        return len(self.datas)

    def __getitem__(self, index):
        """Build every mosaic of sample ``index``.

        Returns:
            ``(imgs, patch_name_list, path, name)`` where ``imgs`` stacks
            the transformed mosaics along dim 0, ``patch_name_list`` holds
            the patch file names of each mosaic in slot order, ``path`` is
            the category directory and ``name`` the sample name.
        """
        image_dir, path, name = self.datas[index]
        patch_files = self.patchs[index]
        canvas_edge = PATCH_EDGE * PATCH_SIZE

        image_list = []
        patch_name_list = []
        for start in range(0, len(patch_files), MAX_PATCH_NUM):
            group = patch_files[start:start + MAX_PATCH_NUM]
            # Unused slots keep the light-grey fill value 241.
            img = np.ones((canvas_edge, canvas_edge, 3), dtype=np.uint8) * 241
            for slot, patch in enumerate(group):
                p_img = np.array(Image.open(os.path.join(image_dir, patch)))
                row, col = divmod(slot, PATCH_EDGE)
                y, x = row * PATCH_SIZE, col * PATCH_SIZE
                img[y:y + PATCH_SIZE, x:x + PATCH_SIZE] = p_img
            image_list.append(img)
            patch_name_list.append(group)

        imgs = torch.stack(self.transform([Image.fromarray(img) for img in image_list]), dim=0)
        return imgs, patch_name_list, path, name
