import os
import yaml
from types import SimpleNamespace
from collections import OrderedDict

from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, WeightedRandomSampler
import torchvision.transforms as T
from sklearn import metrics
from PIL import Image

from .BaseTrainer import BaseTrainer
from data.graph_dataset import PatchGraphNoLinkDataset, get_weight_list, no_link_coll_fn, PatchGraphTestDataset, NoLinkGeneratePseudoDataset
from data.graph_dataset import PATCH_EDGE, SPLIT_EDGE
import data.transforms as MT
from models import load_param, BiGCN_Nolink
from utils.ranger import Ranger
from utils.lr_scheduler import CosineAnnealingWithWarmUpLR
from utils.losses import FuzzyClassificationLoss, GHMCELoss
from data.utils import draw_link

class GCNJoTrainer(BaseTrainer):  # joint training, without link prediction
    """Jointly train the classification and segmentation heads of ``BiGCN_Nolink``.

    Per batch the loss combines:
      * an image-level classification loss (``GHMCELoss``),
      * pixel-wise cross-entropy on the masks of negative (label 0) images,
      * pixel-wise cross-entropy (ignoring target value 1) on positive images,
      * a "filter" loss on positive images: the image is multiplied by the
        predicted seg channel 2 and must still be classified as 1, while the
        image multiplied by the complement (channels 0 + 1) must be classified
        as 0. Backbone and classifier are frozen for these re-classification
        passes.
    No link prediction is used. All settings come from a YAML file.
    """

    def __init__(self, opt_file='args/gcn_jo.yaml'):
        """Build transforms, datasets, loaders, model, optimizer, scheduler and losses.

        Args:
            opt_file: Path to a YAML file. Its top-level keys become attributes of
                ``self.opt`` (e.g. ``device``, ``source``, ``source_test``,
                ``cate_list``, ``image_size``, ``batch_size``, ``lr``, ``epochs``).
        """
        with open(opt_file) as f:
            opt = yaml.safe_load(f)
            opt = SimpleNamespace(**opt)
        self.opt = opt

        super(GCNJoTrainer, self).__init__(checkpoint_root='GCN_JO', opt=opt)
        # Echo the raw config file into the log so each run records its settings.
        with open(opt_file) as f:
            self.logger.info(f'{opt_file} START************************\n'
            f'{f.read()}\n'
            f'************************{opt_file} END**************************\n')

        if opt.device == 'cpu':
            self.device = torch.device('cpu')
        else:
            # NOTE(review): CUDA_VISIBLE_DEVICES only takes effect if CUDA has not
            # been initialised earlier in this process.
            os.environ['CUDA_VISIBLE_DEVICES'] = str(opt.device)
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.logger.info(f'Using device {self.device}')

        # Any mask mode other than 'mask' is handled as multi-label masks
        # (forwarded to ToTensorG; train() remaps mask value 2 for positives).
        self.mask_multilable = (opt.mask != 'mask')
        self.train_transform = T.Compose([
            MT.ResizeG(opt.image_size, link=False),
            MT.RandomHorizontalFlipG(opt.image_size, p=0.5, link=False),
            MT.RandomVerticalFlipG(opt.image_size, p=0.5, link=False),
            T.RandomApply([MT.ColorJitterG(0.4, 0.4, 0.4, link=False)], p=0.6),
            MT.ToTensorG(link=False, multilable = self.mask_multilable),
        ])
        self.val_transform_cls = T.Compose([
            MT.ResizeG(opt.image_size, link=False),
            MT.ToTensorG(link=False, multilable = self.mask_multilable),
        ])
        self.val_transform_seg = T.Compose([
            MT.Resize3D(opt.image_size),
            MT.ToTensor3D(no_mask=True),
        ])

        # Training set, re-balanced with a weighted sampler (n_train draws per epoch).
        self.train_dataset = PatchGraphNoLinkDataset(opt.source, cate_list=opt.cate_list, random=True, transform=self.train_transform, mask=opt.mask)
        self.n_train = len(self.train_dataset)
        self.train_sampler = WeightedRandomSampler(get_weight_list(self.train_dataset), self.n_train)
        self.train_loader = DataLoader(self.train_dataset,
                                       batch_size=opt.batch_size,
                                       sampler=self.train_sampler,
                                       drop_last=False,
                                       num_workers=8,
                                       collate_fn=no_link_coll_fn,
                                       pin_memory=True)

        # Validation: classification loader ('<cate>_val') and a seg dataset ('<source_test>/val').
        self.val_dataset_cls = PatchGraphNoLinkDataset(opt.source, cate_list=[cate+'_val' for cate in opt.cate_list], random=False, transform=self.val_transform_cls, mask=opt.mask)
        self.n_val_cls = len(self.val_dataset_cls)
        self.val_loader = DataLoader(self.val_dataset_cls,
                                     batch_size=opt.batch_size,
                                     shuffle=False,
                                     drop_last=False,
                                     num_workers=8,
                                     collate_fn=no_link_coll_fn,
                                     pin_memory=True)

        self.val_dataset_seg = PatchGraphTestDataset(os.path.join(opt.source_test, 'val'), transform=self.val_transform_seg, mask=opt.test_mask)
        self.n_val_seg = len(self.val_dataset_seg)

        # Test: same layout as validation.
        self.test_dataset_cls = PatchGraphNoLinkDataset(opt.source, cate_list=[cate+'_test' for cate in opt.cate_list], random=False, transform=self.val_transform_cls, mask=opt.mask)
        self.n_test_cls = len(self.test_dataset_cls)
        self.test_loader = DataLoader(self.test_dataset_cls,
                                      batch_size=opt.batch_size,
                                      shuffle=False,
                                      drop_last=False,
                                      num_workers=8,
                                      collate_fn=no_link_coll_fn,
                                      pin_memory=True)

        self.test_dataset_seg = PatchGraphTestDataset(os.path.join(opt.source_test, 'test'), transform=self.val_transform_seg, mask=opt.test_mask)
        self.n_test_seg = len(self.test_dataset_seg)

        self.net = BiGCN_Nolink(n_channels=3, n_cls_classes=2, n_seg_classes=3, input_size=opt.image_size)
        if opt.load_model:
            try:
                self.net.load_state_dict(torch.load(opt.load_model, map_location=self.device))
            except (RuntimeError, KeyError):
                # Fall back to the project's lenient loader when the state dict does not match exactly.
                load_param(self.net, opt.load_model, map_location=self.device)
            self.logger.info(f'Model loaded from {opt.load_model}')
        self.net.to(device=self.device)
        if torch.cuda.device_count() > 1 and self.device.type != 'cpu':
            self.net = nn.DataParallel(self.net)
            self.logger.info(f'torch.cuda.device_count:{torch.cuda.device_count()}, Use nn.DataParallel')
        # Unwrapped model, used for sub-module access and state_dict saving.
        self.net_module = self.net.module if isinstance(self.net, nn.DataParallel) else self.net
        self.net.eval()

        # Graph branch at the base lr; backbone + classification head at lr/10.
        self.optimizer = Ranger([{'params': self.net_module.graph_branch.parameters()},
                                 {'params': list(self.net_module.resnet50.parameters())+list(self.net_module.cls_branch.parameters()), 'lr': opt.lr/10},
                                ], lr=opt.lr, weight_decay=opt.weight_decay)
        # max(epochs, 1) avoids a division-by-zero inside the scheduler when epochs == 0.
        self.scheduler = CosineAnnealingWithWarmUpLR(self.optimizer, T_total=max(opt.epochs,1), eta_min=opt.lr/1000, warm_up_lr=opt.lr/100, warm_up_step=opt.warm_up_step)
        if opt.load_optimizer:
            self.optimizer.load_state_dict(torch.load(opt.load_optimizer, map_location=self.device))
            self.logger.info(f'Optimizer loaded from {opt.load_optimizer}')
        if opt.load_scheduler:
            self.scheduler.load_state_dict(torch.load(opt.load_scheduler))
            self.logger.info(f'Scheduler loaded from {opt.load_scheduler}')

        self.criterion_fuzzy = FuzzyClassificationLoss()
        self.criterion_ce = nn.CrossEntropyLoss()
        # Ignores target value 1; used for the pixel loss of positive images in train().
        self.criterion_ceforbg = nn.CrossEntropyLoss(ignore_index=1)
        #self.criterion_ce_with_label_smooth = nn.CrossEntropyLoss(label_smoothing=0.1)
        self.criterion_cls = GHMCELoss()

        self.epochs = opt.epochs
        self.save_cp = opt.save_cp
        self.early_stopping = opt.early_stopping
        self.training_info = opt.info

        self.logger.info(f'''Starting training net:
        Epochs:              {opt.epochs}
        Batch size:          {opt.batch_size}
        Learning rate:       {opt.lr}
        Image size:          {opt.image_size}
        Training size:       {self.n_train}
        Validation cls size: {self.n_val_cls}
        Validation seg size: {self.n_val_seg}
        Test cls size:       {self.n_test_cls}
        Test seg size:       {self.n_test_seg}
        Checkpoints:         {opt.save_cp}
        Device:              {self.device.type}
        Data source:         {opt.source}
        Training info:       {opt.info}
    ''')

    def train(self):
        """Run the training loop, then evaluate the best checkpoint on val and test.

        Each epoch does the following:
          * train on ``self.train_loader`` and log losses and weight/grad histograms;
          * evaluate on the validation split and step the scheduler;
          * save checkpoints. The model with the highest
            ``0.5 * val accuracy + 0.5 * val dice`` is kept as ``Net_best.pth``.

        Training stops early after ``early_stopping`` consecutive epochs without
        improvement (if ``early_stopping`` is truthy), or on KeyboardInterrupt.
        If ``opt.generate_pseudo_dir`` is set, pseudo masks are written at the end.
        """
        global_step = 0
        best_val_score = -1 #float('inf')
        useless_epoch_count = 0
        # Bound up-front so the KeyboardInterrupt handler never hits an unbound name.
        pbar = None
        # Backbone + classifier parameters, frozen during the filter re-classification passes.
        frozen_params = list(self.net_module.resnet50.parameters()) + list(self.net_module.cls_branch.parameters())
        for epoch in range(self.opt.start_epoch, self.epochs):
            try:
                self.net.train() #self.net_module.graph_branch.train()
                epoch_loss = 0
                epoch_cls, epoch_fuzzy, epoch_ce_zero, epoch_ce_one, epoch_one = 0, 0, 0, 0, 0
                epoch_one_pos, epoch_one_neg, pos_count = 0, 0, 0
                true_list = []
                pred_list = []
                vis_flag = True
                pbar = tqdm(total=self.n_train, desc=f'Epoch {epoch + 1}/{self.epochs}', unit='img')
                for imgs, masks, labels in self.train_loader:
                    global_step += 1
                    true_list += labels.tolist()
                    imgs, masks, labels = imgs.to(self.device), masks.to(self.device), labels.to(self.device)

                    cls_logit, seg_cls_logit = self.net(imgs)
                    seg_cls_pred = torch.softmax(seg_cls_logit, dim=1)
                    pred_list += cls_logit.detach().argmax(dim=1).tolist()

                    # Image-level classification loss.
                    loss_cls = self.criterion_cls(cls_logit, labels) #torch.tensor(0., device=self.device)

                    #loss_fuzzy = self.criterion_fuzzy(seg_cls_pred)

                    # Every per-term loss defaults to 0. The logging below then stays
                    # well-defined for batches without negative or positive images.
                    loss_ce_zero = torch.tensor(0, dtype=torch.float32, device=self.device)
                    loss_ce_one = torch.tensor(0, dtype=torch.float32, device=self.device)
                    loss_one = torch.tensor(0, dtype=torch.float32, device=self.device)
                    one_loss_pos = torch.tensor(0, dtype=torch.float32, device=self.device)
                    one_loss_neg = torch.tensor(0, dtype=torch.float32, device=self.device)

                    # Pixel CE on negative images, weighted by their share of the batch.
                    neg = labels == 0
                    zero_logit = seg_cls_logit[neg]
                    if zero_logit.size(0) != 0:
                        loss_ce_zero = self.criterion_ce(zero_logit, masks[neg]) * zero_logit.size(0) / imgs.size(0)

                    pos = labels == 1
                    one_pred = seg_cls_pred[pos]
                    if one_pred.size(0) != 0:
                        # Pixel CE on positive images (target value 1 is ignored).
                        if self.mask_multilable:
                            # Boolean indexing returns a copy, so `masks` itself is not modified.
                            masks_pos = masks[pos]
                            masks_pos[masks_pos==2] = 1
                            loss_ce_one = self.criterion_ceforbg(seg_cls_logit[pos], masks_pos) * one_pred.size(0) / imgs.size(0)
                        else:
                            loss_ce_one = self.criterion_ceforbg(seg_cls_logit[pos], masks[pos]) * one_pred.size(0) / imgs.size(0)

                        # Freeze backbone + classifier: no grads are recorded for them in the
                        # re-classification passes, and their BN layers run in eval mode.
                        # Gradients still flow back through `one_pred`.
                        for param in frozen_params:
                            param.requires_grad_(False)
                        self.net_module.resnet50.eval()
                        self.net_module.cls_branch.eval()

                        # Image masked by seg channel 2 must still be classified as positive.
                        pos_channel = one_pred[:, 2:3] #self.random_mask(one_pred[:, 2:3], p=0.2)
                        filter_pos = F.interpolate(pos_channel, size=imgs.shape[2:], mode='nearest')
                        filter_pos_logit = self.net(imgs[pos] * filter_pos, 'cls')
                        one_loss_pos = self.criterion_ce(filter_pos_logit, labels[pos])

                        # Image masked by the complement (channels 0 + 1) must be classified as negative.
                        #filter_neg = F.interpolate(one_pred[:, 1:2], size=imgs.shape[2:], mode='nearest')
                        filter_neg = F.interpolate(torch.sum(one_pred[:, 0:2], dim=1, keepdim=True), size=imgs.shape[2:], mode='nearest')
                        #filter_neg = F.interpolate(1-pos_channel, size=imgs.shape[2:], mode='nearest')
                        filter_neg_logit = self.net(imgs[pos] * filter_neg, 'cls')
                        one_loss_neg = self.criterion_ce(filter_neg_logit, torch.zeros_like(labels[pos]))

                        loss_one = (one_loss_pos + one_loss_neg) * one_pred.size(0) / imgs.size(0)

                        # Unfreeze before backward so the main losses update these parameters.
                        for param in frozen_params:
                            param.requires_grad_(True)
                        self.net_module.resnet50.train()
                        self.net_module.cls_branch.train()

                    loss = 0.5*loss_cls+loss_ce_zero+loss_ce_one+loss_one #+loss_fuzzy

                    n_pos = labels[pos].size(0)
                    self.writer.add_scalar(f'Train/batch_loss', loss.item(), global_step)
                    self.writer.add_scalar(f'Train/batch_loss_cls', loss_cls.item(), global_step)
                    self.writer.add_scalar(f'Train/batch_loss_ce_zero', loss_ce_zero.item(), global_step)
                    self.writer.add_scalar(f'Train/batch_loss_ce_one', loss_ce_one.item(), global_step)
                    self.writer.add_scalar(f'Train/batch_loss_one', loss_one.item(), global_step)
                    self.writer.add_scalar(f'Train/batch_loss_one_pos', one_loss_pos.item(), global_step)
                    self.writer.add_scalar(f'Train/batch_loss_one_neg', one_loss_neg.item(), global_step)
                    #self.writer.add_scalar(f'Train/batch_loss_fuzzy', loss_fuzzy.item(), global_step)
                    # Epoch sums are sample-weighted; the filter losses are weighted by positives only.
                    epoch_loss += loss.item() * labels.size(0)
                    epoch_cls += loss_cls.item() * labels.size(0)
                    epoch_ce_zero += loss_ce_zero.item() * labels.size(0)
                    epoch_ce_one += loss_ce_one.item() * labels.size(0)
                    epoch_one += loss_one.item() * labels.size(0)
                    pos_count += n_pos
                    epoch_one_pos += one_loss_pos.item() * n_pos
                    epoch_one_neg += one_loss_neg.item() * n_pos
                    #epoch_fuzzy += loss_fuzzy.item() * labels.size(0)
                    postfix = OrderedDict()
                    postfix['loss'] = loss.item()
                    postfix['cls'] = loss_cls.item()
                    postfix['ce0'] = loss_ce_zero.item()
                    postfix['ce1'] = loss_ce_one.item()
                    postfix['one'] = loss_one.item()
                    postfix['1pos'] = one_loss_pos.item()
                    postfix['1neg'] = one_loss_neg.item()
                    #postfix['fuzzy'] = loss_fuzzy.item()
                    pbar.set_postfix(postfix)

                    self.optimizer.zero_grad()
                    loss.backward()
                    # nn.utils.clip_grad_value_(self.net.parameters(), 0.1)
                    self.optimizer.step()

                    # Visualise only the first batch of each epoch.
                    if vis_flag:
                        vis_flag = False
                        self.draw_tensorboard_images(epoch, imgs, labels, masks, seg_cls_pred)

                    pbar.update(labels.shape[0])
                pbar.close()

                epoch_loss /= self.n_train
                epoch_cls /= self.n_train
                epoch_ce_zero /= self.n_train
                epoch_ce_one /= self.n_train
                epoch_one /= self.n_train
                # An epoch may contain no positive image at all; avoid dividing by zero.
                epoch_one_pos /= max(pos_count, 1)
                epoch_one_neg /= max(pos_count, 1)
                epoch_fuzzy /= self.n_train
                self.logger.info(f'Train epoch {epoch+1} loss: {epoch_loss}, cls:{epoch_cls}, ce0:{epoch_ce_zero}, ce1:{epoch_ce_one}, one:{epoch_one}, pos:{epoch_one_pos}, neg:{epoch_one_neg}') #, fuzzy:{epoch_fuzzy}
                self.writer.add_scalar('Train/epoch_loss', epoch_loss, epoch+1)
                self.writer.add_scalar(f'Train/epoch_cls', epoch_cls, epoch+1)
                self.writer.add_scalar(f'Train/epoch_ce_zero', epoch_ce_zero, epoch+1)
                self.writer.add_scalar(f'Train/epoch_ce_one', epoch_ce_one, epoch+1)
                self.writer.add_scalar(f'Train/epoch_one', epoch_one, epoch+1)
                self.writer.add_scalar(f'Train/epoch_one_pos', epoch_one_pos, epoch+1)
                self.writer.add_scalar(f'Train/epoch_one_neg', epoch_one_neg, epoch+1)
                #self.writer.add_scalar(f'Train/epoch_fuzzy', epoch_fuzzy, epoch+1)
                self.logger.info(f'Train epoch {epoch + 1} train report:\n'+metrics.classification_report(true_list, pred_list, digits=4))

                for tag, value in self.net_module.named_parameters():
                    tag = tag.replace('.', '/')
                    self.writer.add_histogram('weights/' + tag, value.data.cpu().numpy(), epoch+1)
                    if value.grad is not None:
                        self.writer.add_histogram('grads/' + tag, value.grad.data.cpu().numpy(), epoch+1)

                self.writer.add_scalar('learning_rate', self.optimizer.param_groups[0]['lr'], epoch+1)

                map_cls, cls_loss, ac_cls, precision, recall, f1, dice = self.evaluate(epoch)
                self.logger.info(f'Val epoch {epoch+1} cls map_cls: {map_cls}, cls_loss: {cls_loss}, acc: {ac_cls}')
                self.logger.info(f'Val epoch {epoch+1} seg precision: {precision}, recall: {recall}, f1: {f1}, dice: {dice}')
                self.writer.add_scalar('Val/map_cls', map_cls, epoch+1)
                self.writer.add_scalar('Val/cls_loss', cls_loss, epoch+1)
                self.writer.add_scalar('Val/ac_cls', ac_cls, epoch+1)
                self.writer.add_scalar('Val/precision', precision, epoch+1)
                self.writer.add_scalar('Val/recall', recall, epoch+1)
                self.writer.add_scalar('Val/f1', f1, epoch+1)
                self.writer.add_scalar('Val/dice', dice, epoch+1)
                # Model-selection score: equal weight on classification accuracy and seg dice.
                val_score = 0.5*ac_cls + 0.5*dice

                self.scheduler.step()

                if val_score > best_val_score:
                    best_val_score = val_score
                    torch.save(self.net_module.state_dict(), self.checkpoint_dir + 'Net_best.pth')
                    torch.save(self.optimizer.state_dict(), self.checkpoint_dir + 'Optimizer_best.pth')
                    torch.save(self.scheduler.state_dict(), self.checkpoint_dir + 'Scheduler_best.pth')
                    self.logger.info('Best model saved !')
                    useless_epoch_count = 0
                else:
                    useless_epoch_count += 1

                if self.save_cp:
                    torch.save(self.net_module.state_dict(), self.checkpoint_dir + f'Net_epoch{epoch + 1}.pth')
                    self.logger.info(f'Checkpoint {epoch + 1} saved !')
                else:
                    torch.save(self.net_module.state_dict(), self.checkpoint_dir + 'Net_last.pth')
                    self.logger.info('Last model saved !')
                torch.save(self.optimizer.state_dict(), self.checkpoint_dir + 'Optimizer_last.pth')
                torch.save(self.scheduler.state_dict(), self.checkpoint_dir + 'Scheduler_last.pth')

                if self.early_stopping and useless_epoch_count == self.early_stopping:
                    self.logger.info(f'There are {useless_epoch_count} useless epochs! Early Stop Training!')
                    break

            except KeyboardInterrupt:
                self.logger.info('Receive KeyboardInterrupt, stop training...')
                if pbar is not None:
                    pbar.close()
                break

        if self.opt.epochs > 0:
            best_path = self.checkpoint_dir + 'Net_best.pth'
            # No best checkpoint exists if training was interrupted before the first save.
            if os.path.exists(best_path):
                self.net_module.load_state_dict(torch.load(best_path, map_location=self.device))
            else:
                self.logger.info(f'{best_path} not found, evaluating the current model instead')
        self.logger.info(f'Val best model:')
        map_cls, cls_loss, ac_cls, precision, recall, f1, dice = self.evaluate(-1)
        self.logger.info(f'Best model val cls map_cls: {map_cls}, cls_loss: {cls_loss}, acc: {ac_cls}')
        self.logger.info(f'Best model val seg precision: {precision}, recall: {recall}, f1: {f1}, dice: {dice}')
        self.logger.info(f'Test best model:')
        map_cls, cls_loss, ac_cls, precision, recall, f1, dice = self.evaluate(-1, cate='test')
        self.logger.info(f'Best model test cls map_cls: {map_cls}, cls_loss: {cls_loss}, acc: {ac_cls}')
        self.logger.info(f'Best model test seg precision: {precision}, recall: {recall}, f1: {f1}, dice: {dice}')

        if self.opt.generate_pseudo_dir:
            self.generate_pseudo()

    @torch.no_grad()
    def evaluate(self, epoch, cate='val'):
        """Evaluate classification and segmentation on the val or test split.

        Args:
            epoch: Zero-based epoch index, used as the TensorBoard step ``epoch + 1``.
                A negative value disables image logging.
            cate: ``'val'`` or ``'test'``.

        Returns:
            ``(map_cls, cls_loss, ac_cls, precision, recall, f1, dice)``:
              * ``map_cls``: mean average precision over classification classes;
              * ``cls_loss``: mean cross-entropy;
              * ``ac_cls``: classification accuracy;
              * ``precision``, ``recall``, ``f1``: pixel-level segmentation
                scores for label 1 (prediction = seg channel 2 probability > 0.5);
              * ``dice``: mean per-image dice.

        Raises:
            ValueError: If ``cate`` is neither ``'val'`` nor ``'test'``.

        The network is switched back to train mode before returning.
        """
        if cate == 'val':
            dataset, loader, num_cls = self.val_dataset_seg, self.val_loader, self.n_val_cls
        elif cate == 'test':
            dataset, loader, num_cls = self.test_dataset_seg, self.test_loader, self.n_test_cls
        else:
            raise ValueError(f'cate should be val or test, bot got {cate}')

        self.net.eval()

        # ---- Round 1: image-level classification on raw and seg-filtered images ----
        tot_loss = 0
        true_list = []
        pred_list_filter = []
        pred_ori_list_filter = []
        pred_list = []
        pred_ori_list = []
        vis_flag = True
        with tqdm(total=num_cls, desc=f'{cate} classification round', unit='img', leave=False) as pbar:
            for imgs, masks, labels in loader:
                true_list += labels.tolist()
                imgs, masks, labels = imgs.to(self.device), masks.to(self.device), labels.to(self.device)

                cls_logit, seg_cls_logit = self.net(imgs)

                # Keep only the region weighted by seg channel 2 and re-classify it.
                filters = F.interpolate(F.softmax(seg_cls_logit, dim=1)[:, 2:3], size=imgs.shape[2:], mode='nearest')
                filtered_img = imgs * filters
                filter_cls_logit = self.net(filtered_img, 'cls')

                # Visualise the first batch that contains both classes.
                if vis_flag and epoch >= 0 and labels[labels==0].size(0) > 0 and labels[labels==1].size(0) > 0:
                    vis_flag = False
                    imgs = imgs.cpu() #*self.img_std+self.img_mean
                    self.writer.add_images('Val_cls/2origin', imgs, epoch+1, dataformats='NCHW')
                    color_list = [torch.ByteTensor([0,0,0]), torch.ByteTensor([255,0,0]), torch.ByteTensor([0,255,0]), torch.ByteTensor([0,0,255])]
                    # Index tensors are moved to CPU to match the CPU uint8 canvases.
                    labels_cpu = labels.cpu()
                    labels_img = torch.zeros(labels.shape[0], 100, 100, 3, dtype = torch.uint8)
                    pred_img = torch.zeros(cls_logit.shape[0], 100, 100, 3, dtype = torch.uint8)
                    filter_pred_img = torch.zeros(filter_cls_logit.shape[0], 100, 100, 3, dtype = torch.uint8)
                    pred_idx = cls_logit.argmax(dim=1).cpu()
                    filter_pred_idx = filter_cls_logit.argmax(dim=1).cpu()
                    for category in range(1, self.net_module.n_cls_classes):
                        labels_img[labels_cpu==category] = color_list[category]
                        pred_img[pred_idx==category] = color_list[category]
                        filter_pred_img[filter_pred_idx==category] = color_list[category]
                    self.writer.add_images('Val_cls/0categories/0true', labels_img, epoch+1, dataformats='NHWC')
                    self.writer.add_images('Val_cls/0categories/1pred', pred_img, epoch+1, dataformats='NHWC')
                    self.writer.add_images('Val_cls/0categories/2filter_pred', filter_pred_img, epoch+1, dataformats='NHWC')
                    masks_pred_img = torch.zeros(seg_cls_logit.shape[0], seg_cls_logit.shape[2], seg_cls_logit.shape[3], 3, dtype = torch.uint8)
                    masks_pred_idx = seg_cls_logit.argmax(dim=1).cpu()
                    for category in range(1, self.net_module.n_seg_classes):
                        masks_pred_img[masks_pred_idx==category] = color_list[category]
                    masks_pred_img = F.interpolate(masks_pred_img.permute(0,3,1,2), size=imgs.shape[2:])
                    if self.mask_multilable:
                        masks_cpu = masks.cpu()
                        masks_true_img = torch.zeros(masks.shape[0], masks.shape[1], masks.shape[2], 3, dtype = torch.uint8)
                        for category in range(1, self.net_module.n_seg_classes):
                            # Colour by the pixel-wise mask value, not the image-level label.
                            masks_true_img[masks_cpu==category] = color_list[category]
                        masks_true_img = F.interpolate(masks_true_img.permute(0,3,1,2), size=imgs.shape[2:])
                        self.writer.add_images('Val_cls/3mask/0true', masks_true_img, epoch+1, dataformats='NCHW')
                    else:
                        self.writer.add_images('Val_cls/3mask/0true', F.interpolate(masks.byte().unsqueeze(1), size=imgs.shape[2:])*255, epoch+1, dataformats='NCHW')
                    self.writer.add_images('Val_cls/3mask/1pred', masks_pred_img, epoch+1, dataformats='NCHW')
                    self.writer.add_images('Val_cls/4visualize', 0.7*imgs+0.3*masks_pred_img/255, epoch+1, dataformats='NCHW')
                    self.writer.add_images('Val_cls/5filtered_img', filtered_img.cpu(), epoch+1, dataformats='NCHW')

                pred_idx = torch.softmax(filter_cls_logit, dim=1)
                pred_ori_list_filter += pred_idx.tolist()
                pred_idx = pred_idx.argmax(dim=1)
                pred_list_filter.extend(pred_idx.tolist())

                pred_idx = torch.softmax(cls_logit, dim=1)
                pred_ori_list += pred_idx.tolist()
                pred_idx = pred_idx.argmax(dim=1)
                pred_list.extend(pred_idx.tolist())

                tot_loss += F.cross_entropy(cls_logit, labels).item() * labels.size(0)
                pbar.update(labels.size(0))

        # One-vs-rest average precision per class, computed from the unfiltered softmax scores.
        AP = []
        for c in range(self.net_module.n_cls_classes):
            c_true_list = [int(item==c) for item in true_list]
            c_pred_ori_list = [item[c] for item in pred_ori_list]
            AP.append(metrics.average_precision_score(c_true_list, c_pred_ori_list))
        self.logger.info(f'{cate} classify report:\n'+metrics.classification_report(true_list, pred_list, digits=4))
        self.logger.info(f'{cate} filter classify report:\n'+metrics.classification_report(true_list, pred_list_filter, digits=4))
        map_cls, cls_loss, ac_cls = float(np.mean(AP)), tot_loss / num_cls, metrics.accuracy_score(true_list, pred_list)

        # ---- Round 2: pixel-level segmentation, one dataset item (list of patches) at a time ----
        true_list = []
        pred_list = []
        dice_list = []
        vis_flag = True
        for img_patch, lbl_patch in tqdm(dataset, desc=f'{cate} seg round', unit='img', leave=False):
            num = len(img_patch)
            lbl_flatten = []
            pred_flatten = []
            for i in range(0, num, self.opt.batch_size):
                img = torch.stack(img_patch[i:i+self.opt.batch_size], dim=0).to(self.device)
                lbl = torch.stack(lbl_patch[i:i+self.opt.batch_size], dim=0)
                true_list.extend(lbl.flatten().tolist())
                lbl_flatten.append(lbl.flatten())
                logit = self.net(img, 'seg')
                pred = (F.softmax(logit, dim=1)[:, 2]>0.5).long().cpu()
                pred_list.extend(pred.flatten().tolist())
                pred_flatten.append(pred.flatten())
                if vis_flag and epoch >= 0:
                    vis_flag = False
                    img = img.cpu() #*self.img_std+self.img_mean
                    self.writer.add_images('Val_seg/2origin', img, epoch+1, dataformats='NCHW')
                    color_list = [torch.ByteTensor([0,0,0]), torch.ByteTensor([255,0,0]), torch.ByteTensor([0,255,0]), torch.ByteTensor([0,0,255])]
                    masks_pred_img = torch.zeros(logit.shape[0], logit.shape[2], logit.shape[3], 3, dtype = torch.uint8)
                    masks_pred_idx = logit.argmax(dim=1).cpu()
                    for category in range(1, self.net_module.n_seg_classes):
                        masks_pred_img[masks_pred_idx==category] = color_list[category]
                    masks_pred_img = F.interpolate(masks_pred_img.permute(0,3,1,2), size=img.shape[2:])
                    self.writer.add_images('Val_seg/3mask/0true', F.interpolate(lbl.byte().unsqueeze(1), size=img.shape[2:])*255, epoch+1, dataformats='NCHW')
                    self.writer.add_images('Val_seg/3mask/1pred', masks_pred_img, epoch+1, dataformats='NCHW')
                    self.writer.add_images('Val_seg/4visualize', 0.7*img+0.3*masks_pred_img/255, epoch+1, dataformats='NCHW')
                del img
            lbl_flatten, pred_flatten = torch.cat(lbl_flatten), torch.cat(pred_flatten)
            denom = lbl_flatten.sum() + pred_flatten.sum()
            if denom > 0:
                dice = (2*(lbl_flatten*pred_flatten).sum() / denom).item()
            else:
                # Neither label nor prediction has a positive pixel: perfect agreement.
                # Without this guard 0/0 gives NaN. That NaN would poison the mean
                # dice and make the best-model comparison in train() always fail.
                dice = 1.0
            dice_list.append(dice)

        precision = metrics.precision_score(true_list, pred_list, pos_label=1)
        recall = metrics.recall_score(true_list, pred_list, pos_label=1)
        f1 = metrics.f1_score(true_list, pred_list, pos_label=1)
        dice = torch.tensor(dice_list).mean().item()
        self.logger.info(f'{cate} seg report:\n'+metrics.classification_report(true_list, pred_list, digits=4))

        self.net.train() #self.net_module.graph_branch.train()

        return map_cls, cls_loss, ac_cls, precision, recall, f1, dice

    def __del__(self):
        """Drop the DataLoader references (if they were created), then run BaseTrainer's cleanup."""
        # __init__ may have failed before the loaders were assigned.
        for name in ('train_loader', 'val_loader', 'test_loader'):
            if hasattr(self, name):
                delattr(self, name)
        super(GCNJoTrainer, self).__del__()

    def draw_tensorboard_images(self, epoch, imgs, labels, masks, seg_cls_pred):
        """Log one training batch to TensorBoard.

        The logged images are: input images, colour-coded image labels,
        ground-truth masks, argmax seg predictions, and a 0.7/0.3 overlay of
        the images with the predicted masks.

        Args:
            epoch: Zero-based epoch; images are logged at step ``epoch + 1``.
            imgs: Input batch (NCHW).
            labels: Image-level labels, one per image.
            masks: Ground-truth masks; indexed as (N, H, W) below.
            seg_cls_pred: Per-pixel class probabilities (N, C, h, w).
        """
        imgs = imgs.cpu() #*self.img_std+self.img_mean
        self.writer.add_images('Train/2origin_images', imgs, epoch+1, dataformats='NCHW')

        color_list = [torch.ByteTensor([0,0,0]), torch.ByteTensor([255,0,0]), torch.ByteTensor([0,255,0]), torch.ByteTensor([0,0,255])]
        # Index tensors are moved to CPU to match the CPU uint8 canvases.
        labels_cpu = labels.cpu()
        labels_img = torch.zeros(labels.shape[0], 100, 100, 3, dtype = torch.uint8)
        for category in range(1, self.net_module.n_cls_classes):
            labels_img[labels_cpu==category] = color_list[category]
        self.writer.add_images('Train/0categories/true', labels_img, epoch+1, dataformats='NHWC')

        if self.mask_multilable:
            masks_cpu = masks.cpu()
            masks_true_img = torch.zeros(masks.shape[0], masks.shape[1], masks.shape[2], 3, dtype = torch.uint8)
            for category in range(1, self.net_module.n_seg_classes):
                # Colour by the pixel-wise mask value, not the image-level label.
                masks_true_img[masks_cpu==category] = color_list[category]
            masks_true_img = F.interpolate(masks_true_img.permute(0,3,1,2), size=imgs.shape[2:])
            self.writer.add_images('Train/3masks/0true', masks_true_img, epoch+1, dataformats='NCHW')
        else:
            self.writer.add_images('Train/3masks/0true', F.interpolate(masks.byte().unsqueeze(1), size=imgs.shape[2:])*255, epoch+1, dataformats='NCHW')

        masks_pred_img = torch.zeros(seg_cls_pred.shape[0], seg_cls_pred.shape[2], seg_cls_pred.shape[3], 3, dtype = torch.uint8)
        masks_pred_idx = seg_cls_pred.argmax(dim=1).cpu()
        for category in range(1, self.net_module.n_seg_classes):
            masks_pred_img[masks_pred_idx==category] = color_list[category]
        masks_pred_img = F.interpolate(masks_pred_img.permute(0,3,1,2), size=imgs.shape[2:])
        self.writer.add_images('Train/3masks/1pred', masks_pred_img, epoch+1, dataformats='NCHW')

        self.writer.add_images('Train/4visualize', 0.7*imgs+0.3*masks_pred_img/255, epoch+1, dataformats='NCHW')

    @torch.no_grad()
    def generate_pseudo(self):
        """Predict seg masks for the train/val/test categories and save them as pseudo-label patches.

        For every item of ``NoLinkGeneratePseudoDataset``:
          * the argmax seg prediction is cut back into ``SPLIT_EDGE``-sized patches,
            laid out row-major on a ``PATCH_EDGE``-wide grid;
          * each patch is written as an image to
            ``<path>/<opt.generate_pseudo_dir>/<name>/<patch>``.

        ``os.makedirs(..., exist_ok=False)`` raises ``FileExistsError`` if that
        directory already exists, so existing pseudo masks are never overwritten.
        """
        self.net.eval()
        cate_list = self.opt.cate_list + [cate+'_val' for cate in self.opt.cate_list] + [cate+'_test' for cate in self.opt.cate_list]
        dataset = NoLinkGeneratePseudoDataset(self.opt.source, cate_list=cate_list, transform=self.val_transform_seg)

        for imgs, patch_name_list, path, name in tqdm(dataset, desc='Use best model to generate new pseudo masks'):
            preds = self.net(imgs.to(self.device), 'seg')
            preds = preds.argmax(dim=1)
            preds = preds.cpu().numpy().astype(np.uint8)

            mask_save_path = os.path.join(path, self.opt.generate_pseudo_dir, name)
            os.makedirs(mask_save_path, exist_ok=False)

            for pred_img, patchs in zip(preds, patch_name_list):
                for idx, patch in enumerate(patchs):
                    # Patch idx sits at column idx % PATCH_EDGE, row idx // PATCH_EDGE.
                    x, y = idx % PATCH_EDGE * SPLIT_EDGE, idx // PATCH_EDGE * SPLIT_EDGE
                    p_img = Image.fromarray(pred_img[y:y+SPLIT_EDGE, x:x+SPLIT_EDGE])
                    p_img.save(os.path.join(mask_save_path, patch))

                    # # visualize
                    # p_img = pred_img[y:y+SPLIT_EDGE, x:x+SPLIT_EDGE]
                    # vis_img = np.zeros((p_img.shape[0], p_img.shape[1], 3), dtype=np.uint8)
                    # vis_img[p_img==1] = [255,0,0]
                    # vis_img[p_img==2] = [0,255,0]
                    # vis_img = Image.fromarray(vis_img)
                    # vis_img.save(os.path.join(mask_save_path, patch))