from torch.autograd import Variable
from collections import OrderedDict
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import seaborn as sns
from torch.nn import functional as F

'''
Functional definitions of common layers.
Useful when weights are passed in explicitly rather
than being contained in modules.
'''

def linear(input, weight, bias=None):
    """Apply ``F.linear`` with explicitly supplied parameters.

    The weight (and bias, if given) are moved to ``input``'s device before
    use, so the call works on CPU as well as on any CUDA device instead of
    always forcing the current CUDA device.

    Args:
        input: tensor whose last dimension equals ``weight.shape[1]``.
        weight: (out_features, in_features) tensor.
        bias: optional (out_features,) tensor.

    Returns:
        ``input @ weight.T (+ bias)``.
    """
    device = input.device
    weight = weight.to(device)
    if bias is not None:
        bias = bias.to(device)
    return F.linear(input, weight, bias)

def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1):
    """2-D convolution with explicitly supplied parameters.

    Thin wrapper over ``F.conv2d``. The parameters are used as given; no
    device transfer is performed.
    """
    return F.conv2d(
        input,
        weight,
        bias=bias,
        stride=stride,
        padding=padding,
        dilation=dilation,
        groups=groups,
    )

def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, output_padding=0):
    """2-D transposed convolution with explicitly supplied parameters.

    Parameters are moved to ``input``'s device. Every option is forwarded to
    ``F.conv_transpose2d`` by keyword, because that function's positional
    order is ``(stride, padding, output_padding, groups, dilation)``.
    Forwarding positionally (as a previous version did) put ``dilation``
    into the ``output_padding`` slot. It also called ``bias.cuda()`` even
    when ``bias`` was None. Callers that relied on that accident to get
    ``output_padding=1`` (e.g. stride-2 upsampling) must now pass
    ``output_padding=1`` explicitly.

    Args:
        input: (N, C_in, H, W) tensor.
        weight: (C_in, C_out / groups, kH, kW) tensor.
        bias: optional (C_out,) tensor.
        stride, padding, dilation, groups, output_padding: as in
            ``F.conv_transpose2d``.
    """
    device = input.device
    weight = weight.to(device)
    if bias is not None:
        bias = bias.to(device)
    return F.conv_transpose2d(
        input,
        weight,
        bias=bias,
        stride=stride,
        padding=padding,
        output_padding=output_padding,
        groups=groups,
        dilation=dilation,
    )

def relu(input):
    """Out-of-place ReLU built on ``F.threshold``.

    Every element that is not greater than 0 is replaced by 0; the rest pass
    through unchanged.
    """
    return F.threshold(input, threshold=0, value=0, inplace=False)

def dropout(input, p=0.8, training=True):
    """Out-of-place dropout with drop probability ``p`` (default 0.8).

    When ``training`` is False the input is returned unchanged; otherwise
    surviving elements are scaled by ``1 / (1 - p)`` as in ``F.dropout``.
    """
    keep_inplace = False
    return F.dropout(input, p, training, keep_inplace)

def sigmoid(input):
    """Elementwise logistic sigmoid ``1 / (1 + exp(-input))``."""
    return torch.sigmoid(input)

def maxpool(input, kernel_size, stride=None):
    """2-D max pooling; ``stride=None`` lets ``F.max_pool2d`` default it to ``kernel_size``."""
    return F.max_pool2d(input, kernel_size=kernel_size, stride=stride)


def mean_distance(a, b, weight=None, training=True):
    """Squared Euclidean distance along the last axis, optionally weighted.

    Args:
        a, b: tensors broadcastable against each other.
        weight: optional tensor multiplied into the per-element distances; it
            may broadcast to a larger shape than the distances.
        training: selects the reduction (see Returns).

    Returns:
        If ``training`` is False, the (weighted) per-element distances;
        otherwise their mean as a 1-element tensor of shape ``(1,)``.
    """
    dis = ((a - b) ** 2).sum(-1)

    if weight is not None:
        # Out-of-place: ``dis *= weight`` fails when ``weight`` broadcasts
        # ``dis`` up to a larger shape.
        dis = dis * weight

    if not training:
        return dis
    return dis.mean().unsqueeze(0)


