import cv2
import pdb
import torch
import numpy as np
from augement_policy import WaveletTransformation, DepthAnything, NormalAnything, AlbedoFunc, RoughFunc, EdgeDetection, RetinexBrightUp, ImageDemoireing, MIMOdeblur, BinarizationFunc, medianBlur
from FT_utils import img_loader, forward_point_cv2, eval_gt_pred, update_record
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry, SamPredictor

# SAM setup: build the ViT-H model from its checkpoint and wrap it in a predictor.
# Fall back to the CPU when CUDA is unavailable so that moving the model to the
# device does not fail on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("Loading model...")
sam = sam_model_registry["vit_h"](checkpoint="./segment_anything/sam_vit_h_4b8939.pth")
sam.to(device)  # nn.Module.to moves the parameters in place
predictor = SamPredictor(sam)  # predictor is set up

# The test image and its ground-truth mask live in sibling folders.
img_name = "Kvasir.jpg"
gt_name = "Kvasir.jpg"

img_path = "TestAugmentImg/Img/" + img_name
gt_path = "TestAugmentImg/GT/" + gt_name

# cv2.imread returns None (it does not raise) when a file is missing or cannot
# be decoded, so check explicitly and report the offending path instead of
# failing later with an unrelated TypeError / cv2.error.
Img = cv2.imread(img_path)
if Img is None:
    raise FileNotFoundError("Could not read image: " + img_path)

gt_mask = cv2.imread(gt_path, cv2.IMREAD_GRAYSCALE)
if gt_mask is None:
    raise FileNotFoundError("Could not read ground-truth mask: " + gt_path)

# Invert the mask intensities (0 <-> 255).
gt_mask = 255 - gt_mask

# Bring image and mask to the same 1024x1024 size.
Img = cv2.resize(Img, (1024, 1024))
gt_mask = cv2.resize(gt_mask, (1024, 1024))

# Disabled experiment, kept for reference. The block below is a bare
# triple-quoted string expression, so nothing inside it is executed.
# It contains:
#   1. one call to each augmentation on Img with magnitude = 0, each result
#      written under TestAugmentImg/ with an augmentation-specific prefix;
#   2. a single positive point prompt derived from gt_mask via
#      forward_point_cv2, used to query the SAM predictor on the original
#      image and on the Normal, Bin and Depth outputs;
#   3. the original-image and Depth predictions scaled to 0/255 and written
#      as "ImgMask..." / "DepthMask...", followed by a pdb breakpoint.
# NOTE(review): Img comes from cv2.imread (BGR channel order) — confirm the
# channel order predictor.set_image expects before re-enabling this code.
'''
Wavelet = WaveletTransformation(Img, magnitude = 0)
cv2.imwrite("TestAugmentImg/Wavelet" + img_name, Wavelet)

Depth = DepthAnything(Img, magnitude = 0)
cv2.imwrite("TestAugmentImg/Depth" + img_name, Depth)

Normal = NormalAnything(Img, magnitude = 0)
cv2.imwrite("TestAugmentImg/Normal" + img_name, Normal)

Albedo = AlbedoFunc(Img, magnitude = 0)
cv2.imwrite("TestAugmentImg/Albedo" + img_name, Albedo)

Rough = RoughFunc(Img, magnitude = 0)
cv2.imwrite("TestAugmentImg/Rough" + img_name, Rough)

Edge = EdgeDetection(Img, magnitude = 0)
cv2.imwrite("TestAugmentImg/Edge" + img_name, Edge)

RBU = RetinexBrightUp(Img, magnitude = 0)
cv2.imwrite("TestAugmentImg/RBU" + img_name, RBU)

Demoire = ImageDemoireing(Img, magnitude = 0)
cv2.imwrite("TestAugmentImg/Demoire" + img_name, Demoire)

Deblur = MIMOdeblur(Img, magnitude = 0)
cv2.imwrite("TestAugmentImg/Deblur" + img_name, Deblur)

Bin = BinarizationFunc(Img, magnitude = 0)
cv2.imwrite("TestAugmentImg/Bin" + img_name, Bin)



input_point = forward_point_cv2(gt_mask)
input_label = np.array([1])

predictor.set_image(Img)
pred_mask_Img, _, _ = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    box=None,
    multimask_output=False,
)

predictor.set_image(Normal)
pred_mask_Normal, _, _ = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    box=None,
    multimask_output=False,
)

predictor.set_image(Bin)
pred_mask_Bin, _, _ = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    box=None,
    multimask_output=False,
)

predictor.set_image(Depth)
pred_mask_Depth, _, _ = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    box=None,
    multimask_output=False,
)

pred = pred_mask_Img.astype(np.uint8)
pred = pred.squeeze()
pred = pred * 255
cv2.imwrite("ImgMask" + img_name, pred)

pred2 = pred_mask_Depth.astype(np.uint8)
pred2 = pred2.squeeze()
pred2 = pred2 * 255
cv2.imwrite("DepthMask" + img_name, pred2)

pdb.set_trace()
'''

def _save_image(path, image):
    """Write *image* to *path* with cv2.imwrite, raising if the write fails.

    cv2.imwrite reports failure (e.g. a missing output directory) through a
    False return value; surfacing it avoids silently missing output files.
    """
    if not cv2.imwrite(path, image):
        raise IOError("Failed to write image: " + path)


# Chained augmentation: each step consumes the previous step's output, and
# every intermediate result is saved under TestAugmentImg/ for inspection.
Step1 = NormalAnything(Img, magnitude=0)
_save_image("TestAugmentImg/Normal" + img_name, Step1)

Step2 = RetinexBrightUp(Step1, magnitude=0)
_save_image("TestAugmentImg/RBU" + img_name, Step2)

Step3 = medianBlur(Step2, magnitude=4)
_save_image("TestAugmentImg/DB" + img_name, Step3)
