import math
import os
import warnings
from argparse import Namespace
from collections.abc import Mapping
from typing import List

import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader
from torchvision import transforms as T
from torchvision.datasets import CIFAR10
from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
#from pytorch_lightning.metrics import Accuracy
from torchmetrics.classification import Accuracy
from torch.optim.lr_scheduler import _LRScheduler

from src.cifar10_models.densenet import densenet121, densenet161, densenet169
from src.cifar10_models.googlenet import googlenet
from src.cifar10_models.inception import inception_v3
from src.cifar10_models.mobilenetv2 import mobilenet_v2
from src.cifar10_models.resnet import resnet18, resnet34, resnet50
from src.cifar10_models.vgg import vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn


class WarmupCosineLR(_LRScheduler):
    """Linear warmup followed by cosine annealing.

    Over the first ``warmup_epochs`` calls to ``step()`` the learning rate of each
    param group rises linearly from ``warmup_start_lr`` to its base LR. After that
    it follows a half cosine from the base LR down to ``eta_min``, which is reached
    at step ``max_epochs``. An "epoch" here is simply one ``step()`` call, so callers
    may step per batch instead of per epoch.

    Args:
        optimizer: Wrapped optimizer.
        warmup_epochs: Length of the warmup phase in steps. ``0`` disables warmup,
            so the schedule starts directly at the base LR.
        max_epochs: Total length of the schedule in steps, warmup included. It must
            be greater than ``warmup_epochs`` (the cosine phase divides by the
            difference).
        warmup_start_lr: Learning rate at step 0 when warmup is enabled.
        eta_min: Minimum learning rate, reached at the end of the cosine phase.
        last_epoch: Index of the last completed step; ``-1`` starts a fresh schedule.
    """

    def __init__(self, optimizer, warmup_epochs: int, max_epochs: int, warmup_start_lr: float = 1e-8,
                 eta_min: float = 1e-8, last_epoch: int = -1):
        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs
        self.warmup_start_lr = warmup_start_lr
        self.eta_min = eta_min
        # Set the attributes before the base __init__, which performs the initial step
        # (and therefore calls get_lr()).
        super().__init__(optimizer, last_epoch)

    def get_lr(self) -> List[float]:
        """Compute this step's LRs recursively from the param groups' current LRs."""
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, please use `get_last_lr()`.", UserWarning)

        # Only start from warmup_start_lr when there actually is a warmup phase.
        # Otherwise step 0 must already be the base LR. Previously the LR stayed
        # pinned at warmup_start_lr forever, because the cosine recursion below only
        # rescales the current distance to eta_min.
        if self.last_epoch == 0 and self.warmup_epochs > 0:
            return [self.warmup_start_lr] * len(self.base_lrs)
        if self.last_epoch < self.warmup_epochs:
            # Equal increments that reach base_lr at step warmup_epochs - 1.
            return [
                group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
                for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
            ]
        if self.last_epoch == self.warmup_epochs:
            # Start of the cosine phase (this is step 0 when warmup is disabled).
            return list(self.base_lrs)

        cosine_span = self.max_epochs - self.warmup_epochs
        if (self.last_epoch - 1 - self.max_epochs) % (2 * cosine_span) == 0:
            # One step past a cosine trough, the ratio below would divide by
            # 1 + cos(pi) == 0, so add the rising increment explicitly instead.
            return [
                group["lr"] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / cosine_span)) / 2
                for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
            ]
        # Rescale the current distance to eta_min by the ratio of consecutive cosine factors.
        progress = self.last_epoch - self.warmup_epochs
        return [
            (1 + math.cos(math.pi * progress / cosine_span))
            / (1 + math.cos(math.pi * (progress - 1) / cosine_span))
            * (group["lr"] - self.eta_min)
            + self.eta_min
            for group in self.optimizer.param_groups
        ]

    def _get_closed_form_lr(self) -> List[float]:
        """Compute the LRs for ``last_epoch`` directly rather than recursively."""
        if self.last_epoch < self.warmup_epochs:
            # max(1, ...) avoids a ZeroDivisionError when warmup_epochs == 1. The only
            # warmup step is then step 0, where the increment is multiplied by 0 anyway.
            warmup_span = max(1, self.warmup_epochs - 1)
            return [
                self.warmup_start_lr + self.last_epoch * (base_lr - self.warmup_start_lr) / warmup_span
                for base_lr in self.base_lrs
            ]
        return [
            self.eta_min
            + 0.5 * (base_lr - self.eta_min)
            * (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs)))
            for base_lr in self.base_lrs
        ]



