import argparse
# from models.vgg import *
from models.preact_resnet import *
from models.cnn_mnist import *
from models.resnet import *
from utils import *
import torch
import torch.nn as nn
import torch.optim as optim
import logging
import time

# mnist lr 0.01
def args_parser():
    """Parse the command-line options for the clean-training script.

    Returns:
        argparse.Namespace holding the data/log/model directories, the number
        of epochs, model and dataset names, target device, learning rate,
        batch size and number of classes.
    """
    # (flag, type, default, help text) -- help of None is argparse's default.
    option_specs = [
        ('--data_dir', str, './data', None),
        ('--log_dir', str, './logs/', None),
        ('--model_dir', str, "./saved_model/", None),
        ('--epochs', int, 200, None),
        ('--model', str, "preact_resnet", 'cnn_mnist, preact_resnet, resnet'),
        ('--data', str, "cifar10", 'mnist, gtsrb, cifar10, imagenet, celeba'),
        ('--device', str, "cuda:0", None),
        ('--lr', float, 0.01, None),
        ('--batch_size', int, 64, None),
        ('--num_classes', int, 10, None),
    ]

    cli = argparse.ArgumentParser()
    for flag, value_type, default_value, help_text in option_specs:
        cli.add_argument(flag, type=value_type, default=default_value, help=help_text)
    return cli.parse_args()


# Module-level globals read by test() and main(): the parsed command-line
# options (parsed at import time) and the classification loss, moved to the
# requested device.
args = args_parser()
criterion = nn.CrossEntropyLoss().to(args.device)


def test(model, test_loader):
    """Evaluate ``model`` on every batch of ``test_loader``.

    The pass runs under ``torch.no_grad()`` so no autograd graph is built.
    The model is switched to eval mode for the pass and afterwards returned to
    whichever train/eval mode it was in before the call.

    Args:
        model: network to evaluate; each batch is moved to ``args.device``
            and passed as ``model(data)``.
        test_loader: iterable of ``(data, label)`` batches.

    Returns:
        Tuple ``(avg_loss, avg_acc)`` taken from the ``.avg`` attribute of the
        loss and accuracy ``AverageMeter`` instances.
    """
    was_training = model.training
    model.eval()
    loss_avgmeter = AverageMeter()
    acc_avgmeter = AverageMeter()

    try:
        with torch.no_grad():
            for data, label in test_loader:
                data = data.to(args.device)
                # Flatten labels once; used for both accuracy and loss.
                label = label.to(args.device).view(-1)
                output = model(data)

                # Number of correctly classified samples in this batch.
                batch_acc = (output.argmax(1) == label).float().sum()
                loss = criterion(output, label)

                loss_avgmeter.update(loss.detach(), data.size(0))
                # NOTE(review): the third argument presumably tells AverageMeter
                # that batch_acc is a per-batch sum rather than a mean -- confirm
                # against AverageMeter in utils.
                acc_avgmeter.update(batch_acc.detach(), data.size(0), True)
    finally:
        # Restore the caller's mode even if evaluation raised.
        model.train(was_training)

    return loss_avgmeter.avg, acc_avgmeter.avg

# Dataset name -> (number of classes, architecture used for it). Datasets not
# listed here keep the --num_classes / --model values given on the command line.
_DATA_CONFIG = {
    "cifar10": (10, "preact_resnet"),
    "gtsrb": (43, "preact_resnet"),
    "celeba": (8, "resnet"),
    "imagenet": (20, "resnet"),
    "mnist": (10, "cnn_mnist"),
}


def _apply_data_config():
    """Override args.num_classes and args.model when args.data is a known dataset."""
    if args.data in _DATA_CONFIG:
        args.num_classes, args.model = _DATA_CONFIG[args.data]


def _build_model():
    """Instantiate the architecture named by args.model and move it to args.device.

    Raises:
        ValueError: if args.model is not a supported architecture.
    """
    if args.model == "preact_resnet":
        model = PreActResNet18(num_classes=args.num_classes)
    elif args.model == "cnn_mnist":
        model = CNN_MNIST()
    elif args.model == "resnet":
        model = ResNet18(num_classes=args.num_classes)
    else:
        raise ValueError(
            f"Unsupported model '{args.model}'; "
            "expected one of: cnn_mnist, preact_resnet, resnet"
        )
    return model.to(args.device)


