#!/usr/bin/env python3

"""Compare results to ground truth to compute a ROC curve and the AUC score. Results will be written in a numpy .npz archive."""

import sys
import argparse
import itertools

import numpy as np

from sklearn.metrics import roc_curve, roc_auc_score

from tqdm import tqdm


# Demosaicing algorithms used to train the network; images demosaiced (or
# forged) with any of these are skipped when exclude_training_algorithms=True.
_TRAINING_ALGORITHMS = frozenset(('bilinear', 'ha', 'lmmse'))


def compare_res_gt(res, gt, exclude_training_algorithms=False, image_names=None):
    """Compute global and local ROC curves and AUC scores from detection results.

    Parameters
    ----------
    res : mapping
        Maps each image name to its result map (e.g. an ``NpzFile``). The
        score used is ``1 - res[name]``, so lower values in ``res`` count as
        stronger evidence for the positive class.
    gt : mapping
        Ground truth with keys ``'ground_truth'``, ``'forged'``,
        ``'algorithm_main'`` and ``'algorithm_forged'``. Each is iterated in
        lockstep with ``image_names``. Entries of a ground-truth map equal
        to -1 are left out of the local evaluation.
    exclude_training_algorithms : bool
        If True, skip every image whose main or forged-region algorithm is
        one of ``_TRAINING_ALGORITHMS``.
    image_names : sequence of str, optional
        Image names, in the same order as the entries of ``gt``. If None,
        they are loaded from ``image_names.npy`` in the current working
        directory (the original behavior).

    Returns
    -------
    tuple
        ``(fpr_global, tpr_global, auc_global, fpr_local, tpr_local, auc_local)``.
        Global: one score per image (the maximum of its score map) against
        the ``forged`` label. Local: every kept entry of every score map
        against the matching ground-truth entry.

    Raises
    ------
    ValueError
        If no image remains after filtering. scikit-learn also raises it
        when a label set contains a single class.
    """
    if image_names is None:
        image_names = np.load('image_names.npy')
    ground_truth = gt['ground_truth']
    forged = gt['forged']
    algorithm_main = gt['algorithm_main']
    algorithm_forged = gt['algorithm_forged']
    global_confidence = []
    global_label = []
    local_confidence = []
    local_label = []
    # NOTE(review): zip() stops at the shortest input, so a length mismatch
    # between image_names and the gt arrays is silently truncated.
    for this_forged, this_gt, this_alg_main, this_alg_forged, this_image_name in zip(
            forged, ground_truth, algorithm_main, algorithm_forged, image_names):
        if exclude_training_algorithms and (
                this_alg_main in _TRAINING_ALGORITHMS
                or this_alg_forged in _TRAINING_ALGORITHMS):
            continue
        # Invert the result so that higher means "more likely positive",
        # which is the convention roc_curve / roc_auc_score expect.
        score_map = 1 - res[this_image_name]
        valid = this_gt != -1  # ground-truth entries equal to -1 are ignored
        global_confidence.append(np.max(score_map))
        global_label.append(this_forged)
        local_confidence.append(score_map[valid].ravel())
        local_label.append(this_gt[valid].ravel())
    if not global_label:
        raise ValueError(
            "No image left to evaluate: the ground truth is empty or every "
            "image was excluded by exclude_training_algorithms.")
    local_confidence = np.concatenate(local_confidence)
    local_label = np.concatenate(local_label)
    fpr_global, tpr_global, _ = roc_curve(global_label, global_confidence)
    auc_global = roc_auc_score(global_label, global_confidence)
    fpr_local, tpr_local, _ = roc_curve(local_label, local_confidence)
    auc_local = roc_auc_score(local_label, local_confidence)
    return fpr_global, tpr_global, auc_global, fpr_local, tpr_local, auc_local



    
def get_parser():
    """Build the command-line parser for this script.

    Returns
    -------
    argparse.ArgumentParser
        Parser with a positional ``results`` path and the ``--ground-truth``,
        ``--exclude-training-algorithms`` and ``--out`` options.
    """
    # The description previously duplicated the detection script's text;
    # this script compares results to the ground truth instead.
    parser = argparse.ArgumentParser(description="Compare forgery detection results to the ground truth and compute global and local ROC curves and AUC scores. Results are written to one numpy .npz archive.")
    parser.add_argument("results", type=str, help="Path to the results, as obtained by detect_forgeries_multiple.py, choi_intermediate_values.py or shin_variance.py.")
    parser.add_argument("-g", "--ground-truth", type=str, default="ground_truths/only_misaligned.npz", help="Ground truth file to use. Default: ground_truths/only_misaligned.npz.")
    parser.add_argument("-e", "--exclude-training-algorithms", action="store_true", default=False, help="If specified, exclude from the database images demosaiced with one of the algorithms used to train the network, or forged with an image demosaiced with one of those algorithms. Those algorithms are bilinear, LMMSE and Hamilton-Adams.")
    parser.add_argument("-o", "--out", type=str, default="roc.npz", help="Path to the output file.")
    return parser

if __name__ == "__main__":
    parser = get_parser()
    args = parser.parse_args(sys.argv[1:])
    # allow_pickle is needed for the ground truth's mixed-type (object) arrays.
    # NOTE: unpickling can execute arbitrary code, so only pass ground-truth
    # files from a trusted source.
    # Both archives are used as context managers so their file handles are
    # closed as soon as the curves have been computed.
    with np.load(args.ground_truth, allow_pickle=True) as gt, np.load(args.results) as res:
        fpr_global, tpr_global, auc_global, fpr_local, tpr_local, auc_local = compare_res_gt(res, gt, exclude_training_algorithms=args.exclude_training_algorithms)
    np.savez(args.out, fpr_global=fpr_global, tpr_global=tpr_global, auc_global=auc_global, fpr_local=fpr_local, tpr_local=tpr_local, auc_local=auc_local)


