import os
import argparse
import torch
import numpy as np
from config.config_default import DefaultConfig
from datasources import EVESequences_val, init_datasets, get_training_batches
from models import STGazeNet, to_screen_coordinates, STGazeNetAblation

import logging
from tqdm import tqdm
import wandb

def parse_args():
    """Parse the command line into a flat ``{name: value}`` dictionary.

    ``--save_dir`` is the only declared option. Every other ``--key=value``
    or ``--key value`` pair is collected as well, with the value kept as the
    raw string given on the command line.

    Returns:
        dict: ``save_dir`` (``None`` when not given) plus one entry per
        additional ``--key`` option found on the command line.

    Raises:
        SystemExit: via ``parser.error`` when an extra ``--key`` option has
        no value attached.
    """
    parser = argparse.ArgumentParser(description="Evaluation script for SalientGazeNet")
    parser.add_argument('--save_dir', type=str, default=None, help='Directory to save the results')

    known_args, unknown_args = parser.parse_known_args()

    # Start from the declared options and add the free-form ones on top.
    arg_dict = vars(known_args)

    pending = iter(unknown_args)
    for token in pending:
        if not token.startswith("--"):
            # Stray positional tokens are ignored.
            continue
        # partition() never raises, unlike unpacking the result of split().
        key, separator, value = token.lstrip('-').partition('=')
        if not separator:
            # "--key value" form: the value is the following token.
            value = next(pending, None)
            if value is None or value.startswith("--"):
                parser.error(f"option --{key} expects a value (--{key}=VALUE or --{key} VALUE)")
        arg_dict[key] = value

    return arg_dict

# Read the command line, then start from the project-wide default configuration.
args = parse_args()
config = DefaultConfig()

# Fixed settings for an evaluation run, applied in this order.
for key, value in (
    ('do_full_test', True),
    ('save_to_file', True),
    ('max_sequence_len', 30),
    ('start_time', 0),
    ('assumed_frame_rate', 10),
):
    config.override(key, value)
# The first frame index follows from the start time and the assumed frame rate.
config.override('start_frame', int(config.start_time * config.assumed_frame_rate))

# Command-line values win over the settings above; options that were not
# supplied (value None) are skipped, and a failing override is logged
# rather than aborting the run.
for key, value in args.items():
    if value is None:
        continue
    try:
        config.override(key, value)
    except Exception as e:
        logging.error(f"Error: {e}")

# Without an explicit --save_dir, redirect "path/to/dir" to
# "<experiment_dir>/evaluation/dir".
if args['save_dir'] is None:
    leaf_dir = os.path.basename(config.save_dir)
    evaluation_dir = os.path.join(config.experiment_dir, "evaluation", leaf_dir)
    config.override('save_dir', evaluation_dir)

def automatic_worker_number():
    """Suggest a maximum worker count from the CPUs this process can use.

    Where ``os.sched_getaffinity`` is available and succeeds, the result is
    the size of the process's CPU affinity set minus one. Otherwise the
    result is ``os.cpu_count()``.

    Returns:
        int | None: the suggested maximum, or ``None`` when neither source
        can provide a count.
    """
    affinity_of = getattr(os, 'sched_getaffinity', None)
    if affinity_of is not None:
        try:
            return len(affinity_of(0)) - 1
        except Exception:
            # Best effort only: fall through to the plain CPU count.
            pass
    # os.cpu_count() is Optional[int]; None is passed straight to the caller.
    return os.cpu_count()

# Reduce the configured data worker counts when they exceed what this
# machine suggests.
cpu_count = automatic_worker_number()
if cpu_count is not None and cpu_count > 0:
    if config.train_data_workers > cpu_count:
        config.override('train_data_workers', cpu_count-1)
    if config.test_data_workers > cpu_count:
        config.override('test_data_workers', cpu_count-1)
device = torch.device(config.device)

# Set up logging
# logging.basicConfig() does nothing when the root logger already has
# handlers, and any logging call issued before this point installs a default
# stderr handler. Drop existing root handlers first so the configuration
# below always takes effect.
root_logger = logging.getLogger()
for stale_handler in list(root_logger.handlers):
    root_logger.removeHandler(stale_handler)
    stale_handler.close()

log_format = '%(asctime)s - %(levelname)s - %(message)s'
if config.save_to_file:
    log_path = os.path.join(config.save_dir, 'training.log')
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(config.save_dir, exist_ok=True)
    logging.basicConfig(level=logging.INFO, format=log_format, filename=log_path)
else:
    logging.basicConfig(level=logging.INFO, format=log_format)