def distance(a, b):
    """Squared Euclidean distance between ``a`` and ``b`` along the last axis (no reduction otherwise)."""
    diff = a - b
    return torch.sum(diff.pow(2), dim=-1)


def heatmap(x, name='heatmap', out_dir='results/heatmap', shape=(32, 32)):
    """Save one seaborn heatmap PNG per channel of ``x``.

    ``x`` has its trailing singleton axis squeezed, and is then indexed as
    ``x[0, :, j]``. Only batch element 0 is plotted, one image per index
    ``j`` of axis 2. Presumably ``x`` is (batch, positions, channels[, 1]);
    TODO confirm with callers.

    Args:
        x: tensor as described above.
        name: filename prefix; files are written as ``{out_dir}/{name}_{j}.png``.
        out_dir: output directory; created if it does not exist (savefig
            would otherwise fail).
        shape: 2-D shape each ``x[0, :, j]`` is reshaped to before plotting;
            its product must equal ``x.shape[1]``.

    Returns:
        True once all images have been written.
    """
    import os

    os.makedirs(out_dir, exist_ok=True)
    x = x.squeeze(-1)
    for j in range(x.shape[2]):
        plt.cla()
        y = x[0, :, j].reshape(tuple(shape))
        df = pd.DataFrame(y.detach().cpu().numpy())
        sns.heatmap(df)
        plt.savefig(os.path.join(out_dir, '{}_{}.png'.format(name, str(j))))
        plt.close()
    return True