class CIFAR10Data(pl.LightningDataModule):
    """LightningDataModule serving the CIFAR-10 train and test splits.

    ``args`` must provide ``data_dir``, ``batch_size`` and ``num_workers``. All of
    its attributes are also stored as hyperparameters. The dataset is downloaded to
    ``data_dir`` on first use.
    """

    def __init__(self, args):
        super().__init__()
        self.save_hyperparameters(vars(args))
        self.cfg = args
        # Presumably the per-channel mean/std of the CIFAR-10 training images; used
        # to normalise every split.
        self.mean = (0.4914, 0.4822, 0.4465)
        self.std = (0.2471, 0.2435, 0.2616)

    def _loader(self, train: bool, transform, **loader_kwargs) -> DataLoader:
        """Build a DataLoader over the train (``train=True``) or test split."""
        dataset = CIFAR10(root=self.cfg.data_dir, download=True, train=train, transform=transform)
        return DataLoader(dataset, batch_size=self.cfg.batch_size, num_workers=self.cfg.num_workers,
                          pin_memory=True, **loader_kwargs)

    def train_dataloader(self):
        """Shuffled training loader with random-crop and horizontal-flip augmentation."""
        transform = T.Compose([T.RandomCrop(32, padding=4), T.RandomHorizontalFlip(),
                               T.ToTensor(), T.Normalize(self.mean, self.std)])
        # drop_last keeps every training batch full-sized.
        return self._loader(True, transform, shuffle=True, drop_last=True)

    def val_dataloader(self):
        """Deterministic loader over the test split, used for validation and testing."""
        transform = T.Compose([T.ToTensor(), T.Normalize(self.mean, self.std)])
        # Keep the final partial batch. With drop_last=True, the trailing
        # len(dataset) % batch_size images were silently excluded from the reported
        # val/test accuracy.
        return self._loader(False, transform, shuffle=False, drop_last=False)

    def test_dataloader(self):
        """Same loader as validation (the CIFAR-10 test split)."""
        return self.val_dataloader()


def _build_model_dict():
    return {
        "vgg11_bn": vgg11_bn(),
        "vgg13_bn": vgg13_bn(),
        "vgg16_bn": vgg16_bn(),
        "vgg19_bn": vgg19_bn(),
        "resnet18": resnet18(),
        "resnet34": resnet34(),
        "resnet50": resnet50(),
        "densenet121": densenet121(),
        "densenet161": densenet161(),
        "densenet169": densenet169(),
        "mobilenet_v2": mobilenet_v2(),
        "googlenet": googlenet(),
        "inception_v3": inception_v3(),
    }



class CIFAR10Module(pl.LightningModule):
    """LightningModule training one CIFAR-10 classifier with SGD and a warmup/cosine LR.

    ``hparams`` (e.g. an ``argparse.Namespace``) must provide ``classifier`` (a key
    of ``_build_model_dict()``), ``learning_rate`` and ``weight_decay``. Every
    attribute is saved as a hyperparameter.
    """

    def __init__(self, hparams):
        super().__init__()
        self.save_hyperparameters(vars(hparams))
        self.cfg = hparams
        self.criterion = torch.nn.CrossEntropyLoss()
        # NOTE(review): one metric instance is shared by train/val/test. Only the
        # per-batch value returned by each call is logged, so its accumulated state
        # is never read.
        self.accuracy = Accuracy(task="multiclass", num_classes=10)
        self.model = _build_model_dict()[self.cfg.classifier]

    def forward(self, batch):
        """Run the classifier on an ``(images, labels)`` batch.

        Returns:
            ``(loss, acc)``: cross-entropy loss and batch accuracy in percent.
        """
        images, labels = batch
        predictions = self.model(images)
        loss = self.criterion(predictions, labels)
        acc = self.accuracy(predictions, labels) * 100
        return loss, acc

    def training_step(self, batch, batch_nb):
        """Log training loss/accuracy and return the loss to optimise."""
        loss, acc = self.forward(batch)
        self.log("loss/train", loss)
        self.log("acc/train", acc)
        return loss

    def validation_step(self, batch, batch_nb):
        """Log validation loss/accuracy."""
        loss, acc = self.forward(batch)
        self.log("loss/val", loss)
        self.log("acc/val", acc)

    def test_step(self, batch, batch_nb):
        """Log test accuracy."""
        _, acc = self.forward(batch)
        self.log("acc/test", acc)

    def configure_optimizers(self):
        """SGD (Nesterov, momentum 0.9) with a per-step linear warmup + cosine LR.

        The first 30% of all optimizer steps are warmup.
        """
        opt = torch.optim.SGD(self.model.parameters(),
                              lr=self.cfg.learning_rate,
                              weight_decay=self.cfg.weight_decay,
                              momentum=0.9, nesterov=True)
        # Ask the trainer for the total number of optimizer steps. Building a second
        # training DataLoader just to call len() re-instantiated the dataset and
        # failed when no datamodule was attached. This value also accounts for
        # gradient accumulation and multiple devices.
        # NOTE(review): this is infinite when both max_epochs and max_steps are
        # unbounded, which the int() below cannot handle.
        total_steps = int(self.trainer.estimated_stepping_batches)
        # Keep at least one warmup step so very short runs (e.g. fast_dev_run)
        # never configure a zero-length warmup.
        warmup_steps = max(1, int(total_steps * 0.3))
        sch = {"scheduler": WarmupCosineLR(opt, warmup_epochs=warmup_steps, max_epochs=total_steps),
               "interval": "step", "name": "learning_rate"}
        return [opt], [sch]


