import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from skimage import measure
from torch import nn
from torchmetrics.functional import structural_similarity_index_measure as SSIM_accuracy
from models.modelUnet import UNet
from models.modelDiscriminator import linearDiscriminator
from models.modelTwoTailClassifier import TwoTailClassifier_smallerKernel
from training.psnr import PSRN_accuracy
from config import CLASSIFIER_WEIGHTS_PATH

# Default batch size: the larger value is used only when a CUDA device is
# available, the smaller one otherwise.
if torch.cuda.is_available():
    BATCH_SIZE = 256
else:
    BATCH_SIZE = 64

class Decomposition(pl.LightningModule):
    def __init__(
            self,
            channels,
            width,
            height,
            g_loss_weight=1,
            classifier_loss_weight=1,
            rec_loss_l1_weight=1,
            mask_background_loss_l1_weight=1,
            mask_background_loss_mse_weight=1,
            mask_th=0.2,
            sigmoid_slope=25,
            sigmoid_const=5,
            classifier_weights_path=CLASSIFIER_WEIGHTS_PATH,
            discriminator_text_remove_weights_path=None,
            discriminator_text_detect_weights_path=None,
            generator_text_remove_weights_path=None,
            generator_text_detect_weights_path=None,
            latent_dim: int = 100,
            lr_gen: float = 0.0002,
            lr_disc: float = 0.0002,
            b1: float = 0.5,
            b2: float = 0.999,
            batch_size: int = BATCH_SIZE,
            **kwargs,
    ):
        """Build the generators, discriminators and classifier of the model.

        Args:
            channels: channel count of a single image. Each generator is a
                ``UNet`` with ``2 * channels`` inputs and ``channels`` outputs.
            width: combined with ``channels`` and ``height`` into the
                ``img_shape`` given to both discriminators.
            height: see ``width``.
            g_loss_weight: weight of the text-removal generator loss in the
                combined generator/classifier loss.
            classifier_loss_weight: weight of the classifier loss in that
                combined loss.
            rec_loss_l1_weight: weight of the L1 reconstruction loss.
            mask_background_loss_l1_weight: weight of the L1 background loss.
            mask_background_loss_mse_weight: weight of the MSE background loss.
            mask_th: threshold applied to the text-detection output when the
                reconstruction mask is built.
            sigmoid_slope: slope passed to the heat-map sigmoid.
            sigmoid_const: offset passed to the heat-map sigmoid.
            classifier_weights_path: checkpoint for the classifier (always
                loaded).
            discriminator_text_remove_weights_path: optional checkpoint for
                the text-removal discriminator.
            discriminator_text_detect_weights_path: optional checkpoint for
                the text-detection discriminator.
            generator_text_remove_weights_path: optional checkpoint for the
                text-removal generator.
            generator_text_detect_weights_path: optional checkpoint for the
                text-detection generator.
            latent_dim: only stored in ``hparams``; not read by this class.
            lr_gen: generator learning rate (read in ``configure_optimizers``).
            lr_disc: discriminator learning rate (read in
                ``configure_optimizers``).
            b1: first Adam beta (read in ``configure_optimizers``).
            b2: second Adam beta (read in ``configure_optimizers``).
            batch_size: only stored in ``hparams``; not read by this class.
            **kwargs: only stored in ``hparams``.
        """
        super().__init__()
        # Must run before any other local is created: it records the
        # constructor arguments as ``self.hparams``.
        self.save_hyperparameters()
        data_shape = (channels, width, height)

        # the weights for each loss
        self.g_loss_weight = g_loss_weight
        self.classifier_loss_weight = classifier_loss_weight
        self.rec_loss_l1_weight = rec_loss_l1_weight
        self.mask_background_loss_l1_weight = mask_background_loss_l1_weight
        self.mask_background_loss_mse_weight = mask_background_loss_mse_weight
        self.mask_th = mask_th
        self.sigmoid_slope = sigmoid_slope
        self.sigmoid_const = sigmoid_const

        self.loss_mse = nn.MSELoss()
        self.loss_l1 = nn.L1Loss()

        # --------------- The network ----------------:
        # UNet arguments are the number of input channels and output channels
        self.generatorTextRemove = UNet(2 * channels, channels)
        self.generatorTextDetect = UNet(2 * channels, channels)
        if generator_text_remove_weights_path is not None:
            self.generatorTextRemove = self._load_pretrained(
                self.generatorTextRemove, generator_text_remove_weights_path)
        if generator_text_detect_weights_path is not None:
            self.generatorTextDetect = self._load_pretrained(
                self.generatorTextDetect, generator_text_detect_weights_path)

        # Discriminator
        self.discriminatorTextRemove = linearDiscriminator(img_shape=data_shape)
        self.discriminatorTextDetect = linearDiscriminator(img_shape=data_shape)
        if discriminator_text_remove_weights_path is not None:
            self.discriminatorTextRemove = self._load_pretrained(
                self.discriminatorTextRemove, discriminator_text_remove_weights_path)
        if discriminator_text_detect_weights_path is not None:
            self.discriminatorTextDetect = self._load_pretrained(
                self.discriminatorTextDetect, discriminator_text_detect_weights_path)

        # Classifier. Its parameters are not handed to any optimizer in
        # configure_optimizers, so it only contributes a loss term.
        self.classifier = self._load_pretrained(
            TwoTailClassifier_smallerKernel(), classifier_weights_path)

    @staticmethod
    def _load_pretrained(module, checkpoint_path):
        """Return ``module`` carrying the weights stored at ``checkpoint_path``.

        ``load_from_checkpoint`` is invoked exactly as before, but its result
        is no longer thrown away: when the call hands back a module (the
        Lightning classmethod builds and returns a *new* model instead of
        updating the receiver) that restored module is returned so the caller
        can register it. If the call returns anything else, the loader is
        assumed to have updated ``module`` in place and ``module`` itself is
        returned.
        """
        restored = module.load_from_checkpoint(checkpoint_path)
        return restored if isinstance(restored, nn.Module) else module

    def adversarial_loss(self, y_hat, y):
        return F.binary_cross_entropy(y_hat, y)

    def configure_optimizers(self):
        lr_gen = self.hparams.lr_gen
        lr_disc = self.hparams.lr_disc
        b1 = self.hparams.b1
        b2 = self.hparams.b2

        opt_g_text_remove = torch.optim.Adam(self.generatorTextRemove.parameters(), lr=lr_gen, betas=(b1, b2))
        opt_g_text_detect = torch.optim.Adam(self.generatorTextDetect.parameters(), lr=lr_gen, betas=(b1, b2))
        opt_d_text_detect = torch.optim.Adam(self.discriminatorTextDetect.parameters(), lr=lr_disc, betas=(b1, b2))
        opt_d_text_remove = torch.optim.Adam(self.discriminatorTextRemove.parameters(), lr=lr_disc, betas=(b1, b2))
        # schedulers
        scheduler_g_text_remove = torch.optim.lr_scheduler.ReduceLROnPlateau(opt_g_text_remove)
        scheduler_g_text_detect = torch.optim.lr_scheduler.ReduceLROnPlateau(opt_g_text_detect)
        scheduler_d_text_detect = torch.optim.lr_scheduler.ReduceLROnPlateau(opt_d_text_detect)
        scheduler_d_text_remove = torch.optim.lr_scheduler.ReduceLROnPlateau(opt_d_text_remove)

        return [opt_g_text_remove, opt_g_text_detect, opt_d_text_detect, opt_d_text_remove], \
            [{"scheduler": scheduler_g_text_remove, "monitor": "G_loss_text_remove", "name": "g_scheduler_loss_remove"},
             {"scheduler": scheduler_g_text_detect, "monitor": "G_loss_text_detect", "name": "g_scheduler_loss_detect"},
             {"scheduler": scheduler_d_text_detect, "monitor": "d_detect_loss", "name": "d_scheduler_loss_detect"},
             {"scheduler": scheduler_d_text_remove, "monitor": "d_remove_loss", "name": "d_scheduler_loss_remove"}
             ]

    def training_step(self, batch, batch_idx, optimizer_idx):

        # Load the images from the batch
        image_with_text, image_with_text_clean, image_gradCAM_cat, image_without_text, image_only_text \
            = batch['image_with_text'], batch['image_with_text_clean'], batch[
            'image_gradCAM_cat'], batch['image_without_text'], batch['image_only_text']

        # Text removal generator training - optimizer_idx == 0
        if optimizer_idx == 0:
            # Generate the images in both generators
            fake_imgs_text_remove, fake_imgs_text_detect = self.img_generation(image_with_text, image_gradCAM_cat)

            # -------------- mask background loss L1 ---------------
            mask_background_loss_l1 = self.mask_background_loss_calc_l1(fake_imgs_text_remove, image_gradCAM_cat)
            self.log("mask_background_loss_l1", mask_background_loss_l1)

            # -------------- mask background loss MSE ---------------
            mask_background_loss_mse = self.mask_background_loss_calc_mse(fake_imgs_text_remove, image_gradCAM_cat)
            self.log("mask_background_loss_mse", mask_background_loss_mse)

            # ----------- adversarial loss -----------
            g_adversarial_loss_text_remove = self.g_text_remove_adversarial_loss_calc(fake_imgs_text_remove,
                                                                                      image_with_text)
            self.log("g_adversarial_loss_text_remove", g_adversarial_loss_text_remove)

            # ----------- reconstructed loss -----------
            rec_l1_loss = self.rec_loss_l1_calc(fake_imgs_text_remove, fake_imgs_text_detect, image_with_text)
            self.log("rec_l1_loss", rec_l1_loss)

            G_loss_text_remove = (g_adversarial_loss_text_remove +
                                  self.mask_background_loss_mse_weight * mask_background_loss_mse +
                                  self.mask_background_loss_l1_weight * mask_background_loss_l1 +
                                  self.rec_loss_l1_weight * rec_l1_loss)

            self.log("G_loss_text_remove", G_loss_text_remove)

            # train accuracy
            train_acc_psnr, train_std_psnr = PSRN_accuracy(fake_imgs_text_remove, image_with_text_clean)
            train_acc_ssim = SSIM_accuracy(fake_imgs_text_remove, image_with_text_clean)
            self.log("train_acc_psnr", train_acc_psnr)
            self.log("train_std_psnr", train_std_psnr)
            self.log("train_acc_ssim", train_acc_ssim.mean())
            self.log("train_std_ssim", train_acc_ssim.std())

            # ---------------------- Classifier loss -----------------------------
            classifier_loss = self.classifier_loss_calc(fake_imgs_text_remove, image_with_text)
            self.log("classifier_loss", classifier_loss)

            total_loss = (G_loss_text_remove * self.g_loss_weight + classifier_loss * self.classifier_loss_weight) / (
                                 (self.g_loss_weight + self.classifier_loss_weight) * 1.0)
            self.log("total_loss_g_class", total_loss)

            log_dict = {"total_loss": total_loss, 'classifier_loss': classifier_loss}
            return {"loss": total_loss, "G_loss_text_remove": G_loss_text_remove, "progress_bar": log_dict,
                    "log": log_dict}

        # train generators - optimizer_idx == 1 - text detect
        elif optimizer_idx == 1:
            fake_imgs_text_remove, fake_imgs_text_detect = self.img_generation(image_with_text, image_gradCAM_cat)

            # adversarial loss calculation
            y_hat_text_detect = self.discriminatorTextDetect(fake_imgs_text_detect)

            y = torch.ones(image_with_text.size(0), 1)
            y = y.type_as(image_with_text)

            g_adversarial_loss_text_detect = self.adversarial_loss(y_hat_text_detect, y)
            self.log("g_adversarial_loss_text_detect", g_adversarial_loss_text_detect)

            # ----------- reconstructed loss -----------
            rec_l1_loss = self.rec_loss_l1_calc(fake_imgs_text_remove, fake_imgs_text_detect, image_with_text)
            self.log("rec_l1_loss", rec_l1_loss)

            G_loss_text_detect = g_adversarial_loss_text_detect + self.rec_loss_l1_weight * rec_l1_loss
            self.log("G_loss_text_detect", G_loss_text_detect)
            #
            log_dict = {"G_loss_text_detect": G_loss_text_detect}
            return {"loss": G_loss_text_detect, 'G_loss_text_detect': G_loss_text_detect, "progress_bar": log_dict,
                    "log": log_dict}

        # train discriminator - text remove
        elif optimizer_idx == 3:

            # how well can the discriminator label it as real
            y_hat_real_text_remove = self.discriminatorTextRemove(image_without_text)

            y_real = torch.ones(image_with_text.size(0), 1)
            y_real = y_real.type_as(image_with_text)

            real_loss_text_remove = self.adversarial_loss(y_hat_real_text_remove, y_real)
            # real_loss_text_remove.requires_grad = True

            # how well can the discriminator label it as fake
            fake_generated_text_remove = self.generatorTextRemove(image_gradCAM_cat).detach()

            y_hat_fake_text_remove = self.discriminatorTextRemove(fake_generated_text_remove)

            y_fake = torch.zeros(image_with_text.size(0), 1)
            y_fake = y_fake.type_as(image_with_text)

            fake_loss_text_remove = self.adversarial_loss(y_hat_fake_text_remove, y_fake)
            # fake_loss_text_remove.requires_grad = True

            # discriminator loss is the average of these
            d_remove_loss = (real_loss_text_remove + fake_loss_text_remove) / 2.0

            self.log("d_remove_loss", d_remove_loss)
            self.log("real_loss_text_remove", real_loss_text_remove)
            self.log("fake_loss_text_remove", fake_loss_text_remove)

            log_dict = {"d_remove_loss": d_remove_loss}
            return {"loss": d_remove_loss, "d_remove_loss": d_remove_loss, "progress_bar": log_dict, "log": log_dict}

        # train discriminator - text detect
        elif optimizer_idx == 2:

            # how well can the discriminator label it as real
            y_hat_real_text_detect = self.discriminatorTextDetect(image_only_text)

            y_real = torch.ones(image_with_text.size(0), 1)
            y_real = y_real.type_as(image_with_text)

            real_loss_text_detect = self.adversarial_loss(y_hat_real_text_detect, y_real)

            # how well can the discriminator label it as fake
            fake_generated_text_detect = self.generatorTextDetect(image_gradCAM_cat).detach()

            y_hat_fake_text_detect = self.discriminatorTextDetect(fake_generated_text_detect)

            y_fake = torch.zeros(image_with_text.size(0), 1)
            y_fake = y_fake.type_as(image_with_text)

            fake_loss_text_detect = self.adversarial_loss(y_hat_fake_text_detect, y_fake)

            # discriminator loss is the average of these
            d_detect_loss = (real_loss_text_detect + fake_loss_text_detect) / 2.0

            self.log("d_detect_loss", d_detect_loss)
            self.log("real_loss_text_detect", real_loss_text_detect)
            self.log("fake_loss_text_detect", fake_loss_text_detect)

            log_dict = {"d_detect_loss": d_detect_loss}
            return {"loss": d_detect_loss, "d_detect_loss": d_detect_loss, "progress_bar": log_dict, "log": log_dict}

    def validation_step(self, batch, batch_idx):
        """Log validation metrics and losses of the text-removal branch.

        Nothing is returned; every quantity is reported through ``self.log``
        under a ``val_``-prefixed name.

        Args:
            batch: mapping with the keys ``image_with_text``,
                ``image_with_text_clean``, ``image_gradCAM_cat``,
                ``image_without_text`` and ``image_only_text``.
            batch_idx: index of the batch (unused).
        """
        # in each batch there are 5 images -
        # 1. with text
        # 2. the same one without text
        # 3. the gradCAM image of the image with text from 1. this variable has 2 channels - one is the image with text
        #    and the second channel is the gradCAM img.
        # 4. another image without text
        # 5. an image with black background and only
        # text (different frame than the first two ones)

        image_with_text, image_with_text_clean, image_gradCAM_cat, image_without_text, image_only_text \
            = batch['image_with_text'], batch['image_with_text_clean'], batch[
            'image_gradCAM_cat'], batch['image_without_text'], batch['image_only_text']

        fake_imgs_text_remove, fake_imgs_text_detect = self.img_generation(image_with_text, image_gradCAM_cat)

        # calc accuracy
        val_acc_psnr, val_std_psnr = PSRN_accuracy(fake_imgs_text_remove, image_with_text_clean)
        # reduction='none' (as in test_step) keeps the un-reduced SSIM values;
        # with the default reduction the .std() below is taken over a single
        # already-averaged value and is NaN.
        val_acc_ssim = SSIM_accuracy(fake_imgs_text_remove, image_with_text_clean, reduction='none')
        self.log("val_acc_psnr", val_acc_psnr)
        self.log("val_std_psnr", val_std_psnr)
        self.log("val_acc_ssim", val_acc_ssim.mean())
        self.log("val_std_ssim", val_acc_ssim.std())

        # --------------------------------- loss calc ---------------------------------
        # -------------- mask loss ---------------
        val_mask_background_loss_l1 = self.mask_background_loss_calc_l1(fake_imgs_text_remove, image_gradCAM_cat)
        self.log("val_mask_background_loss_l1", val_mask_background_loss_l1)

        # -------------- mask loss MSE ---------------
        val_mask_background_loss_mse = self.mask_background_loss_calc_mse(fake_imgs_text_remove, image_gradCAM_cat)
        self.log("val_mask_background_loss_mse", val_mask_background_loss_mse)

        # adversarial loss calculation
        val_g_adversarial_loss_text_remove = self.g_text_remove_adversarial_loss_calc(fake_imgs_text_remove,
                                                                                      image_with_text)
        self.log("val_g_adversarial_loss_text_remove", val_g_adversarial_loss_text_remove)

        # reconstructed loss
        val_rec_l1_loss = self.rec_loss_l1_calc(fake_imgs_text_remove, fake_imgs_text_detect, image_with_text)
        self.log("val_rec_l1_loss", val_rec_l1_loss)

        val_G_loss_text_remove = (val_g_adversarial_loss_text_remove +
                                  self.mask_background_loss_l1_weight * val_mask_background_loss_l1 +
                                  self.mask_background_loss_mse_weight * val_mask_background_loss_mse +
                                  self.rec_loss_l1_weight * val_rec_l1_loss)

        self.log("val_G_loss_text_remove", val_G_loss_text_remove)

        # classifier loss
        val_classifier_loss = self.classifier_loss_calc(fake_imgs_text_remove, image_with_text)
        self.log("val_classifier_loss", val_classifier_loss)

        # weighted average of the generator loss and the classifier loss
        val_total_loss = (val_G_loss_text_remove * self.g_loss_weight + val_classifier_loss * self.classifier_loss_weight) / (
                                 (self.g_loss_weight + self.classifier_loss_weight) * 1.0)
        self.log("val_total_loss_g_class", val_total_loss)

        # ----------------- text-removal discriminator losses -----------------
        y_hat_text_remove = self.discriminatorTextRemove(fake_imgs_text_remove)

        y = torch.ones(image_with_text.size(0), 1)
        y = y.type_as(image_with_text)

        # how well can the discriminator label it as real
        y_hat_real_text_remove = self.discriminatorTextRemove(image_without_text)
        val_real_loss_text_remove = self.adversarial_loss(y_hat_real_text_remove, y)

        # how well can the discriminator label it as fake
        y_fake = torch.zeros(image_with_text.size(0), 1)
        y_fake = y_fake.type_as(image_with_text)
        val_fake_loss_text_remove = self.adversarial_loss(y_hat_text_remove, y_fake)

        # discriminator loss is the average of these
        val_d_remove_loss = (val_real_loss_text_remove + val_fake_loss_text_remove) / 2.0

        self.log("val_d_remove_loss", val_d_remove_loss)
        self.log("val_real_loss_text_remove", val_real_loss_text_remove)
        self.log("val_fake_loss_text_remove", val_fake_loss_text_remove)

    def test_step(self, batch, batch_idx):
        """Log PSNR/SSIM of the reconstructed image and of the untouched input.

        The removal-generator output is blended with the input image through
        ``generated_img_reconstructed`` and compared with the clean image;
        the same metrics are also logged for the raw image with text, under
        the ``... ground`` names.

        Args:
            batch: mapping with the keys ``image_with_text``,
                ``image_with_text_clean``, ``image_gradCAM_cat``,
                ``image_without_text`` and ``image_only_text``.
            batch_idx: index of the batch (unused).
        """
        # Batch layout: image with text, the same frame without text, the
        # 2-channel (image, gradCAM) tensor, another text-free image and a
        # text-only image on black background (the last two are not used here).
        batch_keys = ('image_with_text', 'image_with_text_clean', 'image_gradCAM_cat',
                      'image_without_text', 'image_only_text')
        (image_with_text, image_with_text_clean, image_gradCAM_cat,
         image_without_text, image_only_text) = (batch[key] for key in batch_keys)

        generated = self.img_generation(image_with_text, image_gradCAM_cat)
        fake_imgs_text_remove = generated[0]
        img_reconstructed = self.generated_img_reconstructed(fake_imgs_text_remove, image_gradCAM_cat)

        # metrics of the reconstructed images over the whole batch
        rec_psnr_mean, rec_psnr_std = PSRN_accuracy(img_reconstructed, image_with_text_clean)
        rec_ssim = SSIM_accuracy(img_reconstructed, image_with_text_clean, reduction='none')
        self.log("test_acc_psnr_rec", rec_psnr_mean)
        self.log("test_acc_ssim_rec", rec_ssim.mean())
        self.log("test_std_psnr_rec", rec_psnr_std)
        self.log("test_std_ssim_rec", rec_ssim.std())

        # baseline: the same metrics for the unprocessed image with text
        base_psnr_mean, base_psnr_std = PSRN_accuracy(image_with_text, image_with_text_clean)
        base_ssim = SSIM_accuracy(image_with_text, image_with_text_clean, reduction='none')
        self.log("psnr ground", base_psnr_mean)
        self.log("ssim ground", base_ssim.mean())
        self.log("ssim ground std", base_ssim.std())
        self.log("psnr ground std", base_psnr_std)

    def one_hot_ground_labels_generate(self, images):
        # Set the labels
        labels = torch.zeros(images.size(0), dtype=torch.int64)
        # Convert the labels vector to one hot vector
        labels_ground = F.one_hot(labels, num_classes=2)
        labels_ground = labels_ground.type_as(images)
        return labels_ground

    def classifier_loss_calc(self, fake_imgs_text_remove, image_with_text):
        label_hat = self.classifier(fake_imgs_text_remove)
        label = self.one_hot_ground_labels_generate(fake_imgs_text_remove)
        return F.binary_cross_entropy_with_logits(label_hat, label)

    def img_generation(self, image_with_text, image_gradCAM_cat):
        fake_imgs_text_detect = self.generatorTextDetect(image_gradCAM_cat)
        fake_imgs_text_remove = self.generatorTextRemove(image_with_text)
        return fake_imgs_text_remove, fake_imgs_text_detect

    def g_text_remove_adversarial_loss_calc(self, fake_imgs_text_remove, image_with_text):
        y_hat_text_remove = self.discriminatorTextRemove(fake_imgs_text_remove)

        y = torch.ones(image_with_text.size(0), 1)
        y = y.type_as(image_with_text)

        g_adversarial_loss_text_remove = self.adversarial_loss(y_hat_text_remove, y)

        return g_adversarial_loss_text_remove

    def rec_loss_l1_calc(self, fake_imgs_text_remove, fake_imgs_text_detect, image_with_text):
        mask = torch.zeros_like(fake_imgs_text_detect)
        mask[fake_imgs_text_detect > self.mask_th] = 1
        rec_img = fake_imgs_text_remove * (1 - mask) + fake_imgs_text_detect * mask
        rec_loss = self.loss_l1(rec_img, image_with_text)

        return rec_loss

    def apply_custom_sigmoid_to_heatmap(self, batch_of_XAImaps, slope=1.0, const=0):
        # Apply the custom sigmoid function to each element in the tensor
        custom_sigmoid_heatmaps = (1 / (1 + torch.exp(-slope * batch_of_XAImaps + const)))

        return custom_sigmoid_heatmaps

    def generated_img_reconstructed(self, fake_imgs_text_remove, image_gradCAM_cat):
        # Split the tensor along the second dimension
        original_img, xai_map = torch.split(image_gradCAM_cat, split_size_or_sections=1, dim=1)
        custom_xai_map = self.apply_custom_sigmoid_to_heatmap(xai_map, self.sigmoid_slope, self.sigmoid_const)
        rec_img = fake_imgs_text_remove * custom_xai_map + original_img * (1 - custom_xai_map)
        return rec_img

    def mask_background_loss_calc_l1(self, fake_imgs_text_remove, image_gradCAM_cat):
        # Split the tensor along the second dimension
        input_img, xai_map = torch.split(image_gradCAM_cat, split_size_or_sections=1, dim=1)
        custom_xai_map = self.apply_custom_sigmoid_to_heatmap(xai_map, self.sigmoid_slope, self.sigmoid_const)
        mask_generator_out = (1 - custom_xai_map) * fake_imgs_text_remove
        mask_input = input_img * (1 - custom_xai_map)
        return self.loss_l1(mask_generator_out, mask_input)

    def mask_background_loss_calc_mse(self, fake_imgs_text_remove, image_gradCAM_cat):
        # Split the tensor along the second dimension
        input_img, xai_map = torch.split(image_gradCAM_cat, split_size_or_sections=1, dim=1)
        custom_xai_map = self.apply_custom_sigmoid_to_heatmap(xai_map, self.sigmoid_slope, self.sigmoid_const)
        mask_generator_out = (1 - custom_xai_map) * fake_imgs_text_remove
        mask_input = input_img * (1 - custom_xai_map)
        return self.loss_mse(mask_generator_out, mask_input)


    def contour_above_threshold(self, xaiMAP, threshold=0.5):
        """Return the longest contour of the region where the map exceeds ``threshold``.

        Args:
            xaiMAP: NumPy array or torch tensor holding one map. Tensors are
                detached, moved to the CPU and converted to NumPy first,
                because the mask is built with the NumPy-only ``astype``.
                NOTE(review): ``find_contours`` is expected to need a 2-D
                array - confirm what callers pass.
            threshold: values strictly above this are inside the region.

        Returns:
            The longest contour returned by ``measure.find_contours``, or
            ``None`` when no contour is found.
        """
        if isinstance(xaiMAP, torch.Tensor):
            xaiMAP = xaiMAP.detach().cpu().numpy()
        # Create a binary mask for values above the threshold
        mask = (xaiMAP > threshold).astype(int)
        # Trace the 0/1 boundary of the binary mask
        contours = measure.find_contours(mask, 0.5)
        # Choose the longest contour (assuming it represents the outer boundary)
        longest_contour = max(contours, key=len) if contours else None
        return longest_contour