"""
Visualization script for evaluation results (validation losses read from validation_loss.csv).
"""

import pandas as pd
import plotly.express as px
import plotly.graph_objects as go

# Load CSV. The delimiter is sniffed from the header line: tab-separated if
# the first line contains a tab character, comma-separated otherwise.
csv_path = "validation_loss.csv"
# Read only the header line, inside a context manager so the handle is closed
# (the previous inline open(...) leaked it). UTF-8 matches pandas' default.
with open(csv_path, encoding="utf-8") as _csv_file:
    _header_line = _csv_file.readline()
df = pd.read_csv(csv_path, sep='\t' if '\t' in _header_line else ',')

# Loss columns of interest in the evaluation CSV.
# NOTE(review): the statistics and plots in the visible code use their own
# low/high-scale lists instead of this one; confirm whether it is still needed.
metrics = [
    "loss_ang",
    "loss_face_ang",
    "loss_pog_px",
    "loss_face_pog_px",
    "loss_pog_cm",
    "loss_consistency",
]

# 1. Summary statistics (mean, median, std) for the selected metrics.
# The metrics are split into two groups by scale; each group is plotted in its
# own figure so that both get a readable y-axis range.

# Group metrics by scale
low_scale_metrics = ["loss_ang", "loss_face_ang", "loss_pog_cm"]
high_scale_metrics = ["loss_pog_px", "loss_face_pog_px", "loss_consistency"]


def _summary_stats(columns):
    """Return mean, median and std of each of *columns* in ``df``.

    The result is indexed by column name and has one column per statistic,
    in the order "mean", "median", "std".
    """
    subset = df[columns]
    return pd.DataFrame(
        {stat: getattr(subset, stat)() for stat in ("mean", "median", "std")}
    )


# Compute stats
stats_low = _summary_stats(low_scale_metrics)
stats_high = _summary_stats(high_scale_metrics)

def _summary_bar_figure(stats, scale_label, title, yaxis_title):
    """Build a grouped bar chart with one trace per column of *stats*.

    *stats* is indexed by metric name (used as the x values); each column
    (e.g. mean/median/std) becomes one trace named "<Column> (<scale_label>)".
    The figure is returned without being shown.
    """
    figure = go.Figure()
    for column in stats.columns:
        figure.add_trace(go.Bar(
            x=stats.index,
            y=stats[column],
            name=f"{column.capitalize()} ({scale_label})",
        ))
    figure.update_layout(title=title, barmode='group', yaxis_title=yaxis_title)
    return figure


# Plot low scale metrics
fig1 = _summary_bar_figure(
    stats_low,
    "Low Scale",
    "Mean, Median, Std for Low Scale Metrics",
    "",
)
fig1.show()

# Plot high scale metrics
# NOTE(review): the "~80-120" range in the axis title is hard-coded, not
# derived from the data; confirm it still matches the loaded CSV.
fig2 = _summary_bar_figure(
    stats_high,
    "High Scale",
    "Mean, Median, Std for High Scale Metrics",
    "Value (~80-120)",
)
fig2.show()

# 2-4. Angular error (loss_ang) broken down by stimulus, participant and camera.
# Each figure shows the per-group mean, median and standard deviation of
# loss_ang as grouped bars. std is NaN for single-row groups, so that bar is
# simply absent for such groups.
_GROUP_STATS = ['mean', 'median', 'std']


def _loss_ang_figure(group_stats, group_col, label):
    """Build a grouped bar chart of loss_ang statistics per group.

    Args:
        group_stats: Output of ``df.groupby(group_col)["loss_ang"]
            .agg(_GROUP_STATS).reset_index()``, i.e. one row per group with
            ``group_col`` plus one column per statistic.
        group_col: Name of the grouping column (used for the x values).
        label: Human-readable group name used in the title and x-axis title.

    Returns:
        The ``go.Figure`` (not yet shown).
    """
    fig = go.Figure()
    for stat in _GROUP_STATS:
        fig.add_trace(go.Bar(
            x=group_stats[group_col],
            y=group_stats[stat],
            name=stat.capitalize()
        ))
    fig.update_layout(
        title=f"Angular Error (loss_ang) per {label}",
        barmode='group',
        xaxis_title=label,
        yaxis_title="Angular Error",
        # Group keys are identifiers, not quantities: force a categorical axis
        # so numeric IDs are not spread out on a linear numeric scale.
        xaxis_type="category",
    )
    return fig


# 2. Per-stimulus angular error (loss_ang)
stimulus_stats = df.groupby("stimuli")["loss_ang"].agg(_GROUP_STATS).reset_index()
# NOTE: this rebinds the module-level name fig2 (previously the high-scale
# summary figure); the name is kept for compatibility with existing references.
fig2 = _loss_ang_figure(stimulus_stats, "stimuli", "Stimulus")
fig2.show()

# 3. Per-participant angular error (loss_ang)
participant_stats = df.groupby("participant")["loss_ang"].agg(_GROUP_STATS).reset_index()
fig3 = _loss_ang_figure(participant_stats, "participant", "Participant")
fig3.show()

# 4. Per-camera angular error (loss_ang)
camera_stats = df.groupby("camera")["loss_ang"].agg(_GROUP_STATS).reset_index()
fig4 = _loss_ang_figure(camera_stats, "camera", "Camera")
fig4.show()