def _select_accelerator(gpu_id: str, precision: int):
    """Return ``(accelerator, devices, precision)`` arguments for ``pl.Trainer``.

    Prefers CUDA, then Apple MPS, then CPU, always on a single device.
    ``precision == 32`` maps to ``"32-true"`` and any other value to ``"16-mixed"``.
    Mixed precision is only used on CUDA; MPS and CPU always get ``"32-true"``.
    """
    cuda_build = torch.version.cuda is not None
    # CUDA_VISIBLE_DEVICES is read when the CUDA runtime initialises, which the
    # torch.cuda queries below can trigger, so it has to be set before them; setting
    # it afterwards was silently ignored. It still has no effect if CUDA was already
    # initialised earlier in this process.
    if cuda_build:
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id

    if cuda_build and torch.cuda.is_available() and torch.cuda.device_count() > 0:
        return "gpu", 1, ("32-true" if precision == 32 else "16-mixed")
    if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
        return "mps", 1, "32-true"
    return "cpu", 1, "32-true"


def train_and_return_model(classifier: str) -> torch.nn.Module:
    """Train a CIFAR-10 classifier (or, with ``test_phase``, only evaluate it).

    Uses a fixed configuration and seeds all RNGs with 0. It trains with
    ``deterministic=True``, runs the test loop, and returns the underlying model.

    Args:
        classifier: Architecture name; must be a key of ``_build_model_dict()``.

    Returns:
        ``lit_model.model``, the trained (or loaded and evaluated) ``nn.Module``.
    """
    args = Namespace(
        data_dir="/data/train/cifar10",
        test_phase=0,
        dev=0,
        logger=False,
        enable_checkpointing=False,
        classifier=classifier,
        pretrained=0,
        precision=32,
        batch_size=256,
        max_epochs=100,
        num_workers=8,
        gpu_id="0",
        learning_rate=1e-2,
        weight_decay=1e-2,
    )

    pl.seed_everything(0)

    accelerator, devices, precision_arg = _select_accelerator(args.gpu_id, args.precision)

    # Honour the logger/checkpoint switches in ``args`` (both disabled above).
    # Fast-dev and test-only runs never log. ``False`` explicitly disables logging.
    logger = False
    if not (args.dev or args.test_phase):
        if args.logger == "wandb":
            logger = WandbLogger(name=args.classifier, project="cifar10")
        elif args.logger == "tensorboard":
            logger = TensorBoardLogger("cifar10", name=args.classifier)
    callbacks = []
    if args.enable_checkpointing:
        # Keep the checkpoint with the best validation accuracy.
        callbacks.append(ModelCheckpoint(monitor="acc/val", mode="max", save_last=False))

    trainer = pl.Trainer(
        fast_dev_run=bool(args.dev),
        logger=logger,
        enable_checkpointing=bool(args.enable_checkpointing),
        callbacks=callbacks,
        accelerator=accelerator,
        devices=devices,
        deterministic=True,
        log_every_n_steps=1,
        max_epochs=args.max_epochs,
        precision=precision_arg,
    )

    lit_model = CIFAR10Module(args)
    data = CIFAR10Data(args)

    if args.pretrained:
        # NOTE(review): this path is relative to the working directory, while the
        # model package is imported as ``src.cifar10_models``; confirm the location.
        # torch.load unpickles the file, so only load trusted checkpoints.
        state_dict_path = os.path.join("cifar10_models", "state_dicts", args.classifier + ".pt")
        lit_model.model.load_state_dict(torch.load(state_dict_path, map_location="cpu"))

    if args.test_phase:
        trainer.test(lit_model, data.test_dataloader())
        return lit_model.model

    trainer.fit(lit_model, data)
    trainer.test(lit_model, datamodule=data)
    return lit_model.model
