#!/usr/bin/env python3
"""
Quick verification that all WACV ablation changes are working
Tests configuration loading and basic functionality
"""

import sys
import os
# Make sibling modules (RLAD1, ablation_visualizations) importable regardless
# of the current working directory; abspath guards against a relative __file__.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

def test_config_parameters():
    """Verify that Config exposes the ablation parameters with expected defaults.

    Checks that ``trust_min``, ``enable_llm_fallbacks``,
    ``use_llm_alignment_reward`` and ``track_fallback_events`` exist on a
    freshly constructed ``Config``, that they hold their expected default
    values, and that ``trust_min`` can be reassigned.

    Raises:
        AssertionError: if a parameter is missing or has an unexpected value.
    """
    print("Testing Config class...")
    from RLAD1 import Config

    config = Config()

    # Check new parameters exist
    assert hasattr(config, 'trust_min'), "Missing trust_min parameter"
    assert hasattr(config, 'enable_llm_fallbacks'), "Missing enable_llm_fallbacks"
    assert hasattr(config, 'use_llm_alignment_reward'), "Missing use_llm_alignment_reward"
    assert hasattr(config, 'track_fallback_events'), "Missing track_fallback_events"

    # Check default values. Every assert carries a message because main()
    # reports a failure only via str(exception), which is empty for a bare assert.
    assert config.trust_min == 0.3, f"Wrong default trust_min: {config.trust_min}"
    for flag in ('enable_llm_fallbacks', 'use_llm_alignment_reward', 'track_fallback_events'):
        value = getattr(config, flag)
        # `is True` rather than `== True`: the flags should be real booleans,
        # not merely truthy values such as 1 or a non-empty string.
        assert value is True, f"Wrong default {flag}: {value!r} (expected True)"

    # Test parameter modification
    config.trust_min = 0.5
    assert config.trust_min == 0.5, "Can't modify trust_min"

    print("All config parameters present and functional")


def test_trust_gating():
    """Check TrustGating's configurable trust floor and its reward modulation.

    Builds one gate with the default ``trust_min`` and one with an explicit
    value, then calls ``modulate_reward`` with the LLM reward switched on and
    then off.

    Raises:
        AssertionError: if the trust floor is missing or wrong, or if toggling
            ``use_llm_reward`` does not add / omit the LLM reward term.
    """
    print("Testing TrustGating class...")
    from RLAD1 import TrustGating

    # Constructor arguments common to both gates under test.
    shared_kwargs = dict(sensor_dim=64, llm_dim=12, device='cpu')

    # Gate relying on the default floor.
    default_gate = TrustGating(**shared_kwargs)
    assert hasattr(default_gate, 'trust_min'), "TrustGating missing trust_min attribute"
    assert default_gate.trust_min == 0.3, f"Wrong default: {default_gate.trust_min}"

    # Gate with the floor passed explicitly.
    custom_gate = TrustGating(**shared_kwargs, trust_min=0.5)
    assert custom_gate.trust_min == 0.5, f"Wrong custom value: {custom_gate.trust_min}"

    import torch

    # One-element tensors so .item() yields a plain Python float.
    base, trust, llm = (torch.tensor([v]) for v in (1.0, 0.8, 0.5))
    base_value = base.item()

    # LLM reward enabled: the total must exceed the base reward.
    with_llm = custom_gate.modulate_reward(base, trust, llm, use_llm_reward=True)
    assert with_llm.item() > base_value, "LLM reward not added"

    # LLM reward disabled: the total must equal the base reward exactly.
    without_llm = custom_gate.modulate_reward(base, trust, llm, use_llm_reward=False)
    assert without_llm.item() == base_value, "LLM reward added when disabled"

    print("TrustGating configurable and functional")


def test_multimodal_llm_reasoner():
    """Verify fallback-event tracking on MultimodalLLMReasoner.

    Constructs a reasoner with ``use_llm`` and ``use_yolo`` disabled, then
    checks that:

    - ``fallback_stats`` has perception, planning and control stages;
    - ``_record_fallback`` increments the stage count and the per-reason counter;
    - ``get_fallback_stats`` reports the total and a ``by_stage`` entry;
    - ``reset_fallback_stats`` zeroes the stage count.

    Raises:
        AssertionError: describing the first check that failed.
    """
    print("Testing MultimodalLLMReasoner fallback tracking...")
    from RLAD1 import MultimodalLLMReasoner, Config

    config = Config()
    config.use_llm = False  # Disable LLM loading for testing
    config.use_yolo = False

    reasoner = MultimodalLLMReasoner(config)

    # Check fallback stats initialized. Every assert carries a message because
    # main() reports only str(exception), which is empty for a bare assert.
    assert hasattr(reasoner, 'fallback_stats'), "Missing fallback_stats"
    for stage in ('perception', 'planning', 'control'):
        assert stage in reasoner.fallback_stats, f"fallback_stats missing stage '{stage}'"

    # Test recording fallback
    reasoner._record_fallback('perception', 'test_reason')
    perception = reasoner.fallback_stats['perception']
    assert perception['count'] == 1, (
        f"Expected perception count 1 after one fallback, got {perception['count']}"
    )
    reasons = perception['reasons']
    assert 'test_reason' in reasons, f"'test_reason' not recorded; reasons={reasons!r}"
    assert reasons['test_reason'] == 1, (
        f"Expected test_reason count 1, got {reasons['test_reason']}"
    )

    # Test get stats
    stats = reasoner.get_fallback_stats()
    assert 'total_fallbacks' in stats, "get_fallback_stats() missing 'total_fallbacks'"
    assert stats['total_fallbacks'] == 1, (
        f"Expected total_fallbacks == 1, got {stats['total_fallbacks']}"
    )
    assert 'by_stage' in stats, "get_fallback_stats() missing 'by_stage'"

    # Test reset. Re-read through the reasoner: reset may replace the stats
    # dict rather than mutate it in place.
    reasoner.reset_fallback_stats()
    reset_count = reasoner.fallback_stats['perception']['count']
    assert reset_count == 0, f"Expected perception count 0 after reset, got {reset_count}"

    print("Fallback tracking functional")