# Each entry: (name, dataset class, data source, stimuli, cameras).
validation_dataset_paths = [
    ('eve_val', EVESequences_val, config.datasrc_eve, config.test_stimuli, config.test_cameras),
]
project_name="ST Gaze Net Using Video Evaluation"
# The run is named after the last three components of the weights path
# (".../<model>/<run>/<weights file>"). Indexing from the end keeps this
# working for paths of any depth.
model_name, run_name, weights_name = config.st_net_weights.split("/")[-3:]
# Keep only the part of the file name before the first dot.
weights_name = weights_name.split(".")[0]

def _predict_sequence(eye_model, full_input_dict, sequence_len):
    """Run the model frame by frame over one batch of sequences.

    For every time step the model is called once per eye ("left", "right"),
    carrying that eye's hidden state over to the next step. The two eye gaze
    predictions are averaged into a "face" gaze, and every gaze is projected
    with ``to_screen_coordinates``.

    Args:
        eye_model: the gaze model, called as
            ``eye_model(inputs, side, hidden_state, flip)``.
        full_input_dict: one batch; only its tensor entries are used, each
            indexed as ``v[:, t]``.
        sequence_len: number of time steps to process.

    Returns:
        dict: ``{side}_gaze``, ``{side}_PoG_mm`` and ``{side}_PoG_px`` for
        side in left/right/face, each stacked over time along ``dim=1``.
    """
    per_frame_outputs = []
    hidden_states = {"left": None, "right": None}
    for t in range(sequence_len):
        # Slice out frame t from every tensor entry; non-tensor entries are dropped.
        sub_input_dict = {k: v[:, t] for k, v in full_input_dict.items() if isinstance(v, torch.Tensor)}
        inv_camera_transformation = sub_input_dict["inv_camera_transformation"]
        pixels_per_millimeter = sub_input_dict["pixels_per_millimeter"]
        sub_output_dict = {}
        for side in ("left", "right"):
            flip = config.flip_right_eye and side == "right"
            gaze_prediction, hidden_state = eye_model(sub_input_dict, side, hidden_states[side], flip)
            sub_output_dict[f"{side}_gaze"] = gaze_prediction
            hidden_states[side] = hidden_state
            PoG_mm, PoG_px = to_screen_coordinates(
                sub_input_dict[f"{side}_o"], gaze_prediction, sub_input_dict[f"{side}_R"],
                inv_camera_transformation, pixels_per_millimeter)
            sub_output_dict[f"{side}_PoG_mm"] = PoG_mm
            sub_output_dict[f"{side}_PoG_px"] = PoG_px
        # The face gaze is the mean of the two eye predictions.
        face_gaze = (sub_output_dict["left_gaze"] + sub_output_dict["right_gaze"]) / 2
        sub_output_dict["face_gaze"] = face_gaze
        face_PoG_mm, face_PoG_px = to_screen_coordinates(
            sub_input_dict["face_o"], face_gaze, sub_input_dict["face_R"],
            inv_camera_transformation, pixels_per_millimeter)
        sub_output_dict["face_PoG_mm"] = face_PoG_mm
        sub_output_dict["face_PoG_px"] = face_PoG_px
        per_frame_outputs.append(sub_output_dict)
    return {k: torch.stack([d[k] for d in per_frame_outputs], dim=1) for k in per_frame_outputs[0].keys()}


