import torch
import pathlib
import cv2
import numpy as np
import matplotlib.pyplot as plt
from vit_pytorch import ViT
from BinarizationModel.binae import BinModel
from einops import rearrange
from skimage import img_as_ubyte

THRESHOLD = 0.5  # binarization threshold applied to the de-normalized model output (clipped to [0, 1])

SPLITSIZE = 256  # the image is divided into SPLITSIZE x SPLITSIZE patches
setting = "large"  # model size: "small", "base" or "large"; must match the checkpoint being loaded
patch_size = 16  # ViT patch size: 8 or 16; must match the checkpoint being loaded
image_size = (SPLITSIZE, SPLITSIZE)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# (encoder_layers, encoder_heads, encoder_dim) for each supported model size.
_MODEL_SETTINGS = {
    'small': (3, 4, 512),
    'base': (6, 8, 768),
    'large': (12, 16, 1024),
}

# Fail loudly on an unknown size instead of leaving the encoder names undefined
# (which would otherwise surface later as a NameError while building the model).
try:
    encoder_layers, encoder_heads, encoder_dim = _MODEL_SETTINGS[setting]
except KeyError:
    raise ValueError(
        f"Unknown model setting {setting!r}; expected one of {sorted(_MODEL_SETTINGS)}"
    ) from None
    
    
# Encoder: a ViT over SPLITSIZE x SPLITSIZE inputs sized by the chosen `setting`.
v = ViT(
    image_size=image_size,
    patch_size=patch_size,
    num_classes=1000,
    dim=encoder_dim,
    depth=encoder_layers,
    heads=encoder_heads,
    mlp_dim=2048,
)
# Encoder-decoder binarization model; the decoder mirrors the encoder's width,
# depth and head count.
Bmodel = BinModel(
    encoder=v,
    decoder_dim=encoder_dim,
    decoder_depth=encoder_layers,
    decoder_heads=encoder_heads,
)

Bmodel = Bmodel.to(device)

# NOTE: this checkpoint name is fixed; it must correspond to the `setting` and
# `patch_size` chosen above or load_state_dict will fail on mismatched shapes.
# torch.load unpickles the file, so only load trusted checkpoints.
model_path = "./PretrainModel/model_16_2018_large.pt"
Bmodel.load_state_dict(torch.load(model_path, map_location=device))
# The model is used for inference only: switch off train-only behaviour
# (e.g. dropout) that would otherwise stay active.
Bmodel.eval()

print("Finishing building Binarization Model")


def split(im, h, w):
    """Cut ``im`` into SPLITSIZE x SPLITSIZE tiles in row-major order.

    Tiles start every SPLITSIZE pixels over ``range(0, h)`` rows and
    ``range(0, w)`` columns. When ``h`` or ``w`` is not a multiple of
    SPLITSIZE, the tiles on the bottom/right edges are smaller (plain NumPy
    slicing is used, so nothing is padded here).

    Args:
        im: 3-D image array indexed as (row, column, channel).
        h: number of rows to tile over.
        w: number of columns to tile over.

    Returns:
        List of tile views into ``im``, ordered row by row.
    """
    step = SPLITSIZE
    return [
        im[top:top + step, left:left + step, :]
        for top in range(0, h, step)
        for left in range(0, w, step)
    ]

def merge_image(splitted_images, h, w):
    """Stitch row-major tiles (as produced by ``split``) back into one image.

    Args:
        splitted_images: indexable sequence of tiles, one per SPLITSIZE grid
            cell, ordered row by row.
        h: height of the reconstructed image.
        w: width of the reconstructed image.

    Returns:
        Float array of shape (h, w, 3), zero wherever no tile was written.
    """
    step = SPLITSIZE
    canvas = np.zeros((h, w, 3))
    origins = [(top, left) for top in range(0, h, step) for left in range(0, w, step)]
    for idx, (top, left) in enumerate(origins):
        canvas[top:top + step, left:left + step, :] = splitted_images[idx]
    return canvas

