import wandb
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning import Trainer

from training.Decomposition import Decomposition
from data.dataLoaderPL import DataModule
from training.DetectionAndLocalization import TwoTailClassifierPL
from models.gradCAM import gradcam, gradcam_ensambel
from config import DIRPATH_DECOMPOSITION, TRIAL_NAME_DECOMPOSITION, LOGIN_KEY, WANDB_PROJECT_DECOMPOSITION
from config import DECOMPOSITION_DATA_DIR_NO_TEXT, DECOMPOSITION_DATA_DIR_TEXT, DECOMPOSITION_DATA_DIR_ONLY_TEXT, DECOMPOSITION_DATA_DIR_TEXT_CLEAN, DATA_DIR_GRADCAM
from config import CLASSIFIER_WEIGHTS_PATH_RUN, GENERATOR_TEXT_REMOVE_WEIGHTS_PATH, GENERATOR_TEXT_DETECT_WEIGHTS_PATH, DISCRIMINATOR_TEXT_DETECT_WEIGHTS_PATH, DISCRIMINATOR_TEXT_REMOVE_WEIGHTS_PATH
from config import TRIAL_NAME_DETECTION_AND_LOCALIZATION, DIRPATH_DETECTION_AND_LOCALIZATION, WANDB_PROJECT_DETECTION_AND_LOCALIZATION, DETECTION_AND_LOCALIZATION_DATA_DIR_NO_TEXT, DETECTION_AND_LOCALIZATION_DATA_DIR_TEXT



def run_train():
    '''Train every pipeline stage in order and store their checkpoints.

    The detection-and-localization stage runs first: it hands DATA_DIR_GRADCAM to
    ``gradcam_ensambel``, and the decomposition stage's DataModule reads that same
    directory, so the decomposition stage is trained second.
    '''
    for stage in (run_train_detection_and_localization, run_train_decomposition):
        stage()


def run_test():
    '''Evaluate each pipeline stage from its saved checkpoints, logging metrics to wandb.'''
    stages = (run_test_detection_and_localization, run_test_decomposition)
    for stage in stages:
        stage()


def run_train_decomposition():
    '''Train the decomposition model and store its checkpoints.

    Checkpoints are saved each epoch (with ``last.ckpt`` kept up to date) under
    ``DIRPATH_DECOMPOSITION + TRIAL_NAME_DECOMPOSITION + '/'``. This is the folder
    ``run_test_decomposition`` loads ``last.ckpt`` from, and it follows the same
    ``<dirpath><trial_name>/`` layout as the detection-and-localization stage.
    The wandb run opened here is finished afterwards, so later stages in the same
    process start their own run instead of reusing this one.
    '''
    # Per-trial checkpoint folder; must agree with the path read in run_test_decomposition.
    dirpath = DIRPATH_DECOMPOSITION + TRIAL_NAME_DECOMPOSITION + '/'

    wandb.login(key=LOGIN_KEY)
    wandb_logger = WandbLogger(project=WANDB_PROJECT_DECOMPOSITION, name=TRIAL_NAME_DECOMPOSITION)

    lr_monitor = LearningRateMonitor(logging_interval='step')
    checkpoint_callback = ModelCheckpoint(every_n_epochs=1,
                                          save_last=True,
                                          dirpath=dirpath)

    dm = DataModule(data_dir_no_text=DECOMPOSITION_DATA_DIR_NO_TEXT,
                    data_dir_text=DECOMPOSITION_DATA_DIR_TEXT,
                    data_dir_only_text=DECOMPOSITION_DATA_DIR_ONLY_TEXT,
                    data_dir_text_clean=DECOMPOSITION_DATA_DIR_TEXT_CLEAN,
                    data_dir_gadCAM=DATA_DIR_GRADCAM)

    # Input dimensions come from the data module; the remaining arguments are
    # (presumably pretrained) weight paths taken from config.
    model = Decomposition(*dm.dims,
                          classifier_weights_path=CLASSIFIER_WEIGHTS_PATH_RUN,
                          generator_text_remove_weights_path=GENERATOR_TEXT_REMOVE_WEIGHTS_PATH,
                          generator_text_detect_weights_path=GENERATOR_TEXT_DETECT_WEIGHTS_PATH,
                          discriminator_text_detect_weights_path=DISCRIMINATOR_TEXT_DETECT_WEIGHTS_PATH,
                          discriminator_text_remove_weights_path=DISCRIMINATOR_TEXT_REMOVE_WEIGHTS_PATH
                          )

    trainer = Trainer(
        log_every_n_steps=1,
        accelerator="gpu",
        max_epochs=300,
        logger=wandb_logger,
        callbacks=[checkpoint_callback, lr_monitor]
    )

    trainer.fit(model, dm)

    # WandbLogger attaches to any wandb run still in progress, so close this one explicitly.
    wandb.finish()


