import torch
from PIL import Image
import os
import time
import shutil
import csv
import random
import argparse
import json
from segmentation import DeepLabV3
from model_general import AddStickerGeneral, HeadGeneral, d_iou, export_case
from model_background import BackgroundClassifier, eval_single, composite
# from model_lookup import StickerLookUp
import cv2
from torch import nn
import func_timeout


class AddStickerModel(nn.Module):
    """Pipeline that decides where/how a sticker is placed on an image.

    Wraps three models: a background classifier (``bg_model``), a general
    placement model (``general_model``) and a segmentation model
    (``segmentation_model``).  Per-call diagnostics are collected in
    ``self.log`` and can be rendered with :meth:`format`.
    """

    # Templates for the entries of ``self.log``.  The iteration order of this
    # dict defines the order of the lines returned by ``format()``.
    info_dict = {
        'lookup': 'Confidence of making a prediction: {}\n',
        'type_pred': 'Non-background probability: {} \n',
        'mask_pred': 'Need-a-mask probability: {} \n',
        'opa_pred': 'Need-low-opacity probability: {} \n',
        'cls_pred': 'Position score: {} \n',
        'raw_box': 'Raw prediction: {} \n',
        'normalized_box': 'Prediction after w/h ratio adjustment: {} \n',
        'seg_time': 'Segmentation time: {} \n',
        'process_time': 'Total process time: {}\n',
        'lookup_time': 'Time to calculate confidence: {}\n',
    }

    def __init__(self, config):
        """Build the sub-models from ``config``.

        Keys read from ``config``: ``model_background``, ``thresholds``,
        ``model_general.backbone``, ``model_general.box_info``,
        ``model_general``, ``two_stage`` and ``confidence_threshold``.

        NOTE(security): ``torch.load`` unpickles the referenced files; only
        point the config at trusted checkpoints.
        """
        super().__init__()

        self.bg_model = BackgroundClassifier()
        self.bg_model.load_state_dict(torch.load(config['model_background']))
        self.bg_model.eval().cuda()

        # Indexed as [type, mask, opacity] when binarising predictions.
        self.thresholds = config['thresholds']

        self.general_model = AddStickerGeneral(config['model_general.backbone'],
                                               torch.load(config['model_general.box_info']))
        self.general_model.load_state_dict(torch.load(config['model_general']))
        self.general_model.eval().cuda()

        self.segmentation_model = DeepLabV3()
        self.two_stage = bool(config['two_stage'])
        self.confidence_threshold = float(config['confidence_threshold'])
        self.log = dict()

    def format(self):
        """Return the formatted log lines for every known key present in ``self.log``."""
        return [template.format(self.log[key])
                for key, template in AddStickerModel.info_dict.items()
                if key in self.log]

    @staticmethod
    def _split_general_output(outcome):
        """Normalise the general model's output to ``(result, extras)``.

        The two call sites in the original code disagreed on whether
        ``eval_single`` returns a bare result or a ``(result, top5pred)``
        pair, so both shapes are accepted; ``extras`` is ``None`` for a
        bare result.
        """
        if isinstance(outcome, (tuple, list)):
            result, extras = outcome
            return result, extras
        return outcome, None

    @staticmethod
    def _centered_box(sticker_size, img_size):
        """Return a centred ``(x, y, w, h)`` box in normalised coordinates.

        The box keeps the sticker's aspect ratio (expressed relative to the
        image) and has ``w * h == 1/9``.
        """
        r = sticker_size[0] * img_size[1] / sticker_size[1] / img_size[0]
        h = (1. / 9 / r) ** 0.5
        w = r * h
        return 0.5 - w / 2, 0.5 - h / 2, w, h

    @staticmethod
    def _rect_json(box_str, opacity):
        """Build the output dict from an ``'x, y, w, h'`` string and an opacity."""
        x, y, w, h = (float(part) for part in box_str.split(','))
        return {
            'rect': {
                'x': x,
                'y': y,
                'w': w,
                'h': h,
            },
            'rotation': 0.,
            'opacity': opacity,
        }

    def process(self, img_url, sticker_url, output_url):
        """Place the sticker on the image, save the result and describe it.

        Args:
            img_url: Path of the input image.
            sticker_url: Path of the sticker image.
            output_url: Path to save the composited image to.  When empty,
                ``'output.<input extension>'`` is used.

        Returns:
            A JSON string with ``rect``, ``rotation`` and ``opacity``; the
            string ``'null'`` if no box is available in single-stage mode.

        Side effects: resets and fills ``self.log``, writes the output image
        and prints the log and the JSON to stdout.
        """
        self.log = dict()

        start_time = time.time()
        img = Image.open(img_url)
        sticker = Image.open(sticker_url)
        self.log['read_time'] = time.time() - start_time
        # The clock restarts here, so 'process_time' below covers the
        # segmentation and prediction steps but not the image reads.
        start_time = time.time()

        raw_for_mask = cv2.imread(img_url)
        img_mask = self.segmentation_model.process_data(raw_for_mask)
        self.log['seg_time'] = time.time() - start_time

        output_json = None
        device = torch.device('cuda')

        if self.two_stage:
            type_pred, mask_pred, opa_pred = eval_single(self.bg_model, sticker, device=device)
            self.log['type_pred'] = float(type_pred.detach().cpu().numpy())
            self.log['mask_pred'] = float(mask_pred.detach().cpu().numpy())
            self.log['opa_pred'] = float(opa_pred.detach().cpu().numpy())

            # Binarise the three predictions with their thresholds.
            type_pred = 0 if type_pred < self.thresholds[0] else 1
            mask_pred = 0 if mask_pred < self.thresholds[1] else 1
            opa_pred = 0 if opa_pred < self.thresholds[2] else 1

            if type_pred == 0:
                # Below the type threshold: composite over the whole image.
                result = composite(img, sticker, mask_pred, opa_pred, img_mask)
                self.log['normalized_box'] = '0.00, 0.00, 1.00, 1.00'
                output_json = self._rect_json(self.log['normalized_box'],
                                              50 if opa_pred else 100)
            else:
                result, top5pred = self._split_general_output(
                    self.general_model.eval_single(img, sticker, img_mask, device=device))
                print(top5pred)
                self.log.update(self.general_model.log.items())

                if self.log['cls_pred'] < self.confidence_threshold:
                    # Low confidence: fall back to a centred default box.
                    box = self._centered_box(sticker.size, img.size)
                    result = export_case(img, sticker, torch.tensor(box), is_url=False)
                    self.log['normalized_box'] = '{:.2f}, {:.2f}, {:.2f}, {:.2f}'.format(*box)

                output_json = self._rect_json(self.log['normalized_box'], 100)
        else:
            result, _ = self._split_general_output(
                self.general_model.eval_single(img, sticker, img_mask, device=device))
            self.log.update(self.general_model.log.items())
            # Report the predicted box instead of always returning 'null'.
            if 'normalized_box' in self.log:
                output_json = self._rect_json(self.log['normalized_box'], 100)

        self.log['process_time'] = time.time() - start_time

        if not output_url:
            output_url = 'output.' + img_url.split('.')[-1]

        result.save(output_url)
        print(self.log)
        payload = json.dumps(output_json, indent=2)
        print(payload)
        return payload