class Meta_Prototype(nn.Module):
    """Prototype module: pools key features into ``proto_size`` prototypes and
    re-expresses the query features as attention-weighted prototype mixtures.

    Args:
        proto_size: number of prototypes ``m`` produced per sample.
        feature_dim: channel count of the query features.
        key_dim: channel count of the key features (input size of ``Mheads``).
        temp_update, temp_gather, shrink_thres: stored as attributes; not
            read anywhere in this class.
    """

    def __init__(self, proto_size, feature_dim, key_dim, temp_update, temp_gather, shrink_thres=0):
        super(Meta_Prototype, self).__init__()
        # Constants
        self.proto_size = proto_size
        self.feature_dim = feature_dim
        self.key_dim = key_dim
        self.temp_update = temp_update
        self.temp_gather = temp_gather
        # One linear "head" per prototype; scores every spatial key vector.
        self.Mheads = nn.Linear(key_dim, proto_size, bias=False)
        self.shrink_thres = shrink_thres

    def get_score(self, pro, query):
        """Softmax-normalized dot-product scores between queries and prototypes.

        Args:
            pro: prototypes, (b, m, d).
            query: query vectors, (b, n, d).

        Returns:
            ``(score_query, score_proto)``, both (b, n, m): the raw scores
            softmaxed over the query axis (dim 1) and over the prototype
            axis (dim 2) respectively.
        """
        bs, n, _ = query.size()
        m = pro.size(1)
        score = torch.bmm(query, pro.permute(0, 2, 1)).view(bs, n, m)  # b x n x m

        score_query = F.softmax(score, dim=1)
        score_proto = F.softmax(score, dim=2)
        return score_query, score_proto

    def _build_protos(self, key, weights):
        """Pool ``key`` (b x d x h x w) into prototypes of shape (b, proto_size, d).

        Each head's logits are softmaxed over the h*w positions and used as
        weights for a weighted sum of the key vectors. When ``weights`` is
        given, the head matrix is taken from
        ``weights['prototype.Mheads.weight']`` instead of ``self.Mheads``.
        """
        batch_size, dims, h, w = key.size()
        key = key.permute(0, 2, 3, 1)  # b x h x w x d
        if weights is None:
            head_logits = self.Mheads(key)
        else:
            head_logits = linear(key, weights['prototype.Mheads.weight'])

        head_weights = head_logits.view((batch_size, h * w, self.proto_size, 1))
        head_weights = F.softmax(head_weights, dim=1)  # normalize over spatial positions

        key = key.reshape((batch_size, w * h, dims))
        return (head_weights * key.unsqueeze(-2)).sum(1)

    def forward(self, key, query, weights, train=True):
        """Build prototypes from ``key`` and refine ``query`` with them.

        Args:
            key: (b, d, h, w) features used to build the prototypes.
            query: (b, feature_dim, h_, w_) features to refine.
            weights: None to use ``self.Mheads``, or a dict of functional
                weights containing ``'prototype.Mheads.weight'``.
            train: selects the training or test loss/return signature.

        Returns:
            train=True: ``(updated_query, protos, fea_loss)`` with a scalar loss.
            train=False: ``(updated_query, protos, query, fea_loss)`` where
            ``query`` is the flattened (b, n, d) query and ``fea_loss`` is
            the unreduced MSE to each query's closest prototype.
            ``updated_query`` is (b, feature_dim, h_, w_) in both cases.
        """
        batch_size = key.size(0)
        _, _, h_, w_ = query.size()
        query = query.permute(0, 2, 3, 1)  # b x h x w x d
        query = query.reshape((batch_size, -1, self.feature_dim))  # b x n x d

        protos = self._build_protos(key, weights)

        if train:
            updated_query, fea_loss = self.query_loss(query, protos, weights, train)
        else:
            updated_query, fea_loss, query = self.query_loss(query, protos, weights, train)

        # Skip connection, then back to b x d x h x w.
        updated_query = updated_query + query
        updated_query = updated_query.permute(0, 2, 1)  # b x d x n
        updated_query = updated_query.view((batch_size, self.feature_dim, h_, w_))

        if train:
            return updated_query, protos, fea_loss
        return updated_query, protos, query, fea_loss

    def gaussian_kernel_vectorization(self, x1, x2, l=1.0, sigma_f=1.0):
        """RBF kernel matrix between the rows of 2-D ``x1`` and ``x2``.

        Returns ``sigma_f**2 * exp(-||x1_i - x2_j||^2 / (2 * l**2))`` as a
        (len(x1), len(x2)) matrix. Not called anywhere in this class.
        """
        dist_matrix = torch.sum(x1 ** 2, 1).reshape(-1, 1) + torch.sum(x2 ** 2, 1) - 2 * torch.mm(x1, x2.T)
        return sigma_f ** 2 * torch.exp(-0.5 / l ** 2 * dist_matrix)

    def query_loss(self, query, keys, weights, train):
        """Re-express queries with prototypes and compute the feature loss.

        Args:
            query: (b, n, d) query vectors.
            keys: (b, m, d) prototypes (L2-normalized here before use).
            weights: unused; kept for interface compatibility.
            train: selects the loss.

        Returns:
            train=True: ``(new_query, loss)``. ``loss`` is the mean MSE
            between each query and a normalized mixture of its top-k (k <= 5)
            prototypes, plus ``1e-4`` times a distinction penalty on
            prototype pairs whose ``1 - squared distance`` is positive.
            train=False: ``(new_query, fea_loss, query)``. ``new_query`` is
            L2-normalized and ``fea_loss`` is the unreduced MSE to the
            single best-scoring prototype.
        """
        dims = query.size(-1)
        loss_mse = torch.nn.MSELoss(reduction='none')

        keys = F.normalize(keys, dim=-1)
        _, softmax_score_proto = self.get_score(keys, query)  # b x n x m
        num_protos = softmax_score_proto.size(-1)

        # Attention-weighted mixture of prototypes for every query: b x n x d
        new_query = (softmax_score_proto.unsqueeze(-1) * keys.unsqueeze(1)).sum(2)

        if not train:
            new_query = F.normalize(new_query, dim=-1)
            # k clamped so fewer than 2 prototypes does not make topk fail.
            _, gathering_indices = torch.topk(softmax_score_proto, min(2, num_protos), dim=-1)
            # Closest prototype for every query position.
            pos = torch.gather(keys, 1, gathering_indices[:, :, :1].repeat((1, 1, dims)))
            fea_loss = loss_mse(query, pos)
            return new_query, fea_loss, query

        # Distinction constraint: 1 - squared distance between prototype pairs
        # (b x m x m), keeping only positive entries, each unordered pair once.
        sim = 1 - ((keys.unsqueeze(1) - keys.unsqueeze(2)) ** 2).sum(-1)
        sim = sim * (sim > 0).float()
        sim = torch.triu(sim, diagonal=1)
        # Guard: with a single prototype there are no pairs (avoid 0 / 0 = NaN).
        num_pairs = max(self.proto_size * (self.proto_size - 1), 1)
        dis_loss = (sim.sum(1).sum(1) * 2 / num_pairs).mean()

        # Top-k prototypes per query (k clamped to the prototype count).
        topk_prob, gathering_indices = torch.topk(softmax_score_proto, min(5, num_protos), dim=-1)
        selected = torch.zeros_like(softmax_score_proto).scatter(
            -1, gathering_indices, topk_prob.float()) > 0
        # Non-selected scores are replaced by 1e-9 before a second softmax.
        topk_feature = softmax_score_proto.masked_fill(~selected, float(1e-9))
        topk_feature = F.softmax(topk_feature, -1)
        topk_feature = (topk_feature.unsqueeze(-1) * keys.unsqueeze(1)).sum(2) + new_query
        topk_feature = F.normalize(topk_feature, dim=-1)

        fea_loss = loss_mse(query, topk_feature)
        return new_query, fea_loss.mean() + 0.0001 * dis_loss.mean()

