import csv
import os
from pathlib import Path

import torch


def save_model_results(save_dir: str, model_name: str | Path, results: dict):
    epochs = list(range(len(results["train_loss"])))
    results_transposed = list(zip(epochs, *results.values()))
    headers = ["epochs"] + list(results.keys())

    with open(os.path.join(save_dir, f"{model_name}.csv"), "w", newline="") as file:
        writer = csv.writer(file)
        writer.writerow(headers),
        writer.writerows(results_transposed)


def save_model(model: torch.nn.Module, save_dir: str | Path, model_name: str) -> None:
    target_dir_path = Path(save_dir)
    target_dir_path.mkdir(parents=True, exist_ok=True)

    assert model_name.endswith(".pth") or model_name.endswith(".pt"), "model_name should end with '.pt' or '.pth'"
    model_save_path = target_dir_path / model_name

    print(f"Saving model to: {model_save_path}")
    torch.save(obj=model.state_dict(), f=model_save_path)
