#!/usr/bin/env python3
"""
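WACV Ablation Study Runner
Executes component (LLM / trust gating / YOLO / domain randomization / fine-tuning),
trust-sensitivity, fallback, and reward-engineering ablations by launching
RLAD1.py once per (configuration, seed) pair.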
"""

import subprocess
import os
import sys
from pathlib import Path

# Arguments shared by every ablation run (algorithm and training budget).
_COMMON_ARGS = ('--algorithm', 'TD3', '--total_timesteps', '200000')

# Component flags for the pure baseline: every optional component disabled.
# Each key is forwarded to RLAD1.py as ``--config_override key=value``.
_BASELINE_FLAGS = {
    'use_llm': False,
    'use_trust_gating': False,
    'use_yolo': False,
    'use_domain_randomization': False,
    'fine_tune_llm': False,
}

# Component flags for the full system: every optional component enabled.
_FULL_SYSTEM_FLAGS = {flag: True for flag in _BASELINE_FLAGS}


def _make_args(base_flags: dict, **overrides) -> list:
    """Build a fresh RLAD1.py argument list for one ablation configuration.

    ``overrides`` are merged on top of ``base_flags`` (an overridden key keeps
    its original position; new keys are appended). Every resulting pair is
    emitted as its own ``--config_override key=value`` argument, after the
    shared ``_COMMON_ARGS``.

    Starting from ``_BASELINE_FLAGS`` or ``_FULL_SYSTEM_FLAGS`` pins all five
    component flags explicitly, so a run never silently inherits whatever
    defaults RLAD1.py happens to have for the flags it does not mention.
    """
    flags = {**base_flags, **overrides}
    args = list(_COMMON_ARGS)
    for key, value in flags.items():
        args.extend(['--config_override', f'{key}={value}'])
    return args


# Ablation configurations to run, in execution order.
# Each value is the extra argument list passed to RLAD1.py for that config.
ABLATION_CONFIGS = {
    # 1. Pure Baseline (TD3 only, no extra features)
    'baseline': _make_args(_BASELINE_FLAGS),

    # 2. Baseline + YOLO (Does better perception help without LLM?)
    'baseline_yolo': _make_args(_BASELINE_FLAGS, use_yolo=True),

    # 3. LLM Only (Naive usage, no trust gating)
    'llm_only': _make_args(_BASELINE_FLAGS, use_llm=True),

    # 4. LLM + Trust (No extra perception/domain features)
    'llm_trust': _make_args(_BASELINE_FLAGS, use_llm=True, use_trust_gating=True),

    # 5. LLM + Trust + YOLO
    'llm_trust_yolo': _make_args(_BASELINE_FLAGS, use_llm=True,
                                 use_trust_gating=True, use_yolo=True),

    # 6. Full System (No Finetuning)
    'full_system_no_finetune': _make_args(_FULL_SYSTEM_FLAGS, fine_tune_llm=False),

    # 7. Full System (Everything enabled: LLM, Trust, YOLO, Domain, Finetuning)
    'full_system': _make_args(_FULL_SYSTEM_FLAGS),

    # 8-12. Trust Sensitivity (on Full System)
    'trust_min_0.1': _make_args(_FULL_SYSTEM_FLAGS, trust_min=0.1),
    'trust_min_0.2': _make_args(_FULL_SYSTEM_FLAGS, trust_min=0.2),
    'trust_min_0.3': _make_args(_FULL_SYSTEM_FLAGS, trust_min=0.3),
    'trust_min_0.4': _make_args(_FULL_SYSTEM_FLAGS, trust_min=0.4),
    'trust_min_0.5': _make_args(_FULL_SYSTEM_FLAGS, trust_min=0.5),

    # NOTE(review): 13-14 do not pin the component flags, so they run on
    # RLAD1.py's defaults. Confirm whether they should be based on
    # _FULL_SYSTEM_FLAGS like the trust-sensitivity sweep.

    # 13. Fallback Ablation
    'no_fallbacks': _make_args({}, enable_llm_fallbacks=False),

    # 14. Reward Engineering Ablation
    'no_llm_reward': _make_args({}, use_llm_alignment_reward=False),

    # 15. Baseline + Domain Randomization (Does domain rand help baseline?)
    'baseline_domain': _make_args(_BASELINE_FLAGS, use_domain_randomization=True),
}

# Seeds for reproducibility (each config is run once per seed)
SEEDS = [42, 123, 456]

# Evaluation episodes passed to RLAD1.py via --eval_episodes
EVAL_EPISODES = 50


def run_ablation(config_name: str, args: list, seed: int) -> bool:
    """Run a single ablation configuration with a specific seed.

    Launches ``RLAD1.py`` (resolved relative to the current working directory)
    in a child process using the current interpreter. The child inherits this
    process's stdout/stderr; nothing is captured.

    Args:
        config_name: Name of the configuration; used in log messages and in the
            per-run results directory ``ablation_results/<config_name>/seed_<seed>``.
        args: Extra command-line arguments forwarded to RLAD1.py.
        seed: Value passed via ``--seed``; also names the results subdirectory.

    Returns:
        True if the child exited with status 0, False if it exited non-zero.
    """
    print(f"\n{'='*80}")
    print(f"Running: {config_name} (seed={seed})")
    print(f"{'='*80}\n")

    cmd = [
        sys.executable,  # Python interpreter
        'RLAD1.py',
        *args,
        '--seed', str(seed),
        '--eval_episodes', str(EVAL_EPISODES),
        '--results_dir', f'ablation_results/{config_name}/seed_{seed}'
    ]

    print(f"Command: {' '.join(cmd)}\n")
    # When stdout is redirected (file/pipe) it is block-buffered; flush so the
    # headers above land before the child starts writing to the same stream.
    sys.stdout.flush()

    try:
        subprocess.run(cmd, check=True)
    except subprocess.CalledProcessError as e:
        print(f"\nFailed: {config_name} (seed={seed})")
        print(f"Error: {e}")
        return False

    print(f"\nCompleted: {config_name} (seed={seed})")
    return True


def main() -> int:
    """Run all ablation studies and print a summary.

    Every (configuration, seed) pair from ABLATION_CONFIGS x SEEDS is executed
    sequentially via run_ablation. A failed run is recorded and the sweep
    continues with the next pair.

    Returns:
        Process exit status: 0 if every run succeeded, 1 if any run failed.
    """
    total_runs = len(ABLATION_CONFIGS) * len(SEEDS)
    rule = '=' * 80

    print("\nWACV Ablation Study Runner")
    print(rule)
    print(f"Total configurations: {len(ABLATION_CONFIGS)}")
    print(f"Seeds per config: {len(SEEDS)}")
    print(f"Total runs: {total_runs}")
    print(f"Eval episodes per run: {EVAL_EPISODES}")
    print(f"{rule}\n")

    # Create results directory
    os.makedirs('ablation_results', exist_ok=True)

    # Track (config_name, seed) pairs by outcome
    successful = []
    failed = []

    # Run each configuration with each seed
    for config_name, args in ABLATION_CONFIGS.items():
        for seed in SEEDS:
            if run_ablation(config_name, args, seed):
                successful.append((config_name, seed))
            else:
                failed.append((config_name, seed))

    # Print summary
    print(f"\n\n{rule}")
    print("ABLATION STUDY COMPLETE")
    print(rule)
    print(f"Successful runs: {len(successful)}/{total_runs}")
    print(f"Failed runs: {len(failed)}")

    if failed:
        print("\nFailed configurations:")
        for config_name, seed in failed:
            print(f"  - {config_name} (seed={seed})")

    print("\nResults saved to: ablation_results/")
    print("\nNext steps:")
    print("  1. Analyze results with ablation_visualizations.py")
    print("  2. Generate tables for paper")
    print("  3. Update main.tex with new findings")
    print(f"{rule}\n")

    return 1 if failed else 0


if __name__ == "__main__":
    # Change to script directory so the relative paths used by run_ablation
    # ('RLAD1.py', 'ablation_results/...') resolve next to this script.
    os.chdir(Path(__file__).parent)

    # Run ablations; propagate failures to the shell so callers/CI can detect them.
    sys.exit(main())
