import numpy as np
import torch
from tqdm import tqdm
import time
from meters import AverageMeter
import wandb
import matplotlib.cm as cm


class Trainer:
    """Train and validate a multi-head segmentation model, logging to Weights & Biases.

    Every head of ``self.model`` is trained against the same target map (the
    target is replicated once per head) and the per-head losses are averaged.
    The learning rate is divided by sqrt(10) once more than
    ``config.lr_patience`` epochs pass without a validation improvement, and
    training stops once more than ``config.train_patience`` epochs pass without
    one.
    """

    def __init__(self, model, optimizer, training_dataset, validation_dataset, config):
        """Store training state and select the loss function(s).

        Args:
            model: network to train. Must expose ``num_heads``,
                ``illumination_layer`` and ``save_model``; it is also passed to
                ``wandb.watch``.
            optimizer: torch optimizer over the model's parameters.
            training_dataset: loader yielding ``(x, y)`` training batches; its
                ``sampler`` must expose ``indices``.
            validation_dataset: loader yielding ``(x, y)`` validation batches.
            config: run configuration; reads ``batch_size``, ``l1_penalty``,
                ``loss``, ``init_lr``, ``force_positive``, ``use_gpu``,
                ``epochs``, ``lr_patience`` and ``train_patience``.

        Raises:
            RuntimeError: if ``config.loss`` is not one of ``mse``, ``mae``,
                ``huber`` or ``ssim`` (case-insensitive).
        """
        self.model = model
        self.optimizer = optimizer
        self.train_loader = training_dataset
        self.val_loader = validation_dataset
        self.config = config
        self.batch_size = self.config.batch_size
        # Coefficient of the weight-norm penalty on the illumination layer;
        # None disables the penalty.
        self.l1_regularization = config.l1_penalty if config.l1_penalty > 0.0 else None
        self.criterion, self.criterion_5 = self._build_criteria(self.config.loss)
        self.num_train = len(self.train_loader.sampler.indices)
        self.num_valid = len(self.val_loader.sampler)
        self.lr = self.config.init_lr
        self.curr_epoch = 0
        # When True, illumination_layer.clamp_weights() runs after every optimizer step.
        self.clamp = config.force_positive
        # Coefficient of the product-norm penalty in _add_regularization;
        # always None here, so that penalty is disabled.
        self.ortho_reg = None
        wandb.watch(self.model, log='all')

    @staticmethod
    def _build_criteria(loss_name):
        """Return ``(criterion, criterion_5)`` for a loss name (case-insensitive).

        ``criterion_5`` is only set for ``'ssim'``. In that case it replaces
        ``criterion`` for head 0 from epoch 5 onwards (see ``_compute_loss``).

        Raises:
            RuntimeError: for an unsupported loss name.
        """
        name = loss_name.lower()
        if name == 'mse':
            return torch.nn.MSELoss(), None
        if name == 'mae':
            return torch.nn.L1Loss(), None
        if name == 'huber':
            return torch.nn.SmoothL1Loss(beta=0.1), None
        if name == 'ssim':
            # Imported lazily so the dependency is only needed for this loss.
            import pytorch_msssim
            return torch.nn.L1Loss(), pytorch_msssim.MSSSIM()
        raise RuntimeError(
            f"Unsupported loss {loss_name!r}; expected one of 'mse', 'mae', 'huber', 'ssim'")

    def train(self):
        """Train for up to ``config.epochs`` epochs with early stopping.

        Logs sample images before training and after every epoch. Saves the
        model whenever the validation loss improves and decays the learning
        rate on plateaus. Returns early once more than
        ``config.train_patience`` epochs pass without improvement.
        """
        print(f"\n[*] Train on {self.num_train} samples, validate on {self.num_valid} samples")
        best_val_loss = np.inf
        epochs_since_best = 0
        lr_pat = 0
        self.curr_epoch = 0
        with torch.no_grad():
            self.log_data(commit=True)
        for epoch in range(self.config.epochs):
            self.curr_epoch = epoch + 1
            # Print the 1-based epoch so it matches the wandb step below.
            print(f'\nEpoch {self.curr_epoch}/{self.config.epochs} -- lr = {self.lr}')
            train_loss, _ = self.run_one_epoch(training=True)
            with torch.no_grad():
                val_loss, _ = self.run_one_epoch(training=False)
                self.log_data(commit=False)
            msg = f'train loss {train_loss:.3f} -- val loss {val_loss:.3f}'
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                self.model.save_model(verbose=True)
                lr_pat = 0
                epochs_since_best = 0
                msg += ' [*]'
            else:
                epochs_since_best += 1
                lr_pat += 1
            print(msg)
            wandb.log({
                'train_loss': train_loss,
                'val_loss': val_loss,
                'best_val_loss': best_val_loss
            }, step=self.curr_epoch)

            if lr_pat > self.config.lr_patience:
                lr_pat = 0
                self.lr = self.lr / np.sqrt(10)
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = self.lr
            if epochs_since_best > self.config.train_patience:
                return

    def _prepare_batch(self, data):
        """Unpack ``(x, y)`` and replicate the target once per model head.

        ``y`` is viewed as ``(1, -1, 1, H, W)`` and expanded (a view, no copy)
        to ``(num_heads, -1, 1, H, W)``, so ``y[head]`` pairs with
        ``output[head]``. Both tensors are moved to the GPU when
        ``config.use_gpu`` is set.
        """
        x, y = data
        y = y.view(1, -1, 1, y.shape[-2], y.shape[-1]).expand(self.model.num_heads, -1, -1, -1, -1)
        if self.config.use_gpu:
            x, y = x.cuda(), y.cuda()
        return x, y

    def _compute_loss(self, output, y):
        """Average the per-head losses between ``output[head]`` and ``y[head]``.

        If ``criterion_5`` is set, head 0 uses it from epoch 5 onwards. All
        other heads, and head 0 before epoch 5, use ``criterion``.
        """
        # NOTE(review): only head 0 ever switches to criterion_5, and its value is
        # minimised as-is. If pytorch_msssim.MSSSIM returns a similarity (higher
        # is better), this presumably should be ``1 - criterion_5(...)`` — confirm.
        num_heads = self.model.num_heads
        total = None
        for head in range(num_heads):
            if head == 0 and self.criterion_5 is not None and self.curr_epoch >= 5:
                term = self.criterion_5(output[head], y[head])
            else:
                term = self.criterion(output[head], y[head])
            total = term if total is None else total + term
        return total / num_heads

    def _add_regularization(self, loss):
        """Return ``loss`` plus the enabled penalties on the illumination layer's parameters."""
        # NOTE(review): despite the "l1" name, torch.norm defaults to the Frobenius
        # (L2) norm — confirm which penalty is intended.
        if self.l1_regularization is not None:
            for param in self.model.illumination_layer.parameters():
                loss = loss + torch.norm(param) * self.l1_regularization
        if self.ortho_reg is not None:
            for param in self.model.illumination_layer.parameters():
                loss = loss + torch.norm(torch.prod(param, dim=0)) * self.ortho_reg
        return loss

    def run_one_epoch(self, training):
        """Run one pass over the training or validation loader.

        Args:
            training: if True, iterate ``train_loader`` and take one optimizer
                step per batch. Otherwise only iterate ``val_loader`` (``train``
                calls this under ``torch.no_grad()`` for validation).

        Returns:
            ``(losses.avg, accs.avg)``. ``losses`` is fed the per-batch data
            loss, which excludes regularisation terms. ``accs`` is never
            updated here.
        """
        # NOTE(review): the model is never switched between train()/eval() modes,
        # so any mode-dependent layers behave the same in both phases. This may be
        # intentional for the simulated illumination layer — confirm.
        if training:
            loader, amnt = self.train_loader, self.num_train
        else:
            loader, amnt = self.val_loader, self.num_valid
        batch_time = AverageMeter()
        losses = AverageMeter()
        accs = AverageMeter()
        epoch_start = time.time()
        last_tick = epoch_start
        with tqdm(total=amnt) as pbar:
            for data in loader:
                x, y = self._prepare_batch(data)
                output = self.model(x)
                if training:
                    self.optimizer.zero_grad()
                loss = self._compute_loss(output, y)
                # Record the data loss before regularisation terms are added.
                loss_data = loss.item()
                losses.update(loss_data)
                if training:
                    loss = self._add_regularization(loss)
                    loss.backward()
                    self.optimizer.step()
                    if self.clamp:
                        self.model.illumination_layer.clamp_weights()
                now = time.time()
                batch_time.update(now - last_tick)  # duration of this batch only
                last_tick = now
                pbar.set_description(f"{(now - epoch_start):.1f}s - loss: {loss_data:.3f}")
                # Advance by the actual batch size so a short final batch is not over-counted.
                pbar.update(x.shape[0])
        return losses.avg, accs.avg

    def _normalize_pattern(self, values):
        """Scale ``values`` by their 99th percentile and map them into [0, 1].

        With ``self.clamp`` set, the scaled values are clipped to [0, 1].
        Otherwise [-1, 1] is clipped and shifted to [0, 1]. A zero percentile
        (e.g. an all-zero pattern) leaves the values unscaled instead of
        producing NaN/inf.
        """
        scale = np.percentile(values, 99)
        if scale != 0:
            values = values / scale
        if self.clamp:
            return np.clip(values, 0, 1)
        return (np.clip(values, -1, 1) + 1) / 2

    def log_data(self, commit=False):
        """Log one validation example and the illumination patterns to wandb.

        Only the first validation batch is used. Logs the first head's
        prediction and label for the first sample (viridis colormap), plus one
        15x15 image per LED taken from the middle 225-value slice of the
        physical layer's weights (cividis colormap). Nothing is logged if the
        validation loader is empty.

        Args:
            commit: passed to the final per-LED ``wandb.log`` call; every other
                call in this method uses ``commit=False``.
        """
        side = 15
        pixels = side * side
        for data in self.val_loader:
            x, y = self._prepare_batch(data)
            output = self.model(x)
            y_sample = y[0, 0].detach().cpu().numpy()
            p_sample = output[0, 0].detach().cpu().numpy()
            wandb.log({"images": [
                    wandb.Image(cm.viridis(p_sample), caption="prediction"),
                    wandb.Image(cm.viridis(y_sample), caption="label")]}, step=self.curr_epoch, commit=False)
            pattern_data = self.model.illumination_layer.physical_layer.weight.detach().cpu().numpy()
            # NOTE(review): assumes each LED's weights are 3 consecutive 15x15 blocks,
            # with the middle block being green (as the variable name suggests) — confirm.
            leds = pattern_data.reshape(-1, 3 * pixels)
            last_led = leds.shape[0] - 1
            for idx, led in enumerate(leds):
                green = self._normalize_pattern(led[pixels:2 * pixels])
                g_viz = cm.cividis(green.reshape(side, side))
                wandb.log({f"led_{idx}": [wandb.Image(g_viz, caption=f"led_{idx}")]},
                          step=self.curr_epoch, commit=commit and idx == last_led)
            print("STEP", self.curr_epoch)
            break  # only the first validation batch is logged