def batchnorm(input, weight=None, bias=None, running_mean=None, running_var=None, training=True, eps=1e-5, momentum=0.1):
    """Functional batch normalization with explicitly supplied parameters.

    Thin wrapper over ``F.batch_norm``. This function orders its arguments
    ``weight, bias`` before the running statistics, unlike ``F.batch_norm``.
    When the running statistics are None and ``training`` is True, the
    statistics of the current mini-batch are used and nothing is tracked.
    """
    return F.batch_norm(
        input,
        running_mean=running_mean,
        running_var=running_var,
        weight=weight,
        bias=bias,
        training=training,
        momentum=momentum,
        eps=eps,
    )

def bilinear_upsample(in_, factor):
    """Bilinearly upsample a 4-D (N, C, H, W) tensor by ``scale_factor=factor``.

    Uses ``F.interpolate`` directly; ``F.upsample`` is a deprecated alias
    that forwards the same arguments and emits a deprecation warning.
    """
    return F.interpolate(in_, size=None, scale_factor=factor, mode='bilinear')

def log_softmax(input, dim=None):
    """Log-softmax of ``input`` along ``dim``.

    Calling ``F.log_softmax`` without ``dim`` is deprecated. When ``dim`` is
    None, this picks the same axis PyTorch's legacy implicit rule did: 0 for
    0-D, 1-D and 3-D inputs, 1 otherwise. Existing callers keep their
    results without the deprecation warning.
    """
    if dim is None:
        dim = 0 if input.dim() in (0, 1, 3) else 1
    return F.log_softmax(input, dim=dim)

class Encoder(torch.nn.Module):
    """Four-stage convolutional encoder returning a bottleneck plus skip features.

    Args:
        t_channel: number of input channels.
        n_channel: accepted for interface compatibility; not used.

    ``forward(x)`` returns ``(bottleneck, skip1, skip2, skip3)``:
    ``skip1..skip3`` are the stage 1-3 outputs before pooling (64, 128 and
    256 channels). ``bottleneck`` is the 512-channel stage-4 output after
    three 2x2 max-pools. Stage 4 has no trailing ReLU.
    """

    def __init__(self, t_channel=20, n_channel=3):
        super(Encoder, self).__init__()

        def double_conv(c_in, c_out, final_relu=True):
            # conv3x3 -> ReLU -> conv3x3 [-> ReLU]; Sequential indices 0/2 hold the convs.
            layers = [
                torch.nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=3, stride=1, padding=1),
                torch.nn.ReLU(inplace=False),
                torch.nn.Conv2d(in_channels=c_out, out_channels=c_out, kernel_size=3, stride=1, padding=1),
            ]
            if final_relu:
                layers.append(torch.nn.ReLU(inplace=False))
            return torch.nn.Sequential(*layers)

        def halve():
            return torch.nn.MaxPool2d(kernel_size=2, stride=2)

        # Attribute names are kept: they define the state_dict keys.
        self.moduleConv1 = double_conv(t_channel, 64)
        self.modulePool1 = halve()
        self.moduleConv2 = double_conv(64, 128)
        self.modulePool2 = halve()
        self.moduleConv3 = double_conv(128, 256)
        self.modulePool3 = halve()
        self.moduleConv4 = double_conv(256, 512, final_relu=False)

    def forward(self, x):
        skip1 = self.moduleConv1(x)
        skip2 = self.moduleConv2(self.modulePool1(skip1))
        skip3 = self.moduleConv3(self.modulePool2(skip2))
        bottleneck = self.moduleConv4(self.modulePool3(skip3))
        return bottleneck, skip1, skip2, skip3


