import torch
import torch.nn as nn
import torch.optim as optim
import util
import sys



class CLASSIFIER:
    """Train a linear log-softmax classifier and evaluate it for (G)ZSL.

    Training and evaluation run eagerly inside ``__init__``:

    * ``generalized=True``: after every epoch the model is evaluated on the
      seen and unseen test splits; the predictions of the epoch with the best
      harmonic mean H are re-scored and stored in ``acc_seen``,
      ``acc_unseen`` and ``H``.
    * ``generalized=False``: after every epoch the model is evaluated on the
      unseen test split (labels remapped with ``util.map_label``); the best
      mean per-class accuracy is stored in ``acc``.

    Args:
        _train_X: training features; ``size(1)`` is used as the input dim.
        _train_Y: training labels, used as ``nn.NLLLoss`` targets (so they
            must be class indices in ``[0, _nclass)``).
        data_loader: object exposing ``test_seen_feature``,
            ``test_seen_label``, ``test_unseen_feature``,
            ``test_unseen_label``, ``seenclasses`` and ``unseenclasses``.
        _nclass: number of output classes of the linear layer.
        _lr, _beta1: Adam learning rate and first-moment beta.
        _nepoch: number of training epochs (must be positive for GZSL).
        _batch_size: mini-batch size for both training and evaluation.
        generalized: run GZSL (True) or conventional ZSL (False) evaluation.
        device: device for the model and batches; defaults to ``'cuda'``.
    """

    def __init__(self, _train_X, _train_Y, data_loader, _nclass, _lr=0.001, _beta1=0.5,
                 _nepoch=20, _batch_size=100, generalized=True, device='cuda'):
        self.train_X = _train_X
        self.train_Y = _train_Y
        self.test_seen_feature = data_loader.test_seen_feature
        self.test_seen_label = data_loader.test_seen_label
        self.test_unseen_feature = data_loader.test_unseen_feature
        self.test_unseen_label = data_loader.test_unseen_label
        self.seenclasses = data_loader.seenclasses
        self.unseenclasses = data_loader.unseenclasses
        self.batch_size = _batch_size
        self.nepoch = _nepoch
        self.nclass = _nclass
        self.input_dim = _train_X.size(1)
        self.device = torch.device(device)

        self.model = LINEAR_LOGSOFTMAX(self.input_dim, self.nclass)
        self.model.apply(util.weights_init)
        self.model.to(self.device)
        self.criterion = nn.NLLLoss()
        self.optimizer = optim.Adam(self.model.parameters(), lr=_lr, betas=(_beta1, 0.999))

        # Cursor state consumed by next_batch().
        self.index_in_epoch = 0
        self.epochs_completed = 0
        self.ntrain = self.train_X.size(0)

        if generalized:
            self.acc_seen, self.acc_unseen, self.H = self.fit_gzsl()
        else:
            self.acc = self.fit_zsl()

    @staticmethod
    def _harmonic_mean(a, b):
        """Return ``2ab / (a + b)``, or 0 when ``a + b == 0`` (avoids 0/0)."""
        total = a + b
        if total == 0:
            # Multiply instead of returning a literal so a 0-dim tensor input
            # yields a 0-dim tensor, like the non-degenerate branch.
            return total * 0.0
        return 2 * a * b / total

    def _train_epoch(self):
        """Run ceil(ntrain / batch_size) optimizer steps on mini-batches."""
        for _ in range(0, self.ntrain, self.batch_size):
            self.model.zero_grad()
            batch_input, batch_label = self.next_batch(self.batch_size)
            # nn.Linear expects float32 features and nn.NLLLoss int64 targets;
            # cast explicitly and move to the model's device.
            output = self.model(batch_input.to(self.device, dtype=torch.float))
            loss = self.criterion(output, batch_label.to(self.device, dtype=torch.long))
            loss.backward()
            self.optimizer.step()

    def fit_gzsl(self):
        """Train for ``nepoch`` epochs and report GZSL accuracy.

        The predictions from the epoch with the highest harmonic mean H are
        split back into seen/unseen parts by ground-truth class and re-scored.

        Returns:
            (acc_seen, acc_unseen, H) computed from those predictions.

        Raises:
            ValueError: if ``nepoch`` is not positive (no predictions exist).
        """
        if self.nepoch <= 0:
            raise ValueError('nepoch must be positive for GZSL evaluation')
        all_test_label = torch.cat((self.test_seen_label, self.test_unseen_label), 0)
        # H is always >= 0, so starting below 0 guarantees the first epoch is
        # recorded even if H stays exactly 0 (e.g. zero unseen accuracy).
        best_H = -1.0
        all_pred = None
        for _ in range(self.nepoch):
            self._train_epoch()
            acc_seen, pred_seen, _ = self.val_gzsl(
                self.test_seen_feature, self.test_seen_label, self.seenclasses)
            acc_unseen, pred_unseen, _ = self.val_gzsl(
                self.test_unseen_feature, self.test_unseen_label, self.unseenclasses)
            H = self._harmonic_mean(acc_seen, acc_unseen)
            if H > best_H:
                best_H = H
                all_pred = torch.cat((pred_seen, pred_unseen), 0)
        seen_pred, seen_label, unseen_pred, unseen_label = self.split_pred(all_pred, all_test_label)
        acc_seen = self.compute_per_class_acc_gzsl(seen_label, seen_pred, self.seenclasses)
        acc_unseen = self.compute_per_class_acc_gzsl(unseen_label, unseen_pred, self.unseenclasses)
        acc_H = self._harmonic_mean(acc_seen, acc_unseen)
        print('Seen Accuracy: {:.2f}%, Unseen Accuracy: {:.2f}%, H: {:.2f}%'.format(acc_seen*100,acc_unseen*100,acc_H*100))
        sys.stdout.flush()
        return acc_seen, acc_unseen, acc_H

    def _predict(self, test_X):
        """Run the model over ``test_X`` in ``batch_size`` chunks without gradients.

        Returns:
            (predicted_label, all_output): arg-max class indices as a CPU
            int64 tensor, and the concatenated model outputs on
            ``self.device`` (``None`` when ``test_X`` is empty).
        """
        outputs = []
        with torch.no_grad():
            for start in range(0, test_X.size(0), self.batch_size):
                batch = test_X[start:start + self.batch_size].to(self.device)
                outputs.append(self.model(batch))
        if not outputs:
            return torch.empty(0, dtype=torch.long), None
        # Concatenate once rather than growing a tensor per batch (quadratic copying).
        all_output = torch.cat(outputs, 0)
        predicted_label = torch.max(all_output, 1)[1].cpu()
        return predicted_label, all_output

    def val_gzsl(self, test_X, test_label, target_classes):
        """Score ``test_X`` against original-id labels over ``target_classes``.

        Returns:
            (mean per-class accuracy, CPU predicted labels, model outputs).
        """
        predicted_label, all_output = self._predict(test_X)
        acc = self.compute_per_class_acc_gzsl(test_label, predicted_label, target_classes)
        return acc, predicted_label, all_output

    def fit_zsl(self):
        """Train for ``nepoch`` epochs and return the best unseen-class ZSL accuracy.

        After each epoch the model is scored with ``val_zsl`` on the unseen
        test split. The maximum mean per-class accuracy is printed and
        returned (0 if no epoch ran).
        """
        best_acc = 0
        for _ in range(self.nepoch):
            self._train_epoch()
            acc, _, _, _ = self.val_zsl(
                self.test_unseen_feature, self.test_unseen_label, self.unseenclasses)
            if acc > best_acc:
                best_acc = acc
        print('ZSL Acc: {:.2f}%'.format(best_acc * 100))
        sys.stdout.flush()
        return best_acc

    def val_zsl(self, test_X, test_label, target_classes):
        """Evaluate conventional ZSL on ``test_X``.

        Ground-truth labels are remapped with
        ``util.map_label(test_label, target_classes)`` before scoring
        (presumably to ``[0, len(target_classes))`` to match the model's
        output indices; confirm against ``util``).

        Returns:
            (mean per-class accuracy, CPU predicted labels, model outputs,
            per-class accuracy tensor of length ``target_classes.size(0)``).
        """
        predicted_label, all_output = self._predict(test_X)
        mapped_label = util.map_label(test_label, target_classes)
        acc_all = self.compute_every_class_acc(mapped_label, predicted_label, target_classes.size(0))
        return acc_all.mean(), predicted_label, all_output, acc_all

    def _shuffle_train(self):
        """Apply one random permutation jointly to ``train_X`` and ``train_Y``."""
        perm = torch.randperm(self.ntrain)
        self.train_X = self.train_X[perm]
        self.train_Y = self.train_Y[perm]

    def next_batch(self, batch_size):
        """Return the next ``batch_size`` training samples as ``(X, Y)``.

        The data is shuffled on the very first call and again each time the
        cursor passes the end. A batch straddling an epoch boundary is the
        leftover tail followed by the head of the reshuffled data. If
        ``ntrain < batch_size`` the batch may be smaller than ``batch_size``.
        """
        start = self.index_in_epoch
        if self.epochs_completed == 0 and start == 0:
            self._shuffle_train()
        if start + batch_size <= self.ntrain:
            self.index_in_epoch = start + batch_size
            end = self.index_in_epoch
            return self.train_X[start:end], self.train_Y[start:end]

        # Epoch boundary: keep the unseen tail, reshuffle, top up from the head.
        self.epochs_completed += 1
        rest_num_examples = max(self.ntrain - start, 0)
        X_rest_part = self.train_X[start:self.ntrain]
        Y_rest_part = self.train_Y[start:self.ntrain]
        self._shuffle_train()
        # Clamp so the cursor never exceeds ntrain (only possible when ntrain < batch_size).
        self.index_in_epoch = min(batch_size - rest_num_examples, self.ntrain)
        end = self.index_in_epoch
        X_new_part = self.train_X[:end]
        Y_new_part = self.train_Y[:end]
        if rest_num_examples > 0:
            return torch.cat((X_rest_part, X_new_part), 0), torch.cat((Y_rest_part, Y_new_part), 0)
        return X_new_part, Y_new_part

    def compute_per_class_acc_gzsl(self, test_label, predicted_label, target_classes):
        """Mean per-class accuracy over ``target_classes`` (original label ids).

        Classes without test samples contribute 0 but still count in the
        denominator ``target_classes.size(0)``.
        """
        acc_per_class = 0
        for i in target_classes:
            idx = (test_label == i)
            n = torch.sum(idx)
            if n == 0:
                continue
            acc_per_class += torch.sum(test_label[idx] == predicted_label[idx]).float() / n.float()
        acc_per_class /= target_classes.size(0)
        return acc_per_class

    def compute_per_class_acc(self, test_label, predicted_label, nclass):
        """Mean of ``compute_every_class_acc`` over labels ``0 .. nclass-1``."""
        return self.compute_every_class_acc(test_label, predicted_label, nclass).mean()

    def _gather_by_class(self, all_pred, real_label, classes):
        """Concatenate ``(pred, label)`` rows class by class; ``(None, None)`` if ``classes`` is empty."""
        preds, labels = [], []
        for c in classes:
            idx = (real_label == c)
            preds.append(all_pred[idx])
            labels.append(real_label[idx])
        if not preds:
            return None, None
        return torch.cat(preds, 0), torch.cat(labels, 0)

    def split_pred(self, all_pred, real_label):
        """Split predictions into seen/unseen parts by ground-truth class.

        Rows are grouped in the order of ``self.seenclasses`` and
        ``self.unseenclasses`` respectively.

        Returns:
            (seen_pred, seen_label, unseen_pred, unseen_label).
        """
        seen_pred, seen_label = self._gather_by_class(all_pred, real_label, self.seenclasses)
        unseen_pred, unseen_label = self._gather_by_class(all_pred, real_label, self.unseenclasses)
        return seen_pred, seen_label, unseen_pred, unseen_label

    def compute_every_class_acc(self, test_label, predicted_label, nclass):
        """Return a length-``nclass`` float tensor of per-class accuracies.

        ``test_label`` must already use labels in ``[0, nclass)``. Classes
        without test samples get accuracy 0.
        """
        acc_per_class = torch.zeros(nclass, dtype=torch.float)
        for i in range(nclass):
            idx = (test_label == i)
            n = torch.sum(idx)
            if n != 0:
                # Explicit float division: integer-tensor `/` truncated (or raised)
                # on older PyTorch releases. Matches compute_per_class_acc_gzsl.
                acc_per_class[i] = torch.sum(test_label[idx] == predicted_label[idx]).float() / n.float()
        return acc_per_class


class LINEAR_LOGSOFTMAX(nn.Module):
    """One fully-connected layer followed by log-softmax over dimension 1.

    The output is log-probabilities, which is what ``nn.NLLLoss`` consumes.

    Args:
        input_dim: size of the last dimension of the input.
        nclass: number of output classes.
    """

    def __init__(self, input_dim, nclass):
        super().__init__()
        # Attribute names are part of the state_dict keys; keep them stable.
        self.fc = nn.Linear(input_dim, nclass)
        self.logic = nn.LogSoftmax(dim=1)

    def forward(self, x):
        """Map ``x`` to per-class log-probabilities."""
        scores = self.fc(x)
        log_probs = self.logic(scores)
        return log_probs