def process_single(config, args):
    """Run the pipeline once on the image/sticker pair given in ``args``.

    ``args`` must expose ``img``, ``sticker`` and ``output`` attributes; a
    fresh ``AddStickerModel`` is built from ``config`` for the call.
    """
    AddStickerModel(config).process(args.img, args.sticker, args.output)


def process_batch(config, source, write_log=False, output_folder='process_batch_030'):
    """Run the pipeline over every row of ``user_random.csv`` and report timings.

    Args:
        config: Configuration mapping passed to ``AddStickerModel``.
        source: Optional directory prepended to the ``input``/``sticker``
            paths of each row; ``None`` uses the paths as given.
        write_log: When true, write one row per sample to
            ``user_random_similarity.csv`` (requires ``model_lookup.pt``).
        output_folder: Directory receiving predictions and copies of the
            inputs.  It is deleted and recreated on every call.

    Each CSV row must provide ``input``, ``sticker``, ``x``, ``y``, ``w`` and
    ``h``.  The first ``warmup`` samples are excluded from the averaged
    timings printed at the end.
    """
    warmup = 5  # leading samples left out of the timing averages

    if os.path.exists(output_folder):
        shutil.rmtree(output_folder)
    os.mkdir(output_folder)

    lookup_model = None
    if write_log:
        # Only needed for the similarity columns of the log.
        # NOTE(security): torch.load unpickles the file; use trusted files only.
        lookup_model = torch.load('model_lookup.pt')
        lookup_model.eval()

    start_time = time.time()
    model = AddStickerModel(config)
    print(time.time() - start_time)

    def resolve(name):
        """Return ``name`` relative to ``source`` when a source dir is set."""
        return os.path.join(source, name) if source is not None else name

    read_time, seg_time, process_time, count = 0., 0., 0., 0
    log_file = open('user_random_similarity.csv', 'w', newline='') if write_log else None
    try:
        time_writer = None
        if write_log:
            time_writer = csv.DictWriter(
                log_file,
                fieldnames=['input', 'sticker', 'output', 'score', 'seg_time',
                            'similarity1', 'similarity3', 'similarity5', 'similarity10',
                            'similarity20', 'similarity50', 'similarity100',
                            'process_time', 'diou', 'meta', 'type_pred'])
            time_writer.writeheader()

        with open('user_random.csv', 'r', newline='') as table:
            for row in csv.DictReader(table):
                count += 1
                input_path = resolve(row['input'])
                sticker_path = resolve(row['sticker'])
                output_url = '{:.0f}_pred.{}'.format(count, row['input'].split('.')[-1])

                # process() returns a JSON string, not timings; the timings
                # are read from the model's log instead.
                model.process(input_path, sticker_path, os.path.join(output_folder, output_url))
                rtime = model.log.get('read_time', 0.)
                stime = model.log['seg_time']
                ptime = model.log['process_time']

                # Copy from the same resolved paths that were processed.
                shutil.copy(input_path, os.path.join(
                    output_folder, '{:.0f}_input.{}'.format(count, row['input'].split('.')[-1])))
                shutil.copy(sticker_path, os.path.join(
                    output_folder, '{:.0f}_sticker.{}'.format(count, row['sticker'].split('.')[-1])))

                # Convert both boxes from (x, y, w, h) to (x1, y1, x2, y2).
                gt_box = torch.tensor((float(row['x']), float(row['y']),
                                       float(row['w']), float(row['h'])))
                gt_box[2], gt_box[3] = gt_box[0] + gt_box[2], gt_box[1] + gt_box[3]

                pred_box = torch.tensor([float(x) for x in model.log['normalized_box'].split(',')])
                pred_box[2], pred_box[3] = pred_box[0] + pred_box[2], pred_box[1] + pred_box[3]

                quality, _, _ = d_iou(pred_box.unsqueeze(0), gt_box.unsqueeze(0))

                if write_log:
                    # NOTE(review): the lookup model receives the raw CSV value,
                    # not the path resolved against `source` - confirm intent.
                    best_dis, best_idx = lookup_model(row['sticker'])

                    record = {
                        'input': row['input'],
                        'sticker': row['sticker'],
                        'output': output_url,
                        'score': float(model.log['cls_pred']) if 'cls_pred' in model.log else -1,
                        'seg_time': stime,
                        'process_time': ptime,
                        # Absent in single-stage mode; mirror the 'score' default.
                        'type_pred': float(model.log['type_pred']) if 'type_pred' in model.log else -1,
                        'diou': float(quality[0].cpu().detach().numpy()),
                        'meta': model.log,
                    }
                    for rank in (1, 3, 5, 10, 20, 50, 100):
                        record['similarity{}'.format(rank)] = \
                            best_dis[rank - 1][0].cpu().detach().numpy()
                    time_writer.writerow(record)

                if count > warmup:
                    read_time += rtime
                    seg_time += stime
                    process_time += ptime
    finally:
        if log_file is not None:
            log_file.close()

    timed = count - warmup
    if timed > 0:
        print(read_time / timed, seg_time / timed, process_time / timed)
    else:
        print('Not enough samples to average timings (need more than {}).'.format(warmup))


