#!/usr/bin/env python3
"""
Visualization functions for WACV ablation study results
Generates trust sensitivity plots, fallback heatmaps, and reward attribution analyses
"""

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
from typing import Dict, List
import os


def plot_trust_floor_sensitivity(results: Dict, output_dir: str = ".",
                                 default_trust_min: float = 0.3):
    """Plot trust floor (τ_min) sensitivity analysis and save a summary table.

    Configurations are selected by name: every key of ``results`` that
    contains ``'trust_min'`` and whose last ``'_'``-separated token parses as
    a float (e.g. ``'trust_min_0.3'`` -> 0.3). Keys that contain
    ``'trust_min'`` but have no numeric suffix are skipped with a message
    instead of aborting the whole plot.

    Args:
        results: Dictionary mapping configuration names to metric dicts. The
            metric keys read are ``avg_reward``, ``std_reward``,
            ``success_rate`` and ``avg_collisions``; missing keys default to 0.
        output_dir: Directory to save the plot and CSV table; created if it
            does not exist.
        default_trust_min: τ_min value highlighted with a dashed reference
            line on every panel.

    Returns:
        A DataFrame with one row per configuration, sorted by ascending τ_min,
        or ``None`` when no usable trust_min configuration is found.
    """
    parsed = []
    for config_name, metrics in (results or {}).items():
        if 'trust_min' not in config_name:
            continue
        try:
            # e.g. 'trust_min_0.3' -> 0.3
            trust_val = float(config_name.split('_')[-1])
        except ValueError:
            print(f"Skipping '{config_name}': cannot parse a trust_min value from its name")
            continue
        parsed.append((trust_val, config_name, metrics))

    if not parsed:
        print("No trust_min configurations found in results")
        return None

    # Sort numerically by τ_min (name breaks ties). Sorting the names as
    # strings would misorder values such as '..._10' before '..._2'.
    parsed.sort(key=lambda item: (item[0], item[1]))

    trust_values = [val for val, _, _ in parsed]
    avg_rewards = [m.get('avg_reward', 0) for _, _, m in parsed]
    std_rewards = [m.get('std_reward', 0) for _, _, m in parsed]
    success_rates = [m.get('success_rate', 0) for _, _, m in parsed]
    collisions = [m.get('avg_collisions', 0) for _, _, m in parsed]

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    fig.suptitle('Trust Floor Sensitivity Analysis', fontsize=16, fontweight='bold')

    # One entry per panel: (axis, series, line style, color, y-label, title).
    panels = [
        (axes[0, 0], avg_rewards, 'o-', 'tab:blue', 'Average Reward', 'Performance vs Trust Floor'),
        (axes[0, 1], std_rewards, 's-', 'tab:orange', 'Reward Std Dev', 'Stability vs Trust Floor'),
        (axes[1, 0], success_rates, '^-', 'tab:green', 'Success Rate', 'Success Rate vs Trust Floor'),
        (axes[1, 1], collisions, 'v-', 'tab:red', 'Average Collisions per Episode', 'Safety vs Trust Floor'),
    ]
    default_label = f'Default ({default_trust_min:g})'
    for ax, series, style, color, ylabel, title in panels:
        ax.plot(trust_values, series, style, linewidth=2, markersize=8, color=color)
        ax.set_xlabel('Trust Floor (τ_min)', fontsize=12)
        ax.set_ylabel(ylabel, fontsize=12)
        ax.set_title(title)
        ax.grid(True, alpha=0.3)
        ax.axvline(x=default_trust_min, color='red', linestyle='--', alpha=0.5, label=default_label)
        ax.legend()

    # Shade ±1 std around the mean-reward curve.
    axes[0, 0].fill_between(trust_values,
                            [r - s for r, s in zip(avg_rewards, std_rewards)],
                            [r + s for r, s in zip(avg_rewards, std_rewards)],
                            alpha=0.2, color='tab:blue')
    # Success rate is a fraction, so pin the axis to [0, 1].
    axes[1, 0].set_ylim([0, 1])

    # Leave headroom at the top so the suptitle does not overlap the panels.
    fig.tight_layout(rect=(0, 0, 1, 0.96))

    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, 'trust_sensitivity_analysis.png')
    fig.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close(fig)

    print(f"Trust sensitivity plot saved to {output_path}")

    # Generate summary table
    df = pd.DataFrame({
        'τ_min': trust_values,
        'Avg Reward': avg_rewards,
        'Std Reward': std_rewards,
        'Success Rate': success_rates,
        'Avg Collisions': collisions
    })

    table_path = os.path.join(output_dir, 'trust_sensitivity_table.csv')
    df.to_csv(table_path, index=False, float_format='%.3f')
    print(f"Trust sensitivity table saved to {table_path}")

    return df


