"""Copyright 2020 ETH Zurich, Seonwook Park

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""
import glob
import json
import os
import sys
import zipfile
import torch
from datetime import datetime
import logging
# Module-level logger; DefaultConfig reports the files it loads and writes through it.
logger = logging.getLogger(__name__)


class DefaultConfig(object):
    """Singleton holding all experiment configuration for this process.

    Entries are declared as class attributes below. Once constructed, the
    instance is immutable: entries can only be changed through ``import_json``,
    ``import_dict`` or ``override``. On first instantiation the class also
    caches the text of ``config_default.py`` and, outside interactive sessions,
    Python sources found next to this module plus the entry-point script.
    ``write_file_contents`` uses these caches to snapshot an experiment.
    """

    identifier_suffix = ''

    # Misc. notes
    computer = ''
    note = ''
    wandb_tags = []  # Tags for Weights & Biases (wandb) logging

    # Device
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    # Data Sources
    # Raw string: same value as before, but the Windows backslashes no longer
    # form invalid escape sequences.
    datasrc_eve = r'C:\Workdir\Develop\Datasets\eve_mini'

    # Logging
    experiment_dir = "experiments"
    # NOTE: evaluated once when the class is defined; overriding experiment_dir
    # later does not update save_dir.
    save_dir = os.path.join(experiment_dir, datetime.now().strftime('%Y-%m-%d_%H-%M-%S-%f'))
    save_to_file = True

    # Data loading
    video_decoder_codec = 'libx264'  # libx264 | nvdec
    assumed_frame_rate = 10  # We will skip frames from source videos accordingly
    max_sequence_len = 30  # In frames assuming {assumed_frame_rate}Hz and 3 seconds of video
    start_time = 0.0  # Start time in seconds
    # NOTE: evaluated once when the class is defined; not recomputed when
    # start_time or assumed_frame_rate are overridden.
    start_frame = int(start_time * assumed_frame_rate)  # Start sequence in frames
    face_size = [128, 128]  # width, height
    eyes_size = [128, 128]  # width, height
    screen_size = [960, 540]  # width, height
    crop_size = [224, 224]  # width, height
    crop_factor = 1.0
    actual_screen_size = [1920, 1080]  # DO NOT CHANGE
    camera_frame_types = ['eyes', 'face']  # full | face | eyes
    sides = ['left', 'right']
    load_screen_content = False
    load_full_frame_for_visualization = False
    flip_right_eye = True
    early_fusion = False
    st_net_combined = False

    train_cameras = ['basler', 'webcam_l', 'webcam_c', 'webcam_r']
    train_stimuli = ['image', 'video', 'wikipedia']
    test_cameras = ['basler', 'webcam_l', 'webcam_c', 'webcam_r']
    test_stimuli = ['image', 'video', 'wikipedia']

    # Inference
    input_path = ''
    output_path = ''

    # Training
    skip_training = False
    fully_reproducible = False  # enable with possible penalty of performance

    batch_size = 6
    weight_decay = 0.00001
    num_epochs = 5

    optimizer = 'adam'  # 'adam' | 'sgd'

    train_data_workers = 16

    log_every_n_steps = 1  # NOTE: Every other interval has to be a multiple of this!!!
    tensorboard_scalars_every_n_steps = 1
    tensorboard_images_every_n_steps = 10
    tensorboard_learning_rate_every_n_steps = 100

    max_steps = None

    # Learning rate
    base_learning_rate = 0.0001

    @property
    def learning_rate(self):
        """Effective learning rate: ``base_learning_rate`` scaled linearly by ``batch_size``."""
        return self.batch_size * self.base_learning_rate

    num_warmup_epochs = 0.0  # No. of epochs to warmup LR from base to target
    # Available strategies (the default here is 'none'):
    #     'exponential': step function with exponential decay
    #     'cyclic':      spiky down-up-downs (with exponential decay of peaks)
    lr_decay_strategy = 'none'
    lr_decay_factor = 0.5
    lr_decay_epoch_interval = 0.5

    # Gradient Clipping
    do_gradient_clipping = True
    gradient_clip_by = 'value'  # 'norm' or 'value'
    gradient_clip_amount = 15.0

    # Ablation study
    st_net_ablation = False
    ablation_eye_encoder = True
    ablation_face_encoder = True
    ablation_eca = True
    ablation_sam = True
    ablation_gru = True

    # Augmentation
    do_offset_augmentation = True
    offset_augmentation_sigma = 3.0

    # ST gaze network configuration
    st_net_model_name = 'efficientnet_b3'
    st_net_load_pretrained = True
    st_net_frozen = False
    st_net_dropout = 0.15
    st_net_gru_num_cells = 2
    st_net_transformer_num_layers = 3
    st_net_transformer_num_heads = 8
    st_net_transformer_ffn_dim = 512
    st_net_static_num_features = 256
    st_net_use_head_pose_input = True
    st_net_pool_before_gru = False
    st_net_weights = "weights/efficientnet_b3/best.pth"
    loss_coeff_PoG_cm_initial = 1e-2  # 0.03
    loss_coeff_PoG_px_initial = 0.0
    loss_coeff_PoG_cons_initial = 0.0  # 0.005
    loss_coeff_g_ang_initial = 1.0  # 0.999
    loss_coeff_g_face = 2/3
    loss_coeff_g_eyes = 1/3
    loss_separate = False

    # Evaluation
    test_num_samples = 240
    test_batch_size = 12
    test_data_workers = 12
    test_every_n_steps = 500
    do_full_test = False
    full_test_batch_size = 38
    full_test_data_workers = 16

    codalab_eval_batch_size = 38
    codalab_eval_data_workers = 16

    # Checkpoints management
    checkpoints_save_every_n_steps = 300
    checkpoints_keep_n = 5
    # resume_from = os.path.join(os.environ['VSC_DATA'], "experiments/SalientGazeNet/resume_from/latest.pth")
    # resume_from = os.path.join(experiment_dir, "2025-02-13_12-14-29/checkpoints/epoch_1.pth")
    resume_from = None

    # Below lie necessary methods for working configuration tracking

    # Cached singleton instance (name-mangled to _DefaultConfig__instance).
    __instance = None

    def __new__(cls):
        """Return the singleton instance, creating and initialising it on first use."""
        if cls.__instance is None:
            instance = super().__new__(cls)
            cls.__filecontents = cls.__get_config_file_contents()
            # sys.ps1 only exists in interactive sessions, where there is no
            # real entry-point script to snapshot.
            if not hasattr(sys, 'ps1'):
                cls.__pycontents = cls.__get_python_file_contents()
            cls.__immutable = True
            # Publish the instance only after setup succeeded, so a failed
            # first construction does not leave a half-initialised singleton.
            cls.__instance = instance
        return cls.__instance

    def import_json(self, json_path, strict=True):
        """Import a JSON config file to over-write existing config entries.

        Only one JSON file may be imported; its raw text is also cached so that
        ``write_file_contents`` stores it alongside the flattened config.

        Args:
            json_path: Path to an existing file containing a JSON object of
                ``{config_key: value}`` pairs.
            strict: Forwarded to ``import_dict``.

        Raises:
            AssertionError: If ``json_path`` is not a file, or a JSON config has
                already been imported.
        """
        assert os.path.isfile(json_path)
        # String literals are not name-mangled, so the mangled attribute name
        # has to be spelled out; checking '__imported_json_path' always passed.
        assert not hasattr(self.__class__, '_DefaultConfig__imported_json_path'), \
            'A JSON config has already been imported'
        logger.info('Loading %s', json_path)
        with open(json_path, 'r', encoding='utf-8') as f:
            json_string = f.read()
        self.import_dict(json.loads(json_string), strict=strict)
        self.__class__.__imported_json_path = json_path
        self.__class__.__filecontents[os.path.basename(json_path)] = json_string

    def override(self, key, value):
        """Set a single config entry without any key or type validation.

        The config is made immutable again even if the assignment fails (for
        example when ``key`` names a read-only property).
        """
        self.__class__.__immutable = False
        try:
            setattr(self, key, value)
        finally:
            self.__class__.__immutable = True

    def import_dict(self, dictionary, strict=True):
        """Import key-value pairs from a dict to over-write existing config entries.

        Args:
            dictionary: Mapping of config key to new value.
            strict: If True, every key must already exist on the config. The new
                value must also have the same type as the current one, with two
                relaxations:
                - an int is accepted for a float entry and converted;
                - an entry whose current value is None accepts any type.
                If False, keys are set without validation, which may add new
                entries.

        Raises:
            ValueError: In strict mode, for an unknown key.
            AssertionError: In strict mode, for a type mismatch.

        Property entries (e.g. ``learning_rate``) are never assigned, so dumps
        produced by ``get_full_json`` can be re-imported. The config is made
        immutable again even if an error is raised part-way through; entries
        applied before the error are kept.
        """
        self.__class__.__immutable = False
        try:
            for key, value in dictionary.items():
                # Use a default so keys that exist only on the instance (added
                # by an earlier non-strict import) do not raise AttributeError.
                is_property = isinstance(getattr(DefaultConfig, key, None), property)
                if strict:
                    if not hasattr(self, key):
                        raise ValueError('Unknown configuration key: ' + key)
                    current = getattr(self, key)
                    if type(current) is float and type(value) is int:
                        value = float(value)
                    elif current is not None:
                        # None-valued defaults (e.g. max_steps, resume_from)
                        # would otherwise be impossible to set in strict mode.
                        assert type(current) is type(value), (
                            'Type mismatch for configuration key %s: expected %s, got %s'
                            % (key, type(current).__name__, type(value).__name__))
                if not is_property:
                    setattr(self, key, value)
        finally:
            self.__class__.__immutable = True

    @staticmethod
    def __get_config_file_contents():
        """Read the default config source file located next to this module.

        Returns:
            dict mapping file basename to file contents.
        """
        out = {}
        for relpath in ['config_default.py']:
            # os.path.join avoids producing '/config_default.py' (filesystem
            # root) when dirname(__file__) is empty.
            path = os.path.relpath(os.path.join(os.path.dirname(__file__), relpath))
            assert os.path.isfile(path)
            with open(path, 'r', encoding='utf-8') as f:
                out[os.path.basename(path)] = f.read()
        return out

    @staticmethod
    def __get_python_file_contents():
        """Read Python sources for reproducibility snapshots.

        Collects ``*.py`` files one directory level below the parent of this
        module's directory. ``**`` is used without ``recursive=True``, so it
        matches a single path component. A ``3rdparty`` folder is skipped. The
        entry-point script is added when ``sys.argv[0]`` names a real file.

        Returns:
            dict mapping normalised relative path to file contents.
        """
        out = {}
        base_path = os.path.relpath(os.path.join(os.path.dirname(__file__), os.pardir))
        excluded_prefix = '3rdparty' + os.sep
        source_fpaths = [
            p for p in glob.glob(base_path + '/**/*.py')
            if not os.path.relpath(p, base_path).startswith(excluded_prefix)
        ]
        # sys.argv[0] may be '' or '-c' (e.g. `python -c ...`); skip it rather
        # than fail when it does not name an actual file.
        if sys.argv and sys.argv[0]:
            entry_script = os.path.relpath(sys.argv[0])
            if os.path.isfile(entry_script):
                source_fpaths.append(entry_script)
        for fpath in source_fpaths:
            if not os.path.isfile(fpath):
                continue
            with open(fpath, 'r', encoding='utf-8') as f:
                # normpath strips a leading './' without truncating paths that
                # lack it (the old fpath[2:] turned 'train.py' into 'ain.py').
                out[os.path.normpath(fpath)] = f.read()
        return out

    def get_all_key_values(self):
        """Return every public, non-callable config entry (properties included) as a dict.

        Keys are in ``dir()`` order; private ``_DefaultConfig...`` and dunder
        names are excluded.
        """
        values = {}
        for key in dir(self):
            if key.startswith('_DefaultConfig') or key.startswith('__'):
                continue
            value = getattr(self, key)
            if not callable(value):
                values[key] = value
        return values

    def get_full_json(self):
        """Return all config entries serialised as an indented JSON string."""
        return json.dumps(self.get_all_key_values(), indent=4)

    def write_file_contents(self, target_base_dir):
        """Snapshot the configuration and source files into an experiment directory.

        Writes the following into ``target_base_dir``:
        - ``configs/combined.json``: the flattened config from ``get_full_json``;
        - every cached config file, also under ``configs/``;
        - ``src.zip``: every ``.py`` and ``.json`` file under the parent of this
          module's directory.

        Args:
            target_base_dir: Directory to write into; created if missing. Must
                not be an existing non-directory path.
        """
        assert os.path.isdir(target_base_dir) or not os.path.exists(target_base_dir)

        # Write config file contents
        target_dir = os.path.join(target_base_dir, 'configs')
        os.makedirs(target_dir, exist_ok=True)
        outputs = {  # Also output flattened config
            'combined.json': self.get_full_json(),
        }
        outputs.update(self.__class__.__filecontents)
        for fname, content in outputs.items():
            fpath = os.path.relpath(os.path.join(target_dir, fname))
            with open(fpath, 'w', encoding='utf-8') as f:
                f.write(content)
            logger.info('Written %s', fpath)

        # Copy source folder contents over
        target_path = os.path.relpath(os.path.join(target_base_dir, 'src.zip'))
        source_path = os.path.relpath(os.path.join(os.path.dirname(__file__), os.pardir))
        # Archive names are relative to the parent of source_path, so entries
        # keep the source folder's own name as a prefix.
        archive_root = os.path.join(source_path, os.pardir)
        with zipfile.ZipFile(target_path, 'w', zipfile.ZIP_DEFLATED) as zip_file:
            for root, _dirs, files in os.walk(source_path):
                for fname in files:
                    full_path = os.path.join(root, fname)
                    if fname.endswith(('.py', '.json')) and os.path.isfile(full_path):
                        zip_file.write(full_path, os.path.relpath(full_path, archive_root))
        logger.info('Written source folder to %s', os.path.relpath(target_path))

    def __setattr__(self, name, value):
        """Reject assignment unless import_dict/override have temporarily unlocked the class."""
        if self.__class__.__immutable:
            raise AttributeError('DefaultConfig instance attributes are immutable.')
        else:
            super().__setattr__(name, value)

    def __delattr__(self, name):
        """Reject deletion unless the class has been temporarily unlocked."""
        if self.__class__.__immutable:
            raise AttributeError('DefaultConfig instance attributes are immutable.')
        else:
            super().__delattr__(name)