class Decoder_new(torch.nn.Module):
    """U-Net style decoder that upsamples a 512-channel bottleneck three times.

    Args:
        r_channel, n_channel: accepted for interface compatibility; not used.

    ``forward(x, skip1, skip2, skip3)`` concatenates each skip (first) with
    the matching upsampled tensor along the channel axis. It returns the
    final concatenation with ``skip1`` (64 skip + 64 upsampled = 128
    channels), without a refining conv after the last stage.
    """

    def __init__(self, r_channel=6, n_channel=3):
        super(Decoder_new, self).__init__()

        def double_conv(c_in, c_out):
            return torch.nn.Sequential(
                torch.nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=3, stride=1, padding=1),
                torch.nn.ReLU(inplace=False),
                torch.nn.Conv2d(in_channels=c_out, out_channels=c_out, kernel_size=3, stride=1, padding=1),
                torch.nn.ReLU(inplace=False),
            )

        def up_block(c_in, c_out):
            # stride 2 + output_padding 1 doubles the spatial size exactly.
            return torch.nn.Sequential(
                torch.nn.ConvTranspose2d(in_channels=c_in, out_channels=c_out, kernel_size=3, stride=2,
                                         padding=1, output_padding=1),
                torch.nn.ReLU(inplace=False),
            )

        # Attribute names are kept: they define the state_dict keys.
        self.moduleConv = double_conv(512, 512)
        self.moduleUpsample4 = up_block(512, 256)
        self.moduleDeconv3 = double_conv(512, 256)
        self.moduleUpsample3 = up_block(256, 128)
        self.moduleDeconv2 = double_conv(256, 128)
        self.moduleUpsample2 = up_block(128, 64)

    def forward(self, x, skip1, skip2, skip3):
        stages = (
            (self.moduleUpsample4, skip3, self.moduleDeconv3),
            (self.moduleUpsample3, skip2, self.moduleDeconv2),
            (self.moduleUpsample2, skip1, None),
        )
        out = self.moduleConv(x)
        for upsample, skip, refine in stages:
            out = torch.cat((skip, upsample(out)), dim=1)
            if refine is not None:
                out = refine(out)
        return out


