import csv
import json
import shutil
from itertools import repeat
from flask import Flask, render_template
from flask import request, url_for, redirect, flash
import os
import sys
import numpy as np
from model import *


# Flask application object. Flask uses `secret_key` to sign session data.
# NOTE(review): the key is hard-coded in source; consider loading it from the
# environment for anything beyond a local demo.
app = Flask(__name__)
app.secret_key = 'haoml'

# Directory (relative to the working directory) where uploaded inputs, the
# sticker and the generated outputs are written.
app.config['UPLOAD_FOLDER'] = 'static'

# Load the model settings through a context manager so the file handle is
# closed as soon as the JSON has been parsed.
with open('add_sticker.inf', 'r') as _config_file:
    model = AddStickerModel(json.load(_config_file))

# Module-level state shared by every request handled by this process.
input_urls = []      # list of (file name, saved path) for each uploaded input
output_urls = []     # list of (file name, target path), parallel to input_urls
sticker_url = None   # None, or (file name, saved path) of the current sticker
predicted = False    # when True, the next request starts by clearing the state
output_count = 0     # number of outputs produced by the last evaluation


def _safe_filename(raw_name):
    """Return the final path component of *raw_name*, or '' if it is unusable.

    Client-supplied file names may carry directory parts such as ``../``; only
    the base name is kept so an upload cannot be written outside the upload
    folder. Backslashes are treated as separators as well.
    """
    base = os.path.basename((raw_name or '').replace('\\', '/'))
    return '' if base in ('.', '..') else base


def _sticker_is_valid(path):
    """Return True when the image at *path* is RGBA with some alpha below 128.

    An image that cannot be opened is reported as not valid instead of
    raising. NOTE(review): `Image` comes from `from model import *` and is
    presumably `PIL.Image`, whose open errors derive from OSError -- confirm.
    """
    try:
        # The context manager releases the file handle even when the pixel
        # data is never loaded (the non-RGBA case).
        with Image.open(path) as img:
            if img.mode != 'RGBA':
                return False
            alpha_channel = np.array(img)[:, :, 3]
            return bool(alpha_channel.min() < 128)
    except OSError:
        return False


@app.route('/', methods=['GET', 'POST'])
def demo():
    """Serve the single-page demo and handle its form actions.

    GET renders an empty page. POST dispatches on which key is present in the
    submitted form:

    * ``reset``       -- flag the state for clearing and redirect back here.
    * ``add_input``   -- save the uploaded image and register its output path.
    * ``add_sticker`` -- save the uploaded sticker and warn if it is not an
      RGBA image with some alpha below 128.
    * ``evaluate``    -- run ``model.process`` on every registered input with
      the current sticker and show the resulting log.

    Any other POST renders the empty page with an 'Unknown action.' log entry
    instead of returning no response.

    State lives in module-level globals, so it is shared by all requests
    handled by this process.
    """
    global input_urls, output_urls, sticker_url, predicted, output_count

    # A finished evaluation (or an explicit reset) clears everything at the
    # start of the following request.
    if predicted:
        input_urls, output_urls = [], []
        sticker_url = None
        predicted = False
        output_count = 0

    if request.method != 'POST':
        return render_template('index.html', input_count=0, output_count=output_count,
                               input=None, sticker=None, output=None, log=[])

    form = request.form
    upload_dir = app.config['UPLOAD_FOLDER']

    if 'reset' in form:
        predicted = True
        return redirect(url_for('demo'))

    if 'add_input' in form:
        upload = request.files.get('input')
        filename = _safe_filename(upload.filename) if upload is not None else ''
        if not filename:
            # Nothing was selected (or the name was unusable): keep the state.
            return render_template('index.html', input_count=len(input_urls), output_count=0,
                                   input=input_urls, sticker=None, output=None,
                                   log=['No image selected.'])

        # Prefix with the running index so repeated uploads do not collide.
        index = len(input_urls)
        input_name = '{}_'.format(index) + filename
        input_target = os.path.join(upload_dir, input_name)
        upload.save(input_target)

        # The output reuses whatever follows the last dot of the input name.
        output_name = '{}_output.'.format(index) + filename.split('.')[-1]
        output_target = os.path.join(upload_dir, output_name)

        input_urls.append((input_name, input_target))
        output_urls.append((output_name, output_target))

        return render_template('index.html', input_count=len(input_urls), output_count=0,
                               input=input_urls, sticker=None, output=None, log=['Image added.'])

    if 'add_sticker' in form:
        upload = request.files.get('sticker')
        filename = _safe_filename(upload.filename) if upload is not None else ''
        if not filename:
            return render_template('index.html', input_count=len(input_urls), output_count=0,
                                   input=input_urls, sticker=sticker_url, output=None,
                                   log=['No sticker selected.'])

        sticker_path = os.path.join(upload_dir, filename)
        upload.save(sticker_path)
        sticker_url = (filename, sticker_path)

        log = ['Sticker added.']
        if not _sticker_is_valid(sticker_path):
            log.append('Warning: please check whether the sticker is valid.')
        return render_template('index.html', input_count=len(input_urls), output_count=0,
                               input=input_urls, sticker=sticker_url, output=None, log=log)

    if 'evaluate' in form:
        if sticker_url is None or not input_urls:
            return render_template('index.html', input_count=0, output_count=0,
                                   input=None, sticker=None, output=None,
                                   log=['No image/sticker uploaded.'])

        for input_entry, output_entry in zip(input_urls, output_urls):
            print(input_entry, sticker_url, output_entry)
            model.process(input_entry[1], sticker_url[1], output_entry[1])

        process_log = model.format()
        output_count = len(input_urls)
        predicted = True  # clear the state on the next request

        return render_template('index.html', input_count=0, output_count=output_count,
                               input=input_urls, sticker=sticker_url, output=output_urls,
                               log=process_log)

    # POST without a recognised action: still return a valid response.
    return render_template('index.html', input_count=0, output_count=output_count,
                           input=None, sticker=None, output=None, log=['Unknown action.'])


if __name__ == '__main__':
    # Start the built-in server on all interfaces, port 3121, debugger off.
    app.run(debug=False, port=3121, host='0.0.0.0')
