import sys
import os
import torch
from torchvision import transforms
import torch.backends.cudnn as cudnn
import numpy as np
import cv2
import gymnasium 
import pdb
import argparse
from skimage import img_as_ubyte
import matplotlib.pyplot as plt
from stable_baselines3 import PPO
from sb3_contrib import RecurrentPPO
from stable_baselines3.common.env_checker import check_env
from controller import Controller
from augement_policy import Policy
from config import get_args
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry, SamPredictor
from FT_utils import img_loader, forward_point_cv2, eval_gt_pred, update_record
from PromptToAdapt import Prompt2Adapt
from PromptToAdaptAll import Prompt2AdaptAll
from datetime import datetime
from logger import Logger
from Pix2Pix.models.pix2pix_model import Pix2PixModel
 




def _read_image_list(list_path, max_images):
    """Return at most ``max_images`` non-blank entries from an image list file.

    Args:
        list_path: text file with one image entry per line.
        max_images: maximum number of entries to keep (taken from the top).

    Returns:
        List of entry strings in file order; blank lines are skipped.
    """
    # The context manager closes the list file even if reading fails.
    with open(list_path, "r") as ann_file:
        entries = [line for line in ann_file.read().splitlines() if line.strip()]
    return entries[:max_images]


def _load_image_stack(imgs_list, img_dir, transform, size):
    """Load every listed image and resize it with ``transform``.

    For each entry, the text before the first "." is used as the file stem;
    ``<stem>.jpg`` is tried first, then ``<stem>.png``.

    Args:
        imgs_list: entries read from the image list file.
        img_dir: directory containing the image files.
        transform: callable applied to a CHW uint8 tensor; it must resize the
            image to ``size`` x ``size`` because the stack is pre-allocated
            with that shape.
        size: spatial size of each stacked image.

    Returns:
        ``np.uint8`` array of shape (len(imgs_list), size, size, 3), in the
        channel order returned by ``cv2.imread`` (BGR).

    Raises:
        FileNotFoundError: if neither the .jpg nor the .png file can be read.
    """
    # Allocate as uint8 once, so the stack is never converted/copied per image.
    img_stack = np.zeros((len(imgs_list), size, size, 3), dtype=np.uint8)
    for i, entry in enumerate(imgs_list):
        img_name = entry.split(".")[0]
        img = None
        for ext in (".jpg", ".png"):
            img = cv2.imread(os.path.join(img_dir, img_name + ext))
            if img is not None:
                break
        if img is None:
            # cv2.imread returns None instead of raising; fail with a clear message.
            raise FileNotFoundError(
                f"Could not read image '{img_name}' as .jpg or .png in {img_dir}"
            )
        chw = torch.from_numpy(img).permute(2, 0, 1)
        img_stack[i] = transform(chw).numpy().transpose(1, 2, 0)
    return img_stack


def _bgr_to_input_tensor(img_bgr, device):
    """Convert an HxWx3 BGR image to a 1x3xHxW float32 RGB tensor on ``device``.

    Pixel values are divided by 255, so 8-bit inputs end up in [0, 1].
    """
    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
    img_rgb = np.float32(img_rgb) / 255.
    return torch.from_numpy(img_rgb).permute(2, 0, 1).unsqueeze(0).to(device)


def _tensor_to_bgr_uint8(tensor):
    """Convert a 1x3xHxW RGB tensor to an HxWx3 uint8 BGR image.

    Values are clamped to [0, 1] before conversion.
    """
    arr = torch.clamp(tensor, 0, 1).cpu().detach().permute(0, 2, 3, 1).squeeze(0).numpy()
    arr = cv2.cvtColor(arr, cv2.COLOR_RGB2BGR)
    return img_as_ubyte(arr)


