import sys
import os
import torch
from PIL import Image
from torchvision import transforms
import torch.backends.cudnn as cudnn
import numpy as np
import cv2
import gymnasium 
import pdb
import time
import argparse
from skimage import img_as_ubyte
from stable_baselines3 import PPO
from sb3_contrib import RecurrentPPO
from stable_baselines3.common.env_checker import check_env
from controller import Controller
from augement_policy import Policy
from config import get_args
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry, SamPredictor
from FT_utils import img_loader, forward_point_cv2, eval_gt_pred, update_record
from PromptToAdapt import Prompt2Adapt
from PromptToAdaptAll import Prompt2AdaptAll
from datetime import datetime
from logger import Logger

 


def is_image_file(filename):
    """Report whether PIL is able to open *filename* as an image.

    The file is opened and immediately closed again. Any ``IOError`` raised
    while opening or closing (``IOError`` is an alias of ``OSError`` in
    Python 3) is treated as "not an image".

    Returns:
        True if ``Image.open`` succeeds, False otherwise.
    """
    try:
        with Image.open(filename):
            pass
    except IOError:
        return False
    return True



def _imread_checked(path, flags):
    """Read *path* with ``cv2.imread(path, flags)``, raising if it cannot be read.

    ``cv2.imread`` returns ``None`` instead of raising on a missing or
    undecodable file. Failing here gives a clear error instead of a confusing
    ``TypeError`` from ``torch.from_numpy(None)`` later on.
    """
    image = cv2.imread(path, flags)
    if image is None:
        raise FileNotFoundError(f"Could not read image file: {path}")
    return image


def _read_name_list(txt_path):
    """Return the non-blank lines of *txt_path* (one image name per line)."""
    with open(txt_path, "r") as ann_file:
        return [line for line in ann_file.read().splitlines() if line.strip()]


def _load_stacks(imgs_list, img_dir, label_dir, transform):
    """Load, resize and stack the images and GT masks named in *imgs_list*.

    For each entry the extension is dropped. ``<stem>.png`` is then read from
    *img_dir* (colour) and ``<stem>_GT.png`` from *label_dir* (grayscale).
    Both are passed through *transform*. The output stacks are 1024x1024, so
    *transform* must produce that spatial size.

    Returns:
        (img_stack, gt_mask_stack): uint8 arrays of shape (N, 1024, 1024, 3)
        and (N, 1024, 1024).
    """
    n = len(imgs_list)
    # Allocate as uint8 once. Casting a float64 stack on every iteration
    # copied the whole stack N times and gives the same values.
    img_stack = np.zeros((n, 1024, 1024, 3), dtype=np.uint8)
    gt_mask_stack = np.zeros((n, 1024, 1024), dtype=np.uint8)
    for i, entry in enumerate(imgs_list):
        stem = os.path.splitext(entry)[0]

        img = _imread_checked(os.path.join(img_dir, stem + ".png"), cv2.IMREAD_COLOR)
        img_resize = transform(torch.from_numpy(img).permute(2, 0, 1))
        img_stack[i] = img_resize.numpy().transpose(1, 2, 0)

        gt_mask = _imread_checked(
            os.path.join(label_dir, stem + "_GT.png"), cv2.IMREAD_GRAYSCALE
        )
        gt_mask_resize = transform(torch.from_numpy(gt_mask).unsqueeze(0))
        gt_mask_stack[i] = gt_mask_resize.squeeze().numpy()
    return img_stack, gt_mask_stack


def _to_input_tensor(bgr_image, device):
    """Convert an HxWx3 BGR image to a 1x3xHxW RGB float32 tensor on *device*.

    Pixel values are divided by 255.
    """
    rgb = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)
    rgb = np.float32(rgb) / 255.
    return torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0).to(device)


def _tensor_to_bgr_uint8(tensor):
    """Convert a 1x3xHxW RGB float tensor back to an HxWx3 BGR uint8 array.

    Values are clamped to [0, 1] before conversion with ``img_as_ubyte``.
    """
    rgb = torch.clamp(tensor, 0, 1).cpu().detach().permute(0, 2, 3, 1).squeeze(0).numpy()
    return img_as_ubyte(cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))


def _predict_and_score(predictor, image_bgr, input_point, input_label, gt_masks_tensor, device):
    """Run a single-mask SAM prediction on *image_bgr* and score it against the GT.

    Returns:
        (pred_vis, score): the predicted mask as a 0/255 uint8 image for
        ``cv2.imwrite``, and the scalar returned by ``eval_gt_pred``.
    """
    # SamPredictor.set_image assumes RGB unless told otherwise. These arrays
    # are in OpenCV BGR order, so declare it and let SAM swap the channels.
    predictor.set_image(image_bgr, image_format="BGR")
    pred_mask, _, _ = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        box=None,
        multimask_output=False,
    )
    pred_vis = pred_mask.astype(np.uint8).squeeze() * 255
    pred_tensor = torch.from_numpy(pred_mask).to(device).float()
    score = eval_gt_pred(gt_masks_tensor, pred_tensor).cpu().numpy().item()
    return pred_vis, score


