import os
import codecarbon
import psutil
import argparse
import importlib
import torch
from torch import nn
import torchvision.models.efficientnet as weight_module
import numpy as np
from config import DefaultConfig, set_tags
from datasources import EVESequences_train, EVESequences_val, init_datasets, get_training_batches, generate_offset_augmentation
from models import STGazeNetVectorized, to_screen_coordinates, apply_offset_augmentation
import logging
from tqdm import tqdm
import wandb
from codecarbon import track_emissions

def parse_args():
    parser = argparse.ArgumentParser(description="Training script for SalientGazeNet")
    parser.add_argument('--batch_size', type=int, default=None, help='Batch size for training')
    parser.add_argument('--base_learning_rate', type=float, default=None, help='Base Learning rate for training')
    parser.add_argument('--learning_rate', type=float, default=None, help='Actual learning rate for training')
    parser.add_argument('--save_dir', type=str, default=None, help='Directory to save the results')
    parser.add_argument('--optimizer', type=str, default=None, help='Optimizer to use for training')
    
    # Allow arbitrary arguments by treating them as key-value pairs
    parser.add_argument('--config', nargs='*', help='Arbitrary key=value pairs')

    known_args, unknown_args = parser.parse_known_args()
    
    # Convert unknown args into a dictionary (key=value format)
    arg_dict = vars(known_args)
    
    for arg in unknown_args:
        if arg.startswith("--"):
            key, value = arg.lstrip('-').split('=', 1)
            if value.isdigit():
                value = int(value)
            elif value.replace('.', '', 1).isdigit():
                value = float(value)
            elif value.lower() in ['true', 'false']:
                value = value.lower() == 'true'
            elif value.startswith('[') and value.endswith(']'):
                value = [v.strip() for v in value[1:-1].split(',')]
            arg_dict[key] = value
    return arg_dict

