import cv2
import torch
import torch.nn as nn
import numpy as np
from PIL import Image
from torch.autograd import Variable
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
import time

from losses import *


def adjust_learning_rate(optimizer, epoch, lr, iter_num, start=0, decrease=0.):
    """Set a step-decayed learning rate on every parameter group.

    The rate is divided by 10 once every ``iter_num`` epochs (counted from
    ``start``). If ``decrease`` is non-zero, inside each step the rate is
    further scaled by ``1 / (1 + k * decrease)``, where ``k`` is the number of
    epochs since the last step.

    Args:
        optimizer: object exposing ``param_groups`` (a list of dicts).
        epoch: current epoch.
        lr: base learning rate.
        iter_num: step length in epochs.
        start: epoch at which the schedule starts counting.
        decrease: optional inverse-time decay factor within a step.
    """
    stage, offset = divmod(epoch - start, iter_num)
    new_lr = lr * 0.1 ** stage
    if decrease:
        new_lr *= 1 / (1 + offset * decrease)
    for group in optimizer.param_groups:
        group['lr'] = new_lr


class HookModule:
    """Capture a module's forward output and offer an autograd helper.

    A forward hook is registered on ``module``; after every forward pass the
    module's output is stored in ``self.activations``. The hook handle is kept
    in ``self.handle`` so the hook can be detached again with :meth:`remove`
    (or by using the instance as a context manager). Without a handle, every
    new HookModule on the same layer would add another permanent hook.

    Args:
        model: the network; ``grads`` calls ``model.zero_grad()``.
        module: the sub-module whose output is captured.
    """

    def __init__(self, model, module):
        self.model = model
        self.activations = None
        self.handle = module.register_forward_hook(self._hook_activations)

    def _hook_activations(self, module, inputs, outputs):
        # Returning None leaves the module's output unchanged.
        self.activations = outputs

    def grads(self, outputs, inputs, retain_graph=True, create_graph=True):
        """Return d(outputs)/d(inputs), then zero the model's ``.grad`` fields."""
        grads = torch.autograd.grad(outputs=outputs,
                                    inputs=inputs,
                                    retain_graph=retain_graph,
                                    create_graph=create_graph)[0]
        self.model.zero_grad()
        return grads

    def remove(self):
        """Unregister the forward hook; calling it again is a no-op."""
        if self.handle is not None:
            self.handle.remove()
            self.handle = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.remove()
        return False


def output_newimgs(epoch_list, img_idx=0, main_name='newimgs_cifar10_noise40'):
    """Export one image from several saved epoch snapshots as PNG files.

    First loads ``result/<main_name>_<epoch>.npy`` for every epoch in
    ``epoch_list``; nothing is written until all files are loaded. Then image
    ``img_idx`` of the k-th snapshot, cast to uint8, is written to
    ``result/new<k>.png``.
    """
    snapshots = [np.load('result/{0}_{1}.npy'.format(main_name, ep)) for ep in epoch_list]
    for k, snapshot in enumerate(snapshots):
        frame = snapshot[img_idx].astype(np.uint8)
        Image.fromarray(frame).save('result/new{}.png'.format(k))


def get_triplet_filtration(triplet_mat, train_label_c, train_label_w, train_pred_w):
    """Recover labels for the column samples of a score matrix.

    Row ``i`` of ``triplet_mat`` is matched to its highest-scoring column
    ``w``. Column ``w`` is then matched back to its highest-scoring row, or to
    its second-highest row if that is ``i`` itself.

    Column ``w`` receives ``train_label_c[i]`` only when both matched rows
    share that label AND it equals ``train_pred_w[w]``; otherwise it is set to
    -1. Rows are visited in order, so a later row overwrites an earlier row's
    decision for the same column. Columns that are never selected stay -1.

    Args:
        triplet_mat: 2-D score matrix; rows index ``train_label_c``, columns
            index ``train_label_w`` / ``train_pred_w``.
        train_label_c: labels of the row samples.
        train_label_w: only its length is used (size of the result).
        train_pred_w: predicted labels of the column samples.

    Returns:
        LongTensor of length ``len(train_label_w)``; -1 means "no label".
    """
    recovered = torch.full((train_label_w.shape[0],), -1, dtype=torch.long)
    for row in range(triplet_mat.shape[0]):
        col = torch.argmax(triplet_mat[row, :]).item()
        column_scores = triplet_mat[:, col]
        partner = torch.argmax(column_scores).item()
        if partner == row:
            # ``col``'s best match is ``row`` itself; fall back to the runner-up.
            partner = torch.argsort(column_scores, descending=True)[1].item()
        anchor = train_label_c[row]
        agrees = anchor == train_label_c[partner] and anchor == train_pred_w[col]
        recovered[col] = anchor if agrees else -1
    return recovered


def view_grads(grads, fig_w, fig_h, fig_path):
    """Save ``grads`` as a heatmap image at ``fig_path``.

    The colour scale is fixed to [0, 1]. Axes are labelled "convolutional
    kernel" (x) and "category" (y).

    The figure is always closed afterwards. ``plt.subplots`` creates a new
    figure on every call; the previous ``plt.clf()`` only cleared it, so
    repeated calls accumulated open figures.

    Args:
        grads: 2-D array-like accepted by ``seaborn.heatmap``.
        fig_w, fig_h: figure size in inches.
        fig_path: output file path.
    """
    fig, ax = plt.subplots(figsize=(fig_w, fig_h), ncols=1)
    try:
        ax.set_xlabel('convolutional kernel')
        ax.set_ylabel('category')
        sns.heatmap(grads, annot=False, ax=ax, vmax=1.0, vmin=0.0)
        fig.savefig(fig_path, bbox_inches='tight')
    finally:
        plt.close(fig)

# weight route (legacy)
# NOTE: this older weight-gradient version of the route statistics is kept
# only for reference. It is wrapped in a bare string literal, so nothing below
# is defined or executed when the module is imported. The active
# feature-route implementation is grad_route further down.
'''
def grad_route_(model, data_loader, loss_name='ce', max_iter=20, branch_idx=0):
    # processing
    dim1 = dim2 = 512
    route = np.zeros((10,dim1,dim2))
    route_num = np.zeros((10))+0.01

    for batch_idx, data in enumerate(data_loader):
        if batch_idx == max_iter:
            break
        if len(data) == 4:
            train_image, train_label_a, train_label_true, train_index = data
        if len(data) == 5:
            train_image, train_label_a, train_label_b, train_label_true, train_index = data
        if branch_idx == 0:
            train_label = train_label_a
        elif branch_idx == 1:
            train_label = train_label_b
        train_image = Variable(train_image.float(), requires_grad=True).cuda()
        train_label = Variable(train_label).cuda()
        train_label_true = Variable(train_label_true).cuda()
        #target_layer = model.encoder[17]
        #target_layer = model.features[40]
        target_layer = model.convsplit
        cls_output = model(train_image)
        cls_output = nn.Softmax(dim=1)(cls_output)
        prediction = torch.sum(cls_output * train_label.float(), dim=1)
        prediction_true = torch.sum(cls_output * train_label_true.float(), dim=1)

        for sample_idx in range(train_label.shape[0]):
            if cls_output[sample_idx].argmax() != train_label[sample_idx].argmax():
                continue
            if loss_name == 'ce':
                sample_loss = -torch.log(prediction[sample_idx])
            elif loss_name == 'sl':
                sample_loss = SLLoss(cls_output[sample_idx].unsqueeze(0), train_label[sample_idx].unsqueeze(0))
            elif loss_name == 'gce':
                sample_loss = GCELoss(cls_output[sample_idx].unsqueeze(0), train_label[sample_idx].unsqueeze(0), 0.9)
            #loss = SLLoss(cls_output[sample_idx], train_label[sample_idx], 0.1, 1.0)
            grads = torch.autograd.grad(outputs=sample_loss, inputs=target_layer.weight, retain_graph=True, grad_outputs=torch.ones_like(sample_loss))
            label = torch.argmax(train_label, dim=1)[sample_idx]
            route[label] = route[label] + grads[0].abs().sum(dim=(2,3)).detach().cpu().numpy()
            route_num[label] += 1
        print('\r', batch_idx, end='')
        
    route_norm = route / route_num.reshape(10, 1, 1)
    route_norm = route_norm
    norm_min = route_norm.min()
    norm_max = route_norm.max()
    route_norm = (route_norm - norm_min) / (norm_max- norm_min)
    return route_norm, norm_min, norm_max
'''


