import os
import yaml
from types import SimpleNamespace
from collections import OrderedDict

from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, WeightedRandomSampler
from torch import optim
import torchvision.transforms as T
from sklearn import metrics

from .BaseTrainer import BaseTrainer
from data.graph_dataset import PatchGraphClsDataset, get_weight_list
from models import BiGCN_Nolink
from utils.ranger import Ranger
from utils.lr_scheduler import CosineAnnealingWithWarmUpLR
from utils.losses import GHMCELoss

class GCNClsTrainer(BaseTrainer):
    """Train, validate and test the classification branch of ``BiGCN_Nolink``.

    All hyper-parameters come from a YAML file whose top-level keys become
    attributes of ``self.opt`` (``device``, ``image_size``, ``source``,
    ``cate_list``, ``batch_size``, ``lr``, ``weight_decay``, ``epochs``,
    ``start_epoch``, ``warm_up_step``, ``load_model``, ``load_optimizer``,
    ``load_scheduler``, ``save_cp``, ``early_stopping``, ``info`` and the
    optional ``num_workers``, default 8).  ``BaseTrainer`` is expected to
    provide ``self.logger``, ``self.writer`` and ``self.checkpoint_dir``.
    """

    # RGB colour used to paint each class index in the TensorBoard class
    # tiles; index 0 is never painted, so it stays black.
    _CATEGORY_COLORS = ((0, 0, 0), (255, 0, 0), (0, 255, 0), (0, 0, 255))

    def __init__(self, opt_file='args/gcn_cls.yaml'):
        """Build datasets, loaders, model, optimizer and scheduler.

        Args:
            opt_file: Path to the YAML configuration file.
        """
        # Read the config once: it is both parsed and echoed to the log.
        with open(opt_file, encoding='utf-8') as f:
            opt_text = f.read()
        opt = SimpleNamespace(**yaml.safe_load(opt_text))
        self.opt = opt

        super().__init__(checkpoint_root='GCN_CLS', opt=opt)
        self.logger.info(f'{opt_file} START************************\n'
                         f'{opt_text}\n'
                         f'************************{opt_file} END**************************\n')

        if opt.device == 'cpu':
            self.device = torch.device('cpu')
        else:
            # Restrict the visible GPUs (only effective if CUDA has not been
            # initialised yet in this process).
            os.environ['CUDA_VISIBLE_DEVICES'] = str(opt.device)
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.logger.info(f'Using device {self.device}')

        self.train_transform = T.Compose([
            T.Resize(opt.image_size),
            T.RandomHorizontalFlip(p=0.5),
            T.RandomVerticalFlip(p=0.5),
            T.RandomApply([T.ColorJitter(0.4, 0.4, 0.4)], p=0.6),
            T.ToTensor(),
        ])
        self.val_transform = T.Compose([
            T.Resize(opt.image_size),
            T.ToTensor(),
        ])

        self.train_dataset = PatchGraphClsDataset(opt.source, opt.cate_list, True, self.train_transform)
        self.n_train = len(self.train_dataset)
        # Draw n_train samples per epoch using the per-sample weights from
        # get_weight_list (presumably class-balancing weights -- confirm).
        self.train_sampler = WeightedRandomSampler(get_weight_list(self.train_dataset), self.n_train)
        self.train_loader = self._make_loader(self.train_dataset, sampler=self.train_sampler)

        self.val_dataset = PatchGraphClsDataset(opt.source, [cate + '_val' for cate in opt.cate_list], False, self.val_transform)
        self.n_val = len(self.val_dataset)
        self.val_loader = self._make_loader(self.val_dataset, shuffle=False)

        self.test_dataset = PatchGraphClsDataset(opt.source, [cate + '_test' for cate in opt.cate_list], False, self.val_transform)
        self.n_test = len(self.test_dataset)
        self.test_loader = self._make_loader(self.test_dataset, shuffle=False)

        self.net = BiGCN_Nolink(n_channels=3, n_cls_classes=2, n_seg_classes=3, input_size=opt.image_size)
        if opt.load_model:
            self.net.load_state_dict(torch.load(opt.load_model, map_location=self.device))
            self.logger.info(f'Model loaded from {opt.load_model}')
        self.net.to(device=self.device)
        if torch.cuda.device_count() > 1 and self.device.type != 'cpu':
            self.net = nn.DataParallel(self.net)
            self.logger.info(f'torch.cuda.device_count:{torch.cuda.device_count()}, Use nn.DataParallel')
        # Unwrapped model, used for state_dict save/load and attribute access.
        self.net_module = self.net.module if isinstance(self.net, nn.DataParallel) else self.net

        self.optimizer = Ranger(self.net.parameters(), lr=opt.lr, weight_decay=opt.weight_decay)
        self.scheduler = CosineAnnealingWithWarmUpLR(self.optimizer, T_total=max(opt.epochs, 1), eta_min=opt.lr / 1000,
                                                     warm_up_lr=opt.lr / 100, warm_up_step=opt.warm_up_step)
        # map_location (as for the model above) lets states saved on a GPU be
        # restored on a CPU-only host.
        if opt.load_optimizer:
            self.optimizer.load_state_dict(torch.load(opt.load_optimizer, map_location=self.device))
            self.logger.info(f'Optimizer loaded from {opt.load_optimizer}')
        if opt.load_scheduler:
            self.scheduler.load_state_dict(torch.load(opt.load_scheduler, map_location=self.device))
            self.logger.info(f'Scheduler loaded from {opt.load_scheduler}')

        self.criterion = GHMCELoss()

        self.epochs = opt.epochs
        self.save_cp = opt.save_cp
        self.early_stopping = opt.early_stopping
        self.training_info = opt.info

        self.logger.info(f'''Starting training net:
        Epochs:          {opt.epochs}
        Batch size:      {opt.batch_size}
        Learning rate:   {opt.lr}
        Image size:      {opt.image_size}
        Training size:   {self.n_train}
        Validation size: {self.n_val}
        Test size:       {self.n_test}
        Checkpoints:     {opt.save_cp}
        Device:          {self.device.type}
        Data source:     {opt.source}
        Training info:   {opt.info}
    ''')

    def _make_loader(self, dataset, **kwargs):
        """Wrap *dataset* in a DataLoader with the shared batch/worker settings.

        Extra keyword arguments (``sampler`` or ``shuffle``) are forwarded.
        """
        return DataLoader(dataset,
                          batch_size=self.opt.batch_size,
                          drop_last=False,
                          num_workers=getattr(self.opt, 'num_workers', 8),
                          pin_memory=True,
                          **kwargs)

    def train(self):
        """Train from ``opt.start_epoch`` to ``opt.epochs``, then report on the best model.

        After every epoch the model is validated; whenever validation accuracy
        improves, model/optimizer/scheduler states are saved as ``*_best.pth``.
        Training stops after ``early_stopping`` epochs without improvement
        (if set) or on Ctrl-C. Finally the best weights are reloaded (when
        they exist) and evaluated on the validation and test sets.
        """
        global_step = 0
        best_val_score = -1
        useless_epoch_count = 0
        vis_flag = True  # visualise only the first training batch of the run
        for epoch in range(self.opt.start_epoch, self.epochs):
            try:
                self.net.train()
                epoch_loss = 0
                true_list = []
                pred_list = []
                # The context manager closes the bar even on KeyboardInterrupt.
                with tqdm(total=self.n_train, desc=f'Epoch {epoch + 1}/{self.epochs}', unit='img') as pbar:
                    for imgs, labels in self.train_loader:
                        global_step += 1
                        imgs, labels = imgs.to(self.device), labels.to(self.device)
                        true_list += labels.tolist()

                        preds = self.net(imgs, 'cls')
                        pred_list += preds.detach().argmax(dim=1).tolist()

                        loss = self.criterion(preds, labels)
                        loss_value = loss.item()
                        self.writer.add_scalar('Train/batch_loss', loss_value, global_step)
                        # Weight by batch size so the epoch mean is per sample.
                        epoch_loss += loss_value * labels.size(0)
                        pbar.set_postfix(OrderedDict(loss=loss_value))

                        self.optimizer.zero_grad()
                        loss.backward()
                        self.optimizer.step()

                        if vis_flag:
                            vis_flag = False
                            self.draw_tensorboard_images(epoch, imgs, labels, preds)

                        pbar.update(labels.shape[0])

                epoch_loss /= self.n_train
                self.logger.info(f'Train epoch {epoch+1} loss: {epoch_loss}')
                self.writer.add_scalar('Train/epoch_loss', epoch_loss, epoch+1)
                self.logger.info(f'Train epoch {epoch + 1} train report:\n'+metrics.classification_report(true_list, pred_list, digits=4))

                self._log_parameter_histograms(epoch + 1)
                self.writer.add_scalar('learning_rate', self.optimizer.param_groups[0]['lr'], epoch+1)

                mAP, val_loss, accuracy = self.evaluate(epoch)
                self.logger.info(f'Val epoch {epoch+1} accuracy: {accuracy}, mAP: {mAP}, loss: {val_loss}')
                self.writer.add_scalar('Val/mAP', mAP, epoch+1)
                self.writer.add_scalar('Val/loss', val_loss, epoch+1)
                self.writer.add_scalar('Val/accuracy', accuracy, epoch+1)

                self.scheduler.step()

                self._save_epoch_checkpoint(epoch)

                if accuracy > best_val_score:
                    best_val_score = accuracy
                    self._save_best_checkpoint()
                    useless_epoch_count = 0
                else:
                    useless_epoch_count += 1

                if self.early_stopping and useless_epoch_count == self.early_stopping:
                    self.logger.info(f'There are {useless_epoch_count} useless epochs! Early Stop Training!')
                    break

            except KeyboardInterrupt:
                self.logger.info('Receive KeyboardInterrupt, stop training...')
                break

        self._load_best_model()
        self.logger.info('Evaluate best model:')
        mAP, val_loss, accuracy = self.evaluate(-1)
        self.logger.info(f'Best model val mAP: {mAP}, loss: {val_loss}, accuracy: {accuracy}')
        mAP, accuracy = self.test()
        self.logger.info(f'Best model test mAP: {mAP}, accuracy: {accuracy}')

    def _log_parameter_histograms(self, step):
        """Write weight and (when present) gradient histograms of every parameter."""
        for tag, value in self.net_module.named_parameters():
            tag = tag.replace('.', '/')
            self.writer.add_histogram('weights/' + tag, value.detach().cpu().numpy(), step)
            if value.grad is not None:
                self.writer.add_histogram('grads/' + tag, value.grad.detach().cpu().numpy(), step)

    def _save_epoch_checkpoint(self, epoch):
        """Save this epoch's weights plus the latest optimizer/scheduler state.

        With ``save_cp`` every epoch gets its own ``Net_epoch{N}.pth``;
        otherwise ``Net_last.pth`` is overwritten.
        """
        if self.save_cp:
            torch.save(self.net_module.state_dict(), self.checkpoint_dir + f'Net_epoch{epoch + 1}.pth')
            self.logger.info(f'Checkpoint {epoch + 1} saved !')
        else:
            torch.save(self.net_module.state_dict(), self.checkpoint_dir + 'Net_last.pth')
            self.logger.info('Last model saved !')
        torch.save(self.optimizer.state_dict(), self.checkpoint_dir + 'Optimizer_last.pth')
        torch.save(self.scheduler.state_dict(), self.checkpoint_dir + 'Scheduler_last.pth')

    def _save_best_checkpoint(self):
        """Save model, optimizer and scheduler state under the ``*_best.pth`` names."""
        torch.save(self.net_module.state_dict(), self.checkpoint_dir + 'Net_best.pth')
        torch.save(self.optimizer.state_dict(), self.checkpoint_dir + 'Optimizer_best.pth')
        torch.save(self.scheduler.state_dict(), self.checkpoint_dir + 'Scheduler_best.pth')
        self.logger.info('Best model saved !')

    def _load_best_model(self):
        """Reload ``Net_best.pth`` into the model if this run produced one."""
        if self.opt.epochs <= 0:
            return
        best_path = self.checkpoint_dir + 'Net_best.pth'
        if os.path.exists(best_path):
            self.net_module.load_state_dict(torch.load(best_path, map_location=self.device))
        else:
            # E.g. interrupted before the first epoch finished, or
            # start_epoch >= epochs: fall back to the current weights.
            self.logger.info(f'{best_path} not found, evaluating current weights instead')

    def _mean_average_precision(self, true_list, prob_list):
        """Return the mean over classes of one-vs-rest average precision.

        Args:
            true_list: Ground-truth class index per sample.
            prob_list: Per-sample list of softmax probabilities, one per class.
        """
        y_true = np.asarray(true_list)
        y_prob = np.asarray(prob_list)
        ap = [
            metrics.average_precision_score((y_true == c).astype(int), y_prob[:, c])
            for c in range(self.net_module.n_cls_classes)
        ]
        return float(np.mean(ap))

    @torch.no_grad()
    def evaluate(self, epoch):
        """Evaluate on the validation set and log a classification report.

        Args:
            epoch: Zero-based epoch index; images are logged at step
                ``epoch + 1``. A negative value disables image logging.

        Returns:
            ``(mAP, mean cross-entropy loss, accuracy)``. The loss is plain
            cross-entropy, not the GHM criterion used for training.
        """
        self.net.eval()
        tot_loss = 0
        true_list = []
        pred_list = []
        prob_list = []
        vis_flag = epoch >= 0
        with tqdm(total=self.n_val, desc='Validation round', unit='img', leave=False) as pbar:
            for imgs, labels in self.val_loader:
                imgs, labels = imgs.to(self.device), labels.to(self.device)
                preds = self.net(imgs, 'cls')

                # Visualise the first batch that contains both class 0 and 1.
                if vis_flag and (labels == 0).any() and (labels == 1).any():
                    vis_flag = False
                    self.writer.add_images('Val_cls/1origin_images', imgs.cpu(), epoch+1, dataformats='NCHW')
                    self._draw_category_images('Val_cls', epoch, labels, preds)

                probs = torch.softmax(preds, dim=1)
                prob_list += probs.tolist()
                pred_list += probs.argmax(dim=1).tolist()
                true_list += labels.tolist()
                tot_loss += F.cross_entropy(preds, labels).item() * labels.size(0)
                pbar.update(labels.size(0))
        self.net.train()
        mAP = self._mean_average_precision(true_list, prob_list)
        self.logger.info('Validation report:\n'+metrics.classification_report(true_list, pred_list, digits=4))
        return mAP, tot_loss / self.n_val, metrics.accuracy_score(true_list, pred_list)

    @torch.no_grad()
    def test(self):
        """Evaluate on the test set, log a report and return ``(mAP, accuracy)``."""
        self.net.eval()
        true_list = []
        pred_list = []
        prob_list = []
        with tqdm(total=self.n_test, desc='Test round', unit='img', leave=False) as pbar:
            for imgs, labels in self.test_loader:
                imgs, labels = imgs.to(self.device), labels.to(self.device)
                probs = torch.softmax(self.net(imgs, 'cls'), dim=1)
                prob_list += probs.tolist()
                pred_list += probs.argmax(dim=1).tolist()
                true_list += labels.tolist()
                pbar.update(labels.size(0))

        mAP = self._mean_average_precision(true_list, prob_list)
        self.logger.info('Test report:\n'+metrics.classification_report(true_list, pred_list, digits=4))
        return mAP, metrics.accuracy_score(true_list, pred_list)

    def __del__(self):
        # __init__ may have failed before the loaders existed; drop only the
        # ones that were created so the base-class cleanup still runs.
        for name in ('train_loader', 'val_loader', 'test_loader'):
            self.__dict__.pop(name, None)
        super().__del__()

    def _draw_category_images(self, prefix, epoch, labels, preds):
        """Log 100x100 colour tiles encoding the true and predicted class per sample.

        Args:
            prefix: TensorBoard tag prefix (``'Train'`` or ``'Val_cls'``).
            epoch: Zero-based epoch; images are logged at step ``epoch + 1``.
            labels: Ground-truth class indices for the batch (any device).
            preds: Class logits for the batch (any device).
        """
        # The image buffers live on the CPU, so the masks must too: indexing a
        # CPU tensor with a CUDA mask raises a RuntimeError.
        labels = labels.detach().cpu()
        preds_idx = preds.detach().argmax(dim=1).cpu()
        labels_img = torch.zeros(labels.shape[0], 100, 100, 3, dtype=torch.uint8)
        preds_img = torch.zeros(preds_idx.shape[0], 100, 100, 3, dtype=torch.uint8)
        for category in range(1, self.net_module.n_cls_classes):
            color = torch.tensor(self._CATEGORY_COLORS[category], dtype=torch.uint8)
            labels_img[labels == category] = color
            preds_img[preds_idx == category] = color
        self.writer.add_images(f'{prefix}/0categories/0true', labels_img, epoch+1, dataformats='NHWC')
        self.writer.add_images(f'{prefix}/0categories/1pred', preds_img, epoch+1, dataformats='NHWC')

    def draw_tensorboard_images(self, epoch, imgs, labels, preds):
        """Log a training batch and its true/predicted class tiles to TensorBoard."""
        self.writer.add_images('Train/1origin_images', imgs.detach().cpu(), epoch+1, dataformats='NCHW')
        self._draw_category_images('Train', epoch, labels, preds)
