import wandb
import torch

from dataloader import get_train_val_loader
from config import get_config
from model import Model
from trainer import Trainer

# Start the Weights & Biases run when this module is loaded.
# The project name is passed by keyword: the first positional parameter of
# wandb.init is not `project` (it is `job_type` in older wandb releases and
# `entity` in newer ones), so wandb.init('MVF') did not select the project.
wandb.init(project='MVF')


def main(config):
    """Build data, model and optimizer from ``config`` and run training.

    The configuration is also recorded in the active wandb run
    (``wandb.config.update``) before training starts.

    Args:
        config: the configuration object returned by ``get_config()``. This
            function reads ``random_seed``, ``use_gpu``, ``num_heads``,
            ``arch``, ``num_channels``, ``batch_norm``, ``init_strategy``,
            ``num_filters``, ``task``, ``noise``, ``force_positive``,
            ``skipping``, ``stats_mo`` and ``init_lr``; ``get_train_val_loader``
            and ``Trainer`` receive the whole object and may read more.
    """
    # Seed torch's RNGs (and the CUDA RNG when a GPU is used) for repeatability.
    torch.manual_seed(config.random_seed)
    if config.use_gpu:
        torch.cuda.manual_seed(config.random_seed)

    # Get the training / validation data (presumably data loaders; the helper
    # also reports num_leds and padding info needed to build the model).
    # Pinned host memory only speeds up host-to-GPU copies, so request it
    # only when training on the GPU.
    train_dataset, val_dataset, num_leds, padded_data = get_train_val_loader(
        config, pin_memory=bool(config.use_gpu)
    )

    # Create the model. `initilization_strategy` (sic) is the keyword name
    # Model expects, so it is kept as-is.
    model = Model(
        config.num_heads,
        num_leds=num_leds,
        arch=config.arch,
        num_channels=config.num_channels,
        batch_norm=config.batch_norm,
        initilization_strategy=config.init_strategy,
        num_filters=config.num_filters,
        task=config.task,
        noise=config.noise,
        padding=padded_data,
        force_positive=config.force_positive,
        skipping=config.skipping,
        stats_mo=config.stats_mo,
    )

    if config.use_gpu:
        model.cuda()
        # noise_layer and the entries of model.nets are moved explicitly too,
        # in case they are not registered submodules of model (calling .cuda()
        # again on an already-moved module is harmless).
        model.noise_layer.cuda()
        for net in model.nets:
            net.cuda()

    # Optimize the parameters of model plus those of every net in model.nets.
    # If model.nets is a registered submodule (e.g. nn.ModuleList) its
    # parameters already appear in model.parameters(); de-duplicate by
    # identity (keeping first-seen order) so no parameter is updated twice
    # per optimizer step.
    unique_params = {}
    for module in (model, *model.nets):
        for param in module.parameters():
            unique_params.setdefault(id(param), param)

    optimizer = torch.optim.Adam(list(unique_params.values()), lr=config.init_lr)
    trainer = Trainer(model, optimizer, train_dataset, val_dataset, config)
    wandb.config.update(config)
    trainer.train()


if __name__ == "__main__":
    import warnings

    config, unparsed = get_config()
    if unparsed:
        # Presumably get_config() returns argparse's parse_known_args() result,
        # where `unparsed` holds arguments that matched no option. Warn rather
        # than silently dropping them, so a mistyped flag does not go unnoticed.
        warnings.warn(f"Ignoring unrecognized arguments: {unparsed}")
    main(config)