def main():
    """Train args.model on args.data, log per-epoch metrics and keep the best model.

    The learning rate starts at args.lr and is halved whenever the test loss
    has not improved for 10 consecutive epochs. After each epoch whose test
    accuracy matches or beats the best seen so far, the whole model object is
    saved with ``torch.save`` to ``<model_dir>/train_clean_<model>_<data>.pt``.
    Per-epoch metrics are printed and logged to
    ``<log_dir>/train_clean_<model>_<data>.txt``.
    """
    import os  # only needed here, for output paths and directory creation

    _apply_data_config()
    model = _build_model()

    train_dataset, test_dataset = load_dataset(args)
    train_loader, test_loader = load_data(args, train_dataset, test_dataset)

    save_name = "train_clean" + "_" + args.model + "_" + args.data
    # Create output directories up front so logging/saving cannot fail on a
    # missing path; os.path.join tolerates dirs given without a trailing slash.
    os.makedirs(args.log_dir, exist_ok=True)
    os.makedirs(args.model_dir, exist_ok=True)
    log_path = os.path.join(args.log_dir, save_name + '.txt')
    model_path = os.path.join(args.model_dir, save_name + '.pt')
    # filemode='w' starts a fresh log file for each run instead of appending.
    logging.basicConfig(filename=log_path, level=logging.INFO, filemode='w',
                        format='%(asctime)s - %(levelname)s - %(message)s')

    loss_avgmeter = AverageMeter()
    acc_avgmeter = AverageMeter()

    counter = 0  # consecutive epochs without a test-loss improvement
    best_acc = 0
    # Start at +inf so the first epoch's loss always counts as an improvement
    # (a loss floor of 0 would never be beaten by a cross-entropy loss).
    test_loss_min = float('inf')
    lr = args.lr
    # Build the optimizer once so SGD momentum buffers persist across epochs;
    # learning-rate decay is applied to its param_groups in place.
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
    # weight_decay=1e-4 imagenet lr=0.01 lr =lr*0.1 [150,250]
    start_time = time.time()
    for epoch in range(args.epochs):
        model.train()

        # Plateau schedule: halve the LR after 10 epochs without improvement.
        if counter >= 10:
            counter = 0
            lr = lr * 0.5
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr

        for data, label in train_loader:
            data = data.to(args.device)
            label = label.to(args.device).view(-1)
            output = model(data)
            optimizer.zero_grad()
            loss = criterion(output, label)
            loss.backward()
            optimizer.step()

            # Number of correctly classified samples in this batch.
            batch_acc = (output.argmax(1) == label).float().sum()
            loss_avgmeter.update(loss.detach(), data.size(0))
            acc_avgmeter.update(batch_acc.detach(), data.size(0), True)

        time_elapsed = time.time() - start_time

        train_avg_loss = loss_avgmeter.avg
        train_avg_acc = acc_avgmeter.avg

        test_avg_loss, test_avg_acc = test(model, test_loader)

        # Epochs are reported 1-based in both the console and the log file.
        print("""Epoch:{}/{}, Avg Train Loss:{:.6f}, Avg Train Acc:{:.4f}, Avg Test Loss:{:.6f}, Avg Test Acc:{:.4f}, Best Acc:\033[91m{:.4f} \033[0m""".
              format(epoch + 1, args.epochs, train_avg_loss, train_avg_acc, test_avg_loss, test_avg_acc, best_acc))
        logging.info(f'Epoch {epoch + 1}/{args.epochs}, Train_Loss: {train_avg_loss}, Train_Accuracy: {train_avg_acc},'
                     f' Test_Loss: {test_avg_loss}, Test_Accuracy: {test_avg_acc}')
        print('Elapsed Time: {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
        print('Current learning rate:{:.6f}'.format(lr))

        if test_avg_loss < test_loss_min:
            test_loss_min = test_avg_loss
            counter = 0
        else:
            counter += 1

        loss_avgmeter.reset()
        acc_avgmeter.reset()

        if best_acc <= test_avg_acc:
            # Save the whole model object whenever test accuracy ties or beats the best.
            torch.save(model, model_path)
            best_acc = test_avg_acc

if __name__ == '__main__':
    # Call main() only when this file is executed directly, not when imported.
    main()