def BinarizationFunc(img, magnitude, model=Bmodel, device=device):
    """Binarize an image with the patch-based transformer model.

    Pipeline:
    1. Scale the image to [0, 1].
    2. Pad it with white (1.0) up to the next multiple of SPLITSIZE in both
       dimensions.
    3. Split it into SPLITSIZE x SPLITSIZE patches and normalize each patch
       with the ImageNet mean/std.
    4. Run each patch through ``model``, de-normalize and clip to [0, 1].
    5. Stitch the patches back together, crop to the original size and
       threshold at THRESHOLD.

    Args:
        img: (H, W, 3) image with values in [0, 255]. A 2-D (H, W) grayscale
            image is replicated to three channels.
        magnitude: unused; kept for interface compatibility.
        model: binarization model, called as ``model(input, target)`` and
            expected to return a 3-tuple whose last item is the predicted
            patch pixels.
        device: device the patches are moved to for inference.

    Returns:
        (H, W, 3) uint8 array whose pixels are either 0 or 255.
    """
    deg_image = img / 255
    if deg_image.ndim == 2:
        # Grayscale input: replicate to the 3 channels the model expects.
        deg_image = np.repeat(deg_image[:, :, None], 3, axis=2)

    orig_h, orig_w = deg_image.shape[:2]

    # Pad with white up to the next multiple of SPLITSIZE so every patch is
    # exactly SPLITSIZE x SPLITSIZE, as the .view() below requires.
    # Without this, images whose sides are not multiples of SPLITSIZE crash.
    h = -(-orig_h // SPLITSIZE) * SPLITSIZE
    w = -(-orig_w // SPLITSIZE) * SPLITSIZE
    deg_image_padded = np.ones((h, w, 3))
    deg_image_padded[:orig_h, :orig_w, :] = deg_image
    patches = split(deg_image_padded, h, w)

    # ImageNet normalization statistics, applied per channel (last axis).
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])

    result = []
    for p in patches:
        # HWC -> CHW, normalized, as a contiguous float32 array for torch.
        out_patch = np.ascontiguousarray(
            ((p - mean) / std).transpose(2, 0, 1), dtype=np.float32
        )
        train_in = torch.from_numpy(out_patch)

        with torch.no_grad():
            train_in = train_in.view(1, 3, SPLITSIZE, SPLITSIZE).to(device)
            # The model's forward takes a second tensor of the same shape.
            # Presumably it is only the target for its loss, which is
            # discarded here. TODO confirm against BinModel.forward.
            dummy_target = torch.rand(train_in.shape).to(device)
            _, _, pred_pixel_values = model(train_in, dummy_target)
            rec_image = torch.squeeze(
                rearrange(
                    pred_pixel_values,
                    'b (h w) (p1 p2 c) -> b c (h p1) (w p2)',
                    p1=patch_size,
                    p2=patch_size,
                    h=image_size[0] // patch_size,
                )
            )
            impred = rec_image.cpu().numpy().transpose(1, 2, 0)
            # Undo the normalization and clamp to the valid [0, 1] range.
            impred = np.clip(impred * std + mean, 0, 1)
        result.append(impred)

    clean_image = merge_image(result, h, w)
    clean_image = clean_image[:orig_h, :orig_w, :]
    return ((clean_image > THRESHOLD) * 255).astype(np.uint8)
    
    

if __name__ == "__main__":
    # Demo: binarize 10.png and write the result to 10Bin.jpg.
    raw_img = cv2.imread('10.png')
    if raw_img is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise FileNotFoundError("Could not read input image '10.png'")

    # 1024 is a multiple of SPLITSIZE; note this resize does not preserve
    # the aspect ratio.
    raw_img = cv2.resize(raw_img, (1024, 1024))

    Bin = BinarizationFunc(raw_img, magnitude=0)

    # cv2.imwrite reports failure via its boolean return value.
    if not cv2.imwrite("10Bin.jpg", Bin):
        raise OSError("Could not write output image '10Bin.jpg'")
