import numpy as np
import cv2
from augement_policy import Policy, get_sub_policies
from FT_utils import img_loader, forward_point_cv2, eval_gt_pred, update_record_env
import torch
import gymnasium as gym
from gymnasium import spaces
import random
import pdb

def segment_prompted_image(img, gt_mask, policy_dict, predictor, args, device):
    """Augment ``img`` with a policy, segment it from a point prompt, and score it.

    The image is first transformed by ``Policy(args, img, policy_dict)``, then
    converted from BGR to RGB and handed to ``predictor``. A single positive
    point prompt (label ``1``) derived from ``gt_mask`` by
    ``forward_point_cv2`` is used to request exactly one mask
    (``multimask_output=False``). The prediction is compared with the ground
    truth through ``eval_gt_pred``.

    Args:
        img: Image passed to ``Policy``; treated as BGR when converted for the
            predictor.
        gt_mask: NumPy ground-truth mask. It is divided by 255 before scoring,
            so it is presumably stored in the 0-255 range (confirm with caller).
        policy_dict: Augmentation policy description forwarded to ``Policy``.
        predictor: Object exposing ``set_image`` and ``predict``.
        args: Project configuration forwarded to ``Policy``.
        device: Torch device on which the score is computed.

    Returns:
        A ``(adapted_image, reward)`` tuple: the augmented image returned by
        ``Policy`` and the Python scalar produced by ``eval_gt_pred``
        (presumably an IoU, judging by the original variable name).
    """
    augmented = Policy(args, img, policy_dict)
    predictor.set_image(cv2.cvtColor(augmented, cv2.COLOR_BGR2RGB))

    # One positive click taken from the ground-truth mask is the only prompt.
    prompt_point = forward_point_cv2(gt_mask)
    prompt_label = np.array([1])
    mask_pred, _, _ = predictor.predict(
        point_coords=prompt_point,
        point_labels=prompt_label,
        box=None,
        multimask_output=False,
    )

    prediction = torch.from_numpy(mask_pred).to(device).float()
    # Rescale the mask by 1/255 and add a leading axis before moving it to
    # the scoring device.
    target = (torch.from_numpy(gt_mask).float() / 255.0).unsqueeze(0).to(device)

    score = eval_gt_pred(target, prediction)

    return augmented, score.cpu().numpy().item()
    

