import os
import torch
from config.config_default import DefaultConfig
from datasources import EVESequences_val, init_datasets, get_training_batches
from models import STGazeNet, STGazeNetAblation
from calflops import calculate_flops

import logging

# Script-wide configuration object, created from the project's default settings.
config = DefaultConfig()

# Override the 'test_batch_size' setting to 1 for this script.
config.override('test_batch_size', 1)

def automatic_worker_number():
    """Suggest a maximum worker count from the CPUs available to this process.

    The CPU affinity mask of the current process is preferred when the
    platform exposes ``os.sched_getaffinity``; in that case the result is the
    number of usable CPUs minus one.  If the affinity query is unavailable or
    fails, the plain ``os.cpu_count()`` value is returned instead (note: this
    fallback does not subtract one).

    Returns:
        An ``int`` suggestion, or ``None`` when neither source yields a count.
    """
    if hasattr(os, 'sched_getaffinity'):
        try:
            return len(os.sched_getaffinity(0)) - 1
        except Exception:
            # Best effort only: fall through to the os.cpu_count() fallback.
            pass
    # os.cpu_count() is Optional[int]; None is passed through unchanged.
    return os.cpu_count()

# Cap the configured data-loading worker counts by the suggested CPU budget.
cpu_count = automatic_worker_number()
if cpu_count is not None and cpu_count > 0:
    # Train and test settings get the same treatment: a value above the
    # suggestion is replaced by (suggestion - 1); anything else is left alone.
    for _worker_option in ('train_data_workers', 'test_data_workers'):
        if getattr(config, _worker_option) > cpu_count:
            config.override(_worker_option, cpu_count - 1)
    del _worker_option

# Torch device used for the model weights and the model itself.
device = torch.device(config.device)

# Configure root logging once for the whole script.  When config.save_to_file
# is set, records go to 'training.log' inside config.save_dir; otherwise
# logging.basicConfig's default handler is used.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'

if config.save_to_file:
    # exist_ok=True avoids the check-then-create race of a separate
    # os.path.exists() test when several processes start at the same time.
    os.makedirs(config.save_dir, exist_ok=True)
    log_path = os.path.join(config.save_dir, 'training.log')
    logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, filename=log_path)
else:
    logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)

# Validation dataset specifications; passed as the second argument to
# init_datasets() in main_evaluation_loop().
# Each entry bundles a name string, the dataset class, and the data source,
# stimuli and camera settings taken from the config.
# NOTE(review): the exact meaning of each tuple field is defined by
# init_datasets — confirm there.
validation_dataset_paths = [
    ('eve_val', EVESequences_val, config.datasrc_eve, config.test_stimuli, config.test_cameras),
]

def _format_count(value, suffix=''):
    """Render a raw count with a T/G/M/K magnitude prefix and two decimals."""
    for scale, prefix in ((1e12, 'T'), (1e9, 'G'), (1e6, 'M'), (1e3, 'K')):
        if value >= scale:
            return f"{value / scale:.2f} {prefix}{suffix}"
    # Below 1000: no prefix; strip the trailing space left by an empty suffix.
    return f"{value:.2f} {suffix}".rstrip()


def _count_parameters(module):
    """Return the total number of parameter elements of ``module`` (0 if None)."""
    if module is None:
        return 0
    return sum(p.numel() for p in module.parameters())


