#!/usr/bin/env python

import argparse
import collections
import os
import time
import tqdm, cv2
import numpy as np
from PIL import Image
import torchvision.transforms as transforms
from b2b.utils import util

from b2b.consts import MERGE_NONE
from b2b.eval.funcs import (
    load_eval_model_dset_from_cmdargs, tensor_to_image, tensor2img, slice_data_loader,
    get_eval_savedir, make_image_subdirs
)
from b2b.utils.parsers import (
    add_standard_eval_parsers, add_plot_extension_parser
)
import torch
import torch.nn.functional as F
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
from math import log10

def metric(img1, img2, data_range = None):
    """Compute PSNR and SSIM between two image tensors.

    Both tensors are detached, copied to the CPU and transposed with
    ``transpose(1, 2, 0)`` before being passed to scikit-image, i.e. the
    inputs are expected as (C, H, W) and are compared as (H, W, C).

    Args:
        img1: reference image tensor, shape (C, H, W).
        img2: image tensor compared against ``img1``, same shape.
        data_range: dynamic range of the pixel values (max - min possible
            value). When ``None`` it is inferred from ``img1`` with the same
            rule scikit-image's PSNR uses by default: for floating-point data
            the nominal range is [-1, 1], giving 1.0 when ``img1`` has no
            negative values and 2.0 otherwise; for integer data the dtype's
            limits are used instead.

    Returns:
        Tuple ``(psnr, ssim)``.
    """
    img1_np = img1.cpu().detach().numpy().transpose(1, 2, 0)
    img2_np = img2.cpu().detach().numpy().transpose(1, 2, 0)

    if data_range is None:
        if np.issubdtype(img1_np.dtype, np.integer):
            info = np.iinfo(img1_np.dtype)
            lo, hi = info.min, info.max
        else:
            lo, hi = -1.0, 1.0
        data_range = float(hi if img1_np.min() >= 0 else hi - lo)

    psnr_val = psnr(img1_np, img2_np, data_range = data_range)

    # ``channel_axis`` replaces the deprecated ``multichannel`` flag; without
    # it the trailing channel axis is treated as a spatial dimension.
    # ``data_range`` is mandatory for floating-point inputs in recent skimage.
    ssim_val = ssim(
        img1_np, img2_np, data_range = data_range, channel_axis = -1
    )
    return psnr_val, ssim_val

def parse_cmdargs():
    """Build this script's argument parser and parse ``sys.argv``.

    The standard evaluation options are registered first, then the
    plot/image extension option.

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()``.
    """
    cli_parser = argparse.ArgumentParser(description='Save model predictions as images')

    for register_options in (add_standard_eval_parsers, add_plot_extension_parser):
        register_options(cli_parser)

    parsed = cli_parser.parse_args()
    return parsed

def save_images(model, savedir, sample_counter, ext, batch):
    """Write every image currently held in ``model.images`` to disk.

    For each ``name -> tensor`` entry of ``model.images`` (``None`` entries
    are skipped), every sample along the tensor's first axis is converted
    with ``tensor2img`` and written with ``cv2.imwrite`` to
    ``<savedir>/<name>/<stem>.<e>`` for each extension ``e`` in ``ext``.
    ``<stem>`` is the file name, without directory or extension, of the
    matching entry of ``batch['name']``.

    Args:
        model: object exposing an ``images`` dict of batched tensors.
        savedir: root output directory; the ``<savedir>/<name>``
            subdirectories must already exist.
        sample_counter: mapping ``name -> int``; incremented once per
            sample saved, so it counts across calls.
        ext: iterable of file extensions (without the leading dot).
        batch: data-loader batch; ``batch['name']`` holds the source path of
            each sample (a lone string is treated as a one-element batch).

    Raises:
        OSError: if ``cv2.imwrite`` reports that a file could not be written.
    """
    names = batch['name']
    if isinstance(names, str):
        names = [ names, ]

    for (name, torch_image) in model.images.items():
        if torch_image is None:
            continue

        for index in range(torch_image.shape[0]):
            # Use the name of *this* sample: indexing with 0 made every
            # sample of a multi-sample batch overwrite the same file.
            stem = os.path.splitext(os.path.basename(names[index]))[0]
            path = os.path.join(savedir, name, stem)
            image = tensor2img(torch_image[index])

            for e in ext:
                out_path = path + '.' + e
                # cv2.imwrite reports failure by returning False, not raising.
                if not cv2.imwrite(out_path, image):
                    raise OSError(f"Failed to write image '{out_path}'")

            sample_counter[name] += 1

