import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import seaborn as sns
import matplotlib.pyplot as plt
# Torchvision
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet34, resnet50
from torch.utils.tensorboard import SummaryWriter

from dataset import *
from mymodels import *
import models
from losses import *
from utils import *

class HookModule:
    """Capture a module's forward output and take gradients with respect to it.

    A forward hook is registered on ``module`` at construction time. After each
    forward pass that runs ``module``, its output is stored in ``self.activations``.
    Call :meth:`remove` (or use the instance as a context manager) to unregister
    the hook. Otherwise it stays attached to ``module`` for the module's lifetime.

    Args:
        model: the network whose ``.grad`` buffers are cleared after each
            :meth:`grads` call.
        module: the sub-module whose forward output is captured.
    """

    def __init__(self, model, module):
        self.model = model
        self.activations = None
        # Keep the handle so the hook can be unregistered later. Discarding it
        # would leave a permanent hook on ``module`` for every instance created.
        self.handle = module.register_forward_hook(self._hook_activations)

    def _hook_activations(self, module, inputs, outputs):
        # Returning None leaves the module's output unchanged.
        self.activations = outputs

    def grads(self, outputs, inputs, retain_graph=True, create_graph=True):
        """Return d(outputs)/d(inputs) and zero the model's parameter gradients.

        Args:
            outputs: tensor to differentiate (scalar, or pass-through to autograd).
            inputs: tensor (or tensors) to differentiate with respect to; only the
                gradient for the first one is returned.
            retain_graph: keep the autograd graph for further gradient calls.
            create_graph: build a graph of the gradient for higher-order derivatives.

        Returns:
            The gradient tensor for the first entry of ``inputs``.
        """
        grads = torch.autograd.grad(outputs=outputs,
                                    inputs=inputs,
                                    retain_graph=retain_graph,
                                    create_graph=create_graph)[0]
        self.model.zero_grad()
        return grads

    def remove(self):
        """Unregister the forward hook. Calling it more than once is a no-op."""
        if self.handle is not None:
            self.handle.remove()
            self.handle = None

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        self.remove()
        return False

def view_grads(grads, fig_w, fig_h, fig_path):
    """Render a 2-D gradient matrix as a heatmap and save it to an image file.

    Rows are labelled as categories and columns as convolutional kernels.

    Args:
        grads: 2-D array-like passed to ``seaborn.heatmap``.
        fig_w: figure width in inches.
        fig_h: figure height in inches.
        fig_path: output path; missing parent directories are created.
    """
    from pathlib import Path

    Path(fig_path).parent.mkdir(parents=True, exist_ok=True)
    fig, ax = plt.subplots(figsize=(fig_w, fig_h), ncols=1)
    try:
        sns.heatmap(grads, annot=False, ax=ax)
        # seaborn's heatmap sets the axis labels itself (empty for unlabeled
        # array input), so our labels must be applied after it draws.
        ax.set_xlabel('convolutional kernel')
        ax.set_ylabel('category')
        fig.savefig(fig_path, bbox_inches='tight')
    finally:
        # Close rather than just clear: clf() leaves the figure registered
        # with pyplot, leaking one figure per call.
        plt.close(fig)


def _positive_kernel_sums(grads, sample_idx):
    """Sum the strictly positive gradient entries of each channel of one sample.

    Vectorized equivalent of
    ``[grads[sample_idx, i][grads[sample_idx, i] > 0].sum() for i in channels]``:
    one tensor op and one device-to-host copy instead of a ``.item()`` per channel.

    Args:
        grads: gradient tensor indexed as ``grads[sample, channel, ...]``.
        sample_idx: index of the sample within the batch.

    Returns:
        1-D numpy array with one value per channel.
    """
    sample_grads = grads[sample_idx]
    per_channel = sample_grads.clamp(min=0).reshape(sample_grads.shape[0], -1)
    return per_channel.sum(dim=1).detach().cpu().numpy()


def _normalize_jointly(maps):
    """Min-max scale every array in ``maps`` with one shared range.

    The shared range keeps the resulting heatmaps comparable with each other.
    A degenerate range (all values equal) maps everything to 0 instead of
    dividing by zero.

    Args:
        maps: dict of name -> numpy array.

    Returns:
        dict with the same keys and the scaled arrays.
    """
    max_value = max(m.max() for m in maps.values())
    min_value = min(m.min() for m in maps.values())
    value_range = max_value - min_value
    if value_range == 0:
        value_range = 1.0
    return {name: (m - min_value) / value_range for name, m in maps.items()}