def postprocess_exported_data():
    """Rebuild ``dataset/`` with one side-by-side image per CSV row.

    Deletes and recreates the ``dataset`` directory, then reads
    ``unlabeled_source.csv`` (columns ``rid``, ``input``, ``output``,
    ``sticker``) and writes ``dataset/<rid>.png`` for each row via
    ``composite_triple``.
    """
    if os.path.exists('dataset'):
        shutil.rmtree('dataset')
    os.mkdir('dataset')

    # Open via a context manager so the CSV handle is closed even on error.
    with open('unlabeled_source.csv', 'r', newline='') as table:
        for row in csv.DictReader(table):
            composite_triple(row['input'], row['output'], row['sticker'],
                             'dataset/{}.png'.format(int(row['rid'])))


def composite_triple(img1_url, img2_url, img3_url, target):
    """Save the three images side by side as one RGBA image at ``target``.

    The canvas is as tall as the first image.  The images are laid out left
    to right with a 20-pixel gap between neighbours.  The third image is
    scaled down (keeping its aspect ratio) when taller than the first, and
    is centred vertically; the other two are pasted at the top edge.
    """
    gap = 20  # horizontal spacing between neighbouring panels

    # Context managers release the three file handles once the output is saved.
    with Image.open(img1_url) as img1, \
            Image.open(img2_url) as img2, \
            Image.open(img3_url) as img3:
        panel3 = img3
        if panel3.size[1] > img1.size[1]:
            scaled_w = int(panel3.size[0] * img1.size[1] / panel3.size[1])
            panel3 = panel3.resize((scaled_w, img1.size[1]))

        new_w = img1.size[0] + img2.size[0] + panel3.size[0] + 2 * gap
        new_h = img1.size[1]
        new_img = Image.new('RGBA', (new_w, new_h))

        new_img.paste(img1, (0, 0))
        new_img.paste(img2, (img1.size[0] + gap, 0))
        new_img.paste(panel3, (img1.size[0] + img2.size[0] + 2 * gap,
                               (new_h - panel3.size[1]) // 2))
        new_img.save(target)


if __name__ == '__main__':
    # Command-line entry point: parse the options, load the JSON config and
    # run the dataset post-processing step.
    parser = argparse.ArgumentParser(description='Auto completion for adding stickers.')
    parser.add_argument('--img', type=str, default='12_input.png')
    parser.add_argument('--sticker', type=str, default='12_sticker.png')
    parser.add_argument('--output', type=str, default='')
    parser.add_argument('--config', type=str, default='add_sticker.inf')
    args = parser.parse_args()
    # Load the config through a context manager so the file handle is closed.
    with open(args.config, 'r') as config_file:
        config = json.load(config_file)
    print(args)
    # NOTE(review): `args` and `config` are parsed but not used by the step
    # below; process_single(config, args) / process_batch(config, ...) are
    # the entry points that consume them.
    postprocess_exported_data()