def _save_jpg(save_dir, exp_name, tag, index, image):
    """Write *image* to ``<save_dir>/<exp_name>_img_<tag>_<index>.jpg``."""
    cv2.imwrite(f"{save_dir}/{exp_name}_img_{tag}_{index}.jpg", image)


def main(args):
    """Compare SAM on original images vs. images transformed by a fixed policy.

    Steps:

    * Redirect stdout/stderr to a ``Logger`` named after the experiment and
      policy.
    * Load up to 100 images/GT masks listed in ``args.val_txt`` (or
      ``args.train_txt`` when ``mode == "train"``).
    * Build A (original) / B (``Policy(args, image, BestPolicy)``) pairs.
    * Prompt SAM with the point from ``forward_point_cv2(gt_mask)`` on both.
    * Print mean scores for A, for B, and for the B - A residual.

    Images and predictions are written to
    ``./Paper_Results/Paper_<exp_name>/``.

    Reads ``args.exp_name``, ``val_txt``, ``seed``, ``model_type``,
    ``checkpoint``, ``img_dir`` and ``label_dir`` (plus whatever ``Policy``
    reads).
    """
    BestPolicy = [{0: {'op': 'sharpen_gaussian', 'magnitude': 7}}, {0: {'op': 'contrast_up', 'magnitude': 9}}, {0: {'op': 'gaussianBlur', 'magnitude': 8}}]
    model_name = "Paper_" + args.exp_name + "_" + str(BestPolicy)
    log_file_name = f"{model_name}.log"
    sys.stdout = Logger(str(log_file_name))
    sys.stderr = Logger(str(log_file_name))

    print(args.val_txt)

    # Use CUDA when present, otherwise CPU. This device is used throughout;
    # it is not overridden later.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)

    # SAM
    sam = sam_model_registry[args.model_type](checkpoint=args.checkpoint)
    sam.to(device)
    predictor = SamPredictor(sam)

    transform = transforms.Compose([
        transforms.Resize(
            (1024, 1024),
            interpolation=transforms.InterpolationMode.BICUBIC,
            antialias=True,
        )
    ])

    mode = "val"
    max_images = 100
    ann_txt = args.train_txt if mode == "train" else args.val_txt
    imgs_list = _read_name_list(ann_txt)[:max_images]

    img_stack, gt_mask_stack = _load_stacks(imgs_list, args.img_dir, args.label_dir, transform)

    # pix2pix-style pairs: A = original image, B = image after BestPolicy.
    DataPairList = []
    for i, img in enumerate(img_stack):
        # A is built before Policy runs, in case Policy modifies its input.
        inputA = _to_input_tensor(img, device)
        obs = Policy(args, img, BestPolicy)
        print("Finish generating real B " + str(i))
        DataPairList.append({
            "A": inputA,
            "A_paths": "Pix2PixDataPair/" + args.exp_name + str(i) + ".jpg",
            "B": _to_input_tensor(obs, device),
            "B_paths": "Pix2PixDataPair/" + args.exp_name + str(i) + ".png",
        })

    save_distill_path = './Paper_Results/' + 'Paper_' + args.exp_name
    os.makedirs(save_distill_path, exist_ok=True)

    mIoU_realA = []
    mIoU_realB = []
    mIoU_residualAB = []

    for num, DP in enumerate(DataPairList):
        gt_mask = gt_mask_stack[num]
        _save_jpg(save_distill_path, args.exp_name, "Mask", num, gt_mask)
        input_point = forward_point_cv2(gt_mask)
        input_label = np.array([1])

        gt_masks_tensor = (torch.from_numpy(gt_mask).float() / 255.0).unsqueeze(0).to(device)

        realA = _tensor_to_bgr_uint8(DP["A"])
        _save_jpg(save_distill_path, args.exp_name, "real_A", num, realA)
        predA, IoU_cost_rA = _predict_and_score(
            predictor, realA, input_point, input_label, gt_masks_tensor, device
        )
        _save_jpg(save_distill_path, args.exp_name, "real_A_pred", num, predA)
        mIoU_realA.append(IoU_cost_rA)

        realB = _tensor_to_bgr_uint8(DP["B"])
        _save_jpg(save_distill_path, args.exp_name, "real_B", num, realB)
        predB, IoU_cost_rB = _predict_and_score(
            predictor, realB, input_point, input_label, gt_masks_tensor, device
        )
        _save_jpg(save_distill_path, args.exp_name, "real_B_pred", num, predB)
        mIoU_realB.append(IoU_cost_rB)

        mIoU_residualAB.append(IoU_cost_rB - IoU_cost_rA)

    print("############")

    if not mIoU_residualAB:
        # np.max on an empty list raises; report instead of crashing.
        print("No images were evaluated; nothing to summarise.")
        print("############")
        return

    print("original image performance: ", np.mean(mIoU_realA))

    print("No distill policy performance: ", np.mean(mIoU_realB))

    print("Residual performance: ", np.mean(mIoU_residualAB))

    best_idx = int(np.argmax(mIoU_residualAB))
    print("Residual best: ", np.max(mIoU_residualAB), "its index", best_idx)

    print("############")
        

if __name__ == '__main__':
    # Parse CLI options via config.get_args and hand them straight to main.
    main(get_args())