#!/usr/bin/env bash
#
# SLURM batch job: run ./scripts/optimize.py under MPI on the sparse
# CIFAR-10 conv models (models/snn/... and models/ann/... are passed as the
# two positional arguments, in that order).
#
# Usage (from the repository root -- all paths below are relative):
#   sbatch <path to this script>
#
# Preconditions:
#   * The outputs/ directory must already exist. SLURM opens the --output file
#     before this script starts, so it cannot be created from here.
#
# NOTE: #SBATCH directives are only honoured before the first executable
# command, so all of them must stay directly below this header.

#SBATCH --cpus-per-task=2
#SBATCH --gpus-per-task=1
#SBATCH --job-name=conv_cifar10
#SBATCH --mem-per-cpu=8GB
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=4
#SBATCH --output=outputs/optimize_conv_cifar10.txt
#SBATCH --partition=batch_default
#SBATCH --requeue
# Append rather than truncate, so a --requeue'd run does not wipe the log of
# the previous attempt.
#SBATCH --open-mode=append
#SBATCH --time=4-00:00:00

set -euo pipefail

# Derive the MPI layout from the SLURM allocation instead of repeating the
# literal 4 in several places. The fallbacks match the #SBATCH request above
# and are used when the script is run outside SLURM.
readonly n_ranks="${SLURM_NTASKS:-4}"
readonly ranks_per_node="${SLURM_NTASKS_PER_NODE:-4}"

# A 2% accuracy drop is worth it if it halves latency and power consumption
# NOTE(review): presumably this rationale applies to the --lambdas weights
# below -- TODO confirm against scripts/optimize.py.
#
# Kept as an array so each option/value is passed as exactly one argument,
# with no reliance on word-splitting or globbing.
optimize_args=(
  --granularities 1 2 3
  --global-v-initial 0.5
  --lambdas 50.0 1.0 1.0
  --max-iterations 100 1000 10000
  --ranks-per-node "$ranks_per_node"
  --auto-scale 0 1 1
  --n-time-chunks 5
  --val-freq 10
  --optimize-v-initial
)

mpirun -n "$n_ranks" ./scripts/optimize.py \
  models/snn/conv_cifar10_sparse.h5 \
  models/ann/conv_cifar10_sparse.h5 \
  "${optimize_args[@]}"