class convAE(torch.nn.Module):
    """Encoder -> decoder -> prototype module -> tanh output head.

    The decoder output (128 channels) is L2-normalized over channels and
    passed to ``Meta_Prototype`` as both key and query. The refined features
    then go through the output head. Every forward can run either with the
    module's own parameters (``weights=None``) or functionally with a dict of
    weights keyed by parameter name (``'ohead.0.weight'`` etc.).

    NOTE(review): the decoder emits 128 channels, while ``feature_dim`` and
    ``key_dim`` default to 512. It looks like callers must pass 128 for
    both; confirm.
    """

    def __init__(self, n_channel=3, t_channel=20, r_channel = 6, proto_size=10, feature_dim=512, key_dim=512, temp_update=0.1,
                 temp_gather=0.1):
        super(convAE, self).__init__()

        def Outhead(intInput, intOutput, nc):
            return torch.nn.Sequential(
                torch.nn.Conv2d(in_channels=intInput, out_channels=nc, kernel_size=3, stride=1, padding=1),
                torch.nn.ReLU(inplace=False),
                torch.nn.Conv2d(in_channels=nc, out_channels=nc, kernel_size=3, stride=1, padding=1),
                torch.nn.ReLU(inplace=False),
                torch.nn.Conv2d(in_channels=nc, out_channels=intOutput, kernel_size=3, stride=1, padding=1),
                torch.nn.Tanh()
            )

        self.encoder = Encoder(t_channel, n_channel)
        self.decoder = Decoder_new(r_channel, n_channel)
        self.prototype = Meta_Prototype(proto_size, feature_dim, key_dim, temp_update, temp_gather)
        # Output head; Sequential indices 0/2/4 are the convs read by the functional path.
        self.ohead = Outhead(128, r_channel, 64)

    def set_learnable_params(self, layers):
        """Set ``requires_grad`` True exactly for parameters whose name starts with a prefix in ``layers``."""
        prefixes = tuple(layers)
        for k, p in self.named_parameters():
            p.requires_grad = k.startswith(prefixes)

    def get_learnable_params(self):
        """Return an OrderedDict of the parameters that currently require grad."""
        return OrderedDict((k, p) for k, p in self.named_parameters() if p.requires_grad)

    def get_params(self, layers):
        """Return an OrderedDict of parameters whose name starts with a prefix in ``layers``."""
        prefixes = tuple(layers)
        return OrderedDict((k, p) for k, p in self.named_parameters() if k.startswith(prefixes))

    def _output_head(self, fea, weights):
        """Apply the output head, as ``self.ohead`` or functionally from ``weights``."""
        if weights is None:
            return self.ohead(fea)
        x = conv2d(fea, weights['ohead.0.weight'], weights['ohead.0.bias'], stride=1, padding=1)
        x = relu(x)
        x = conv2d(x, weights['ohead.2.weight'], weights['ohead.2.bias'], stride=1, padding=1)
        x = relu(x)
        x = conv2d(x, weights['ohead.4.weight'], weights['ohead.4.bias'], stride=1, padding=1)
        return torch.tanh(x)

    def forward(self, x, weights=None, train=True):
        """Run the full model.

        Returns:
            train=True: ``(output, fea, updated_fea, keys, fea_loss)`` where
            ``fea`` is the encoder bottleneck, ``keys`` the prototypes and
            ``fea_loss`` a scalar.
            train=False: ``(output, fea_loss)`` with the prototype module's
            unreduced test-time feature loss.
        """
        fea, skip1, skip2, skip3 = self.encoder(x)
        new_fea = self.decoder(fea, skip1, skip2, skip3)
        new_fea = F.normalize(new_fea, dim=1)

        if train:
            updated_fea, keys, fea_loss = self.prototype(new_fea, new_fea, weights, train)
            output = self._output_head(updated_fea, weights)
            return output, fea, updated_fea, keys, fea_loss.mean()

        # test
        updated_fea, _, _, fea_loss = self.prototype(new_fea, new_fea, weights, train)
        output = self._output_head(updated_fea, weights)
        return output, fea_loss


def meta_update(model, model_weights, meta_init_grads, model_alpha, meta_alpha_grads,
                meta_init_optimizer, meta_alpha_optimizer):
    """Apply task-averaged meta-gradients to the initial weights and per-weight step sizes.

    Args:
        model: unused; kept for interface compatibility.
        model_weights: name -> tensor dict of meta-initial weights owned by
            ``meta_init_optimizer``.
        meta_init_grads: list of name -> gradient dicts (one per task) for
            ``model_weights``.
        model_alpha: name -> tensor dict of learned step sizes owned by
            ``meta_alpha_optimizer``.
        meta_alpha_grads: list of name -> gradient dicts for ``model_alpha``.
        meta_init_optimizer, meta_alpha_optimizer: optimizers stepped once each.

    The averaged gradients are written directly into ``.grad`` before each
    ``step()``. A previous version first ran a dummy forward/backward that
    forced CUDA and whose gradients were overwritten anyway. It also divided
    the alpha sum by ``len(meta_init_grads)`` instead of
    ``len(meta_alpha_grads)``.
    """
    init_gradients = {k: sum(d[k] for d in meta_init_grads) / len(meta_init_grads)
                      for k in meta_init_grads[0].keys()}
    alpha_gradients = {k: sum(d[k] for d in meta_alpha_grads) / len(meta_alpha_grads)
                       for k in meta_alpha_grads[0].keys()}

    # update meta_init (initial weights)
    meta_init_optimizer.zero_grad()
    for k, init in model_weights.items():
        init.grad = init_gradients[k]
    meta_init_optimizer.step()

    # update meta_alpha (learning rates)
    meta_alpha_optimizer.zero_grad()
    for k, alpha in model_alpha.items():
        alpha.grad = alpha_gradients[k]
    meta_alpha_optimizer.step()


