import sys
sys.path.insert(1, "../Source Codes/")
import keras
import keras.models as km
import keras.layers as kl
import keras.backend as kb
import Utils_MAI as u
import tensorflow as tf
from SRModels_MAI import *
import os
import random
# Make only the CPU devices visible to the TensorFlow runtime.
# NOTE: the previous form, set_visible_devices(devices, device_type='CPU'),
# reconfigured only CPU-type devices and left every other device type (e.g.
# GPUs) visible, so it did not force a CPU-only run. Omitting device_type
# makes the visible device list exactly `my_devices`.
my_devices = tf.config.experimental.list_physical_devices(device_type='CPU')
tf.config.experimental.set_visible_devices(devices=my_devices)

def representative_dataset(
    dataset_folder_lr="../Other/DIV2K_train_LR_bicubic_X3",
    indices=(299,),
    crop_height=360,
    crop_width=640,
):
    """Yield calibration samples for TFLite full-integer quantization.

    The TFLite converter calls this with no arguments, so every parameter has
    a default reproducing the original hard-coded behavior.

    Args:
        dataset_folder_lr: Folder of low-resolution training images.
        indices: Positions, in the sorted directory listing, of the images to
            use. The default picks image 299, which was selected manually as
            the best-scoring representative image for the models.
        crop_height: Height of the top-left crop fed to the model.
        crop_width: Width of the top-left crop fed to the model.

    Yields:
        A one-element list holding a (1, crop_height, crop_width, C) array:
        the top-left crop of the image with a leading batch axis.

    Raises:
        IndexError: If an index is out of range for the folder listing.
        ValueError: If no selected image is at least crop_height x crop_width.
            Without this check the converter would receive an empty dataset.
    """
    sorted_files = sorted(os.listdir(dataset_folder_lr))
    yielded = 0
    for index in indices:
        file_name = sorted_files[index]
        print(file_name)
        # NOTE(review): assumes u.modcrop_imread_RGB returns an (H, W, C)
        # array whose dtype matches the model input. TODO confirm.
        data = u.modcrop_imread_RGB(os.path.join(dataset_folder_lr, file_name))
        # Images smaller than the crop cannot provide a full-size sample.
        if data.shape[0] < crop_height or data.shape[1] < crop_width:
            continue
        yielded += 1
        yield [data[None, 0:crop_height, 0:crop_width, :]]
    if yielded == 0:
        raise ValueError(
            f"No representative image in {dataset_folder_lr!r} at indices "
            f"{tuple(indices)} is at least {crop_height}x{crop_width}"
        )


# Select which TFLite artifact to generate:
#   0 -> ../TFlite/model.tflite            uint8-quantized, fixed 1x360x640x3 input
#   1 -> ../TFlite/model_none.tflite       uint8-quantized, dynamic HxW input
#   2 -> ../TFlite/model_none_float.tflite float, dynamic HxW input
generation_type = 0 #generate 0 model.tflite
# generation_type = 1 #generate 1 model_none.tflite
# generation_type = 2 #generate 2 model_none_float.tflite

convert_tflite = True  # export the loaded checkpoint to a .tflite file
downscale = 3  # super-resolution scale factor; matches the X3 DIV2K data paths
save_fp32_submission = False  # if True, run u.submission_prediction on the float Keras model

# (use_uint8, use_fixed_size, tflite_name) for each supported generation_type.
_GENERATION_CONFIGS = {
    0: (True, True, "../TFlite/model.tflite"),
    1: (True, False, "../TFlite/model_none.tflite"),
    2: (False, False, "../TFlite/model_none_float.tflite"),
}

# Fail fast on an unsupported value. Otherwise use_uint8, use_fixed_size and
# tflite_name would stay undefined and the script would crash later with a
# NameError.
try:
    use_uint8, use_fixed_size, tflite_name = _GENERATION_CONFIGS[generation_type]
except KeyError:
    raise ValueError(
        f"generation_type must be one of {sorted(_GENERATION_CONFIGS)}, "
        f"got {generation_type!r}"
    ) from None


# Keras checkpoint to convert, and the LR validation folder used by the
# optional float32 submission run.
model_name = "best_model.h5"  # checkpoint model file
validation_dataset_folder_lr = "../Other/DIV2K_valid_LR_bicubic_X3_db"

# Register the custom loss/metric functions and the `tf`/`kb` modules while
# deserializing. The checkpoint presumably refers to these names (e.g. from
# Lambda layers or compile metrics) -- TODO confirm.
with keras.utils.custom_object_scope({
    "custom_mse": u.custom_mse,
    "custom_psnr": u.custom_psnr,
    "tf": tf,
    "kb": kb,
}):
    model = km.load_model(model_name)

# Optionally run predictions with the un-converted (float32) Keras model.
if save_fp32_submission:
    u.submission_prediction(model, validation_dataset_folder_lr, False)

if convert_tflite:
    # Round-trip through a TF SavedModel. This exposes the serving signature
    # as a concrete function, whose input shape can then be pinned below.
    model.save("./TFModel", overwrite=True, include_optimizer=False, save_format='tf')
    model = tf.saved_model.load("./TFModel")
    concrete_func = model.signatures[tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]

    # The fixed size matches the 360x640 crop produced by representative_dataset.
    # Otherwise only batch and channels are fixed and H and W stay dynamic.
    if use_fixed_size:
        concrete_func.inputs[0].set_shape([1, 360, 640, 3])
    else:
        concrete_func.inputs[0].set_shape([1, None, None, 3])

    converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
    converter.experimental_new_converter = True
    # Generation type 2 is the plain float model, so no optimizations for it.
    if generation_type != 2:
        converter.optimizations = [tf.lite.Optimize.DEFAULT]

    if use_uint8:
        # Full-integer quantization: calibrate with the representative
        # dataset, allow only int8 builtin ops, and expose uint8 I/O.
        converter.representative_dataset = representative_dataset
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        converter.inference_input_type = tf.uint8
        converter.inference_output_type = tf.uint8

    model_byte = converter.convert()
    # Create the target directory (e.g. ../TFlite) if needed. Otherwise open()
    # fails with FileNotFoundError after the expensive conversion has finished.
    output_dir = os.path.dirname(tflite_name)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(tflite_name, "wb") as file:
        file.write(model_byte)


    # The following lines can be used to infer using the generated model
    # NOTE(review): `validation_images` is not defined in this script; it
    # presumably should be a validation image source such as
    # validation_dataset_folder_lr -- confirm against u.dataset_prediction.
    # model_int = tf.lite.Interpreter(model_path=tflite_name)
    #
    # model_lite = u.KerasLite(model_int)
    #
    # if not use_fixed_size:
    #     u.dataset_prediction(model_lite, validation_images, downscale, use_uint8, use_fixed_size)