def test_ablation_configurations():
    """Verify that run_ablations.py defines the expected ablation configurations.

    Scans the source text of ``run_ablations.py`` (it is not imported) for
    dictionary keys of the form ``'trust_min_<value>':``, ``'no_fallbacks':``
    and ``'no_llm_reward':``. Keys may be single- or double-quoted and may
    have whitespace before the colon.

    The file is looked up next to this script first, because that is the
    directory placed on ``sys.path``. If it is not there, the current working
    directory is tried, which is where the original version looked. The check
    therefore does not depend on where the script is launched from.

    Raises:
        FileNotFoundError: if run_ablations.py cannot be found.
        AssertionError: if fewer than 5 distinct trust_min configurations are
            found, or the no_fallbacks / no_llm_reward configuration is missing.
    """
    print("Testing ablation configurations...")

    import re

    script_dir = os.path.dirname(os.path.abspath(__file__))
    path = os.path.join(script_dir, 'run_ablations.py')
    if not os.path.exists(path) and os.path.exists('run_ablations.py'):
        # Backward compatibility: fall back to the CWD-relative location.
        path = 'run_ablations.py'
    with open(path, 'r', encoding='utf-8') as f:
        content = f.read()

    # Dictionary keys such as 'trust_min_0.1': with any number of digits.
    # A set ensures a key mentioned more than once is counted only once.
    trust_configs = sorted(set(
        re.findall(r"""['"]trust_min_(\d+(?:\.\d+)?)['"]\s*:""", content)
    ))
    assert len(trust_configs) >= 5, (
        f"Expected at least 5 trust_min configs, found {len(trust_configs)}: {trust_configs}"
    )

    # Check for fallback config
    assert re.search(r"""['"]no_fallbacks['"]\s*:""", content), "Missing no_fallbacks config"

    # Check for reward config
    assert re.search(r"""['"]no_llm_reward['"]\s*:""", content), "Missing no_llm_reward config"

    print(f"Found {len(trust_configs)} trust configurations")
    print("Found fallback ablation config")
    print("Found reward ablation config")


def test_visualization_module():
    """Confirm ablation_visualizations exposes the three ablation plot helpers.

    Raises:
        ImportError: if the module cannot be imported.
        AssertionError: naming the first plotting function that is missing.
    """
    print("Testing visualization module...")

    import ablation_visualizations as viz

    required = (
        ('plot_trust_floor_sensitivity', "Missing trust sensitivity plot"),
        ('plot_fallback_heatmap', "Missing fallback heatmap"),
        ('plot_reward_attribution', "Missing reward attribution plot"),
    )
    # Checked in order, so the first missing function is the one reported.
    for attr_name, failure_msg in required:
        assert hasattr(viz, attr_name), failure_msg

    print("All visualization functions present")


def main():
    """Run every verification test and print a pass/fail summary.

    Each test runs independently. An exception raised by one test is reported
    with its type and message, counted as a failure, and does not stop the
    remaining tests.

    Returns:
        True if all tests passed. False otherwise, including when assertions
        are disabled and the checks therefore cannot run at all.
    """
    print("\n" + "="*70)
    print("WACV ABLATION IMPLEMENTATION VERIFICATION")
    print("="*70 + "\n")

    # Every check in the tests is an `assert`. Under `python -O` asserts are
    # stripped, so every test would "pass" without checking anything.
    if not __debug__:
        print("Assertions are disabled (python -O) - verification cannot run.")
        print("Re-run without the -O flag.")
        print("="*70 + "\n")
        return False

    tests = [
        ("Config Parameters", test_config_parameters),
        ("TrustGating Modifications", test_trust_gating),
        ("Fallback Tracking", test_multimodal_llm_reasoner),
        ("Ablation Configurations", test_ablation_configurations),
        ("Visualization Module", test_visualization_module),
    ]

    passed = 0
    failed = 0

    for test_name, test_func in tests:
        try:
            test_func()
            passed += 1
        except Exception as e:
            # Include the exception type. A bare `assert` has an empty message,
            # and telling ImportError from AssertionError matters when diagnosing.
            print(f"✗ {test_name} FAILED:")
            print(f"  Error: {type(e).__name__}: {e}")
            failed += 1

    print("\n" + "="*70)
    print(f"VERIFICATION RESULTS: {passed}/{len(tests)} tests passed")
    if failed == 0:
        print("ALL TESTS PASSED - Ready to run ablation experiments!")
    else:
        print(f"{failed} tests failed - Please review errors above")
    print("="*70 + "\n")

    return failed == 0


if __name__ == "__main__":
    # Exit status 0 when every verification test passes, 1 otherwise.
    sys.exit(0 if main() else 1)