def train_init(model, model_weights, model_alpha, loss_fn, img, lh_img, gt, lh_gt, idx, args):
    """One inner adaptation step on (img, gt), then outer-loss gradients on (lh_img, lh_gt).

    The inner loss is differentiated with ``create_graph=True``, and the
    weights are updated as ``w - alpha * grad``. The loss of the updated
    weights on the second batch is then differentiated w.r.t. both the
    original weights and the step sizes.

    Args:
        model: object whose ``forward(x, weights, True)`` returns a 7-tuple
            with the prediction at index 0, a feature loss at 4 and a
            distinction loss at 6.
        model_weights: ordered name -> tensor dict of the weights to adapt.
        model_alpha: step sizes, iterated in the same order as ``model_weights``.
        loss_fn: ``loss_fn(pred, target)``; must return a scalar here (no
            reduction is applied).
        idx: counter, returned incremented by one.
        args: provides ``loss_fea_reconstruct``, ``loss_distinguish`` and
            ``loss_fra_reconstruct`` weights.

    Returns:
        ``(meta_init_grads, meta_alpha_grads, loss, lh_loss, idx + 1)``. Both
        grad dicts are keyed by the names in ``model_weights``.
    """
    # NOTE(review): unpacks 7 values, but convAE.forward(..., train=True) in
    # this file returns 5 -- confirm which model is passed here.
    pred, _, _, _, fea_loss, _, dis_loss = model.forward(img, model_weights, True)

    loss_pixel = loss_fn(pred, gt)
    loss = (args.loss_fea_reconstruct * fea_loss
            + args.loss_distinguish * dis_loss
            + args.loss_fra_reconstruct * loss_pixel)

    # create_graph=True so the outer loss can differentiate through this update.
    grads = torch.autograd.grad(loss, model_weights.values(), create_graph=True)

    update_weights = OrderedDict(
        (name, param - torch.mul(meta_alpha, grad))
        for (name, param), meta_alpha, grad in zip(model_weights.items(), model_alpha.values(), grads))

    lh_pred, _, _, _, lh_fea_loss, _, lh_dis_loss = model.forward(lh_img, update_weights, True)

    idx = idx + 1

    lh_loss_pixel = loss_fn(lh_pred, lh_gt)
    lh_loss = (args.loss_fea_reconstruct * lh_fea_loss
               + args.loss_distinguish * lh_dis_loss
               + args.loss_fra_reconstruct * lh_loss_pixel)

    grads_ = torch.autograd.grad(lh_loss, model_weights.values(), retain_graph=True)
    alpha_grads = torch.autograd.grad(lh_loss, model_alpha.values())

    names = list(model_weights.keys())
    meta_init_grads = dict(zip(names, grads_))
    meta_alpha_grads = dict(zip(names, alpha_grads))
    return meta_init_grads, meta_alpha_grads, loss, lh_loss, idx


