import os
from os.path import join
import cv2
import numpy as np
import torch
import shutil
import torch
import torch.nn as nn
import torch.nn.functional as F
from numba import jit
import gc

class DiceLoss(nn.Module):
    """Soft Dice loss averaged over the batch.

    Every sample is flattened and scored with
    ``dice = (2 * sum(x * y) + smooth) / (sum(x) + sum(y) + smooth)``.
    The returned loss is ``1 - mean(dice)``. For inputs in [0, 1] the Dice
    score is at most 1, so the loss lies in [0, 1].
    """

    def __init__(self, smooth=1):
        """
        Args:
            smooth: additive smoothing applied to numerator and denominator
                alike; keeps the ratio defined when both masks are empty.
        """
        super().__init__()
        self.smooth = smooth

    def forward(self, input, target):
        """Return the scalar Dice loss.

        Args:
            input: predictions; the first dimension is the batch.
            target: ground truth with the same number of elements per sample
                as ``input``; its first dimension gives the batch size N.

        Returns:
            0-dim tensor ``1 - sum(dice_i) / N``.
        """
        N = target.size(0)

        # reshape (unlike view) also accepts non-contiguous tensors.
        input_flat = input.reshape(N, -1)
        target_flat = target.reshape(N, -1)

        intersection = (input_flat * target_flat).sum(1)

        # Smoothing is added once to both sides: a perfect match scores exactly
        # 1 and two empty masks score 1. Adding it inside the factor of 2
        # would push the score above 1 and the loss below 0.
        dice = (2 * intersection + self.smooth) / (
            input_flat.sum(1) + target_flat.sum(1) + self.smooth
        )
        return 1 - dice.sum() / N


def get_connection(mask, thresh=100):
    '''Label the 4-connected foreground components of a 2-D mask.

    Any non-zero entry of ``mask`` is foreground. Components are seeded in
    row-major order and grown breadth-first (neighbour order: up, down, left,
    right). Components with fewer than ``thresh`` pixels are dropped: their
    pixels stay 0 in ``output`` and they consume no label.

    Args:
        mask: 2-D array (n*n in the original notes, but any H x W works).
        thresh: minimum number of pixels a component needs to be kept.

    Returns:
        output: float array shaped like ``mask``. Pixels of the k-th kept
            component hold ``k`` (labels start from 1); background is 0.
        bboxs: ``[hmin, wmin, hmax, wmax]`` per kept component, as inclusive
            0-based row/column indices.
        regions: for each kept component, its ``(row, col)`` tuples in the
            order the flood fill reached them.
    '''
    # deque gives O(1) popleft; list.pop(0) made the flood fill quadratic.
    from collections import deque

    foreground = np.asarray(mask) != 0
    height, width = foreground.shape
    visited = np.zeros((height, width), dtype=bool)
    output = np.zeros((height, width))
    bboxs = []
    regions = []
    count = 0

    # np.argwhere lists coordinates in row-major order, the same order as a
    # full raster scan, but skips background pixels.
    for i, j in np.argwhere(foreground).tolist():
        if visited[i, j]:
            continue
        visited[i, j] = True
        region = [(i, j)]
        queue = deque(region)
        while queue:
            y, x = queue.popleft()
            for ny, nx in ((y - 1, x), (y + 1, x), (y, x - 1), (y, x + 1)):
                if (0 <= ny < height and 0 <= nx < width
                        and foreground[ny, nx] and not visited[ny, nx]):
                    visited[ny, nx] = True
                    region.append((ny, nx))
                    queue.append((ny, nx))

        if len(region) < thresh:
            continue

        count += 1
        regions.append(region)
        rows, cols = zip(*region)
        bboxs.append([min(rows), min(cols), max(rows), max(cols)])
        output[list(rows), list(cols)] = count

    return output, bboxs, regions