# weight route (legacy)
# NOTE: this older weight-gradient masking version is kept only for reference.
# It is wrapped in a bare string literal, so nothing below is defined or
# executed when the module is imported. The active feature-route counterpart
# is MaskedGrad further down.
'''
def MaskedGrad_(output, labels, target_layer, route_norm, norm_values, rec_ratio, mode='abs', kernel_size=3, loss_name='ce'):
    total_grad = torch.zeros(output.shape[0], route_norm[0].shape[0], route_norm[0].shape[1], kernel_size, kernel_size).cuda()
    norm_min, norm_max = norm_values
    route_norm_low = route_norm < route_norm.mean()
    threshold = output.max(dim=1)[0].sort()[0][-min(int(rec_ratio*labels.shape[0]+1), labels.shape[0])]
    for sample_idx in range(output.shape[0]):
        if loss_name == 'ce':
            #sample_loss = -torch.log(torch.sum(output * labels.float(), dim=1)[sample_idx])
            sample_loss = -torch.log(output[sample_idx].max())
        elif loss_name == 'sl':
            sample_loss = SLLoss(output[sample_idx].unsqueeze(0), labels[sample_idx].unsqueeze(0))
        elif loss_name == 'gce':
            sample_loss = GCELoss(output[sample_idx].unsqueeze(0), labels[sample_idx].unsqueeze(0), 0.9)
        total_grad[sample_idx] = torch.autograd.grad(outputs=sample_loss, inputs=target_layer.weight, retain_graph=True, grad_outputs=torch.ones_like(sample_loss))[0]
    grads_norm = (total_grad - norm_min) / (norm_max - norm_min)
    cls_distance = torch.tensor([grads_norm.abs().sum(dim=(3,4))[:,route_norm_low[i]].sum(dim=1).cpu().numpy() / route_norm_low[i].sum() for i in range(10)])
    rec_mask = (cls_distance.argmin(dim=0).cuda() == output.argmax(dim=1)) & (output.argmax(dim=1) != labels.argmax(dim=1)) & (output.max(dim=1)[0] > threshold)
    
    map_pred = cls_distance.argmin(dim=0)
    total_grad[rec_mask] = 0
    total_grad *= (torch.tensor(route_norm)[map_pred] ** (1/2)).reshape(map_pred.shape[0], *route_norm.shape[1:], 1, 1).float().cuda()
    masked_grad = total_grad.sum(dim=0)
    map_pred[~rec_mask] = -1
    return masked_grad, map_pred
'''