def main_evaluation_loop():
    """Evaluate the configured gaze model on the validation data.

    Starts a wandb run, loads the weights from ``config.st_net_weights``,
    runs the model over the validation batches without gradients, logs one
    "detailed/..." wandb record per sequence and one record with the mean of
    every loss at the end, then finishes the wandb run.

    Returns:
        tuple: ``(loss_dict, tags_dict)``. ``loss_dict`` maps each loss name
        to the unreduced loss tensor concatenated over all batches.
        ``tags_dict`` maps "camera", "participant", "stimuli" and "subfolder"
        to numpy arrays in which each sequence's tag is repeated once per
        frame, concatenated over all batches.
    """
    wandb.init(
        # set the wandb project where this run will be logged
        project=project_name,
        name="_".join([model_name, run_name, weights_name]),
        group=config.st_net_weights.split("/")[-2].split("_")[-1],
        config=config.get_all_key_values(),
        dir=config.experiment_dir
    )

    _, val_data = init_datasets([], validation_dataset_paths)

    if config.st_net_ablation:
        eye_model = STGazeNetAblation(config.st_net_model_name, config.st_net_dropout,
                                      config.ablation_eye_encoder, config.ablation_face_encoder,
                                      config.ablation_eca, config.ablation_sam, config.ablation_gru)
    else:
        eye_model = STGazeNet(config.st_net_model_name, config.st_net_dropout)

    # map_location lets a checkpoint saved on another device (e.g. CUDA) be
    # loaded when evaluating on the configured device.
    eye_model.load_state_dict(torch.load(config.st_net_weights, map_location=device))
    eye_model.to(device)
    eye_model.eval()

    # Enough steps to cover the largest dataset once (ceiling division).
    max_dataset_len = np.amax([len(data_dict['dataset']) for data_dict in val_data.values()])
    configured_batch_size = config.full_test_batch_size
    num_steps_per_epoch = int((max_dataset_len // configured_batch_size) + (max_dataset_len % configured_batch_size > 0))

    loss_dicts = []
    tags_dicts = []
    with torch.no_grad():
        tqdm_bar = tqdm(range(num_steps_per_epoch), desc="Validation")
        for _ in tqdm_bar:
            batch = get_training_batches(val_data)
            full_input_dict = next(iter(batch.values()))
            # Actual batch shape, read from the first entry of the batch.
            batch_size, sequence_len = next(iter(full_input_dict.values())).shape[:2]
            intermediate_dict = _predict_sequence(eye_model, full_input_dict, sequence_len)
            loss_dict = eye_model.loss(full_input_dict, intermediate_dict, reduction='none')
            # Losses whose name contains "pog" are reported as their square root.
            for k, v in loss_dict.items():
                if "pog" in k:
                    loss_dict[k] = torch.sqrt(v)
            # One value per sequence: average every loss over the time axis.
            wandb_logs = {k: v.reshape(batch_size, sequence_len).mean(dim=1) for k, v in loss_dict.items()}
            cameras = np.array(full_input_dict["camera"])
            participants = np.array(full_input_dict["participant"])
            stimulis = np.array([sub_fold.split("_")[1] for sub_fold in full_input_dict["subfolder"]])
            subfolder = np.array(full_input_dict["subfolder"])
            wandb_logs["camera"] = cameras
            wandb_logs["participant"] = participants
            wandb_logs["stimuli"] = stimulis
            wandb_logs["subfolder"] = subfolder
            # Repeat each sequence's tags once per frame.
            tags_dicts.append({
                "camera": cameras.repeat(sequence_len, axis=0),
                "participant": participants.repeat(sequence_len, axis=0),
                "stimuli": stimulis.repeat(sequence_len, axis=0),
                "subfolder": subfolder.repeat(sequence_len, axis=0)
            })
            loss = loss_dict["loss_ang"].mean()
            loss_dicts.append(loss_dict)
            for i in range(batch_size):
                wandb.log({
                    f"detailed/{k}": v[i].item() if isinstance(v, torch.Tensor) else v[i]
                    for k, v in wandb_logs.items()
                })
            tqdm_bar.set_postfix({"Loss": loss.item()})
    loss_dict = {k: torch.cat([d[k] for d in loss_dicts], dim=0) for k in loss_dicts[0].keys()}
    tags_dict = {k: np.concatenate([d[k] for d in tags_dicts], axis=0) for k in tags_dicts[0].keys()}
    val_loss = loss_dict["loss_ang"].mean()
    # Pass a plain float so the bar shows a number rather than a tensor repr.
    tqdm_bar.set_postfix({"Loss": val_loss.item()})
    # write to wandb
    wandb.log({k: v.mean().item() for k, v in loss_dict.items()})
    wandb.finish()
    logging.info("Finished evaluation.")
    return loss_dict, tags_dict

if __name__ == "__main__":
    # Only the script entry point writes CSV, so the import lives here.
    import csv

    loss_dict, tags_dict = main_evaluation_loop()
    logging.info(f"Writing results in: {config.save_dir}")
    print(f"Writing results in: {config.save_dir}")
    for k, v in loss_dict.items():
        logging.info(f"{k}: {v.mean().item()}")
    # write to file
    if config.save_to_file:
        keys_loss = list(loss_dict.keys())
        keys_tags = list(tags_dict.keys())
        # Copy each loss column to host memory once instead of reading it
        # from its original device for every cell.
        loss_columns = [loss_dict[k].detach().cpu() for k in keys_loss]
        tag_columns = [tags_dict[k] for k in keys_tags]
        csv_path = os.path.join(config.save_dir, 'validation_loss.csv')
        # csv.writer quotes any field containing a comma, quote or newline,
        # which hand-joined strings would silently corrupt. "\n" keeps the
        # row terminator the file has always used.
        with open(csv_path, 'w', newline='') as f:
            writer = csv.writer(f, lineterminator='\n')
            # Header: loss names first, then tag names.
            writer.writerow(keys_loss + keys_tags)
            # One row per entry, in the order returned by the evaluation loop.
            for i in range(len(loss_columns[0])):
                writer.writerow(
                    [column[i].item() for column in loss_columns]
                    + [column[i] for column in tag_columns]
                )
    logging.info("Results saved.")