def get_det(mask, r=6, rr=2):
    """Build a normalised map of pixels that stand out in all 8 directions.

    For each interior pixel p (at least ``r`` pixels from every border) and
    each of the 8 compass directions u, the directional score is
    ``sum over near in [0, rr), far in (near, r] of mask[p + near*u] - mask[p + far*u]``.
    This measures how much the values close to p exceed those farther out
    along u. A pixel receives the sum of its 8 scores only if every score is
    strictly positive. All other pixels, including the ``r``-wide border,
    get 0. The map is then min-max scaled.

    Args:
        mask: array of shape (H, W, 1) or (H, W); values are cast to ``int``.
        r: outer reach in pixels; also the width of the border left at 0.
        rr: number of inner offsets (0 .. rr-1) compared against farther ones.

    Returns:
        float64 array of shape (H, W) scaled to [0, 1]. It is all zeros when
        no pixel passes the gate, instead of the NaNs a 0/0 scaling gives.

    Raises:
        ValueError: if ``mask`` is not (H, W) or (H, W, 1).
    """
    mask = np.asarray(mask).astype(int)
    if mask.ndim == 3 and mask.shape[-1] == 1:
        mask = mask.squeeze(-1)
    if mask.ndim != 2:
        raise ValueError("mask must have shape (H, W) or (H, W, 1)")

    height, width = mask.shape
    det = np.zeros((height, width))
    if height <= 2 * r or width <= 2 * r:
        # No pixel is r away from every border; the shifted windows below
        # would not even have matching shapes.
        return det

    # (row step, col step): right, up-right, up, up-left, left, down-left,
    # down, down-right.
    directions = ((0, 1), (-1, 1), (-1, 0), (-1, -1),
                  (0, -1), (1, -1), (1, 0), (1, 1))

    def shifted(dy, dx, step):
        # Interior window moved `step` pixels along direction (dy, dx).
        return mask[r + dy * step:height - r + dy * step,
                    r + dx * step:width - r + dx * step]

    g = np.zeros((len(directions), height - 2 * r, width - 2 * r))
    for d, (dy, dx) in enumerate(directions):
        for near in range(rr):
            for far in range(near + 1, r + 1):
                g[d] += shifted(dy, dx, near) - shifted(dy, dx, far)

    gate = (g > 0).all(axis=0)
    interior = det[r:height - r, r:width - r]  # view: writes land in det
    interior[gate] = g.sum(axis=0)[gate]

    det_max, det_min = det.max(), det.min()
    span = det_max - det_min
    if span == 0:
        return np.zeros_like(det)
    return (det - det_min) / span


def get_new_temp(pred, gt, img):
    """Unfinished placeholder; currently always returns ``None``.

    Args:
        pred: presumably an n*n array; only its first two dimensions are read.
        gt: presumably an n*n array; not used yet.
        img: presumably an n*n array; not used yet.
    """
    # TODO: body was never completed; nothing is computed from the inputs.
    _height, _width = pred.shape[0], pred.shape[1]
    _new = np.zeros(())  # placeholder buffer, never filled or returned
    return None

def datanorm(img):
    """Scale ``img`` from the 0..255 range down to 0..1.

    The multiplication by ``1.0`` yields a floating-point value for integer
    inputs (e.g. ``uint8`` arrays) before the division.
    """
    as_float = img * 1.0
    return as_float / 255

def gaussian_shaped_labels(sigma, sz):
    """Return an un-normalised Gaussian label map with its peak at index [0, 0].

    Args:
        sigma: standard deviation of the Gaussian, in pixels.
        sz: pair ``(sz[0], sz[1])``. ``np.meshgrid`` (default ``xy``
            indexing) lays ``sz[0]`` values along the columns and ``sz[1]``
            values along the rows, so the output has shape ``(sz[1], sz[0])``.

    Returns:
        float32 array of shape ``(sz[1], sz[0])`` whose maximum (1.0) is at
        ``[0, 0]``, with values wrapping around the array edges.
    """
    half_x = np.floor(float(sz[0]) / 2)
    half_y = np.floor(float(sz[1]) / 2)
    x, y = np.meshgrid(np.arange(1, sz[0] + 1) - half_x,
                       np.arange(1, sz[1] + 1) - half_y)
    d = x ** 2 + y ** 2
    g = np.exp(-0.5 / (sigma ** 2) * d)

    # Before rolling, the peak sits at row half_y - 1, column half_x - 1.
    # Each axis must be rolled by its own half-size. Using sz[0] for axis 0
    # misplaces the peak whenever sz[0] != sz[1].
    g = np.roll(g, int(-half_y + 1), axis=0)
    g = np.roll(g, int(-half_x + 1), axis=1)
    return g.astype(np.float32)

def gaussian(sigma, sz):
    """Return an un-normalised 2-D Gaussian centred on the grid midpoint.

    Args:
        sigma: standard deviation in pixels.
        sz: pair ``(sz[0], sz[1])``; the output has ``sz[1]`` rows and
            ``sz[0]`` columns.

    Returns:
        float array of shape ``(sz[1], sz[0])`` equal to 1 at row
        ``floor(sz[1] / 2)``, column ``floor(sz[0] / 2)``.
    """
    col_offsets = np.arange(0, sz[0]) - np.floor(float(sz[0]) / 2)
    row_offsets = np.arange(0, sz[1]) - np.floor(float(sz[1]) / 2)
    # Broadcasting a row vector against a column vector builds the same grid
    # that np.meshgrid would.
    sq_dist = col_offsets[np.newaxis, :] ** 2 + row_offsets[:, np.newaxis] ** 2
    return np.exp(-0.5 / (sigma ** 2) * sq_dist)
    