import os
import argparse
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data.dataset import random_split
from image_datasets import MyCocoClassification


def train_one_epoch(epoch, model, loss_fn, optimizer, train_loader):
    """Run one optimisation pass over ``train_loader`` and report loss/accuracy.

    Each batch is moved to the device holding the model's parameters, so a
    model that was moved to CUDA trains on CUDA and a CPU model trains on CPU.

    Args:
        epoch: Epoch index; used only for logging.
        model: Network emitting one raw logit per label (no sigmoid applied).
        loss_fn: Callable ``loss_fn(logits, target)`` returning a scalar
            batch loss, e.g. ``nn.BCEWithLogitsLoss()``.
        optimizer: Optimizer stepping ``model``'s parameters.
        train_loader: Iterable of ``(data, target)`` batches with ``len()``.

    Returns:
        ``(epoch_loss, train_acc)`` as Python floats. ``epoch_loss`` is the
        batch loss weighted by batch size and averaged over samples seen.
        ``train_acc`` is the percentage of individual label predictions that
        equal ``target``.
    """
    print("\nEpoch {} starting.".format(epoch))
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    model.train()

    num_batch = len(train_loader)
    epoch_loss = 0.0
    correct = 0
    num_samples = 0
    num_labels = 0
    for batch_index, batch in enumerate(train_loader, start=1):
        data, target = batch[0].to(device), batch[1].to(device)
        predict = model(data)
        loss = loss_fn(predict, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # The model outputs logits (the loss applies the sigmoid itself), and
        # sigmoid(z) > 0.5  <=>  z > 0, so the decision threshold is 0.
        pred = predict.detach() > 0
        correct += (pred == target).sum().item()
        # Count label cells actually seen instead of assuming a fixed class count.
        num_labels += target.numel()
        batch_size = target.size(0)
        num_samples += batch_size
        batch_loss = loss.item()
        epoch_loss += batch_loss * batch_size

        train_log = 'Epoch {:2d}\tLoss: {:.2f}\tTrain: [{}/{} ({:.0f}%)]'.format(
            epoch, batch_loss,
            batch_index, num_batch,
            100. * batch_index / num_batch)
        print(train_log, end='\r')

    # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
    epoch_loss /= max(num_samples, 1)
    train_acc = 100.0 * correct / max(num_labels, 1)
    print()
    print("Train average loss: {:.6f}\t".format(epoch_loss))
    print("Train accuracy: {:.2f}%".format(train_acc))
    return epoch_loss, train_acc


def validate(model, loss_fn, valid_loader):
    """Evaluate ``model`` on ``valid_loader`` without gradient tracking.

    Each batch is moved to the device holding the model's parameters. The
    model is left in eval mode on return.

    Args:
        model: Network emitting one raw logit per label (no sigmoid applied).
        loss_fn: Callable ``loss_fn(logits, target)`` returning a scalar
            batch loss, e.g. ``nn.BCEWithLogitsLoss()``.
        valid_loader: Iterable of ``(data, target)`` batches.

    Returns:
        ``(valid_loss, valid_acc)`` as Python floats. They are returned as
        floats rather than device tensors so they can be logged and plotted
        directly. ``valid_loss`` is the batch loss weighted by batch size and
        averaged over samples seen. ``valid_acc`` is the percentage of
        individual label predictions that equal ``target``.
    """
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    model.eval()

    valid_loss = 0.0
    correct = 0
    num_samples = 0
    num_labels = 0
    with torch.no_grad():
        for batch in valid_loader:
            data, target = batch[0].to(device), batch[1].to(device)
            predict = model(data)
            batch_size = target.size(0)
            valid_loss += loss_fn(predict, target).item() * batch_size
            # Logits: sigmoid(z) > 0.5  <=>  z > 0.
            correct += ((predict > 0) == target).sum().item()
            num_labels += target.numel()
            num_samples += batch_size

    # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
    valid_loss /= max(num_samples, 1)
    valid_acc = 100.0 * correct / max(num_labels, 1)
    print('Valid average loss: {:.6f}\t'.format(valid_loss))
    print("Valid accuracy: {:.2f}%".format(valid_acc))
    return valid_loss, valid_acc


def main(args):
    """Fine-tune an ImageNet-pretrained ResNet-50 as a multi-label COCO classifier.

    Loads COCO 2017 train/val images and annotations from ``./image_data/coco``
    via ``MyCocoClassification``. Trains for ``args.epochs`` epochs with batch
    size ``args.batch_size``. After every epoch it:

    * appends metrics to ``model_outputs/valid_<experiment>.txt``;
    * redraws ``accuracies_<experiment>.png`` in the working directory;
    * saves ``model_outputs/ckpt_<experiment>.pth.tar`` if validation loss
      improved.

    Relies on a module-level ``NORM`` = ``(mean, std)`` tuple for input
    normalisation, which this script defines inside its ``__main__`` guard.

    Args:
        args: Parsed CLI namespace providing ``batch_size`` and ``epochs``.
    """
    num_classes = 91
    data_transforms = {
        'train': transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(*NORM),
        ]),
        'val': transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(*NORM),
        ]),
    }

    datasets = {'val': MyCocoClassification(root='./image_data/coco/val2017',
                                            annFile='./image_data/coco/annotations/instances_val2017.json',
                                            transform=data_transforms['val']),
                'train': MyCocoClassification(root='./image_data/coco/train2017',
                                              annFile='./image_data/coco/annotations/instances_train2017.json',
                                              transform=data_transforms['train'])}
    dataloaders = {x: torch.utils.data.DataLoader(datasets[x], batch_size=args.batch_size,
                                                  shuffle=True, num_workers=0)
                   for x in ['train', 'val']}

    model = models.resnet50(pretrained=True)
    # Swap the 1000-way ImageNet head for one raw logit per COCO class.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    model = model.cuda()
    ckpt_path = 'model_outputs'
    experiment_name = 'pretrain_coco_resnet50'
    # The log file and checkpoints are written here; open() fails if it is missing.
    os.makedirs(ckpt_path, exist_ok=True)

    optimizer = torch.optim.Adam(model.parameters())
    scheduler = ReduceLROnPlateau(optimizer, verbose=True)
    loss_fn = nn.BCEWithLogitsLoss()
    print("Model setup...")

    # Train and validate
    best_valid_loss = float('inf')
    train_accuracies = []
    valid_accuracies = []
    with open(os.path.join(ckpt_path, 'valid_%s.txt' % experiment_name), 'w') as f:
        for epoch in range(args.epochs):
            train_loss, train_acc = train_one_epoch(epoch, model, loss_fn, optimizer, dataloaders['train'])
            ave_valid_loss, valid_acc = validate(model, loss_fn, dataloaders['val'])
            # Coerce to plain floats: metrics returned as (possibly CUDA) tensors
            # cannot be converted to numpy by matplotlib.
            train_loss, train_acc = float(train_loss), float(train_acc)
            ave_valid_loss, valid_acc = float(ave_valid_loss), float(valid_acc)
            train_accuracies.append(train_acc)
            valid_accuracies.append(valid_acc)
            scheduler.step(ave_valid_loss)
            f.write('epoch: %d\n' % epoch)
            f.write('train loss: %f\n' % train_loss)
            f.write('train accuracy: %f\n' % train_acc)
            f.write('validation loss: %f\n' % ave_valid_loss)
            f.write('validation accuracy: %f\n' % valid_acc)

            plt.figure()
            plt.plot(train_accuracies, '-o', label='train')
            plt.plot(valid_accuracies, '-o', label='valid')
            plt.xlabel('Epoch')
            plt.legend(loc='upper right')
            plt.savefig('accuracies_%s.png' % experiment_name)
            plt.close()

            if ave_valid_loss < best_valid_loss:
                best_valid_loss = ave_valid_loss
                print('==> new checkpoint saved')
                f.write('==> new checkpoint saved\n')
                torch.save({
                    'epoch': epoch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict()
                }, os.path.join(ckpt_path,
                                'ckpt_%s.pth.tar' % experiment_name))
            # Make per-epoch progress visible on disk during long runs.
            f.flush()


if __name__ == "__main__":
    # Per-channel ImageNet (mean, std). Kept at module scope because main()
    # reads NORM as a global when building its Normalize transforms.
    NORM = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

    arg_parser = argparse.ArgumentParser()
    for flag, default, meta in (('--batch-size', 64, 'BS'),
                                ('--epochs', 3, 'E')):
        arg_parser.add_argument(flag, type=int, default=default, metavar=meta)
    args = arg_parser.parse_args()
    main(args)
