from utils import *
from tqdm import tqdm
from pprint import PrettyPrinter
import torch
from sparseSSDmodel import *
from IXdataloader import *
import time 

# Good formatting when printing the APs for each class and mAP
pp = PrettyPrinter()

# Parameters
data_folder = './'
keep_difficult = True  # difficult ground truth objects must always be considered in mAP calculation, because these objects DO exist!
batch_size = 16
workers = 4
sparse = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

checkpoint = './VGG16_IXReal_sparse_ckp.pth.tar'

# Load model checkpoint that is to be evaluated.
# NOTE(review): the result of torch.load is passed straight to load_state_dict,
# so this assumes the checkpoint file stores a bare state_dict — confirm.
sparse_SSD = SSD300_sparse(n_classes=len(label_map))
sparse_SSD.load_state_dict(torch.load(checkpoint, map_location=device))
sparse_SSD.eval().to(device)

# Load test data
sparse_test_dataset = IXRealDataset_sparse_local(data_folder,
                                                 split='local',
                                                 keep_difficult=keep_difficult,
                                                 sparse=sparse)

# Evaluation does not need shuffling: a fixed image order makes repeated runs
# reproducible. Pinned host memory only speeds up host->GPU copies, so request
# it only when evaluating on CUDA.
sparse_test_loader = torch.utils.data.DataLoader(sparse_test_dataset, batch_size=batch_size, shuffle=False,
                                                 collate_fn=sparse_test_dataset.collate_fn_sparse, num_workers=workers,
                                                 pin_memory=device.type == 'cuda')


def evaluate(test_loader, model):
    """
    Evaluate a detector on a test set and print per-class APs, mAP and inference time.

    The reported time is the accumulated wall-clock time spent in the model's
    forward pass plus ``detect_objects`` post-processing. It excludes data
    loading and the final mAP computation.

    :param test_loader: DataLoader for test data; each batch is unpacked as
        (images, boxes, labels, difficulties), where boxes, labels and
        difficulties are iterables of tensors (presumably one per image)
    :param model: model; must be callable on ``images`` and provide
        ``detect_objects(locs, scores, min_score=..., max_overlap=..., top_k=...)``
    """

    # Make sure it's in eval mode
    model.eval()

    # Lists to store detected and true boxes, labels, scores
    det_boxes = list()
    det_labels = list()
    det_scores = list()
    true_boxes = list()
    true_labels = list()
    true_difficulties = list()  # it is necessary to know which objects are 'difficult', see 'calculate_mAP' in utils.py
    # Accumulated inference time in seconds (forward pass + detection post-processing).
    elapsed_times = 0.0
    # CUDA kernels run asynchronously, so synchronize around the timed region;
    # otherwise only kernel-launch overhead would be measured.
    use_cuda = device.type == 'cuda'

    with torch.no_grad():
        # Batches
        for images, boxes, labels, difficulties in tqdm(test_loader, desc='Evaluating'):
            # NOTE(review): images are passed to the model as-is (not moved to `device`);
            # presumably the sparse input produced by the collate fn is handled by the model — confirm.
            if use_cuda:
                torch.cuda.synchronize()
            start = time.perf_counter()

            predicted_locs, predicted_scores = model(images)

            # Detect objects in SSD output
            # Evaluation MUST be at min_score=0.01, max_overlap=0.45, top_k=200 for fair comparison with the paper's results and other repos
            det_boxes_batch, det_labels_batch, det_scores_batch = model.detect_objects(predicted_locs, predicted_scores,
                                                                                       min_score=0.01, max_overlap=0.45,
                                                                                       top_k=200)

            if use_cuda:
                torch.cuda.synchronize()
            elapsed_times += time.perf_counter() - start

            # Store this batch's results for mAP calculation
            boxes = [b.to(device) for b in boxes]
            labels = [lbl.to(device) for lbl in labels]
            difficulties = [d.to(device) for d in difficulties]

            det_boxes.extend(det_boxes_batch)
            det_labels.extend(det_labels_batch)
            det_scores.extend(det_scores_batch)
            true_boxes.extend(boxes)
            true_labels.extend(labels)
            true_difficulties.extend(difficulties)

        # Calculate mAP
        APs, mAP = calculate_mAP(det_boxes, det_labels, det_scores, true_boxes, true_labels, true_difficulties)

    # Print AP for each class
    pp.pprint(APs)

    print('\nMean Average Precision (mAP): %.3f' % mAP)
    print('Elapsed times is: %.3f s' % elapsed_times)


if __name__ == '__main__':
    # Script entry point: evaluate the module-level model on the module-level test loader.
    evaluate(test_loader=sparse_test_loader, model=sparse_SSD)
