#!/usr/bin/env python3

"""Detect forgeries on multiple images with the proposed method, and write the results onto one numpy archive"""

import os
import sys
import argparse

import numpy as np
from tqdm import tqdm
import matplotlib as mpl
from matplotlib import pyplot as plt
import torch

from utils import img_to_tensor, jpeg_compress
from structure import FullNet




def get_parser():
    """Create the argument parser for the batch forgery-detection script.

    Returns:
        An ``argparse.ArgumentParser`` accepting ``-m/--model``, ``-j/--jpeg``,
        ``-b/--block-size``, ``-o/--out`` and one or more positional image paths.
    """
    description = ("Detect forgeries on multiple images with the proposed method. "
                   "Results are returned in one numpy archive file")
    cli = argparse.ArgumentParser(description=description)
    # (flags, value type, default, help text) for each optional argument,
    # listed in the order they should appear in --help.
    optional_args = (
        (("-m", "--model"), str, "models/pretrained.pt",
         "Model to use for the network. Default: models/pretrained.pt."),
        (("-j", "--jpeg"), int, None,
         "JPEG compression quality. Default: no compression is done before analysis."),
        (("-b", "--block-size"), int, 32,
         "Block size. Default: 32."),
        (("-o", "--out"), str, "out.npz",
         "Path to output file. Default: out.npz"),
    )
    for flags, value_type, default_value, help_text in optional_args:
        cli.add_argument(*flags, type=value_type, default=default_value, help=help_text)
    cli.add_argument("input", nargs='+', type=str, help="Images to analyse.")
    return cli

def to_float_rgb(img):
    """Return the first three channels of *img* as floating-point values in [0, 1].

    Integer images (e.g. ``uint8`` from JPEG, ``uint16`` from TIFF) are scaled by
    the maximum value of their dtype. Float images whose maximum exceeds 1 are
    assumed to be on a 0-255 scale and divided by 255; other float images are
    returned as-is. Channels beyond the third (e.g. alpha) are dropped.

    Args:
        img: image array as returned by ``plt.imread``.

    Returns:
        (H, W, 3) floating-point array.

    Raises:
        ValueError: if *img* is not an (H, W, C) array with at least 3 channels.
    """
    img = np.asarray(img)
    if img.ndim != 3 or img.shape[2] < 3:
        raise ValueError(f"expected an (H, W, C>=3) colour image, got shape {img.shape}")
    img = img[:, :, :3]
    if np.issubdtype(img.dtype, np.integer):
        # Convert out of place: an in-place `/=` on an integer array raises.
        return img.astype(np.float32) / np.iinfo(img.dtype).max
    if img.max() > 1:
        return img / 255
    return img


def confidence_map(probs):
    """Reduce per-grid network probabilities to a spatial confidence map.

    Args:
        probs: array of shape (4, K, H, W) with K >= 4 (required by the
            indexing below). Presumably this is the exponentiated network
            output, with axis 0 indexing the candidate grid (TODO confirm).
            The input is not modified.

    Returns:
        (H, W) array in [0, 1]. A position whose most likely grid equals the
        image-wide best grid gets 1. Any other position gets
        ``1 - max probability``, clipped to [0, 1].
    """
    probs = probs.copy()
    # Re-index heads 1-3 with fixed permutations of the grid axis before they
    # are averaged with head 0 (presumably aligning grid offsets; TODO confirm).
    probs[:, 1] = probs[[1, 0, 3, 2], 1]
    probs[:, 2] = probs[[2, 3, 0, 1], 2]
    probs[:, 3] = probs[[3, 2, 1, 0], 3]
    probs = probs.mean(axis=1)
    best_grid = np.argmax(probs.mean(axis=(1, 2)))
    authentic = np.argmax(probs, axis=0) == best_grid
    confidence = np.clip(1 - probs.max(axis=0), 0, 1)
    confidence[authentic] = 1
    return confidence


if __name__ == "__main__":
    args = get_parser().parse_args(sys.argv[1:])
    # NOTE: --block-size is parsed but not used by this script.
    confidences = {}
    net = FullNet().cuda()
    # NOTE(review): torch.load unpickles the file; only load trusted model files.
    net.load_state_dict(torch.load(args.model))
    for image_name in tqdm(args.input):
        try:
            img = to_float_rgb(plt.imread(image_name))
        except ValueError as err:
            # Skip unusable images instead of aborting the whole batch.
            tqdm.write(f"Skipping {image_name}: {err}", file=sys.stderr)
            continue
        if args.jpeg is not None:
            img = jpeg_compress(img, args.jpeg)
        img = img_to_tensor(img).cuda().type(torch.float)
        with torch.no_grad():  # inference only: do not build an autograd graph
            probs = np.exp(net(img).cpu().numpy())
        key = os.path.splitext(os.path.basename(image_name))[0]
        if key in confidences:
            # Archive keys are basenames, so same-named files in different
            # directories collide; make the overwrite visible.
            tqdm.write(f"Warning: {image_name} overwrites earlier results stored under '{key}'",
                       file=sys.stderr)
        confidences[key] = confidence_map(probs)
    np.savez(args.out, **confidences)
