#!/usr/bin/env bash

#SBATCH --cpus-per-task=2
#SBATCH --gpus-per-task=1
#SBATCH --job-name=dense_mnist
#SBATCH --mem-per-cpu=8GB
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=2
#SBATCH --output=outputs/optimize_dense_mnist.txt
#SBATCH --partition=batch_default
#SBATCH --requeue
#SBATCH --time=2-00:00:00

# Strict mode: abort on the first failing command, on any use of an unset
# variable, and on failures inside pipelines. This must come *after* the
# #SBATCH block above, because sbatch stops reading directives at the first
# executable line.
set -euo pipefail

# Launch ./scripts/optimize.py under MPI on the SNN / ANN model pair below.
#
# NOTE: SLURM opens the --output file (outputs/...) before this script starts
# and does not create missing directories, so outputs/ must already exist at
# submission time.
#
# The rank count passed to `mpirun -n` and to `--ranks-per-node` mirrors
# "#SBATCH --nodes=1" and "#SBATCH --ntasks-per-node=2" above; change them
# together.

# A 0.5% accuracy drop is worth it if it halves latency and power consumption.
# NOTE(review): presumably this is the rationale for the --lambdas weighting;
# confirm against scripts/optimize.py.
#
# Options are kept in an array (not a whitespace-separated string) so each
# token reaches optimize.py as exactly one argument, with no reliance on
# word splitting or accidental glob expansion.
args=(
    --granularities 1 2 3
    --global-v-initial 0.5
    --lambdas 200.0 1.0 1.0
    --max-iterations 100 1000 10000
    --ranks-per-node 2
    --auto-scale 0 1 1
    --n-time-chunks 2
    --val-freq 10
    --optimize-v-initial
)

mpirun -n 2 ./scripts/optimize.py \
    models/snn/dense_mnist_sparse.h5 \
    models/ann/dense_mnist_sparse.h5 \
    "${args[@]}"
