import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
from torch.autograd import Variable
import matplotlib.pyplot as plt
import numpy as np
import torchvision.models as models
import ImagenetLoader
import time
import densenetcont
  
data_path = './imagenet/imagenet_tmp/raw_data/train'
# NOTE(review): both returned datasets come from this single directory, so
# get_imagenet_datasets presumably splits it into train/test — confirm in
# ImagenetLoader.
dataset_train, dataset_test = ImagenetLoader.get_imagenet_datasets(data_path)

print(f"Number of train samples {len(dataset_train)}")
print(f"Number of samples in test split {len(dataset_test)}")

BATCH_SIZE = 20

# NOTE(review): with num_workers > 0 and the 'spawn' start method (default on
# Windows/macOS) DataLoader workers re-import this script; since this file has
# no `if __name__ == "__main__":` guard, that would re-run the setup code.
data_loader_train = DataLoader(dataset_train, BATCH_SIZE, shuffle=True, num_workers=8)

clf = densenetcont.DenseNet()

# Prefer the first GPU, but fall back to the CPU so the script can still run
# on a machine without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
clf.to(device)
# Resume from the checkpoint saved after epoch 9. map_location lets a
# checkpoint written on a GPU be loaded onto whichever device was chosen above.
clf.load_state_dict(torch.load("densenet/densenetcont_9_.pt", map_location=device))

criterion = nn.CrossEntropyLoss()

# NOTE(review): the checkpoint holds only model weights, so Adam restarts from
# a fresh state with default hyperparameters when resuming — confirm intended.
opt = optim.Adam(clf.parameters())
# Alternative optimizer: optim.SGD(clf.parameters(), lr=0.01, momentum=0.9)

# Metric curves appended to by train() on each logged batch and written to
# .npy files once training finishes.
train_loss_history = []
train_acc1_history = []
train_acc5_history = []

def train(epoch, log_interval=100):
    """Run one training epoch of ``clf`` over ``data_loader_train``.

    Each batch is moved to ``device``, the ``criterion`` loss is
    back-propagated and ``opt`` takes a step. Every ``log_interval`` batches
    (including batch 0), the batch loss and top-1/top-5 accuracy are printed
    and appended to ``train_loss_history``, ``train_acc1_history`` and
    ``train_acc5_history``.

    Args:
        epoch: Epoch number. It is used only in the log line.
        log_interval: Positive number of batches between logged and recorded
            metrics. Defaults to 100.
    """
    clf.train()
    for batch_id, batch in enumerate(data_loader_train):
        data = batch[0].to(device)
        label = batch[1].to(device)

        opt.zero_grad()
        preds = clf(data)
        loss = criterion(preds, label)
        loss.backward()
        opt.step()

        if batch_id % log_interval != 0:
            continue

        # Metrics are only needed on logged batches. Computing them here,
        # rather than on every step, avoids a forced GPU->CPU sync per batch.
        with torch.no_grad():
            _, top1 = preds.max(dim=1)
            _, top5 = torch.topk(preds, k=5, dim=1)
            acc1 = top1.eq(label).float().mean().item()
            # A sample counts as top-5 correct if its label is any of the 5
            # highest-scoring classes. label.unsqueeze(1) broadcasts (B,) -> (B, 1).
            acc5 = top5.eq(label.unsqueeze(1)).any(dim=1).float().mean().item()
        loss_value = loss.item()

        print(f"Epoch: {epoch} Batch: {batch_id} Train Loss: {loss_value} Acc1: {acc1} Acc5: {acc5}")
        train_loss_history.append(loss_value)
        train_acc1_history.append(acc1)
        train_acc5_history.append(acc5)

# Continue training from the epoch-9 checkpoint, covering epochs 10..19.
# A checkpoint is written after every epoch. The try/finally ensures the
# metric curves collected so far are still saved if the run crashes or is
# interrupted. On success they are written once, after the last epoch.
try:
    for epoch in range(10, 20):
        print(f"Epoch {epoch}")
        train(epoch)
        torch.save(clf.state_dict(), f"densenet/densenetcont_{epoch}_.pt")
finally:
    np.save("densenet/densenetcont_loss.npy", np.array(train_loss_history))
    np.save("densenet/densenetcont_acc1.npy", np.array(train_acc1_history))
    np.save("densenet/densenetcont_acc5.npy", np.array(train_acc5_history))