# feature route
def grad_route(model, input_layers, output_layers, data_loader, n_class, loss_name='ce', max_iter=351, branch_idx=0):
    """Accumulate per-class gradient and activation "routes" over a data loader.

    For each of up to ``max_iter`` batches:

    * The model is run on the images while the outputs of ``input_layers``
      (activations) and ``output_layers`` (whose gradients are taken) are
      captured.
    * The gradient of the summed labelled logits w.r.t. each output layer's
      activations is clipped at 0.
    * Only the half of the batch with the highest labelled-class probability
      (``ceil(B/2)`` samples) contributes.
    * Absolute values are reduced to one value per channel (summed over the
      last two dims for 4-D tensors) and summed per class through the
      one-hot labels.

    Args:
        model: network to probe; its train/eval mode is not changed.
        input_layers: modules whose outputs form the activation routes.
        output_layers: modules whose output gradients form the gradient routes.
        data_loader: yields 4-tuples ``(img, label_a, label_true, index)`` or
            5-tuples ``(img, label_a, label_b, label_true, index)`` with
            one-hot labels.
        n_class: number of classes (rows of every route).
        loss_name: unused; kept for interface compatibility.
        max_iter: maximum number of batches to process.
        branch_idx: 0 uses ``label_a``; 1 uses ``label_b`` (5-tuples only).

    Returns:
        ``(g_routes_norm, a_routes_norm, g_routes, a_routes)``: lists of
        per-layer ``(n_class, C)`` tensors. ``g_routes`` / ``a_routes`` are
        per-class averages. The ``*_norm`` variants are min-max scaled per
        class to [0, 1]; a constant route (e.g. a class that never appeared)
        becomes all zeros instead of NaN.

    Raises:
        ValueError: for an unsupported ``branch_idx`` or batch layout, or if
            no batch was processed.
    """
    if branch_idx not in (0, 1):
        raise ValueError('branch_idx must be 0 or 1, got {}'.format(branch_idx))

    # Register the hooks once for the whole pass; each forward call simply
    # overwrites ``.activations``. Creating HookModules inside the loop added
    # another never-removed hook to every layer per batch, so forward passes
    # got progressively slower.
    input_modules = [HookModule(model=model, module=layer) for layer in input_layers]
    output_modules = [HookModule(model=model, module=layer) for layer in output_layers]

    g_routes = a_routes = routes_num = None
    for batch_idx, data in enumerate(data_loader):
        if batch_idx == max_iter:
            break
        if len(data) == 4:
            train_image, train_label_a = data[0], data[1]
            train_label_b = None
        elif len(data) == 5:
            train_image, train_label_a, train_label_b = data[0], data[1], data[2]
        else:
            raise ValueError('expected 4- or 5-tuple batches, got {} entries'.format(len(data)))
        if branch_idx == 0:
            train_label = train_label_a
        elif train_label_b is None:
            raise ValueError('branch_idx=1 requires 5-tuple batches')
        else:
            train_label = train_label_b
        train_image = Variable(train_image.float(), requires_grad=True).cuda()
        train_label = Variable(train_label).cuda()
        target = train_label.argmax(dim=1)

        cls_logits = model(train_image)
        cls_output = nn.Softmax(dim=1)(cls_logits)
        # NLLLoss on negated inputs returns each sample's labelled entry.
        label_logits = torch.nn.NLLLoss(reduction='none')(-cls_logits, target)
        grads = [torch.autograd.grad(outputs=label_logits.sum(), inputs=m.activations,
                                     retain_graph=True, create_graph=False)[0]
                 for m in output_modules]
        for g in grads:
            g[g < 0] = 0
        activations = [m.activations for m in input_modules]

        if g_routes is None:
            g_routes = [torch.zeros((n_class, g.shape[1])).cuda() for g in grads]
            a_routes = [torch.zeros((n_class, a.shape[1])).cuda() for a in activations]
            routes_num = torch.zeros((n_class)).cuda() + 0.01

        # Keep the more confident half of the batch (highest labelled-class probability).
        label_prob = torch.nn.NLLLoss(reduction='none')(-cls_output, target)
        conf_set = label_prob.argsort()[-train_label.shape[0] // 2:]
        conf_label_t = train_label[conf_set].T.float()
        for i, g in enumerate(grads):
            g = g[conf_set]
            if g.dim() == 4:
                g_routes[i] += (conf_label_t @ g.abs().sum(dim=(2, 3))).detach()
            elif g.dim() == 2:
                g_routes[i] += (conf_label_t @ g.abs()).detach()
        for i, a in enumerate(activations):
            a = a[conf_set]
            if a.dim() == 4:
                a_routes[i] += (conf_label_t @ a.abs().sum(dim=(2, 3))).detach()
            elif a.dim() == 2:
                a_routes[i] += (conf_label_t @ a.abs()).detach()
        routes_num += conf_set.shape[0]
        print('\r', batch_idx, end='')

    if g_routes is None:
        raise ValueError('no batches were processed (empty data_loader or max_iter=0)')

    def _minmax_per_class(route):
        # Per-row min-max scaling; the clamp maps a constant row to 0 instead
        # of 0/0 = NaN.
        r_min = route.min(dim=-1)[0].unsqueeze(1)
        r_max = route.max(dim=-1)[0].unsqueeze(1)
        return (route - r_min) / (r_max - r_min).clamp(min=1e-16)

    g_routes = [route / routes_num.reshape(-1, 1) for route in g_routes]  # (n_class, dim)
    a_routes = [route / routes_num.reshape(-1, 1) for route in a_routes]  # (n_class, dim)
    g_routes_norm = [_minmax_per_class(r) for r in g_routes]
    a_routes_norm = [_minmax_per_class(r) for r in a_routes]
    return g_routes_norm, a_routes_norm, g_routes, a_routes


def norm(g, norm_type='max1'):
    """Reduce gradient maps to per-channel scores and normalise them.

    Negative entries are zeroed (``(g + |g|) / 2``), the spatial dimensions
    are summed, and the result is normalised along the channel axis:

    * ``'max1'``: min-max scaled per row to [0, 1] (``1e-16`` guards a zero
      range) and detached;
    * ``'length'``: L2-normalised along dim 1 and detached;
    * any other value: returned un-normalised and not detached.

    Args:
        g: a tensor (its last two dims are summed), or a list of 4-D tensors
            (dims 2 and 3 are summed; each is normalised independently).
        norm_type: ``'max1'`` or ``'length'``.

    Returns:
        The reduced tensor, or a list of them when ``g`` is a list.
    """
    if isinstance(g, list):
        g = [(g_ + g_.abs()) / 2 for g_ in g]
        g = [g_.sum(dim=(2, 3)) for g_ in g]
        if norm_type == 'max1':
            # reshape(-1, 1) instead of the previous hard-coded reshape(10, 1),
            # which only worked for exactly 10 rows.
            g_min = [g_.min(dim=1)[0].reshape(-1, 1) for g_ in g]
            g_max = [g_.max(dim=1)[0].reshape(-1, 1) for g_ in g]
            g = [(g_ - min_) / (max_ - min_ + 1e-16) for g_, min_, max_ in zip(g, g_min, g_max)]
            g = [g_.detach() for g_ in g]
        elif norm_type == 'length':
            g = [nn.functional.normalize(g_, p=2.0, dim=1).detach() for g_ in g]
    else:
        g = (g + g.abs()) / 2
        g = g.sum(dim=(-1, -2))
        if norm_type == 'max1':
            g_min = g.min(dim=-1)[0].reshape(*g.shape[:-1], 1)
            g_max = g.max(dim=-1)[0].reshape(*g.shape[:-1], 1)
            g = (g - g_min) / (g_max - g_min + 1e-16)
            g = g.detach()
        elif norm_type == 'length':
            g = nn.functional.normalize(g, p=2.0, dim=1).detach()
    return g


# feature route
def grad_map(model, input_layers, output_layers, data_loader, n_class, loss_name='ce', max_iter=351, branch_idx=0):
    """Accumulate per-class spatial gradient and activation maps over a loader.

    Like ``grad_route``, but keeps the full per-channel maps instead of
    reducing them. The model is switched to eval mode. For each of up to
    ``max_iter`` batches:

    * The gradients of the summed labelled logits w.r.t. the outputs of
      ``output_layers`` (not clipped) and the raw outputs of
      ``input_layers`` are collected.
    * Only the half of the batch with the highest labelled-class
      probability (``ceil(B/2)`` samples) contributes.
    * Contributions are summed per class through the one-hot labels.

    All hooked tensors must be 4-D (the accumulation transposes dims 0 and 3).

    Args:
        model: network to probe (left in eval mode).
        input_layers: modules whose outputs form the activation maps.
        output_layers: modules whose output gradients form the gradient maps.
        data_loader: yields batches in one of three layouts:

            * ``(img, img2, label, ...)``, where the first two entries share a
              shape;
            * ``(img, label, ...)`` with at most 4 entries;
            * 5-tuples ``(img, label_a, label_b, label_true, index)``.

            Integer labels are one-hot encoded with ``n_class``.
        n_class: number of classes.
        loss_name: unused; kept for interface compatibility.
        max_iter: maximum number of batches to process.
        branch_idx: 0 uses the first label; 1 uses ``label_b`` (5-tuples only).

    Returns:
        ``(g_routes_norm, a_routes_norm, g_routes, a_routes)``: lists of
        per-layer ``(n_class, *feature_shape)`` tensors. ``g_routes`` /
        ``a_routes`` are per-class averages. The ``*_norm`` variants are
        min-max scaled per class to [0, 1]; a constant class map becomes all
        zeros instead of NaN.

    Raises:
        ValueError: for an unsupported ``branch_idx`` or batch layout, or if
            no batch was processed.
    """
    if branch_idx not in (0, 1):
        raise ValueError('branch_idx must be 0 or 1, got {}'.format(branch_idx))
    model.eval()

    # Register the hooks once for the whole pass; each forward call simply
    # overwrites ``.activations``. Creating HookModules inside the loop added
    # another never-removed hook to every layer per batch.
    input_modules = [HookModule(model=model, module=layer) for layer in input_layers]
    output_modules = [HookModule(model=model, module=layer) for layer in output_layers]

    g_routes = a_routes = routes_num = None
    for batch_idx, data in enumerate(data_loader):
        if batch_idx == max_iter:
            break
        train_label_b = None
        if data[0].shape == data[1].shape:
            # Presumably two views of the same images, followed by the labels.
            train_image, _, train_label_a = data[:3]
        elif len(data) <= 4:
            train_image, train_label_a = data[:2]
        elif len(data) == 5:
            train_image, train_label_a, train_label_b = data[:3]
        else:
            raise ValueError('unsupported batch layout with {} entries'.format(len(data)))
        if branch_idx == 0:
            train_label = train_label_a
        elif train_label_b is None:
            raise ValueError('branch_idx=1 requires 5-tuple batches')
        else:
            train_label = train_label_b
        if len(train_label.shape) == 1:
            # Integer class ids -> one-hot rows.
            train_label = torch.eye(n_class)[train_label]
        train_image = Variable(train_image.float(), requires_grad=True).cuda()
        train_label = Variable(train_label).cuda()
        target = train_label.argmax(dim=1)

        cls_logits = model(train_image)
        # Softmax probabilities, pulled slightly away from exact 0 and 1.
        cls_output = nn.Softmax(dim=1)(cls_logits)
        cls_output = (cls_output + 1e-4) / (1 + 2e-4)
        # NLLLoss on negated inputs returns each sample's labelled entry.
        label_logits = torch.nn.NLLLoss(reduction='none')(-cls_logits, target)
        grads = [torch.autograd.grad(outputs=label_logits.sum(), inputs=m.activations,
                                     retain_graph=True, create_graph=False)[0]
                 for m in output_modules]
        activations = [m.activations for m in input_modules]

        if g_routes is None:
            g_routes = [torch.zeros((n_class, *g.shape[1:])).cuda() for g in grads]
            a_routes = [torch.zeros((n_class, *a.shape[1:])).cuda() for a in activations]
            routes_num = torch.zeros((n_class)).cuda() + 0.01

        # Keep the more confident half of the batch (highest labelled-class probability).
        label_prob = torch.nn.NLLLoss(reduction='none')(-cls_output, target)
        conf_set = label_prob.argsort()[-train_label.shape[0] // 2:]
        conf_label = train_label[conf_set].float()
        # Move the batch axis last, contract it with the (B, n_class) labels,
        # then move the resulting class axis back to the front.
        for i, g in enumerate(grads):
            g_routes[i] += torch.matmul(g[conf_set].transpose(0, 3), conf_label).transpose(0, 3).detach()
        for i, a in enumerate(activations):
            a_routes[i] += torch.matmul(a[conf_set].transpose(0, 3), conf_label).transpose(0, 3).detach()
        routes_num += conf_set.shape[0]
        print('\r', batch_idx, end='')

    if g_routes is None:
        raise ValueError('no batches were processed (empty data_loader or max_iter=0)')

    def _minmax_per_class(route):
        # Min-max over every non-class dimension; the clamp maps a constant
        # class map to 0 instead of 0/0 = NaN.
        flat = route.flatten(1)
        view = (-1,) + (1,) * (route.dim() - 1)
        r_min = flat.min(dim=1)[0].reshape(view)
        r_max = flat.max(dim=1)[0].reshape(view)
        return (route - r_min) / (r_max - r_min).clamp(min=1e-16)

    g_routes = [route / routes_num.reshape(-1, 1, 1, 1) for route in g_routes]
    a_routes = [route / routes_num.reshape(-1, 1, 1, 1) for route in a_routes]
    g_routes_norm = [_minmax_per_class(r) for r in g_routes]
    a_routes_norm = [_minmax_per_class(r) for r in a_routes]
    return g_routes_norm, a_routes_norm, g_routes, a_routes
    

# feature route
def activation_pred(x, labels, model, input_layers, a_routes_norm):
    """Predict each sample's class from its activation pattern.

    Runs ``model`` (switched to eval mode) on ``x`` while capturing the
    outputs of ``input_layers``. Each captured output is processed as follows:

    1. Reduce it to one value per channel (absolute value, summed over the
       last two dims for 4-D outputs).
    2. Min-max scale it per sample.
    3. Compare it (L1 distance) with every class row of the matching
       ``a_routes_norm`` entry.

    The distances are summed over layers.

    The forward hooks exist only for this one forward pass and are always
    removed afterwards. Previously a HookModule was registered per call and
    never removed, so hooks piled up on the layers.

    Args:
        x: model input batch.
        labels: unused; kept for interface compatibility.
        model: the network.
        input_layers: modules whose outputs are compared with the routes.
        a_routes_norm: one ``(n_class, C)`` tensor per input layer, e.g. the
            ``a_routes_norm`` returned by ``grad_route``.

    Returns:
        Tensor of shape ``(batch,)`` holding the nearest-route class per sample.
    """
    model.eval()
    captured = [None] * len(input_layers)

    def _make_hook(slot):
        def _hook(module, inputs, outputs):
            captured[slot] = outputs
        return _hook

    handles = [layer.register_forward_hook(_make_hook(i)) for i, layer in enumerate(input_layers)]
    try:
        model(x)
    finally:
        for handle in handles:
            handle.remove()

    reduced = [a.abs().sum(dim=(2, 3)) if a.dim() == 4 else a.abs() for a in captured]
    scaled = [(a - a.min(dim=1)[0].reshape(-1, 1)) / (a.max(dim=1)[0] - a.min(dim=1)[0] + 1e-6).reshape(-1, 1)
              for a in reduced]
    distance = sum((s.unsqueeze(1) - a_routes_norm[i].unsqueeze(0)).abs().sum(dim=-1)
                   for i, s in enumerate(scaled))
    return distance.argmin(dim=1)


# feature route
def MaskedGrad(x, labels, model, input_layers, output_layers , g_routes_norm, a_routes_norm, rec_ratio, loss_name='ce', a_pred=None):
    """Compare a batch's feature gradients with per-class gradient routes.

    Switches ``model`` to eval mode and runs it on ``x`` with forward hooks
    on ``input_layers`` and ``output_layers``. The gradient of the summed
    labelled logits w.r.t. each output layer's activations is taken with
    ``create_graph=True``, so the returned targets stay differentiable. Each
    gradient is then reduced to one value per channel (summed over the last
    two dims for 4-D tensors).

    Args:
        x: input batch.
        labels: one-hot labels (``labels.argmax(dim=1)`` gives the class).
        model: the network. ``model.zero_grad()`` is called once per output
            layer, inside ``HookModule.grads``.
        input_layers: modules hooked for activations.
        output_layers: modules hooked for gradients.
        g_routes_norm, a_routes_norm: per-layer ``(n_class, C)`` route
            tensors, presumably the normalised routes from ``grad_route`` --
            TODO confirm against the caller.
        rec_ratio: only feeds ``threshold``, which is currently unused.
        loss_name: unused.
        a_pred: optional precomputed per-sample class. When None, it is the
            class whose activation route is nearest (L1) to the sample's
            min-max-normalised absolute activations, summed over layers.

    Returns:
        ``(g_target1, g_target2, g_pred, a_pred)``:

        * ``g_target1``: per output layer, the mean of
          ``|positive_part.detach() - abs|`` over the per-channel gradients,
          taken before scaling. This is identically 0 for 2-D activations,
          where both terms use ``g.abs()``.
        * ``g_target2``: per output layer, the mean of
          ``|g_routes_norm[a_pred].detach() - scaled positive gradients|``.
        * ``g_pred``: per-sample class whose gradient route is nearest (L1).
        * ``a_pred``: as described above.

    NOTE(review): new HookModules are registered on every call and never
    removed, so hooks accumulate on these layers across calls.
    """
    model.eval()
    layer_num = len(output_layers)  # unused
    #g_routes_norm_low = [rn < rn.mean(dim=1).reshape(-1, 1) for rn in g_routes_norm]
    #a_routes_norm_low = [rn < rn.mean(dim=1).reshape(-1, 1) for rn in a_routes_norm]
    input_modules = [HookModule(model=model, module=layer) for layer in input_layers]
    output_modules = [HookModule(model=model, module=layer) for layer in output_layers]
    cls_logits = model(x)
    cls_output = nn.Softmax(dim=1)(cls_logits)
    # NLLLoss on negated logits = each sample's labelled logit (not a loss).
    cls_logits = torch.nn.NLLLoss(reduction='none')(-cls_logits, labels.argmax(dim=1))
    # Confidence cut-off for the top ~rec_ratio of the batch; not used below.
    threshold = cls_output.max(dim=1)[0].sort()[0][-min(int(rec_ratio*labels.shape[0]+1), labels.shape[0])]
    #for sample_idx in range(output.shape[0]):
    #    sample_loss = cls_loss[sample_idx]
    #    weight_grads = [torch.autograd.grad(outputs=sample_loss, inputs=t_layer.weight, retain_graph=True)[0] for t_layer in target_layers]
    #    if sample_idx == 0:
    #        total_weight_grads = [torch.zeros(output.shape[0], *w_grad.shape).cuda() for w_grad in weight_grads]
    #    for layer_idx, w_grad in enumerate(weight_grads):
    #        total_weight_grads[layer_idx][sample_idx] = w_grad
    if a_pred is None:
        total_activations = [m.activations for m in input_modules]
        #total_activations[total_activations<0] = 0
        # One value per channel, then per-sample min-max scaling to ~[0, 1].
        total_activations = [a.abs().sum(dim=(2,3)) if len(a.shape)==4 else a.abs() for a in total_activations]
        total_activations = [(a - a.min(dim=1)[0].reshape(-1,1)) / (a.max(dim=1)[0] - a.min(dim=1)[0] + 1e-6).reshape(-1,1) for a in total_activations]
        # (batch, n_class) L1 distance to every class route, summed over layers.
        activation_distance = sum([abs((a.unsqueeze(1)-a_routes_norm[layer_idx].unsqueeze(0))).sum(dim=-1) for layer_idx, a in enumerate(total_activations)])
        a_pred = activation_distance.argmin(axis=1)

    # create_graph=True so the targets below can be back-propagated.
    total_feature_grads = [m.grads(outputs=cls_logits.sum(), inputs=m.activations, retain_graph=True, create_graph=True) for m in output_modules]
    total_feature_grads_abs = [g.abs().sum(dim=(2,3)) if len(g.shape)==4 else g.abs() for g in total_feature_grads]
    # Positive part for 4-D gradients; the 2-D branch uses abs().
    total_feature_grads_pos = [((g.abs()+g)/2).sum(dim=(2,3)) if len(g.shape)==4 else g.abs() for g in total_feature_grads]
    # Computed on the unscaled reductions, before the scaling below.
    g_target1 = [(g_pos.detach() - g_abs).abs().mean() for g_abs, g_pos in zip(total_feature_grads_abs, total_feature_grads_pos)]
    # The scaled abs version is not used after this line.
    total_feature_grads_abs = [(g - g.min(dim=1)[0].reshape(-1,1)) / (g.max(dim=1)[0] - g.min(dim=1)[0] + 1e-6).reshape(-1,1) for g in total_feature_grads_abs]
    total_feature_grads_pos = [(g - g.min(dim=1)[0].reshape(-1,1)) / (g.max(dim=1)[0] - g.min(dim=1)[0] + 1e-6).reshape(-1,1) for g in total_feature_grads_pos]
    g_target2 = [(g_norm[a_pred].detach() - g_pos).abs().mean() for g_pos, g_norm in zip(total_feature_grads_pos, g_routes_norm)]
    #g_target2 = [g_neg.abs().mean() for g_neg in total_feature_grads_neg]
    grad_distance = sum([abs((g.unsqueeze(1)-g_routes_norm[layer_idx].unsqueeze(0))).sum(dim=-1) for layer_idx, g in enumerate(total_feature_grads_pos)])
    g_pred = grad_distance.argmin(axis=1)

    #rec_mask = (map_pred == output.argmax(dim=1)) & (output.argmax(dim=1) != labels.argmax(dim=1)) & (output.max(dim=1)[0] > threshold)
    #cons_mask = (map_pred == labels.argmax(dim=1))
    
    #masked_grads = [g * (routes_norm[layer_idx][map_pred] ** 1).reshape(map_pred.shape[0], routes_norm[layer_idx].shape[-1], 1, 1 ,1).float() if layer_idx != 2 else g * (routes_norm[layer_idx][map_pred] ** 1).reshape(map_pred.shape[0], routes_norm[layer_idx].shape[-1], 1).float() for layer_idx, g in enumerate(total_weight_grads)]
    #for layer_idx in range(layer_num):
    #    masked_grads[layer_idx][~cons_mask] = 0
    #    masked_grads[layer_idx] = torch.mean(masked_grads[layer_idx], dim=0)
    #map_pred[~rec_mask] = -1
    #g_target = [g * (g_routes_norm_low[layer_idx][labels.argmax(dim=1)]).float() for layer_idx, g in enumerate(total_feature_grads_pos)]
    #a_target = [g * (a_routes_norm_low[layer_idx][labels.argmax(dim=1)]).float() for layer_idx, g in enumerate(total_activations)]
    return g_target1, g_target2, g_pred, a_pred


# feature route
def RectifiedGrad(cls_logits, labels, model, total_activations, total_feature_grads, g_routes, noise_mask, ly2z, n_class, tau=0, std_w=1.0, T=0.1):
    """Build rectified convolution-weight gradients for a batch.

    Each returned tensor is ``F.conv2d(a^T, l2a^T, padding=1)^T``: the input
    activations correlated with a per-sample "loss w.r.t. layer output" term
    ``l2a``. This is the weight-gradient form of a stride-1, padding-1
    convolution (presumably 3x3 kernels -- TODO confirm).

    ``l2a`` is built from the chain-rule product ``sum_c ly2z[c] * z2a[c]``.
    Samples with ``noise_mask`` False use it unchanged. Samples with
    ``noise_mask`` True rebuild it per class from ``grad_direction``:

    * Classes that are neither the label nor the prediction (+1) keep the
      chain-rule term.
    * For the label (-1) and predicted (0) classes, the chain-rule term is
      scaled by ``grad_w``. Route terms from ``g_routes``, scaled by
      ``(1 - grad_w) * py / pa``, are added: the label entry uses
      ``g_routes[a_pred]`` and the prediction entry uses ``g_routes[label]``.
      When the prediction equals the label, that column is treated as 0.

    Here ``py`` and ``pa`` are the (smoothed) probabilities of the labelled
    and predicted class. Because ``a_pred`` is the argmax, ``p_max == pa``.

    Args:
        cls_logits: logits of the batch (N, C).
        labels: one-hot labels (N, C).
        model, n_class, tau, std_w: unused.
        total_activations: per-layer input activations of the conv layers.
        total_feature_grads: per-layer tensors multiplied with ``ly2z`` and
            summed over dim 0 -- presumably (C, N, D, H, W) per-class
            logit-to-feature gradients; TODO confirm.
        g_routes: per-layer class-indexed route tensors (``g[a_pred]``).
        noise_mask: boolean mask of length N selecting "noisy" samples.
        ly2z: reshaped to ``grad_direction.T`` (C, N, ...) -- presumably the
            per-class dLoss/dlogit; TODO confirm.
        T: temperature of the ``grad_w`` sigmoid-like weighting.

    Returns:
        List with one rectified weight-gradient tensor per layer.

    Requires CUDA (``.cuda()`` on the direction mask).
    """
    #model.train()
    #cls_logits = model(x)
    # Softmax, pulled slightly away from exact 0 and 1.
    cls_output = nn.Softmax(dim=1)(cls_logits)
    cls_output = (cls_output+1e-4)/(1+2e-4)
    a_pred = cls_output.argmax(dim=1).detach()

    # Positive part of the feature gradients; not used below.
    total_feature_grads_pos = [(g.abs()+g)/2 for g in total_feature_grads]

    # NLLLoss on negated probabilities = probability of the given class.
    py = torch.nn.NLLLoss(reduction='none')(-cls_output, labels.argmax(dim=1)).detach()
    pa = torch.nn.NLLLoss(reduction='none')(-cls_output, a_pred).detach()
    #grad_w = (np.e**(py/pa)-1) / (np.e**(py/pa)+1)
    #grad_w = (1-((1-np.e**-(pa/py/2)) / (1+np.e**-(pa/py/2))))
    p_max = torch.cat((py.unsqueeze(0), pa.unsqueeze(0))).max(dim=0)[0]
    # Per-sample weight of the original (chain-rule) gradient for label/pred classes.
    grad_w = (torch.e**((py-pa)/T) + 1-p_max) / (torch.e**((py-pa)/T) + 1)
 
    '''
    grad_direction = torch.ones_like(labels)
    grad_direction[labels.argmax(dim=1)] = -1
    grad_direction[a_pred] = 0
    l2a_clean = [(ly2z*z2a).sum(dim=0) for z2a in total_feature_grads]
    l2a_noise = [(ly2z*z2a*(grad_direction>0).float().transpose(0,1).reshape(*ly2z.shape)).sum(dim=0) * grad_w.reshape(-1,1,1,1) for z2a in total_feature_grads]
    l2a_route = [((lq2z*(grad_direction==0).float().transpose(0,1).reshape(*lq2z.shape)).sum(dim=0) * g[a_pred] +\
                 (lq2z*(grad_direction<0).float().transpose(0,1).reshape(*lq2z.shape)).sum(dim=0) * g[labels.argmax(dim=1)]\
                 ) * (1-grad_w).reshape(-1,1,1,1) for g in g_routes]
    l2a_total = [c * (~noise_mask).float().reshape(ly2z.shape[1],1,1,1) + (r+n) * noise_mask.float().reshape(ly2z.shape[1],1,1,1) for c, n, r in zip(l2a_clean, l2a_noise, l2a_route)]
    '''
    # -1:y, 0:q, 1:k
    # Start at +1 everywhere, flip the labelled class to -1, then zero the predicted class.
    grad_direction = torch.ones_like(labels).float()
    grad_direction *= (labels*-2+1)
    grad_direction *= (1-torch.eye(labels.shape[1])[a_pred]).cuda()
    # confident samples
    l2a_clean = [(ly2z*z2a).sum(dim=0) for z2a in total_feature_grads]
    # unconfident samples
    l2a_k = [(ly2z*z2a*(grad_direction==1).float().transpose(0,1).reshape(*ly2z.shape)).sum(dim=0) for z2a in total_feature_grads]
    l2a_tq_ori = [(ly2z*z2a*(grad_direction<1).float().transpose(0,1).reshape(*ly2z.shape)).sum(dim=0) * grad_w.reshape(-1,1,1,1) for z2a in total_feature_grads]

    # cgo std 
    l2a_t_std = [((ly2z*(grad_direction==-1).float().transpose(0,1).reshape(*ly2z.shape)).sum(dim=0) * g[a_pred] * ((1-grad_w) * (py+1e-6) / (pa+1e-6)).reshape(-1,1,1,1)) for g in g_routes] # 
    l2a_q_std = [((ly2z*(grad_direction==0).float().transpose(0,1).reshape(*ly2z.shape)).sum(dim=0) * g[labels.argmax(dim=1)] * ((1-grad_w) * (py+1e-6) / (pa+1e-6)).reshape(-1,1,1,1)) for g in g_routes]
    # cgo ori
    #l2a_t_std = [((ly2z*(grad_direction==-1).float().transpose(0,1).reshape(*ly2z.shape)).sum(dim=0) * (z2a*(grad_direction==0).float().transpose(0,1).reshape(*ly2z.shape)).sum(dim=0) * ((1-grad_w) * (pa-1+1e-6) / (py-1+1e-6)).reshape(-1,1,1,1)) for z2a in total_feature_grads]
    #l2a_q_std = [((ly2z*(grad_direction==0).float().transpose(0,1).reshape(*ly2z.shape)).sum(dim=0) * (z2a*(grad_direction==-1).float().transpose(0,1).reshape(*ly2z.shape)).sum(dim=0) * ((1-grad_w) * (py+1e-6) / (pa+1e-6)).reshape(-1,1,1,1)) for z2a in total_feature_grads]
    # std
    #l2a_t_std = [((ly2z*(grad_direction==-1).float().transpose(0,1).reshape(*ly2z.shape)).sum(dim=0) * g[labels.argmax(dim=1)] * ((1-grad_w)).reshape(-1,1,1,1))  for g in g_routes]
    #l2a_q_std = [((ly2z*(grad_direction==0).float().transpose(0,1).reshape(*ly2z.shape)).sum(dim=0) * g[a_pred] * ((1-grad_w)).reshape(-1,1,1,1)) for g in g_routes]
    # none
    #l2a_t_std = [((ly2z*z2a*(grad_direction==-1).float().transpose(0,1).reshape(*ly2z.shape)).sum(dim=0)) * (1-grad_w).reshape(-1,1,1,1) for z2a in total_feature_grads]
    #l2a_q_std = [((ly2z*z2a*(grad_direction==0).float().transpose(0,1).reshape(*ly2z.shape)).sum(dim=0)) * (1-grad_w).reshape(-1,1,1,1) for z2a in total_feature_grads]

    l2a_noise = [k+tq_ori+t_std+q_std for k, tq_ori, t_std, q_std in zip(l2a_k,l2a_tq_ori,l2a_t_std,l2a_q_std)]
    # total
    # Per-sample select: clean term where noise_mask is False, rebuilt term where True.
    #noise_mask = torch.zeros_like(noise_mask)
    l2a_total = [c * (~noise_mask).float().reshape(ly2z.shape[1],1,1,1) + n * noise_mask.float().reshape(ly2z.shape[1],1,1,1) for c, n in zip(l2a_clean, l2a_noise)]
    # Batch axis becomes the channel axis, so conv2d sums over samples.
    rectified_grad = [F.conv2d(a.transpose(0,1), l2a.transpose(0,1), padding=1).transpose(0,1) for a, l2a in zip(total_activations, l2a_total)]

    return rectified_grad


# feature route
def GradConstrain(x, labels, model, output_layers):
    """Penalty on the feature gradients of the non-label logits.

    Switches ``model`` to eval mode and runs it on ``x``. For each module in
    ``output_layers``, it takes the gradient of
    ``sum(all logits) - sum(labelled logits)`` -- the summed logits of every
    non-label class -- w.r.t. that module's output. ``create_graph=True`` is
    used so the result stays differentiable, and ``model.zero_grad()`` is
    called after each gradient. Each gradient is reduced per channel
    (positive part summed over the last two dims for 4-D outputs, absolute
    value otherwise) and averaged.

    The forward hooks exist only for this one forward pass and are always
    removed afterwards. Previously a HookModule was registered per call and
    never removed, so hooks piled up on the layers.

    Args:
        x: input batch.
        labels: one-hot labels (class = ``labels.argmax(dim=1)``).
        model: the network.
        output_layers: modules whose outputs are differentiated.

    Returns:
        List with one scalar tensor (the mean reduced gradient) per output layer.
    """
    model.eval()
    captured = [None] * len(output_layers)

    def _make_hook(slot):
        def _hook(module, inputs, outputs):
            captured[slot] = outputs
        return _hook

    handles = [layer.register_forward_hook(_make_hook(i)) for i, layer in enumerate(output_layers)]
    try:
        cls_logits = model(x)
    finally:
        for handle in handles:
            handle.remove()

    # NLLLoss on negated logits returns each sample's labelled logit.
    label_logits = torch.nn.NLLLoss(reduction='none')(-cls_logits, labels.argmax(dim=1))
    objective = cls_logits.sum() - label_logits.sum()

    total_feature_grads = []
    for activation in captured:
        grad = torch.autograd.grad(outputs=objective, inputs=activation,
                                   retain_graph=True, create_graph=True)[0]
        model.zero_grad()
        total_feature_grads.append(grad)

    total_feature_grads_pos = [((g.abs() + g) / 2).sum(dim=(2, 3)) if len(g.shape) == 4 else g.abs()
                               for g in total_feature_grads]
    return [g.abs().mean() for g in total_feature_grads_pos]


# feature route
def grad_route_map(model, target_layers, data_loader, loss_name='ce', max_iter=20, branch_idx=0):
    """Accumulate a per-class channel-to-channel gradient map between two layers.

    ``target_layers`` is ``(layer_in, layer_out)``. For each batch (up to
    ``max_iter``) and every output channel ``d`` of ``layer_out``:

    * The gradient of ``|layer_out[:, d]|.sum()`` w.r.t. ``layer_in``'s
      output is taken.
    * Its absolute value is summed over the last two dims, giving a
      ``(batch, dim_out, dim_in)`` tensor.

    These maps are summed per class through the one-hot labels and averaged
    by the per-class sample counts. The class count and both channel counts
    are now inferred from the first batch; previously they were hard-coded
    to 10 classes and 512 channels.

    Args:
        model: the network (switched to eval mode).
        target_layers: ``(layer_in, layer_out)`` modules.
        data_loader: yields 4-tuples ``(img, label_a, label_true, index)`` or
            5-tuples ``(img, label_a, label_b, label_true, index)`` with
            one-hot labels.
        loss_name: unused; kept for interface compatibility.
        max_iter: maximum number of batches to process.
        branch_idx: 0 uses ``label_a``; 1 uses ``label_b`` (5-tuples only).

    Returns:
        ``(route_norm, route_min, route_max)``:

        * ``route_norm``: ``(n_class, dim_out, dim_in)`` map, min-max scaled
          per class (a constant class map becomes zeros instead of NaN).
        * ``route_min``, ``route_max``: ``(n_class, 1, 1)`` scaling constants
          of the averaged map.

    Raises:
        ValueError: for an unsupported ``branch_idx`` or batch layout, or if
            no batch was processed.
    """
    if branch_idx not in (0, 1):
        raise ValueError('branch_idx must be 0 or 1, got {}'.format(branch_idx))
    model.eval()

    # Register the hooks once; creating them per batch stacked
    # never-removed hooks on both layers.
    module_in = HookModule(model=model, module=target_layers[0])
    module_out = HookModule(model=model, module=target_layers[1])

    route = route_num = None
    for batch_idx, data in enumerate(data_loader):
        if batch_idx == max_iter:
            break
        if len(data) == 4:
            train_image, train_label_a = data[0], data[1]
            train_label_b = None
        elif len(data) == 5:
            train_image, train_label_a, train_label_b = data[0], data[1], data[2]
        else:
            raise ValueError('expected 4- or 5-tuple batches, got {} entries'.format(len(data)))
        if branch_idx == 0:
            train_label = train_label_a
        elif train_label_b is None:
            raise ValueError('branch_idx=1 requires 5-tuple batches')
        else:
            train_label = train_label_b
        train_image = Variable(train_image.float(), requires_grad=True).cuda()
        train_label = Variable(train_label).cuda()

        model(train_image)  # forward pass only to fill the hooks
        acts_in = module_in.activations
        acts_out = module_out.activations
        dim_out = acts_out.shape[1]
        # One backward pass per output channel.
        grads = [torch.autograd.grad(outputs=acts_out[:, d].abs().sum(), inputs=acts_in,
                                     retain_graph=True, create_graph=False)[0].abs().sum(dim=(2, 3))
                 for d in range(dim_out)]
        grads = torch.stack(grads).transpose(0, 1)  # (batch, dim_out, dim_in)
        dim_in = grads.shape[2]

        if route is None:
            n_class = train_label.shape[1]
            route = torch.zeros((n_class, dim_out, dim_in)).cuda()
            route_num = torch.zeros((n_class)).cuda() + 0.01

        route += (train_label.T.float() @ grads.reshape(grads.shape[0], -1)).reshape(-1, dim_out, dim_in)
        route_num += train_label.sum(dim=0).float()
        print('\r', batch_idx, end='')

    if route is None:
        raise ValueError('no batches were processed (empty data_loader or max_iter=0)')

    route_norm = route / route_num.reshape(-1, 1, 1)
    route_min = route_norm.min(dim=2)[0].min(dim=1)[0].reshape(-1, 1, 1)
    route_max = route_norm.max(dim=2)[0].max(dim=1)[0].reshape(-1, 1, 1)
    # The clamp maps a constant class map to 0 instead of 0/0 = NaN.
    route_norm = (route_norm - route_min) / (route_max - route_min).clamp(min=1e-16)
    return route_norm, route_min, route_max


# feature route
def MaskedGrad_map(x, output, labels, model, target_layers, route_norm, norm_values, rec_ratio, cons_ratio, mode='abs', kernel_size=3, loss_name='ce'):
    """Classify samples by their channel-to-channel gradient maps.

    For the batch ``x``, recomputes the ``(batch, 512, 512)`` map of
    ``|d |layer_out[:, d]|.sum() / d layer_in|`` summed over the last two
    dims. This is the same quantity ``grad_route_map`` accumulates. The map
    is then scaled with every class's ``route_min`` / ``route_max``, and the
    sample is assigned (``map_pred``) to the class with the smallest mean
    scaled value over the entries where that class's ``route_norm`` is below
    its own mean.

    Args:
        x: input batch.
        output: class probabilities for the batch. Only used for
            ``threshold`` / ``rec_mask``, which are not used afterwards.
        labels: one-hot labels.
        model: the network (switched to eval mode).
        target_layers: ``(layer_in, layer_out)`` modules to hook.
        route_norm: ``(10, 512, 512)`` normalised route, presumably from
            ``grad_route_map`` -- TODO confirm.
        norm_values: ``(route_min, route_max)``, each ``(10, 1, 1)``.
        rec_ratio: only feeds the unused ``threshold``.
        cons_ratio: fraction controlling ``cons_mask``.
        mode, kernel_size: unused.
        loss_name: selects ``cls_loss``, which is not used afterwards.

    Returns:
        ``(None, route_grad_target, map_pred, cons_mask)``:

        * The first element is always None; the masked-weight-gradient path
          is disabled inside a string literal.
        * ``route_grad_target``: per sample, the scaled map of its
          ``map_pred`` class, restricted to that class's below-mean route
          entries, shape ``(batch, 512, 512)``.
        * ``map_pred``: per-sample predicted class.
        * ``cons_mask``: True where the model's fresh softmax probability of
          ``map_pred`` is strictly below the value at index
          ``int(cons_ratio * batch) - 1`` of those probabilities sorted in
          descending order.

    Class count (10) and channel counts (512) are hard-coded, and CUDA is
    required. NOTE(review): HookModules are registered per call and never
    removed, so hooks accumulate across calls.
    """
    model.eval()
    dim_in = dim_out = 512
    route_min, route_max = norm_values
    # Entries below each class's mean route value.
    route_norm_low = route_norm < route_norm.mean(axis=(1,2)).reshape(-1, 1, 1)
    route_norm_high = route_norm > 0.5  # unused
    module_in = HookModule(model=model, module=target_layers[0])
    module_out = HookModule(model=model, module=target_layers[1])
    cls_output = model(x)
    cls_output = nn.Softmax(dim=1)(cls_output)
    threshold = output.max(dim=1)[0].sort()[0][-min(int(rec_ratio*labels.shape[0]+1), labels.shape[0])]
    # cls_loss is computed but not used below.
    if loss_name == 'ce':
        cls_loss = torch.nn.NLLLoss(reduction='none')(torch.log(cls_output), labels.argmax(dim=1))
    elif loss_name == 'sl':
        cls_loss = SLLoss(cls_output, labels, reduction='none')
    elif loss_name == 'gce':
        cls_loss = GCELoss(cls_output, labels, 0.9, reduction='none')
    #for sample_idx in range(output.shape[0]):
    #    sample_loss = cls_loss[sample_idx]
    #    weight_grads = torch.autograd.grad(outputs=sample_loss, inputs=target_layers[2].weight, retain_graph=True)[0]
    #    if sample_idx == 0:
    #        total_weight_grads = torch.zeros(output.shape[0], *weight_grads.shape).cuda()
    #    total_weight_grads[sample_idx] = weight_grads
    # One backward pass per output channel (dim_out passes).
    total_feature_grads = [torch.autograd.grad(outputs=module_out.activations[:, d].abs().sum(), inputs=module_in.activations, retain_graph=True, create_graph=False)[0].abs().sum(dim=(2,3)) for d in range(dim_out)]
    total_feature_grads = torch.stack(total_feature_grads).transpose(0, 1) # batch, dim_out, dim_in
    #total_feature_grad[total_feature_grad<0] = 0
    # Broadcast to (batch, n_class, dim_out, dim_in): scaled with each class's min/max.
    total_feature_grads = (total_feature_grads.unsqueeze(1) - route_min.unsqueeze(0)) / (route_max.unsqueeze(0) - route_min.unsqueeze(0))
    
    # (batch, n_class): mean scaled value over each class's below-mean route entries.
    cls_distance = (total_feature_grads*route_norm_low.unsqueeze(0).float()).sum(dim=(2,3)) / route_norm_low.float().sum(dim=(1,2)).reshape(1,-1)
    map_pred = cls_distance.argmin(dim=1)
    rec_mask = (map_pred == output.argmax(dim=1)) & (output.argmax(dim=1) != labels.argmax(dim=1)) & (output.max(dim=1)[0] > threshold)  # unused
    # NLLLoss on negated probabilities = probability assigned to map_pred.
    map_conf = torch.nn.NLLLoss(reduction='none')(-cls_output, map_pred)
    cons_threshold = map_conf.sort(descending=True)[0][int(cons_ratio*map_conf.shape[0])-1]
    cons_mask = map_conf < cons_threshold
    '''
    masked_grads = total_weight_grads * (route_norm[map_pred] ** (1/2)).reshape(map_pred.shape[0], *route_norm.shape[1:], 1 ,1).float().cuda()
    masked_grads[~cons_mask] = 0
    masked_grads = torch.mean(masked_grads, dim=0)
    map_pred[~rec_mask] = -1
    return [masked_grads], route_grad_target, map_pred, cons_mask
    '''
    # Select each sample's map_pred class slice and keep only its below-mean route entries.
    route_grad_target = (total_feature_grads * torch.eye(10).cuda()[map_pred].reshape(map_pred.shape[0],10,1,1)).sum(dim=1) * (route_norm_low[map_pred]).reshape(labels.shape[0], *route_norm.shape[1:]).float().cuda()
    return None, route_grad_target, map_pred, cons_mask


def grad_rectify(g_list, a_list):
    """Average pairwise class-difference weight gradients.

    ``g_list[0]`` must be 4-D with the class count ``C`` first. For each
    ``(act, grad)`` pair:

    * The ``C`` activation rows are tiled so that row ``i*C + j`` holds
      ``act[i]``.
    * The pairwise gradient differences are built so that row ``i*C + j``
      holds ``(grad[j] - grad[i]) / C**2``.
    * The two are correlated with a padding-1 ``conv2d`` in weight-gradient
      form (batch axis as the channel axis), and the result is divided by
      ``C**2`` again.

    Returns:
        List with one tensor per zipped ``(act, grad)`` pair.
    """
    n_cls, _, _, _ = g_list[0].shape
    pairs = n_cls ** 2
    tiled_acts = [act.repeat_interleave(n_cls, dim=0).reshape(pairs, *act.shape[-3:]) for act in a_list]
    pair_diffs = [(grad.unsqueeze(0) - grad.unsqueeze(1)).reshape(pairs, *grad.shape[-3:]) / pairs for grad in g_list]
    rectified = []
    for act, diff in zip(tiled_acts, pair_diffs):
        out = F.conv2d(act.transpose(0, 1), diff.transpose(0, 1), padding=1)
        rectified.append(out.transpose(0, 1) / pairs)
    return rectified


def grad_rectify_real(g_total, a_total, y_noisy, y_true):
    """Weight gradients driven by the difference between noisy and true labels.

    Each entry of ``g_total`` is 6-D ``(C, N, N, D, FH, FW)``. It is weighted
    by ``(y_noisy - y_true)`` reshaped to ``(N, C)`` and transposed to
    ``(C, N)``, divided by ``N``, and summed over its first two axes. The
    result is then correlated with the matching ``a_total`` entry through a
    padding-1 ``conv2d`` in weight-gradient form.

    Note: as in the original tuple unpacking ``C, N, N, ...``, ``N`` is taken
    from the third axis (``shape[2]``).

    Returns:
        List with one tensor per zipped ``(activation, gradient)`` pair.
    """
    n_cls, _, n_smp, _, _, _ = g_total[0].shape
    label_shift = (y_noisy - y_true).reshape(n_smp, n_cls, 1, 1, 1, 1).transpose(0, 1)
    weighted = [(grad * label_shift / n_smp).sum(dim=(0, 1)) for grad in g_total]
    rectified = []
    for act, grad in zip(a_total, weighted):
        out = F.conv2d(act.transpose(0, 1), grad.transpose(0, 1), padding=1)
        rectified.append(out.transpose(0, 1))
    return rectified


def grad_real(g_total, a_total, ly2z):
    """Chain-rule convolution weight gradients.

    Each entry of ``g_total`` must be 6-D. It is multiplied by ``ly2z``
    (broadcast) and summed over its first two axes. The result is correlated
    with the matching ``a_total`` entry through a padding-1 ``conv2d`` in
    weight-gradient form.

    Raises:
        ValueError: if ``g_total[0]`` is not 6-D.

    Returns:
        List with one tensor per zipped ``(activation, gradient)`` pair.
    """
    if len(g_total[0].shape) != 6:
        raise ValueError('expected 6-D gradients, got shape {}'.format(tuple(g_total[0].shape)))
    results = []
    ly2a_list = [(ly2z * z2a).sum(dim=(0, 1)) for z2a in g_total]
    for act, l2a in zip(a_total, ly2a_list):
        conv = F.conv2d(act.transpose(0, 1), l2a.transpose(0, 1), padding=1)
        results.append(conv.transpose(0, 1))
    return results