class Prompt2AdaptAll(gym.Env):
    """One-step environment that rewards an augmentation policy by segmentation score.

    Every ``step`` decodes the action into ``args.subpolicy_num``
    (operation id, magnitude id) pairs, builds the matching sub-policies with
    ``get_sub_policies``, applies them to one or more images of the stack and
    uses the score returned by ``eval_gt_pred`` as the reward. Episodes always
    terminate after a single step.

    Supported ``mode`` values:

    * ``"Single"``: score the policy on the image selected at ``reset``.
    * ``"SingleResidual"``: score on the selected image minus the score of
      the same image without the policy.
    * ``"Three"``: mean score over the selected image and up to two other
      randomly chosen images.
    * ``"All"``: mean score over every image in the stack.

    Any other mode makes ``step`` raise ``NameError``.

    Args:
        ori_images_stack: NumPy stack of images indexed along axis 0. Per the
            original author's notes these are resized BGR images in the
            0-255 range (e.g. 10 x 1024 x 1024 x 3); confirm with the caller.
        gt_mask_stack: NumPy stack of ground-truth masks aligned with
            ``ori_images_stack`` (0-255 per the original notes).
        predictor: Object exposing ``set_image`` and ``predict``.
        device: Torch device used for scoring.
        mode: One of the mode names listed above.
        args: Project configuration; ``subpolicy_num`` and
            ``save_policy_len`` are read here, the rest is forwarded to
            project helpers.
    """

    def __init__(self, ori_images_stack, gt_mask_stack, predictor, device, mode, args):
        self.args = args
        self.observation_space = spaces.Box(0, 255, shape=(1024, 1024, 3), dtype=np.uint8)  # TODO: What is your observation space
        # Three (operation id, magnitude id) pairs.
        self.action_space = spaces.MultiDiscrete([32, 10, 32, 10, 32, 10])
        self._image_stack = ori_images_stack  # image stack, indexed along axis 0
        self._gt_mask_stack = gt_mask_stack  # ground-truth mask stack, same indexing
        self._random_int = np.random.randint(0, self._image_stack.shape[0])
        self._adapted_image = self._image_stack[self._random_int]
        self._predictor = predictor
        self._device = device
        self._mode = mode
        self._policy_dict = []
        self._policy_dict_info = {}
        # Keys used for the info dict; zip() in step() truncates to these.
        self._policy_dict_wrapper = ["step1", "step2", "step3"]
        self._best_policy_list = []
        self._best_mIoU_list = []

    def _get_obs(self):
        """Return the current observation (the latest adapted image)."""
        return self._adapted_image

    def _get_info(self):
        """Return the mapping from step label to ``"op : magnitude"`` text."""
        return self._policy_dict_info

    def reset(self, seed=None, options=None):
        """Start a new episode on a randomly selected image.

        Args:
            seed: Optional seed forwarded to ``gym.Env.reset``; it seeds
                ``self.np_random``, which drives the image selection.
            options: Unused; accepted for gymnasium API compatibility.

        Returns:
            ``(observation, info)`` where the observation is the selected
            un-augmented image and ``info`` is an empty dict.
        """
        super().reset(seed=seed)

        # Draw from the env-owned generator so that ``seed`` actually makes
        # the episode reproducible (the global NumPy RNG ignores it).
        self._random_int = int(self.np_random.integers(0, self._image_stack.shape[0]))
        self._adapted_image = self._image_stack[self._random_int]
        self._policy_dict = []
        self._policy_dict_info = {}

        return self._get_obs(), self._get_info()

    def step(self, action):
        """Evaluate the policy encoded by ``action`` and end the episode.

        Args:
            action: Indexable sequence laid out as
                ``[op_0, mag_0, op_1, mag_1, ...]``; the first
                ``2 * args.subpolicy_num`` entries are read.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` with
            ``terminated`` always ``True`` and ``truncated`` always ``False``.

        Raises:
            NameError: If the environment was built with an unknown mode.
        """
        subpolicy_num = self.args.subpolicy_num
        policy_id_list = [action[2 * i] for i in range(subpolicy_num)]
        magnitude_id_list = [action[2 * i + 1] for i in range(subpolicy_num)]

        self._policy_dict = get_sub_policies(policy_id_list, magnitude_id_list, self.args)

        op_magnitude_list = [
            str(self._policy_dict[i][0]['op']) + " : " + str(self._policy_dict[i][0]['magnitude'])
            for i in range(subpolicy_num)
        ]
        self._policy_dict_info = dict(zip(self._policy_dict_wrapper, op_magnitude_list))

        if self._mode == "Three":
            reward = self._reward_three()
        elif self._mode == "Single":
            self._adapted_image, reward = self._segment_index(self._random_int)
        elif self._mode == "SingleResidual":
            reward = self._reward_single_residual()
        elif self._mode == "All":
            reward = self._reward_all()
        else:
            raise NameError("PleaseSetCorrectModeName")

        print("Adopt ", self._policy_dict, " in this step and get ", reward, " mIoU.")

        self._record_if_best(reward)

        observation = self._get_obs()  # the adapted image
        info = self._get_info()  # the applied operations with their magnitudes

        return observation, reward, True, False, info

    def _segment_index(self, index):
        """Apply the current policy to image ``index`` and return ``(adapted_image, score)``."""
        return segment_prompted_image(
            img=self._image_stack[index],
            gt_mask=self._gt_mask_stack[index],
            policy_dict=self._policy_dict,
            predictor=self._predictor,
            args=self.args,
            device=self._device)

    def _score_at_point(self, image_bgr, gt_mask, input_point):
        """Segment ``image_bgr`` from a given point prompt and score it against ``gt_mask``."""
        self._predictor.set_image(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB))
        pred_mask, _, _ = self._predictor.predict(
            point_coords=input_point,
            point_labels=np.array([1]),
            box=None,
            multimask_output=False,
        )
        pred_masks_tensor = torch.from_numpy(pred_mask).to(self._device).float()
        gt_masks_tensor = torch.from_numpy(gt_mask).float() / 255.0
        gt_masks_tensor = gt_masks_tensor.unsqueeze(0).to(self._device)
        return eval_gt_pred(gt_masks_tensor, pred_masks_tensor).cpu().numpy().item()

    def _reward_three(self):
        """Mean score over the selected image and up to two other random images."""
        self._adapted_image, selected_reward = self._segment_index(self._random_int)

        other_indices = [
            i for i in range(self._image_stack.shape[0]) if i != self._random_int
        ]
        # Seeded draw without replacement; slicing a permutation also copes
        # with stacks holding fewer than three images instead of raising.
        extra_indices = self.np_random.permutation(other_indices)[:2]

        rewards = [selected_reward]
        rewards.extend(self._segment_index(int(idx))[1] for idx in extra_indices)
        return float(np.mean(rewards))

    def _reward_single_residual(self):
        """Score gained by the policy over the un-augmented selected image."""
        origin_image = self._image_stack[self._random_int]
        gt_mask = self._gt_mask_stack[self._random_int]

        self._adapted_image = Policy(self.args, origin_image, self._policy_dict)

        # The same prompt point is shared by both predictions so that the
        # difference only reflects the augmentation.
        input_point = forward_point_cv2(gt_mask)
        adapted_score = self._score_at_point(self._adapted_image, gt_mask, input_point)
        origin_score = self._score_at_point(origin_image, gt_mask, input_point)

        return adapted_score - origin_score

    def _reward_all(self):
        """Mean score of the current policy over every image in the stack."""
        rewards = []
        for index in range(self._image_stack.shape[0]):
            adapted_image, reward = self._segment_index(index)
            if index == self._random_int:
                # Reuse this result as the observation instead of running
                # the policy on the selected image a second time.
                self._adapted_image = adapted_image
            rewards.append(reward)
        return float(np.mean(rewards))

    def _record_if_best(self, reward):
        """Add the current policy to the best-policy record unless it is redundant or too weak."""
        best_rewards = self._best_mIoU_list
        record_is_full = len(best_rewards) == self.args.save_policy_len
        # The last entry is compared as the weakest kept score; this presumes
        # update_record_env keeps the list ordered best-first (TODO confirm).
        too_weak = bool(best_rewards) and record_is_full and reward < best_rewards[-1]
        already_known = (reward in best_rewards) or (self._policy_dict in self._best_policy_list)
        if too_weak or already_known:
            return

        self._best_mIoU_list.append(reward)
        self._best_policy_list.append(self._policy_dict)
        print("Take ", self._policy_dict, " into best policy list.")
        self._best_mIoU_list, self._best_policy_list = update_record_env(
            self._best_mIoU_list, self._best_policy_list, self.args.save_policy_len)