def plot_fallback_heatmap(fallback_data: pd.DataFrame, output_dir: str = "."):
    """Create a heatmap of fallback trigger counts by scenario and stage.

    Counts are summed per (scenario, stage) pair. Pairs with no rows are
    shown as 0.

    Args:
        fallback_data: DataFrame with columns ``scenario``, ``stage`` and
            ``fallback_count``.
        output_dir: Directory to save the plot; created if it does not exist.

    Returns:
        None. Prints a message and returns early when the data is ``None``,
        empty, or lacks one of the required columns.
    """
    if fallback_data is None or fallback_data.empty:
        print("Insufficient fallback data for heatmap")
        return

    # pivot_table below needs all three columns. Check them up front instead
    # of letting pivot_table raise a KeyError.
    missing = [col for col in ('scenario', 'stage', 'fallback_count')
               if col not in fallback_data.columns]
    if missing:
        print(f"Insufficient fallback data for heatmap (missing columns: {', '.join(missing)})")
        return

    heatmap_data = fallback_data.pivot_table(
        values='fallback_count',
        index='scenario',
        columns='stage',
        aggfunc='sum',
        fill_value=0,
    )

    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(heatmap_data, annot=True, fmt='.0f', cmap='YlOrRd',
                cbar_kws={'label': 'Fallback Count'}, ax=ax)
    ax.set_title('Fallback Trigger Frequency by Scenario and Stage', fontsize=14, fontweight='bold')
    ax.set_xlabel('LLM Stage', fontsize=12)
    ax.set_ylabel('Driving Scenario', fontsize=12)
    fig.tight_layout()

    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, 'fallback_heatmap.png')
    fig.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close(fig)

    print(f"Fallback heatmap saved to {output_path}")


def plot_reward_attribution(with_llm_reward: Dict, without_llm_reward: Dict, output_dir: str = "."):
    """Plot reward attribution analysis comparing with/without LLM alignment reward.

    Draws grouped bars for four metrics. Success rate is scaled to 0-100.
    Collisions and reward std dev are negated, so a taller bar is better for
    every metric.

    Args:
        with_llm_reward: Metrics with LLM alignment reward enabled. Keys read:
            ``avg_reward``, ``success_rate``, ``avg_collisions``,
            ``std_reward``; missing keys default to 0.
        without_llm_reward: Metrics with LLM alignment reward disabled (same
            keys).
        output_dir: Directory to save the plot; created if it does not exist.

    Returns:
        Dict with the relative improvement (in percent) of the with-LLM run
        over the without-LLM run for ``avg_reward`` and ``success_rate``. An
        entry is 0 when the without-LLM baseline is 0, where the relative
        change is undefined.
    """
    categories = ['Average\nReward', 'Success\nRate', 'Collisions\n(×-1)', 'Std Reward\n(×-1)']

    def _bar_values(metrics: Dict) -> List[float]:
        # Same order as `categories`.
        return [
            metrics.get('avg_reward', 0),
            metrics.get('success_rate', 0) * 100,  # Scale to 0-100
            -metrics.get('avg_collisions', 0),  # Negated so higher is better
            -metrics.get('std_reward', 0),  # Negated so higher is better
        ]

    def _relative_change(new: float, old: float) -> float:
        # Percent change relative to |old|; 0 when undefined (old == 0).
        return (new - old) / abs(old) * 100 if old != 0 else 0

    with_llm = _bar_values(with_llm_reward)
    without_llm = _bar_values(without_llm_reward)

    x = np.arange(len(categories))
    width = 0.35

    fig, ax = plt.subplots(figsize=(12, 6))
    bars1 = ax.bar(x - width/2, with_llm, width, label='With LLM Reward', color='tab:blue', alpha=0.8)
    bars2 = ax.bar(x + width/2, without_llm, width, label='Without LLM Reward', color='tab:orange', alpha=0.8)

    ax.set_xlabel('Metric', fontsize=12)
    ax.set_ylabel('Value', fontsize=12)
    ax.set_title('Reward Engineering Ablation: Impact of LLM Alignment Reward', fontsize=14, fontweight='bold')
    ax.set_xticks(x)
    ax.set_xticklabels(categories)
    ax.legend(fontsize=11)
    ax.grid(axis='y', alpha=0.3)
    ax.axhline(y=0, color='black', linestyle='-', linewidth=0.5)

    # Value labels: above positive bars, below non-positive ones.
    for bar in (*bars1, *bars2):
        height = bar.get_height()
        above = height > 0
        ax.annotate(f'{height:.1f}',
                    xy=(bar.get_x() + bar.get_width() / 2, height),
                    xytext=(0, 3 if above else -15),
                    textcoords="offset points",
                    ha='center', va='bottom' if above else 'top',
                    fontsize=9)

    fig.tight_layout()
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, 'reward_attribution.png')
    fig.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close(fig)

    print(f"Reward attribution plot saved to {output_path}")

    improvement = {
        'avg_reward': _relative_change(with_llm[0], without_llm[0]),
        'success_rate': _relative_change(with_llm[1], without_llm[1]),
    }

    # A real newline here; the previous "\\n" printed a literal backslash-n.
    print("\nReward Engineering Results:")
    print(f"  With LLM Reward: {with_llm_reward.get('avg_reward', 0):.2f} ± {with_llm_reward.get('std_reward', 0):.2f}")
    print(f"  Without LLM Reward: {without_llm_reward.get('avg_reward', 0):.2f} ± {without_llm_reward.get('std_reward', 0):.2f}")
    print(f"  Improvement: {improvement['avg_reward']:.1f}%")
    print(f"  Success Rate Improvement: {improvement['success_rate']:.1f}%")

    return improvement


if __name__ == "__main__":
    # Running the module directly produces no plots. It only explains how to
    # use it: import the module and call the plotting functions.
    for _usage_line in (
        "WACV Ablation Visualization Module",
        "Import this module and call the plotting functions with your experimental results",
    ):
        print(_usage_line)