def _sync_cuda():
    """Wait for queued CUDA work so wall-clock timings cover it."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()

def dump_single_domain_images(
    model, data_it, domain, n_eval, batch_size, savedir, sample_counter, ext,
    name_filter = None
):
    """Run inference over one domain's data and save the resulting images.

    Args:
        model: evaluation model; ``set_input(batch['img'], domain=...)`` and
            ``forward_nograd()`` are called on it for each batch.
        data_it: data loader / iterable for this domain.
        domain: domain index passed to ``model.set_input``.
        n_eval: number of samples to evaluate, forwarded to
            ``slice_data_loader``.
        batch_size: loader batch size, forwarded to ``slice_data_loader``.
        savedir: root output directory passed to ``save_images``.
        sample_counter: running per-image-type counter passed to
            ``save_images``.
        ext: list of file extensions to save.
        name_filter: optional substring. When given, a batch is processed
            only if at least one entry of ``batch['name']`` contains it.
            ``None`` (the default) processes every batch.
    """
    # pylint: disable=too-many-arguments
    data_it, steps = slice_data_loader(data_it, batch_size, n_eval)

    for batch in tqdm.tqdm(data_it, desc = f'Dumping {domain}', total = steps):
        if name_filter is not None and not any(
            name_filter in sample_name for sample_name in batch['name']
        ):
            continue

        model.set_input(batch['img'], domain = domain)

        # Synchronize around the forward pass: CUDA kernels run
        # asynchronously and would otherwise be missing from the timing.
        _sync_cuda()
        start_time = time.perf_counter()
        model.forward_nograd()
        _sync_cuda()
        time_taken = time.perf_counter() - start_time

        # tqdm.write prints without corrupting the active progress bar.
        tqdm.tqdm.write(f'Time taken for inference: {time_taken} seconds')

        save_images(model, savedir, sample_counter, ext, batch)

def dump_images(model, data_list, n_eval, batch_size, savedir, ext):
    """Save model predictions for every domain loader in ``data_list``.

    Output subdirectories are created first. Each loader is then processed
    in turn, with its position in ``data_list`` used as the domain index.

    Args:
        model: evaluation model.
        data_list: sequence of per-domain data loaders.
        n_eval: number of samples to evaluate per domain.
        batch_size: loader batch size.
        savedir: root output directory.
        ext: a single file extension string or a list of extensions.
    """
    # pylint: disable=too-many-arguments
    make_image_subdirs(model, savedir)

    counters = collections.defaultdict(int)
    extensions = [ext] if isinstance(ext, str) else ext

    for domain_index, loader in enumerate(data_list):
        dump_single_domain_images(
            model, loader, domain_index, n_eval, batch_size, savedir,
            counters, extensions,
        )
        

def main():
    """Script entry point: load the model and data, then dump predictions."""
    cli_args = parse_cmdargs()
    eval_args, model, loaders, evaldir = load_eval_model_dset_from_cmdargs(
        cli_args, merge_type = MERGE_NONE
    )

    # A single loader is wrapped so the per-domain loop can treat it uniformly.
    if isinstance(loaders, (list, tuple)):
        domain_loaders = loaders
    else:
        domain_loaders = [loaders]

    output_dir = get_eval_savedir(
        evaldir, cli_args.prefix, cli_args.model_state, cli_args.split
    )

    dump_images(
        model, domain_loaders, cli_args.n_eval, eval_args.batch_size,
        output_dir, cli_args.ext,
    )

if __name__ == '__main__':
    # Run only when executed as a script, not when imported.
    main()

