import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
from torch.autograd import Variable
import matplotlib.pyplot as plt
import numpy as np
import torchvision.models as models
import ImagenetLoaderValidation
import time
import densenetcont
  
# Evaluation setup: ImageNet validation dataset, its DataLoader, and the
# DenseNet model restored from a saved state_dict; test() reads these globals.
data_path = './imagenet/imagenet_tmp/raw_data/validation'
dataset_val = ImagenetLoaderValidation.ImageNetDataset(data_path, is_train=False)
BATCH_SIZE = 4

data_loader_val = DataLoader(dataset_val, BATCH_SIZE, num_workers=8)

clf = densenetcont.DenseNet()

# Use the first GPU when available, otherwise fall back to CPU instead of
# crashing on machines without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
clf.to(device)
# map_location remaps tensors saved on any device (another GPU index or CPU)
# onto `device`; without it torch.load fails when the saving device is absent.
# NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
clf.load_state_dict(torch.load("densenet/densenetcont_19_.pt", map_location=device))

def test():
    clf.eval()
    SumAcc1=0.0
    SumAcc5=0.0
    ind=0.0
    for batch_id, batch in enumerate(data_loader_val):
        start = time.time()
        data=batch[0].to(device)
        label=batch[1].to(device)
        
        preds =clf(data)
        
        _, predind1 = preds.data.max(1)
        _, predind5 = torch.topk(preds.data,k=5, dim=1)
        
        acc1 = predind1.eq(label.data).float().mean().cpu() 
        
        label5=torch.unsqueeze(label.data,1)
        label5=label5.data.expand_as(predind5)
        correct5,_= predind5.eq(label5).max(1)

        acc5 = correct5.float().mean().cpu() 
        end = time.time()
        ind+=1.0
        SumAcc1+=acc1.item()
        SumAcc5+=acc5.item()
        if batch_id % 100== 0:
            print(" Batch: "+str(batch_id)+" Acc1: "+str(acc1.item())+" Acc5: "+str(acc5.item()))
    print( "FinalAcc1: "+str(SumAcc1/ind)+" FinalAcc5: "+str(SumAcc5/ind) )            
if __name__ == "__main__":
    # Guard the entry point: with num_workers > 0, DataLoader workers started
    # via the "spawn" method re-import this module, and an unguarded call would
    # make each worker try to start its own evaluation (and its own workers).
    test()
      

