import argparse

# Single module-level parser; every argument group below is attached to it.
parser = argparse.ArgumentParser(description='MVF')
# Registry of the argument groups created through add_argument_group().
arg_lists = []


def str2bool(v):
    """Parse a command-line boolean flag value.

    Matching is case-insensitive and ignores surrounding whitespace:
    ``true``, ``t``, ``yes``, ``y``, ``on`` and ``1`` mean True;
    ``false``, ``f``, ``no``, ``n``, ``off`` and ``0`` mean False.
    A value that is already a ``bool`` is returned unchanged.

    Args:
        v: The raw option string from the command line (or a bool).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If ``v`` is not a recognised boolean
            spelling. argparse turns this into a usage error. Previously,
            any unrecognised value (e.g. ``--use_gpu yes`` or a typo such
            as ``ture``) silently became False.
    """
    if isinstance(v, bool):
        return v
    token = str(v).strip().lower()
    if token in ('true', 't', 'yes', 'y', 'on', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', 'off', '0'):
        return False
    raise argparse.ArgumentTypeError(
        'Boolean value expected (e.g. true/false), got {!r}'.format(v))


def add_argument_group(name):
    """Create a titled argument group on the module parser and record it.

    Args:
        name: Title shown for the group in ``--help`` output.

    Returns:
        The newly created argument group, which is also appended to
        ``arg_lists``.
    """
    new_group = parser.add_argument_group(title=name)
    arg_lists.append(new_group)
    return new_group


# --- Data params -------------------------------------------------------
data_arg = add_argument_group('Data Params')
data_arg.add_argument(
    '--batch_size',
    type=int,
    default=4,
    help='# of images in each batch of data',
)
data_arg.add_argument(
    '--shuffle',
    type=str2bool,
    default=True,
    help='Whether to shuffle the train and valid indices',
)

# --- Training params ---------------------------------------------------
train_arg = add_argument_group('Training Params')
train_arg.add_argument(
    '--epochs',
    type=int,
    default=50,
    help='# of epochs to train for',
)
train_arg.add_argument(
    '--init_lr',
    type=float,
    default=1e-3,
    help='Initial learning rate value',
)
train_arg.add_argument(
    '--lr_patience',
    type=int,
    default=5,
    help='Number of epochs to wait before reducing lr',
)
train_arg.add_argument(
    '--train_patience',
    type=int,
    default=1000,
    help='Number of epochs to wait before stopping train',
)
# NOTE(review): there is no `choices=` constraint, so any string is accepted
# here even though the help text mentions only MSE/MAE -- confirm which
# values the training code actually understands.
train_arg.add_argument(
    '--loss',
    type=str,
    default='mse',
    help='Loss function to use (MSE or MAE)',
)

# --- Other params ------------------------------------------------------
misc_arg = add_argument_group('Misc.')
misc_arg.add_argument(
    '--use_gpu',
    type=str2bool,
    default=True,
    help="Whether to run on the GPU",
)
# NOTE(review): this option is registered on the *training* group
# (train_arg), not misc_arg, so it is listed under "Training Params" in
# --help. It has no help text and its purpose is not visible in this file.
train_arg.add_argument(
    '--skipping',
    type=str2bool,
    default=False,
)
misc_arg.add_argument(
    '--force_positive',
    type=str2bool,
    default=True,
    help='Force the illumination weights to be positive',
)
misc_arg.add_argument(
    '--random_seed',
    type=int,
    default=1,
    help='Seed to ensure reproducibility',
)
misc_arg.add_argument(
    '--num_channels',
    type=int,
    default=1,
    help='Number of images to form with the physical layer',
)
misc_arg.add_argument(
    '--num_heads',
    type=int,
    default=1,
    help='Number of models to attach physical layer to',
)
misc_arg.add_argument(
    '--batch_norm',
    type=str2bool,
    default=False,
    help='To use batchnorm or not ( every layer)',
)
misc_arg.add_argument(
    '--task',
    type=str,
    default='pan',
    help='Task to train on',
)
misc_arg.add_argument(
    '--init_strategy',
    type=str,
    default=None,
    help='initialization strategy for physical layer',
)
misc_arg.add_argument(
    '--num_filters',
    type=int,
    default=16,
    help="number of starting filters in U-net",
)
misc_arg.add_argument(
    '--l1_penalty',
    type=float,
    default=0.0004,
    help='weight on l1 loss',
)
misc_arg.add_argument(
    '--orth_penalty',
    type=float,
    default=0.0,
    help='Penalty on non-ortho weights',
)
misc_arg.add_argument(
    '--arch',
    type=str,
    default='unet',
    help='Which architecture to use',
)
misc_arg.add_argument(
    '--flip',
    type=str2bool,
    default=False,
    help='Apply random flipping (vertical and horizontal)',
)
# NOTE(review): no help text; meaning of "stats_mo" is not visible here.
misc_arg.add_argument(
    '--stats_mo',
    type=float,
    default=0.1,
)

# variance control
misc_arg.add_argument(
    '--shift',
    type=str,
    default="",
    help="shifting of the LED pattern",
)
misc_arg.add_argument(
    '--noise',
    type=float,
    default=0.1,
    help='amount of noise to add to the formed image',
)


def get_config(args=None):
    """Parse command-line options into a config namespace.

    Uses ``parse_known_args``, so unrecognised options are returned in
    ``unparsed`` instead of causing argparse to exit with an error.

    Args:
        args: Optional list of argument strings to parse. When ``None``
            (the default), argparse reads ``sys.argv[1:]``, exactly as
            before. Passing an explicit list lets callers build a config
            programmatically (e.g. in tests) without touching ``sys.argv``.

    Returns:
        A ``(config, unparsed)`` tuple: the ``argparse.Namespace`` holding
        every registered option, and the list of argument strings that
        were not recognised.
    """
    config, unparsed = parser.parse_known_args(args)
    return config, unparsed