def _plot_losses(losses, out_path):
    """Plot each loss history in a 2x2 grid and save the figure to ``out_path``.

    Each history holds one value per training iteration (one per data pair
    per epoch), so the x-axis counts iterations rather than epochs.
    """
    fig, axes = plt.subplots(2, 2, figsize=(12, 10))  # 2 rows x 2 columns of subplots
    for ax, (loss_name, loss_values) in zip(axes.flatten(), losses.items()):
        ax.plot(range(1, len(loss_values) + 1), loss_values, label=loss_name, color='b')
        ax.set_title(loss_name)
        ax.set_xlabel('Iterations')
        ax.set_ylabel('Loss')
        ax.legend()
        ax.grid(True)
    # Adjust subplot spacing automatically, save as an image, then close the figure.
    fig.tight_layout()
    fig.savefig(out_path, dpi=300)
    plt.close(fig)


def main(args):
    """Build (original, augmented) image pairs and train a Pix2Pix model on them.

    Pipeline:
      1. Read up to ``args.max_images`` entries (default 500) from ``args.val_txt``.
         Load them from ``args.img_dir`` and resize them to 1024x1024.
      2. For each image, apply the fixed augmentation ``BestPolicy`` through
         ``Policy`` to obtain the target, forming an A -> B pair.
      3. Train a Pix2Pix model for ``args.pix2pix_epoches`` epochs over all pairs.
      4. Save a loss-curve plot, the 'latest' checkpoint, and the real_A /
         fake_B / real_B images for every pair.

    Args:
        args: namespace from ``config.get_args()``. Reads ``val_txt``,
            ``img_dir``, ``seed``, ``exp_name`` and ``pix2pix_epoches``, plus
            the optional ``max_images``. The namespace is also forwarded to ``Policy``.
    """
    print(args.val_txt)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)  # silently ignored when CUDA is unavailable

    img_size = 1024
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size),
                          interpolation=transforms.InterpolationMode.BICUBIC,
                          antialias=True)
    ])

    # Dataset generation always uses the validation list.
    max_images = getattr(args, "max_images", 500)
    imgs_list = _read_image_list(args.val_txt, max_images)
    print(f"{len(imgs_list)} images in All")

    img_stack = _load_image_stack(imgs_list, args.img_dir, transform, img_size)

    # Fixed augmentation policy that produces the B (target) side of every pair.
    BestPolicy = [{0: {'op': 'Erode', 'magnitude': 6}}, {0: {'op': 'contrast_up', 'magnitude': 3}}, {0: {'op': 'gaussianBlur', 'magnitude': 6}}]

    model_name = "MakeDataset_" + args.exp_name + "_" + str(BestPolicy)
    log_file_name = f"{model_name}.log"
    # NOTE(review): stdout and stderr get two separate Logger objects on the same
    # file; if Logger opens the file itself, their writes may interleave — confirm.
    sys.stdout = Logger(str(log_file_name))
    sys.stderr = Logger(str(log_file_name))

    # NOTE(review): every pair tensor is moved to `device` up front (~12 MB per
    # 1024x1024x3 float32 tensor). If Pix2PixModel.set_input already moves its
    # inputs to the model device, keeping these on the CPU would save GPU memory.
    DataPairList = []
    for i in range(img_stack.shape[0]):
        src_img = img_stack[i]
        # Build A before calling Policy, so A never reflects any change Policy makes.
        input_a = _bgr_to_input_tensor(src_img, device)
        obs = Policy(args, src_img, BestPolicy)
        input_b = _bgr_to_input_tensor(obs, device)
        DataPairList.append({
            "A": input_a,
            "A_paths": "Pix2PixDataPair/" + args.exp_name + str(i) + ".jpg",
            "B": input_b,
            "B_paths": "Pix2PixDataPair/" + args.exp_name + str(i) + ".png",
        })
        print("DataPair_" + str(i))

    opt = argparse.Namespace(
        batch_size=1,
        beta1=0.5,
        checkpoints_dir='./Pix2Pix_checkpoints',
        continue_train=False,
        crop_size=256,
        dataroot='./base/',
        dataset_mode='aligned',
        direction='AtoB',
        display_env='main',
        display_freq=400,
        display_id=1,
        display_ncols=4,
        display_port=8097,
        display_server='http://localhost',
        display_winsize=256,
        epoch='latest',
        epoch_count=1,
        gan_mode='vanilla',
        # [0] keeps the original single-GPU setup. An empty list presumably puts
        # the model on the CPU (as in the upstream pix2pix BaseModel) — verify.
        gpu_ids=[0] if use_cuda else [],
        init_gain=0.02,
        init_type='normal',
        input_nc=3,
        isTrain=True,
        lambda_L1=100.0,
        load_iter=0,
        load_size=256,
        lr=0.0002,
        lr_decay_iters=50,
        lr_policy='linear',
        max_dataset_size=float('inf'),
        model='pix2pix',
        n_epochs=100,
        n_epochs_decay=100,
        n_layers_D=3,
        name='Pix2Pix_' + args.exp_name,
        ndf=64,
        netD='basic',
        netG='unet_256',
        ngf=64,
        no_dropout=False,
        no_flip=False,
        no_html=False,
        norm='batch',
        num_threads=4,
        output_nc=3,
        phase='train',
        pool_size=0,
        preprocess=None,  # previously: preprocess='resize_and_crop'
        print_freq=100,
        save_by_iter=False,
        save_epoch_freq=50,
        save_latest_freq=5000,
        serial_batches=False,
        suffix='',
        update_html_freq=1000,
        use_wandb=False,
        verbose=False,
        wandb_project_name='CycleGAN-and-pix2pix'
    )

    P2PModel = Pix2PixModel(opt)
    P2PModel.setup(opt)

    losses = {'G_GAN': [], 'G_L1': [], 'D_real': [], 'D_fake': []}

    # NOTE(review): update_learning_rate() is never called here, so any schedule
    # configured via lr_policy is presumably not stepped — confirm this is intended.
    for epoch in range(args.pix2pix_epoches):
        for DP in DataPairList:
            P2PModel.set_input(DP)
            P2PModel.optimize_parameters()
            loss = P2PModel.get_current_losses()
            for loss_name, history in losses.items():
                history.append(loss[loss_name])
            print(f"Epoch {epoch+1}/{args.pix2pix_epoches}, G_GAN: {loss['G_GAN']}, G_L1: {loss['G_L1']}, D_real: {loss['D_real']}, D_fake: {loss['D_fake']}")
        print("Finish Pix2Pix training epoch " + str(epoch))

    _plot_losses(losses, 'training_loss_curve_all.png')

    # Create the directory that presumably receives save_networks() output
    # (<checkpoints_dir>/<name>).
    os.makedirs(os.path.join(opt.checkpoints_dir, opt.name), exist_ok=True)
    P2PModel.save_networks('latest')

    print("finish checkpoint for policy " + str(BestPolicy))

    P2PModel.eval()

    save_distill_path = os.path.join('./Pix2Pix_results', 'Pix2Pix_' + args.exp_name)
    os.makedirs(save_distill_path, exist_ok=True)

    for num, DP in enumerate(DataPairList):
        P2PModel.set_input(DP)
        P2PModel.test()
        visuals = P2PModel.get_current_visuals()
        for key in ("real_A", "fake_B", "real_B"):
            out_path = os.path.join(save_distill_path, f"img_{key}_{num}.jpg")
            # cv2.imwrite reports failure by returning False rather than raising.
            if not cv2.imwrite(out_path, _tensor_to_bgr_uint8(visuals[key])):
                print(f"Warning: failed to write {out_path}")
        
        
    

        
    
    

if __name__ == '__main__':
    # Script entry point: parse the command-line configuration via config.get_args,
    # then hand it to main(), which runs the pair-building / training / export pipeline.
    args = get_args()
    main(args=args)