import os
import torch
import numpy as np
from config.config_default import DefaultConfig
from datasources import EVESequences_val, init_datasets, get_training_batches
from models import STGazeNet, to_screen_coordinates, STGazeNetAblation

import logging
from tqdm import tqdm
from time import time

config = DefaultConfig()

# Fixed settings for this benchmark script, applied in the listed order.
for _option, _value in (
    ('do_full_test', False),
    ('test_num_samples', 50),
    ('save_to_file', True),
    ('max_sequence_len', 150),
    ('start_time', 0),
    ('assumed_frame_rate', 10),
):
    config.override(_option, _value)

# start_frame is computed from config.start_time and config.assumed_frame_rate,
# read after the overrides above have been issued.
config.override('start_frame', int(config.start_time * config.assumed_frame_rate))
config.override('test_batch_size', 1)

def automatic_worker_number():
    """Suggest an upper bound for the number of data-loading workers.

    Returns:
        ``len(os.sched_getaffinity(0)) - 1`` when ``os.sched_getaffinity`` exists
        and the call succeeds; otherwise the value of ``os.cpu_count()`` (which is
        not reduced by one and may itself be ``None``).
    """
    affinity_query = getattr(os, 'sched_getaffinity', None)
    if affinity_query is not None:
        try:
            return len(affinity_query(0)) - 1
        except Exception:
            # Best effort only: any failure falls through to the cpu_count() fallback.
            pass
    # os.cpu_count() returns Optional[int]; a None result is passed on to the caller.
    return os.cpu_count()

# Cap the configured data-loader worker counts using the suggested maximum.
cpu_count = automatic_worker_number()
if cpu_count is not None and cpu_count > 0:
    for _workers_option in ('train_data_workers', 'test_data_workers'):
        # Only values above the suggestion are lowered, and they are set to one below it.
        if getattr(config, _workers_option) > cpu_count:
            config.override(_workers_option, cpu_count - 1)
device = torch.device(config.device)

# Configure root logging once: into <config.save_dir>/training.log when
# config.save_to_file is set, otherwise with basicConfig's default handler.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
if config.save_to_file:
    # exist_ok=True replaces the exists()-then-makedirs() pair, which could fail
    # if the directory was created between the check and the call.
    os.makedirs(config.save_dir, exist_ok=True)
    log_path = os.path.join(config.save_dir, 'training.log')
    logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, filename=log_path)
else:
    logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)

# Validation sources passed to init_datasets. Each entry is a 5-tuple; the field
# meanings below are read off the config attribute names — TODO confirm against
# init_datasets.
validation_dataset_paths = [
    (
        'eve_val',              # tag for this entry
        EVESequences_val,       # dataset class
        config.datasrc_eve,     # presumably the data source location
        config.test_stimuli,    # stimuli selection used for testing
        config.test_cameras,    # camera selection used for testing
    ),
]

def main_evaluation_loop():
    """Measure inference throughput (frames per second) on the validation data.

    Builds the validation datasets, instantiates ``STGazeNetAblation`` or
    ``STGazeNet`` depending on ``config.st_net_ablation``, loads the weights from
    ``config.st_net_weights`` and then, one frame at a time, runs the model and
    ``to_screen_coordinates`` for every side in ``config.sides``.  All model
    outputs are discarded; only the elapsed time matters.

    The timed region includes fetching batches through ``get_training_batches``,
    so the reported figure covers data loading as well as the forward pass.
    The summary line is both logged and printed.  Returns ``None``.
    """
    # Function-scope import so this block does not depend on the file's import list.
    from time import perf_counter

    def _now():
        """Return a monotonic timestamp after pending CUDA work has finished."""
        # CUDA kernels are launched asynchronously; without a synchronize the clock
        # could be read before the queued GPU work has actually completed.
        if device.type == "cuda":
            torch.cuda.synchronize(device)
        return perf_counter()

    _, val_data = init_datasets([], validation_dataset_paths)

    if config.st_net_ablation:
        eye_model = STGazeNetAblation(config.st_net_model_name, config.st_net_dropout,
                                      config.ablation_eye_encoder, config.ablation_face_encoder,
                                      config.ablation_eca, config.ablation_sam, config.ablation_gru)
    else:
        eye_model = STGazeNet(config.st_net_model_name, config.st_net_dropout)

    # NOTE(review): torch.load unpickles the checkpoint; only point
    # config.st_net_weights at trusted files (consider weights_only=True on
    # PyTorch versions that support it).
    eye_model.load_state_dict(torch.load(config.st_net_weights, map_location=device))
    eye_model.to(device)
    eye_model.eval()

    # One step per batch of the largest dataset (ceiling division).
    max_dataset_len = np.amax([len(data_dict['dataset']) for data_dict in val_data.values()])
    batch_size = config.test_batch_size
    num_steps_per_epoch = int((max_dataset_len // batch_size) + (max_dataset_len % batch_size > 0))

    # Loop-invariant: each side paired with whether its input should be flipped.
    sides_with_flip = [(side, config.flip_right_eye and side == "right") for side in config.sides]

    with torch.no_grad():
        start = _now()
        total_frames = 0
        tqdm_bar = tqdm(range(num_steps_per_epoch), desc="Inference Speed Evaluation", ncols=100)
        for _ in tqdm_bar:
            batch = get_training_batches(val_data)
            # Only the first dataset's batch is used for the measurement.
            full_input_dict = next(iter(batch.values()))
            # The second axis of "face_patch" is taken as the number of frames.
            # NOTE(review): this counts T frames per batch regardless of the batch
            # axis; it equals the true frame count only for one sequence per batch
            # (test_batch_size is overridden to 1 in this script) - confirm that
            # get_training_batches honours that setting.
            _, T, _, _, _ = full_input_dict["face_patch"].shape
            total_frames += T
            for t in range(T):
                # Slice frame t out of every tensor entry; non-tensor entries are dropped.
                sub_input_dict = {k: v[:, t] for k, v in full_input_dict.items() if isinstance(v, torch.Tensor)}
                for side, flip in sides_with_flip:
                    gaze_prediction_sequence, _ = eye_model(sub_input_dict, side=side, flip=flip)
                    origin = sub_input_dict[f"{side}_o"]
                    rotation = sub_input_dict[f"{side}_R"]
                    inv_camera_transformation = sub_input_dict['inv_camera_transformation']
                    pixels_per_millimeter = sub_input_dict['pixels_per_millimeter']
                    _, _ = to_screen_coordinates(
                        origin,
                        gaze_prediction_sequence,
                        rotation,
                        inv_camera_transformation,
                        pixels_per_millimeter
                    )
            elapsed = _now() - start
            tqdm_bar.set_postfix({"FPS": total_frames / elapsed if elapsed > 0 else 0})
        total_time = _now() - start
        tqdm_bar.close()
        # Same zero-elapsed guard as the running FPS above (e.g. when no step ran).
        fps = total_frames / total_time if total_time > 0 else 0.0
        summary = f"Processed {total_frames} frames in {total_time:.2f} seconds. Inference Speed: {fps:.2f} FPS"
        logging.info(summary)
        print(summary)

# Start the benchmark only when this file is executed as a script.
if __name__ == "__main__":
    main_evaluation_loop()
