
from __future__ import print_function

import argparse
import os
import sys
import classifier
import torch
import model
import util

def _str2bool(value):
    """Interpret a command-line string as a boolean.

    Common "false" spellings and the empty string give False; every other
    value gives True, so values that used to be truthy strings stay truthy.
    """
    return str(value).lower() not in ('', '0', 'false', 'f', 'no', 'n', 'off')


# Command-line configuration for the evaluation script. The parsed namespace
# `opt` is handed to the data loader and the generator constructor and is read
# by the feature-synthesis code below.
parser = argparse.ArgumentParser()
# NOTE(review): the help text 'AWA1' looks like an example alternative value
# rather than a description -- confirm.
parser.add_argument('--dataset', default='CUB', help='AWA1')
parser.add_argument('--gpu_id',type=str,default='0')
parser.add_argument('--syn_num', type=int, default=300, help='number features to generate per class')
parser.add_argument('--nepoch', type=int, default=300, help='number of epochs to train for')
parser.add_argument('--batch_size', type=int, default=64, help='input batch size')
parser.add_argument('--nclass_all', type=int, default=200, help='number of all classes')
parser.add_argument('--resSize', type=int, default=2048, help='size of visual features')
parser.add_argument('--attSize', type=int, default=312, help='size of semantic features')
parser.add_argument('--nz', type=int, default=312, help='size of the latent z vector; same with dimension of attr')
parser.add_argument('--ngh', type=int, default=4096, help='size of the hidden units in generator')
parser.add_argument('--classifier_lr', type=float, default=0.001, help='learning rate to train softmax classifier')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--netG_name', default='MLP_G')
parser.add_argument('--netG', default='./checkpoint/netG.pkl', help="path to netG")
parser.add_argument('--dataroot', default='./data', help='path to dataset')
# Without a converter, `--matdataset False` produced the truthy string 'False'.
parser.add_argument('--matdataset', type=_str2bool, default=True, help='Data in matlab format')
parser.add_argument('--image_embedding', default='res101')
parser.add_argument('--class_embedding', default='att')
# `store_true` combined with `default=True` can never yield False, so each of
# the two switches below gets a `--no_*` counterpart writing to the same dest.
# The original flags are still accepted and the defaults are unchanged.
parser.add_argument('--gzsl', action='store_true', default=True, help='enable generalized zero-shot learning')
parser.add_argument('--no_gzsl', dest='gzsl', action='store_false',
                    help='disable generalized zero-shot learning')
parser.add_argument('--preprocessing', action='store_true', default=True,
                    help='enbale MinMaxScaler on visual features')
parser.add_argument('--no_preprocessing', dest='preprocessing', action='store_false',
                    help='disable MinMaxScaler on visual features')

opt = parser.parse_args()
# Restrict which GPUs this process may use. This is done before any CUDA work
# is issued so that the restriction is in place when the driver is initialized.
os.environ['CUDA_VISIBLE_DEVICES']=opt.gpu_id

print('Begin run!!!')
sys.stdout.flush()
# Both constructors receive the full option namespace.
data = util.DATA_LOADER(opt)
netG = model.MLP_G(opt)
if opt.netG != '':
    # Deserialize onto the CPU: by default torch.load restores tensors to the
    # device they were saved from, which fails when that device is hidden by
    # CUDA_VISIBLE_DEVICES or absent. The generator is moved to the GPU later.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    netG.load_state_dict(torch.load(opt.netG, map_location='cpu'))

