import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torch.nn as nn
# from lightning.pytorch.accelerators import find_usable_cuda_devices
import numpy as np
from models.modelTwoTailClassifier import TwoTailClassifier_smallerKernel
from torchmetrics import Accuracy


# Default batch size: 256 when a CUDA device is available, otherwise 64.
if torch.cuda.is_available():
    BATCH_SIZE = 256
else:
    BATCH_SIZE = 64


class TwoTailClassifierPL(pl.LightningModule):
    """Lightning module that trains a binary "text / no text" image classifier.

    Each batch is a dict of image tensors. Only ``batch['image_with_text']``
    (labelled 1) and ``batch['image_without_text']`` (labelled 0) are used:
    they are concatenated, shuffled, passed through the classifier and scored
    with binary cross-entropy on one-hot targets.

    According to the original author's notes, a batch may also carry
    ``image_with_text_clean`` (the same frame without text),
    ``image_gradCAM_cat`` (the text image stacked with its gradCAM map) and
    ``image_only_text`` (text on a black background); none of these are read
    by this module.
    """

    def __init__(
            self,
            latent_dim: int = 100,
            lr: float = 0.0002,
            b1: float = 0.5,
            b2: float = 0.999,
            weight_decay: float = 1e-5,
            batch_size: int = BATCH_SIZE,
            **kwargs,
    ):
        """Save the hyperparameters and build the classifier network.

        Args:
            latent_dim: Saved with the hyperparameters; not read elsewhere in
                this class.
            lr: Learning rate handed to Adam in ``configure_optimizers``.
            b1: First Adam beta.
            b2: Second Adam beta.
            weight_decay: Saved with the hyperparameters; not passed to the
                optimizer.
            batch_size: Saved with the hyperparameters; not read elsewhere in
                this class.
            **kwargs: Extra keyword arguments; accepted but not used directly
                by this class.
        """
        super().__init__()
        self.save_hyperparameters()
        # Per-batch test accuracies, averaged and cleared in on_test_end.
        self.test_acc_step_out = []
        # NOTE(review): this metric object is created but the accuracy is
        # actually computed by acc_calc; kept so the attribute stays available.
        self.accuracy = Accuracy(task='multiclass', num_classes=2)

        # --------------- The network ----------------
        self.classifier = TwoTailClassifier_smallerKernel()

    def forward(self, x):
        """Run ``x`` through the classifier layers and return the raw logits.

        The classifier is expected to expose ``conv1``, ``conv2``, ``pool``,
        ``fc1``, ``fc2`` and ``fc3``; they are applied here one by one.
        """
        x = self.classifier.pool(F.relu(self.classifier.conv1(x)))
        x = self.classifier.pool(F.relu(self.classifier.conv2(x)))
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.classifier.fc1(x))
        x = F.relu(self.classifier.fc2(x))
        x = self.classifier.fc3(x)
        return x

    def _loss_and_acc(self, batch):
        """Return ``(loss, acc)`` for one batch (shared by the train/val steps).

        Only the two batch entries that are actually needed are looked up, so
        batches without the auxiliary images still work.
        """
        labels_ground, pred_labels, _ = self.ground_and_pred_labels_generate(
            batch['image_with_text'], batch['image_without_text'])
        loss = F.binary_cross_entropy_with_logits(pred_labels, labels_ground)
        acc = self.acc_calc(labels_ground, pred_labels)
        return loss, acc

    def training_step(self, batch):
        """Compute and log the training loss and accuracy for ``batch``.

        Returns:
            A dict with ``"loss"`` (tensor used for the backward pass) and
            ``"acc"`` (float). The ``"progress_bar"`` and ``"log"`` entries
            are kept only so the returned dict has the same shape as before;
            the metrics themselves are reported through ``self.log``.
        """
        loss, acc = self._loss_and_acc(batch)
        self.log("loss", loss)
        self.log("train_acc", acc)

        log_dict = {"loss": loss, "acc": acc}
        return {"loss": loss, "acc": acc, "progress_bar": log_dict, "log": log_dict}

    def validation_step(self, batch):
        """Compute and log the validation loss and accuracy for ``batch``.

        Returns:
            A dict with ``"val_loss"`` and ``"val_acc"``, plus the same pair
            repeated under ``"progress_bar"`` and ``"log"`` (kept for
            compatibility with the previous return shape).
        """
        val_loss, val_acc = self._loss_and_acc(batch)
        self.log("val_loss", val_loss)
        self.log("val_acc", val_acc)

        log_dict = {"val_loss": val_loss, "val_acc": val_acc}
        return {"val_loss": val_loss, "val_acc": val_acc, "progress_bar": log_dict, "log": log_dict}

    def test_step(self, batch):
        """Log the accuracy for ``batch`` and remember it for ``on_test_end``."""
        labels_ground, pred_labels, _ = self.ground_and_pred_labels_generate(
            batch['image_with_text'], batch['image_without_text'])
        test_acc = self.acc_calc(labels_ground, pred_labels)

        self.test_acc_step_out.append(test_acc)
        self.log("test_acc", test_acc)

    def on_test_end(self) -> None:
        """Print the mean of the per-batch test accuracies and reset the list.

        The mean is unweighted: every batch counts equally regardless of its
        size. If no test batch was recorded, ``nan`` is printed.
        """
        if self.test_acc_step_out:
            test_acc = np.mean(self.test_acc_step_out)
        else:
            # np.mean([]) would also give nan, but with a RuntimeWarning.
            test_acc = float("nan")
        self.test_acc_step_out.clear()
        print("the total test acc is: ", test_acc)

    def configure_optimizers(self):
        """Build the Adam optimizer and a ReduceLROnPlateau scheduler.

        The scheduler monitors the ``"loss"`` metric logged in
        ``training_step``.
        """
        lr = self.hparams.lr
        b1 = self.hparams.b1
        b2 = self.hparams.b2
        # tried weight_decay which did not improve the results

        # optimizers
        opt = torch.optim.Adam(self.classifier.parameters(), lr=lr, betas=(b1, b2))
        # schedulers
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt)

        return [opt],\
               [{"scheduler": scheduler, "monitor": "loss", "name": "scheduler_loss"}]

    def ground_and_pred_labels_generate(self, image_with_text, image_without_text):
        """Build shuffled inputs with one-hot labels and run the classifier.

        Args:
            image_with_text: Batch of images containing text (class 1).
            image_without_text: Batch of images without text (class 0).

        Returns:
            ``(labels_ground, pred_labels, images)``: the one-hot ground-truth
            labels, the classifier logits and the shuffled image batch, all
            in the same (shuffled) order.
        """
        # One label per image of each kind. The "no text" labels are sized
        # from image_without_text so labels and images always line up, even
        # if the two inputs have different batch sizes.
        labels_text = torch.ones(image_with_text.size(0), dtype=torch.int64)
        labels_no_text = torch.zeros(image_without_text.size(0), dtype=torch.int64)
        # Concatenate the labels to one vector of ones and zeros
        labels_ground = torch.cat((labels_text, labels_no_text), dim=0)
        # Convert to one-hot and cast to the images' tensor type, so the
        # targets are usable by binary_cross_entropy_with_logits.
        labels_ground = F.one_hot(labels_ground, num_classes=2)
        labels_ground = labels_ground.type_as(image_with_text)

        # Concatenate the images in the same order as the labels
        images = torch.cat((image_with_text, image_without_text), dim=0)
        # Apply one random permutation to both images and labels
        indices = torch.randperm(images.size(0))
        images = images[indices]
        labels_ground = labels_ground[indices]

        pred_labels = self.forward(images)

        return labels_ground, pred_labels, images

    def acc_calc(self, labels_ground, pred_labels):
        """Return the fraction of rows whose predicted class matches the label.

        Args:
            labels_ground: One-hot ground-truth labels, one row per sample.
            pred_labels: Classifier logits, one row per sample.

        Returns:
            Accuracy as a Python float in ``[0, 1]``.
        """
        pred_labels = torch.sigmoid(pred_labels)
        # Index of the largest score per row is the predicted class.
        _, pred_labels = torch.max(pred_labels, 1)

        # Index of the 1 in each one-hot row is the true class.
        _, labels_ground = torch.max(labels_ground, 1)
        sum_correct = (torch.sum(torch.eq(pred_labels, labels_ground))).item()  # predicted correctly

        acc = sum_correct / labels_ground.size(0)

        return acc