def automatic_batch_size(cfg=None):
    """Estimate the largest batch size that fits in memory for the configured model.

    A hard-coded per-sample cost in MiB is charged for each selected camera
    frame type ('eyes', 'face') and multiplied by ``max_sequence_len``. The
    budget is 80% of total system RAM on CPU, or 95% of the CUDA device's
    total memory.

    Args:
        cfg: configuration object to read; defaults to the module-level
            ``config``. Reads ``device``, ``st_net_model_name``,
            ``camera_frame_types`` and ``max_sequence_len``.

    Returns:
        The estimated batch size (at least 1), or ``None`` when no estimate is
        available. That happens for an unsupported backbone, when neither
        'eyes' nor 'face' is selected, or for a device that is neither CPU nor
        CUDA.
    """
    cfg = config if cfg is None else cfg

    # Per-frame-type memory cost in MiB (constants presumably measured for each backbone).
    if cfg.st_net_model_name.startswith("efficientnet"):
        mib_per_frame_type = 160
    elif cfg.st_net_model_name.startswith("resnet"):
        mib_per_frame_type = 25
    else:
        return None  # previously an UnboundLocalError
    frame_types = cfg.camera_frame_types
    memory_per_sample = mib_per_frame_type * (('eyes' in frame_types) + ('face' in frame_types))
    if memory_per_sample == 0:
        return None  # previously a ZeroDivisionError

    device = torch.device(cfg.device)
    if device.type == 'cpu':
        # 80% of *total* (not currently free) system RAM, in MiB
        memory = psutil.virtual_memory().total * 0.80 / 1024 / 1024
    elif device.type == 'cuda':
        # Respect an explicit "cuda:N"; plain "cuda" means the current device
        index = device.index if device.index is not None else torch.cuda.current_device()
        memory = torch.cuda.get_device_properties(index).total_memory * 0.95 / 1024 / 1024
    else:
        return None
    batch_size = int(memory // (memory_per_sample * cfg.max_sequence_len))
    return max(batch_size, 1)

def automatic_worker_number():
    """Suggest a maximum data-loader worker count: usable CPUs minus one.

    One CPU is left for the main training process. The set of CPUs this
    process may run on (``os.sched_getaffinity``) is preferred; where that is
    unavailable or fails, ``os.cpu_count()`` is used instead. The fallback
    previously skipped the "minus one", so the suggestion differed by platform.

    Returns:
        The suggested maximum (0 on a single-CPU machine), or ``None`` if the
        CPU count cannot be determined.
    """
    usable_cpus = None
    if hasattr(os, 'sched_getaffinity'):
        try:
            usable_cpus = len(os.sched_getaffinity(0))
        except OSError:
            pass  # best effort: fall back to os.cpu_count() below
    if usable_cpus is None:
        usable_cpus = os.cpu_count()  # may itself return None
    if usable_cpus is None:
        return None
    return usable_cpus - 1

args = parse_args()
config = DefaultConfig()

# Override provided parameters. Override failures are collected here and only
# logged once logging is configured below. A logging.error() call before
# logging.basicConfig() makes the root logger install a default stderr
# handler, which turns the later basicConfig() into a no-op: no training.log,
# and INFO messages are dropped.
learning_rate = None
override_errors = []
for key, value in args.items():
    if value is None:
        continue
    if key == "learning_rate":
        # An explicit --learning_rate is used as-is and also stored as base_learning_rate
        config.override('base_learning_rate', value)
        learning_rate = value
        continue
    try:
        config.override(key, value)
    except Exception as e:
        override_errors.append(e)

# Set up logging
log_format = '%(asctime)s - %(levelname)s - %(message)s'
if config.save_to_file:
    log_path = os.path.join(config.save_dir, 'training.log')
    os.makedirs(config.save_dir, exist_ok=True)
    logging.basicConfig(level=logging.INFO, format=log_format, filename=log_path)
else:
    logging.basicConfig(level=logging.INFO, format=log_format)
for override_error in override_errors:
    logging.error(f"Error: {override_error}")

if config.st_net_load_pretrained and config.st_net_model_name.startswith('efficientnet'):
    # e.g. "efficientnet_v2_s" -> "EfficientNet_V2_S_Weights"
    model_id = "_".join(s.capitalize() for s in config.st_net_model_name.split('_')[1:])
    weight_class_name = f"EfficientNet_{model_id}_Weights"
    try:
        EfficientNet_Weights = getattr(weight_module, weight_class_name)
    except AttributeError:
        raise ValueError(f"Invalid model name: {config.st_net_model_name}. Please check the model name and try again.") from None

if learning_rate is None:
    learning_rate = config.learning_rate

# Only ever shrink the configured batch sizes to the memory-based estimate.
auto_batch_size = automatic_batch_size()
if auto_batch_size is not None:
    if auto_batch_size < config.batch_size:
        config.override('batch_size', auto_batch_size)
    if auto_batch_size < config.test_batch_size:
        config.override('test_batch_size', auto_batch_size)

# Clamp worker counts to the suggested maximum. The old code compared against
# the limit but assigned limit-1, so a value equal to the limit was kept while
# limit+1 was reduced to limit-1.
cpu_count = automatic_worker_number()
if cpu_count is not None and cpu_count > 0:
    if config.train_data_workers > cpu_count:
        config.override('train_data_workers', cpu_count)
    if config.test_data_workers > cpu_count:
        config.override('test_data_workers', cpu_count)
device = torch.device(config.device)

if config.max_steps:
    config.override('num_epochs', 1)

tags = config.wandb_tags
if not isinstance(tags, list):
    tags = [tags]
config.override('wandb_tags', set_tags(config, tags))

project_name = "ST Gaze Net Using Video"

train_dataset_paths = [
    ('eve_train', EVESequences_train, config.datasrc_eve, config.train_stimuli, config.train_cameras),
]

validation_dataset_paths = [
    ('eve_val', EVESequences_val, config.datasrc_eve, config.test_stimuli, config.test_cameras),
]

def _screen_coordinates(input_dict, prefix, gaze, B, T):
    """Project a gaze sequence for one eye side or the face onto the screen.

    Uses ``input_dict[f"{prefix}_o"]`` (origin) and ``input_dict[f"{prefix}_R"]``
    (rotation), plus the shared ``inv_camera_transformation`` and
    ``pixels_per_millimeter`` entries. Everything is flattened to ``B*T`` rows
    for ``to_screen_coordinates`` and the results are viewed back.

    Args:
        input_dict: batch dictionary taken from ``get_training_batches``.
        prefix: key prefix, an entry of ``config.sides`` or ``"face"``.
        gaze: gaze sequence viewable as ``(B*T, 2)``.
        B: batch size of the current batch.
        T: sequence length of the current batch.

    Returns:
        tuple: ``(PoG_mm, PoG_px)``, each viewed as ``(B, T, 2)``.
    """
    PoG_mm, PoG_px = to_screen_coordinates(
        input_dict[f"{prefix}_o"].view(B*T, 3),
        gaze.view(B*T, 2),
        input_dict[f"{prefix}_R"].view(B*T, 3, 3),
        input_dict['inv_camera_transformation'].view(B*T, 4, 4),
        input_dict['pixels_per_millimeter'].view(B*T, 2),
    )
    return PoG_mm.view(B, T, 2), PoG_px.view(B, T, 2)


def _apply_offset_to_target(input_dict, prefix, B, T):
    """Offset-augment ``{prefix}_g_tobii`` in place and refresh ``{prefix}_PoG_tobii``.

    ``{prefix}_g_tobii`` is replaced by ``apply_offset_augmentation`` of itself
    with ``head_R`` and ``{prefix}_kappa_fake``. ``{prefix}_PoG_tobii`` is then
    recomputed, in pixel coordinates, from the augmented gaze.
    """
    input_dict[f"{prefix}_g_tobii"] = apply_offset_augmentation(
        input_dict[f"{prefix}_g_tobii"].view(B*T, 2),
        input_dict["head_R"].view(B*T, 3, 3),
        input_dict[f"{prefix}_kappa_fake"].view(B*T, 2),
    ).view(B, T, 2)
    _, PoG_px = _screen_coordinates(input_dict, prefix, input_dict[f"{prefix}_g_tobii"], B, T)
    input_dict[f"{prefix}_PoG_tobii"] = PoG_px


@track_emissions(
    project_name=project_name,
    output_dir=config.experiment_dir,
    logging_logger=logging,
    log_level="info",
    save_to_api=True,
    save_to_file=True)
def main_training_loop():
    """Train ``STGazeNetVectorized`` using the module-level ``config``.

    Starts a Weights & Biases run and builds the train/validation data via
    ``init_datasets``. It optionally loads ImageNet EfficientNet weights into
    the encoders and optionally resumes from ``config.resume_from``. Training
    uses Adam or SGD with a cosine-annealed learning rate. Every
    ``config.test_every_n_steps`` steps and after every epoch it runs
    ``validate_model``. When ``config.save_to_file`` is set it writes the best,
    rolling and per-epoch checkpoints plus the loss history under
    ``config.save_dir``.

    Raises:
        ValueError: for an unknown ``config.optimizer``, when there is no
            training data to step over, or when the training loss becomes NaN.
    """
    wandb.init(
        project=project_name,
        tags=config.wandb_tags,
        group="Vectorized Training",
        config=config.get_all_key_values(),
        dir=config.experiment_dir
    )
    if config.save_to_file:
        config.write_file_contents(config.save_dir)
    train_data, val_data = init_datasets(train_dataset_paths, validation_dataset_paths)

    # Initialize the vectorized model
    if config.st_net_ablation:
        eye_model = STGazeNetVectorized(
            config.st_net_model_name,
            config.st_net_dropout,
            eye_encoder=config.ablation_eye_encoder,
            face_encoder=config.ablation_face_encoder,
            eca=config.ablation_eca,
            sam=config.ablation_sam,
            gru=config.ablation_gru
        )
    else:
        eye_model = STGazeNetVectorized(config.st_net_model_name, config.st_net_dropout)

    if config.st_net_load_pretrained and config.st_net_model_name.startswith('efficientnet'):
        b_w = EfficientNet_Weights.DEFAULT.get_state_dict(progress=True, check_hash=True)
        # Rename keys to match the model and remove classifier
        b_w = {k.replace("features", "model"): v for k, v in b_w.items() if 'classifier' not in k}
        # Remove last conv layer (because channel size is different)
        b_w = {k: v for k, v in b_w.items() if not k.startswith('model.8')}
        if config.st_net_model_name.startswith('efficientnet_v2'):
            b_w = {k: v for k, v in b_w.items() if not k.startswith('model.7')}
        # Load weights into the model (strict=False: the dropped layers are missing)
        if not config.st_net_ablation or config.ablation_eye_encoder:
            _ = eye_model.eye_encoder.load_state_dict(b_w, strict=False)
        if not config.st_net_ablation or config.ablation_face_encoder:
            _ = eye_model.face_encoder.load_state_dict(b_w, strict=False)

    # Load weights if resuming training. load_state_dict copies values into the
    # existing parameters, so mapping the checkpoint to the CPU is always safe
    # and lets a GPU-saved checkpoint be resumed on a CPU-only machine.
    if config.resume_from is not None:
        eye_model.load_state_dict(torch.load(config.resume_from, map_location='cpu'))
        logging.info(f"Loaded model weights from {config.resume_from}")

    # Freeze every backbone child module except the last one
    if config.st_net_frozen:
        for child in list(eye_model.eye_encoder.model.children())[:-1]:
            for param in child.parameters():
                param.requires_grad = False
        for child in list(eye_model.face_encoder.model.children())[:-1]:
            for param in child.parameters():
                param.requires_grad = False

    eye_model = eye_model.to(config.device)
    eye_model.train()

    if config.optimizer == 'adam':
        optimizer = torch.optim.Adam(eye_model.parameters(), lr=learning_rate)
    elif config.optimizer == 'sgd':
        optimizer = torch.optim.SGD(eye_model.parameters(), lr=learning_rate)
    else:
        raise ValueError(f"Unknown optimizer {config.optimizer}")

    # Gradient hooks: replace NaN gradient entries with zeros (with a warning)
    # and, if configured, clamp gradients element-wise by value.
    def zero_nan_gradients(grad, name):
        nan_mask = torch.isnan(grad)
        if nan_mask.any():
            logging.warning(f'NaN detected in gradient for {name}')
        return torch.where(nan_mask, torch.zeros_like(grad), grad)

    for name, param in eye_model.named_parameters():
        if param.requires_grad:
            # Bind `name` now via a default argument. A plain closure sees only
            # the final loop value, so every warning named the last parameter.
            param.register_hook(lambda grad, name=name: zero_nan_gradients(grad, name))
            if config.do_gradient_clipping and config.gradient_clip_by == 'value':
                param.register_hook(lambda grad: torch.clamp(grad, -config.gradient_clip_amount, config.gradient_clip_amount))

    # Training loop: one epoch = enough batches to cover the largest dataset once (rounded up)
    max_dataset_len = np.amax([len(data_dict['dataset']) for data_dict in train_data.values()])
    num_steps_per_epoch = int((max_dataset_len // config.batch_size) + (max_dataset_len % config.batch_size > 0))

    if config.max_steps:
        logging.warning("!!!!!!! WARNING !!!!!!")
        logging.warning(f"Training for a fixed number of steps: {config.max_steps}")
        logging.warning("This will ignore the number of epochs.")
        num_steps_per_epoch = min(int(config.max_steps), num_steps_per_epoch)

    if num_steps_per_epoch <= 0:
        # Without this the epoch average divides by zero and `step` is undefined.
        raise ValueError("No training steps to run: the training data is empty.")

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=config.num_epochs*num_steps_per_epoch, eta_min=learning_rate*0.001)

    logging.info(f"Starting training loop with {config.num_epochs} epochs and {num_steps_per_epoch} steps per epoch.")
    history = {
        "loss_dict": [],
        "output_dict": [],
    }
    checkpoint = 0
    best_val_loss = float('inf')
    os.makedirs(f"{config.save_dir}/checkpoints", exist_ok=True)
    for epoch in range(config.num_epochs):
        eye_model.train()
        epoch_loss = 0.0
        tqdm_bar = tqdm(range(num_steps_per_epoch), desc=f"Epoch {epoch + 1}/{config.num_epochs}")
        for i in tqdm_bar:
            step = epoch * num_steps_per_epoch + i + 1
            batch = get_training_batches(train_data)
            optimizer.zero_grad()
            intermediate_dict = {}
            full_input_dict = next(iter(batch.values()))
            B, T, _, _, _ = full_input_dict["face_patch"].shape
            intermediate_dict["face_gaze"] = torch.zeros((B, T, 2), device=config.device)
            if config.do_offset_augmentation:
                # Use the actual batch size B, not config.batch_size, so the
                # generated offsets match the tensors they are applied to.
                generate_offset_augmentation(B, T, full_input_dict)
            # Per-side gaze prediction and its point of gaze on the screen
            for side in config.sides:
                flip = config.flip_right_eye and side == "right"
                gaze_prediction_sequence, _ = eye_model(full_input_dict, side=side, flip=flip)
                intermediate_dict[f"{side}_gaze"] = gaze_prediction_sequence
                # Face gaze is the mean of the per-side predictions
                intermediate_dict["face_gaze"] += gaze_prediction_sequence / len(config.sides)
                (intermediate_dict[f"{side}_PoG_mm"],
                 intermediate_dict[f"{side}_PoG_px"]) = _screen_coordinates(
                    full_input_dict, side, gaze_prediction_sequence, B, T)
                if config.do_offset_augmentation:
                    _apply_offset_to_target(full_input_dict, side, B, T)
            # Face-level point of gaze from the averaged gaze
            (intermediate_dict["face_PoG_mm"],
             intermediate_dict["face_PoG_px"]) = _screen_coordinates(
                full_input_dict, "face", intermediate_dict["face_gaze"], B, T)
            if config.do_offset_augmentation:
                _apply_offset_to_target(full_input_dict, "face", B, T)

            # Compute loss
            loss_dict = eye_model.loss(full_input_dict, intermediate_dict)
            loss = loss_dict["loss"]
            if torch.isnan(loss):
                logging.error('NaN detected in loss at step %d', step)
                raise ValueError('NaN detected in loss')

            loss.backward()

            if config.do_gradient_clipping and config.gradient_clip_by == 'norm':
                nn.utils.clip_grad_norm_(eye_model.parameters(), config.gradient_clip_amount)

            optimizer.step()
            scheduler.step()

            # Step-level logging. The dict stored in history is the same object
            # that later receives "custom step", "last_lr" and "val_loss".
            epoch_loss += loss.item()
            loss_dict_detached = {k: v.item() for k, v in loss_dict.items()}
            history["loss_dict"].append(loss_dict_detached)
            loss_dict_detached["custom step"] = step
            loss_dict_detached["last_lr"] = scheduler.get_last_lr()[0]

            if step % config.test_every_n_steps == 0:
                val_loss = validate_model(eye_model, val_data)
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    if config.save_to_file:
                        torch.save(eye_model.state_dict(), f"{config.save_dir}/checkpoints/best.pth")
                loss_dict_detached["val_loss"] = val_loss
                eye_model.train()  # validate_model leaves the model in eval mode

            wandb.log(loss_dict_detached)
            tqdm_bar.set_postfix(loss=loss.item(), lr=loss_dict_detached["last_lr"])

            # Drop references to this step's tensors before the next batch is built
            del loss, loss_dict, intermediate_dict, full_input_dict, batch, loss_dict_detached

            # Rolling checkpoints: slot index cycles through checkpoints_keep_n files
            if config.save_to_file and step % config.checkpoints_save_every_n_steps == 0:
                torch.save(eye_model.state_dict(), f"{config.save_dir}/checkpoints/checkpoint_{checkpoint}.pth")
                checkpoint = (checkpoint + 1) % config.checkpoints_keep_n

        # Epoch-level logging, validation and checkpointing
        avg_epoch_loss = epoch_loss / num_steps_per_epoch
        tqdm_bar.set_postfix(loss=avg_epoch_loss)
        logging.info(f"Epoch {epoch + 1}/{config.num_epochs} loss: {avg_epoch_loss:.4f}")
        val_loss = validate_model(eye_model, val_data)
        logging.info(f"Validation loss: {val_loss:.4f}")
        wandb.log({"epoch_loss": avg_epoch_loss, "val_loss": val_loss, "custom step": step, "epoch": epoch})
        if config.save_to_file:
            torch.save(eye_model.state_dict(), f"{config.save_dir}/checkpoints/epoch_{epoch + 1}.pth")
            torch.save(history, f"{config.save_dir}/history.pth")

    wandb.finish()
    logging.info("Training complete")

def validate_model(model, val_data):
    """Return the mean validation loss of ``model`` over ``val_data``.

    The number of steps is the length of the largest ``['dataset']`` in
    ``val_data`` divided by ``config.batch_size``, rounded up. Each step draws
    one batch with ``get_training_batches`` and evaluates it under
    ``torch.no_grad()``. Only the per-side gaze and points of gaze for
    ``config.sides`` go into the prediction dict passed to ``model.loss``;
    there are no face-level entries. The model is switched to eval mode and
    left there, so callers re-enable training mode themselves.

    NOTE(review): the step count uses ``config.batch_size`` although a
    separate ``config.test_batch_size`` exists. Confirm which one the
    validation loaders are built with.
    """
    largest_len = np.amax([len(source['dataset']) for source in val_data.values()])
    model.eval()
    running_loss = 0.0
    with torch.no_grad():
        whole_batches, leftover = divmod(largest_len, config.batch_size)
        n_steps = int(whole_batches + (leftover > 0))
        progress = tqdm(range(n_steps), desc="Validation")
        for _ in progress:
            batch = get_training_batches(val_data)
            inputs = next(iter(batch.values()))
            batch_len, seq_len, _, _, _ = inputs["face_patch"].shape
            rows = batch_len * seq_len
            predictions = {}
            for side in config.sides:
                gaze, _ = model(inputs, side, flip=config.flip_right_eye and side == "right")
                predictions[f"{side}_gaze"] = gaze
                pog_mm, pog_px = to_screen_coordinates(
                    inputs[f"{side}_o"].view(rows, 3),
                    gaze.view(rows, 2),
                    inputs[f"{side}_R"].view(rows, 3, 3),
                    inputs['inv_camera_transformation'].view(rows, 4, 4),
                    inputs['pixels_per_millimeter'].view(rows, 2),
                )
                predictions[f"{side}_PoG_mm"] = pog_mm.view(batch_len, seq_len, 2)
                predictions[f"{side}_PoG_px"] = pog_px.view(batch_len, seq_len, 2)
            step_loss = model.loss(inputs, predictions)["loss"].item()
            running_loss += step_loss
            progress.set_postfix({"Loss": step_loss})
            # Release this step's batch tensors before drawing the next one
            del batch, inputs, predictions
    mean_loss = running_loss / n_steps
    progress.set_postfix({"Loss": mean_loss})
    return mean_loss

if __name__ == "__main__":
    # Run training only when executed as a script. The module-level setup
    # above (argument parsing, config overrides, logging setup) runs on any
    # import of this module as well.
    main_training_loop()