def test_init(model, model_weights, model_alpha, loss_fn, imgs, gts, args):
    update_weights = model_weights
    for j in range(args.test_iter):

        grad_list = []
        for k in range(imgs.shape[0]):
            pred, _, _, _, fea_loss, _, dis_loss = model.forward(imgs[k:k + 1], model_weights, True)

            loss_pixel = loss_fn(pred, gts[k:k + 1]).mean()
            loss = args.loss_fea_reconstruct * fea_loss + args.loss_distinguish * dis_loss + args.loss_fra_reconstruct * loss_pixel
            grads = torch.autograd.grad(loss, model_weights.values())
            grad_list.append(grads)

        k_grads = ()
        for i in range(len(grad_list[0])):
            grad_temp = grad_list[0][i]
            for k in range(1, len(grad_list)):
                grad_temp += grad_list[k][i]
            k_grads += (grad_temp / len(grad_list),)

        update_weights = OrderedDict((name, param - torch.mul(meta_alpha, grad)) for
                                     ((name, param), (_, meta_alpha), grad) in
                                     zip(model_weights.items(), model_alpha.items(), k_grads))
        model_weights = update_weights

    return update_weights


def test_ft(model, model_weights, model_alpha, loss_fn, img, gt, args):
    update_weights = model_weights
    for j in range(args.test_iter):
        pred, _, _, _, fea_loss, _, dis_loss = model.forward(img, model_weights, True)

        loss_pixel = loss_fn(pred, gt).mean()
        loss = args.loss_fea_reconstruct * fea_loss + args.loss_distinguish * dis_loss + args.loss_fra_reconstruct * loss_pixel

        grads = torch.autograd.grad(loss, model_weights.values())

        update_weights = OrderedDict((name, param - torch.mul(meta_alpha, grad)) for
                                     ((name, param), (_, meta_alpha), grad) in
                                     zip(model_weights.items(), model_alpha.items(), grads))

        model_weights = update_weights

    return update_weights


def dismap(x, name='pred', out_dir='results/dismap'):
    """Save one seaborn heatmap PNG per batch element of ``x``, averaged over axis 1.

    Args:
        x: tensor with a leading batch axis; axis 1 is averaged away before
            plotting. Presumably (b, c, h, w) -- TODO confirm with callers.
        name: filename prefix; files are written as ``{out_dir}/{name}_{j}.png``.
        out_dir: output directory; created if it does not exist (savefig
            would otherwise fail).

    Returns:
        True once all images have been written.
    """
    import os

    os.makedirs(out_dir, exist_ok=True)
    x = x.detach().cpu().numpy().mean(1)
    for j, y in enumerate(x):
        plt.cla()
        sns.heatmap(pd.DataFrame(y))
        plt.savefig(os.path.join(out_dir, '{}_{}.png'.format(name, str(j))))
        plt.close()
    return True

class unet(torch.nn.Module):
    """Plain U-Net: ``Encoder`` -> ``Decoder_new`` -> two-conv ReLU output head.

    Args:
        n_channel: number of input channels.
        o_channel: number of output channels.

    ``forward`` ignores ``weights`` and ``train``; they are accepted only for
    signature compatibility with the other models.
    """

    def __init__(self, n_channel=3, o_channel=3):
        super(unet, self).__init__()

        self.encoder = Encoder(n_channel, 64)
        self.decoder = Decoder_new()
        head_layers = [
            torch.nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(inplace=False),
            torch.nn.Conv2d(in_channels=64, out_channels=o_channel, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(inplace=False),
        ]
        self.output = torch.nn.Sequential(*head_layers)

    def set_learnable_params(self, layers):
        """Set ``requires_grad`` True exactly for parameters whose name starts with an entry of ``layers``."""
        for name, param in self.named_parameters():
            param.requires_grad = any([name.startswith(prefix) for prefix in layers])

    def get_learnable_params(self):
        """Return an OrderedDict of the parameters that currently require grad."""
        return OrderedDict((name, param) for name, param in self.named_parameters() if param.requires_grad)

    def get_params(self, layers):
        """Return an OrderedDict of parameters whose name starts with an entry of ``layers``."""
        return OrderedDict(
            (name, param)
            for name, param in self.named_parameters()
            if any([name.startswith(prefix) for prefix in layers])
        )

    def forward(self, x, weights=None, train=True):
        # Encoder returns (bottleneck, skip1, skip2, skip3), exactly the
        # decoder's positional argument order.
        decoded = self.decoder(*self.encoder(x))
        return self.output(decoded)