def main_evaluation_loop():
    """Measure and print the compute cost and parameter breakdown of the gaze model.

    Steps:
      1. Build the validation data via ``init_datasets`` and instantiate either
         ``STGazeNetAblation`` or ``STGazeNet`` depending on
         ``config.st_net_ablation``; load ``config.st_net_weights`` and switch
         the model to eval mode.
      2. Take the first entry of one batch from ``get_training_batches``, slice
         index 0 along dimension 1 of every tensor in it, and run calflops'
         ``calculate_flops`` once for every side in ``config.sides``.
      3. Print the FLOPs and MACs summed over the sides and the largest
         reported parameter count.
      4. Print the parameter count (and share of the total) of each model part.

    The function returns nothing; all results are printed to stdout.
    """
    _, val_data = init_datasets([], validation_dataset_paths)

    if config.st_net_ablation:
        eye_model = STGazeNetAblation(
            config.st_net_model_name, config.st_net_dropout,
            config.ablation_eye_encoder, config.ablation_face_encoder,
            config.ablation_eca, config.ablation_sam, config.ablation_gru,
        )
    else:
        eye_model = STGazeNet(config.st_net_model_name, config.st_net_dropout)

    # NOTE(review): torch.load unpickles the checkpoint file; only point
    # config.st_net_weights at checkpoints from a trusted source.
    eye_model.load_state_dict(torch.load(config.st_net_weights, map_location=device))
    eye_model.to(device)
    eye_model.eval()

    with torch.no_grad():
        batch = get_training_batches(val_data)
        full_input_dict = next(iter(batch.values()))
        # Keep only index 0 along dimension 1 of every tensor entry
        # (presumably the first time step of the sequence — TODO confirm).
        t = 0
        sub_input_dict = {k: v[:, t] for k, v in full_input_dict.items() if isinstance(v, torch.Tensor)}

        # Ask calflops for raw numbers rather than pre-formatted strings:
        # summing the numeric prefix of strings such as "900 MFLOPS" and
        # "1.2 GFLOPS" would silently mix units.
        per_side = []
        for side in config.sides:
            flip = config.flip_right_eye and side == "right"
            per_side.append(calculate_flops(
                model=eye_model,
                kwargs={'input_dict': sub_input_dict, 'side': side, 'flip': flip},
                output_as_string=False,
            ))

    # One forward pass per side: FLOPs and MACs add up, while the parameters
    # belong to the same model, so the largest reported value is used.
    flops = sum(side_flops for side_flops, _, _ in per_side)
    macs = sum(side_macs for _, side_macs, _ in per_side)
    params = max((side_params for _, _, side_params in per_side), default=0)
    print(f"FLOPs: {_format_count(flops, 'FLOPS')}, MACs: {_format_count(macs, 'MACs')}, Parameters: {_format_count(params)}")

    # Print the number of parameters in the model for each part (eye and face encoders, ECA, SAM, GRU, and final layers).
    # getattr(..., None) lets a part that is absent or set to None be reported
    # as 0 instead of raising AttributeError.
    # NOTE(review): assumes the ablation model may drop some of these — confirm.
    eye_parameters = _count_parameters(getattr(eye_model, 'eye_encoder', None))
    face_parameters = _count_parameters(getattr(eye_model, 'face_encoder', None))
    eca_parameters = _count_parameters(getattr(eye_model, 'eca', None))
    sam_parameters = _count_parameters(getattr(eye_model, 'sam', None))
    gru_parameters = _count_parameters(getattr(eye_model, 'gru', None))
    final_parameters = _count_parameters(getattr(eye_model, 'fc_to_gaze', None))
    total_parameters = _count_parameters(eye_model)
    print(f"""
        Model Parameters: 
            Eye Encoder: {eye_parameters} ({eye_parameters / total_parameters * 100:.2f}%),
            Face Encoder: {face_parameters} ({face_parameters / total_parameters * 100:.2f}%),
            ECA: {eca_parameters} ({eca_parameters / total_parameters * 100:.2f}%),
            SAM: {sam_parameters} ({sam_parameters / total_parameters * 100:.2f}%),
            GRU: {gru_parameters} ({gru_parameters / total_parameters * 100:.2f}%),
            Final Layers: {final_parameters} ({final_parameters / total_parameters * 100:.2f}%),
        Total: {total_parameters}
    """)

# Run the measurement only when this file is executed directly as a script.
if __name__ == '__main__':
    main_evaluation_loop()