def run_train_detection_and_localization(k=5):
    '''Train ``k`` detection-and-localization classifiers and ensemble their Grad-CAM outputs.

    Each trial ``i`` in ``1..k`` does the following:

    - Trains a fresh ``TwoTailClassifierPL`` with its own wandb run and checkpoint
      folder, ``DIRPATH_DETECTION_AND_LOCALIZATION + TRIAL_NAME_DETECTION_AND_LOCALIZATION + str(i) + '/'``.
    - Runs ``gradcam`` on that trial's ``last.ckpt``, with the folder's ``gradCAM``
      sub-directory as the destination.

    The per-trial ``gradCAM`` folders are then passed, together with
    ``DATA_DIR_GRADCAM``, to ``gradcam_ensambel``.

    Args:
        k: number of independently trained classifiers (default 5).
    '''
    wandb.login(key=LOGIN_KEY)

    gradcam_folders = []
    for trial in range(1, k + 1):
        trial_name = TRIAL_NAME_DETECTION_AND_LOCALIZATION + str(trial)
        dirpath = DIRPATH_DETECTION_AND_LOCALIZATION + trial_name + '/'

        wandb_logger = WandbLogger(project=WANDB_PROJECT_DETECTION_AND_LOCALIZATION,
                                   name=trial_name)

        lr_monitor = LearningRateMonitor(logging_interval='step')
        checkpoint_callback = ModelCheckpoint(every_n_epochs=1,
                                              dirpath=dirpath,
                                              save_last=True)

        dm = DataModule(data_dir_text=DETECTION_AND_LOCALIZATION_DATA_DIR_TEXT,
                        data_dir_no_text=DETECTION_AND_LOCALIZATION_DATA_DIR_NO_TEXT)

        model = TwoTailClassifierPL()

        trainer = Trainer(
            accelerator="gpu",
            max_epochs=100,
            logger=wandb_logger,
            callbacks=[checkpoint_callback,
                       lr_monitor],
            log_every_n_steps=1
        )

        trainer.fit(model, dm)

        # save_last=True guarantees last.ckpt exists in this trial's folder.
        classifier_weights_path = dirpath + 'last.ckpt'
        dst_folder_path_gradCam = dirpath + 'gradCAM'
        gradcam(classifier_weights_path, dst_folder_path_gradCam, DETECTION_AND_LOCALIZATION_DATA_DIR_TEXT)
        gradcam_folders.append(dst_folder_path_gradCam)

        # Without this, the next trial's WandbLogger would reuse (and log into) this trial's run.
        wandb.finish()

    gradcam_ensambel(gradcam_folders, DATA_DIR_GRADCAM)


def run_test_detection_and_localization(k=5):
    '''Evaluate each of the ``k`` trained detection-and-localization classifiers.

    For every trial ``i`` in ``1..k``, the model is restored from
    ``DIRPATH_DETECTION_AND_LOCALIZATION + TRIAL_NAME_DETECTION_AND_LOCALIZATION + str(i) + '/last.ckpt'``
    and tested. Each trial logs to its own wandb run.

    Args:
        k: number of trials to evaluate (default 5).
    '''
    wandb.login(key=LOGIN_KEY)

    for trial in range(1, k + 1):
        trial_name = TRIAL_NAME_DETECTION_AND_LOCALIZATION + str(trial)
        dirpath = DIRPATH_DETECTION_AND_LOCALIZATION + trial_name + '/'

        wandb_logger = WandbLogger(project=WANDB_PROJECT_DETECTION_AND_LOCALIZATION,
                                   name=trial_name)

        dm = DataModule(data_dir_text=DETECTION_AND_LOCALIZATION_DATA_DIR_TEXT,
                        data_dir_no_text=DETECTION_AND_LOCALIZATION_DATA_DIR_NO_TEXT)

        # load_from_checkpoint is a classmethod that builds a new model; call it on the
        # class (Lightning 2.x raises when it is called on an instance).
        model = TwoTailClassifierPL.load_from_checkpoint(dirpath + 'last.ckpt')

        trainer = Trainer(
            accelerator="gpu",
            logger=wandb_logger,
            log_every_n_steps=1
        )

        trainer.test(model, dm)

        # Close this trial's run so the next WandbLogger starts a fresh one.
        wandb.finish()


def run_test_decomposition():
    '''Evaluate the trained decomposition model and log test metrics to wandb.

    The model is restored from
    ``DIRPATH_DECOMPOSITION + TRIAL_NAME_DECOMPOSITION + '/last.ckpt'``.
    The wandb run is finished afterwards.
    '''
    wandb.login(key=LOGIN_KEY)
    wandb_logger = WandbLogger(project=WANDB_PROJECT_DECOMPOSITION, name=TRIAL_NAME_DECOMPOSITION)

    dm = DataModule(data_dir_no_text=DECOMPOSITION_DATA_DIR_NO_TEXT,
                    data_dir_text=DECOMPOSITION_DATA_DIR_TEXT,
                    data_dir_only_text=DECOMPOSITION_DATA_DIR_ONLY_TEXT,
                    data_dir_text_clean=DECOMPOSITION_DATA_DIR_TEXT_CLEAN,
                    data_dir_gadCAM=DATA_DIR_GRADCAM)

    path = DIRPATH_DECOMPOSITION + TRIAL_NAME_DECOMPOSITION + '/last.ckpt'
    # load_from_checkpoint is a classmethod that constructs the model itself.
    # Calling it on the class avoids building (and discarding) a second Decomposition
    # with its weight files, and is required in Lightning 2.x.
    # As before, reconstruction relies on the constructor arguments stored in the
    # checkpoint's hyperparameters.
    # NOTE(review): confirm Decomposition calls save_hyperparameters().
    model = Decomposition.load_from_checkpoint(path)

    trainer = Trainer(
        log_every_n_steps=1,
        accelerator="gpu",
        logger=wandb_logger,
    )

    trainer.test(model, dm)

    # WandbLogger attaches to any wandb run still in progress, so close this one explicitly.
    wandb.finish()



if __name__ == "__main__":
    # Train every stage first, then evaluate the checkpoints that training produced.
    for _step in (run_train, run_test):
        _step()