def generate_syn_feature(netG, classes, attribute, num, nz=None, res_size=None):
    """Synthesize ``num`` visual features for every class in ``classes``.

    For each class id, its attribute row is repeated ``num`` times, paired
    with fresh standard-normal noise and passed through ``netG`` with gradient
    tracking disabled.

    Args:
        netG: generator, called as ``netG(noise, att)``. Its inputs are created
            on the device of its first parameter; if it exposes no parameters,
            CUDA is used when available and the CPU otherwise.
        classes: 1-D integer tensor of class ids used to index ``attribute``.
        attribute: 2-D tensor holding one attribute row per class id.
        num: number of features to generate per class.
        nz: width of the noise vector; defaults to ``opt.nz``.
        res_size: width of one generated feature; defaults to ``opt.resSize``.

    Returns:
        ``(syn_feature, syn_label)``, both on the CPU: a float32 tensor of
        shape ``(len(classes) * num, res_size)`` and a long tensor holding the
        class id of every row.

    NOTE(review): the generator's train/eval mode is left untouched -- confirm
    whether the model has mode-dependent layers that need ``eval()`` first.
    """
    if nz is None:
        nz = opt.nz
    if res_size is None:
        res_size = opt.resSize

    # Follow the generator's device instead of hard-coding CUDA, so the same
    # code runs on CPU-only hosts and never feeds the model mismatched inputs.
    try:
        device = next(netG.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    nclass = classes.size(0)
    syn_feature = torch.empty(nclass * num, res_size, dtype=torch.float32)
    syn_label = torch.empty(nclass * num, dtype=torch.long)
    # Generator inputs are allocated once and refilled for every class.
    syn_att = torch.empty(num, attribute.size(1), dtype=torch.float32, device=device)
    syn_noise = torch.empty(num, nz, dtype=torch.float32, device=device)

    for i in range(nclass):
        # A plain int is a valid index and fill value on any device.
        class_id = int(classes[i])
        syn_att.copy_(attribute[class_id].repeat(num, 1))
        syn_noise.normal_(0, 1)
        with torch.no_grad():
            output = netG(syn_noise, syn_att)
        # Rows [i * num, (i + 1) * num) belong to the i-th class.
        syn_feature.narrow(0, i * num, num).copy_(output.cpu())
        syn_label.narrow(0, i * num, num).fill_(class_id)

    return syn_feature, syn_label

def map_label(label, classes):
    """Map raw class ids in ``label`` to their positions in ``classes``.

    Args:
        label: integer tensor of raw class ids.
        classes: 1-D tensor of class ids; entry ``i`` is mapped to index ``i``.

    Returns:
        A long tensor shaped like ``label``. Every element equal to
        ``classes[i]`` becomes ``i``; elements that match no entry of
        ``classes`` are set to -1. The result lives on CUDA when it is
        available (as before) and otherwise on ``label``'s device.
    """
    # Keep the historical CUDA placement when possible, but do not crash on
    # CPU-only hosts.
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = label.device

    # Pre-fill with a sentinel: the previous uninitialized buffer left garbage
    # values for labels that are absent from `classes`.
    mapped_label = torch.full(label.size(), -1, dtype=torch.long, device=device)
    for i in range(classes.size(0)):
        mapped_label[label == classes[i]] = i

    return mapped_label


# Move the generator to the GPU before it is used for feature synthesis.
netG.cuda()

# Both evaluation modes start from the same synthetic unseen-class features.
syn_feature, syn_label = generate_syn_feature(netG, data.unseenclasses, data.attribute, opt.syn_num)

# The classifier object is discarded right after construction, so its work
# presumably happens inside the constructor -- confirm in classifier.py.
# NOTE(review): the trailing positional arguments (50, 2 * opt.syn_num, False)
# are identical in both modes; check them against CLASSIFIER's signature.
if not opt.gzsl:
    # Unseen classes only: train on synthetic features with remapped labels.
    cls = classifier.CLASSIFIER(
        syn_feature,
        util.map_label(syn_label, data.unseenclasses),
        data,
        data.unseenclasses.size(0),
        opt.classifier_lr,
        opt.beta1,
        50,
        2 * opt.syn_num,
        False,
    )
else:
    # Generalized mode: the loader's training features are concatenated with
    # the synthetic ones, and the classifier covers all classes.
    train_X = torch.cat((data.train_feature, syn_feature), 0)
    train_Y = torch.cat((data.train_label, syn_label), 0)
    nclass = opt.nclass_all
    cls = classifier.CLASSIFIER(
        train_X,
        train_Y,
        data,
        nclass,
        opt.classifier_lr,
        opt.beta1,
        50,
        2 * opt.syn_num,
        False,
    )

# Drop the only reference to the classifier; the name stays bound to None.
cls = None

print('End run!!!')