def conv_grad_visual(noise_ratio=40, loss_name='ce', epoch_idx=0):
    """Plot per-class, per-kernel positive gradients at one layer of a noisy-label model.

    Loads the checkpoint
    ``checkpoints/cifar-10_{loss_name}_cifar10_noise{noise_ratio}_epoch{epoch_idx}.pth``
    into a GoogLeNet (CUDA required). It then runs the first 10 training batches,
    which are shuffled and augmented. For each sample it takes the gradient of
    ``log p(label)`` w.r.t. the output of ``model.b5.b4[1]`` and sums the
    positive entries of every channel.

    The sums are averaged per class into four groups (names used in the output
    filenames):

      * taa: clean sample predicted as its label; gradient on that label.
      * fnn: noisy sample predicted as its noisy label; gradient on the noisy label.
      * fna: same samples as fnn; gradient on the true label, bucketed by true class.
      * tna: noisy sample predicted as its true label; gradient on the true label.

    A group keeps accumulating until every class has at least 100 samples. The
    four matrices are min-max scaled with a shared range, the per-class counts
    are printed, and one heatmap per group is saved to
    ``fig/{loss_name}_noise{noise_ratio}_epoch{epoch_idx}_{group}.png``.

    Args:
        noise_ratio: label-noise ratio passed to ``get_data`` and used in paths.
        loss_name: loss identifier used in the checkpoint and figure paths.
        epoch_idx: checkpoint epoch used in the checkpoint and figure paths.
    """
    # parameters
    data_type = 'cifar-10'
    n_class = 10
    n_kernels = 128  # channels read from the hooked layer; assumed to be its channel count
    max_per_class = 100
    max_batches = 10
    batch_size = 128
    padding = 2
    img_size = 32
    checkpoint_path = 'checkpoints/{2}_{0}_cifar10_noise{1}_epoch{3}.pth'.format(loss_name, noise_ratio, data_type, epoch_idx)

    # dataset: the training split minus the held-out validation indices
    transform_train = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(img_size, padding=padding),
        transforms.ToTensor(),
    ])
    X_train, y_train, y_train_clean, X_test, y_test = get_data(data_type, noise_ratio=noise_ratio, asym=False, random_shuffle=False)
    valid_idx = np.load('data/%s_valid_idx.npy' % (data_type))
    train_idx = np.setdiff1d(np.arange(X_train.shape[0]), valid_idx)
    train_datasets = MyDataset(X_train[train_idx], y_train[train_idx], y_train_clean[train_idx],
                               require_index=True, transform=transform_train)
    train_loader = DataLoader(train_datasets, batch_size=batch_size, shuffle=True, num_workers=2)

    # model
    model = models.load_model(model_name='googlenet', in_channels=3, num_classes=n_class).cuda()
    model.load_state_dict(torch.load(checkpoint_path))
    model.eval()
    # Register the hook once. Creating a HookModule per batch stacked one more
    # forward hook on the layer for every batch.
    module = HookModule(model=model, module=model.b5.b4[1])

    group_names = ('taa', 'fnn', 'fna', 'tna')
    sums = {name: np.zeros((n_class, n_kernels)) for name in group_names}
    # Counts start at 0.01 so the per-class mean below never divides by zero.
    counts = {name: np.zeros(n_class) + 0.01 for name in group_names}

    def wants(name):
        # A group keeps collecting while its least-populated class is under the cap.
        return counts[name].min() < max_per_class

    def log_prob_grads(probs, sample_idx):
        # Gradient of log p (i.e. -NLL) w.r.t. the hooked activations.
        return module.grads(outputs=torch.log(probs[sample_idx]), inputs=module.activations,
                            retain_graph=True, create_graph=False)

    def accumulate(name, label, grads, sample_idx):
        sums[name][label] += _positive_kernel_sums(grads, sample_idx)
        counts[name][label] += 1

    for batch_idx, (train_image, train_label, train_label_true, _) in enumerate(train_loader):
        if batch_idx >= max_batches:
            break
        # requires_grad on the input guarantees a graph reaches the hooked layer.
        train_image = train_image.float().cuda().requires_grad_(True)
        train_label = train_label.cuda()
        train_label_true = train_label_true.cuda()

        cls_output = nn.Softmax(dim=1)(model(train_image))
        # Probability mass on the given (possibly noisy) label and on the clean label.
        prediction = torch.sum(cls_output * train_label.float(), dim=1)
        prediction_true = torch.sum(cls_output * train_label_true.float(), dim=1)

        # Class indices moved to the host once per batch, not synced per sample.
        pred_cls = cls_output.argmax(dim=1).tolist()
        noisy_cls = train_label.argmax(dim=1).tolist()
        true_cls = train_label_true.argmax(dim=1).tolist()

        for sample_idx, (pred, noisy, true) in enumerate(zip(pred_cls, noisy_cls, true_cls)):
            is_clean = noisy == true
            if pred == noisy:
                group = 'taa' if is_clean else 'fnn'
                if wants(group):
                    accumulate(group, noisy, log_prob_grads(prediction, sample_idx), sample_idx)
                if not is_clean and wants('fna'):
                    accumulate('fna', true, log_prob_grads(prediction_true, sample_idx), sample_idx)
            elif pred == true and wants('tna'):
                # Reached only for noisy samples: with a clean label, pred == true
                # would already have matched the branch above.
                accumulate('tna', true, log_prob_grads(prediction_true, sample_idx), sample_idx)
        print('\r', batch_idx, end='')

    # visual: per-class means (the x100 cancels under the joint min-max scaling)
    means = {name: sums[name] / counts[name].reshape(n_class, 1) * 100 for name in group_names}
    normalized = _normalize_jointly(means)
    print(counts['taa'], counts['tna'], counts['fnn'], counts['fna'])

    for name in ('taa', 'fnn', 'fna', 'tna'):
        view_grads(normalized[name], n_kernels, n_class,
                   'fig/{0}_noise{1}_epoch{2}_{3}.png'.format(loss_name, noise_ratio, epoch_idx, name))