import argparse
import copy
import logging
import os

import matplotlib.pyplot as plt
import torch
import torch.nn as nn

from models.preact_resnet import *
from models.resnet import *
from data_poison import *

def parse_args():
    """Parse the command-line options of the fine-pruning experiment.

    Returns:
        argparse.Namespace with one attribute per option registered below.
    """
    cli = argparse.ArgumentParser()
    # (flag, type, default, help) -- registered in exactly this order.
    option_specs = (
        ('--data_dir', str, './data', None),
        ('--log_dir', str, './logs/', None),
        ('--epochs', int, 200, None),
        ('--model', str, "resnet", None),
        ('--data', str, "celeba", None),  # gtsrb / cifar10 / celeba (main also maps imagenet)
        ('--device', str, "cuda:0", None),
        ('--lr', float, 0.1, None),
        ('--batch_size', int, 128, None),
        ('--model_dir', str, "./saved_model/", None),
        ('--num_classes', int, 10, None),
        ('--attack_mode', str, "freq",
         'clean, square, sig, refool, ftrojan, fiba, freq'),
        ('--poison_ratio', float, 0.1, None),
        ('--target_label', int, 7, None),
    )
    for flag, value_type, default, help_text in option_specs:
        cli.add_argument(flag, type=value_type, default=default, help=help_text)
    return cli.parse_args()

# Command-line options are parsed once, at import time, and shared as a
# module-level global by test_bd() and main().
args = parse_args()
# Loss used by test_bd() for both the clean and the backdoor evaluation pass;
# Module.to() returns the module itself, so this is the moved criterion.
criterion = nn.CrossEntropyLoss().to(args.device)


def _bd_update_meters(model, data, label, loss_meter, acc_meter):
    """Run one batch through `model` and fold its loss / correct-count into the meters."""
    data = data.to(args.device)
    label = label.to(args.device)
    output = model(data)

    target = label.view(-1, )
    # Number of correct predictions in the batch (a sum, not a mean).
    batch_acc = (output.argmax(1) == target).float().sum()
    loss = criterion(output, target)

    loss_meter.update(loss.detach(), data.size(0))
    # The trailing True is forwarded unchanged to AverageMeter.update;
    # presumably it marks batch_acc as a count rather than a mean -- TODO
    # confirm against AverageMeter.
    acc_meter.update(batch_acc.detach(), data.size(0), True)


def _poison_whole_batch(data, label, freq_vars):
    """Poison every sample of a batch with the attack selected by ``args.attack_mode``.

    Returns:
        The ``(data, label)`` pair produced by the matching poisoning helper.

    Raises:
        ValueError: for an unsupported ``args.attack_mode`` (including 'clean').
    """
    mode = args.attack_mode
    if mode == 'square':
        return square_poison(args, data, label, args.target_label, poison_ratio=1.0)
    if mode == 'sig':
        return sig_poison(args, data, label, args.target_label, poison_ratio=1.0)
    if mode == 'blend':
        return blend_poison(args, data, label, args.target_label, poison_ratio=1.0)
    if mode == 'ftrojan':
        return ftrojan_poison(args, data, label, args.target_label, poison_ratio=1.0)
    if mode == 'fiba':
        return fiba_poison(args, data, label, args.target_label, poison_ratio=1.0)
    if mode == 'freq':
        return freq_poison(args, data, label, args.target_label,
                           poisoning_frac=1.0, vars=freq_vars)
    raise ValueError(f'Error, unknown attack mode {mode}')


def test_bd(model, test_loader, vars=None):
    """Evaluate `model` on clean and on fully poisoned versions of `test_loader`.

    The loader is traversed twice. The first pass uses the batches unchanged
    (clean metrics). The second pass runs every batch through the poisoning
    helper selected by ``args.attack_mode`` with a poisoning fraction of 1.0
    and scores predictions against the labels that helper returns (backdoor
    metrics).

    NOTE(review): if the poisoning helpers relabel every sample to
    ``args.target_label``, samples whose true class already is the target are
    counted as backdoor hits -- confirm whether they should be excluded.

    Args:
        model: network to evaluate. Only the inputs are moved to
            ``args.device``, so the model must already live there.
        test_loader: iterable of ``(data, label)`` batches.
        vars: extra state forwarded to ``freq_poison`` when
            ``args.attack_mode == 'freq'``. The name shadows the builtin but is
            kept because callers pass it by keyword.

    Returns:
        ``(clean_loss, clean_acc, bd_loss, bd_acc)``: the ``.avg`` of four
        ``AverageMeter`` instances.

    Raises:
        ValueError: if ``args.attack_mode`` is not a supported attack.
    """
    clean_loss_avgmeter = AverageMeter()
    clean_acc_avgmeter = AverageMeter()
    bd_loss_avgmeter = AverageMeter()
    bd_acc_avgmeter = AverageMeter()

    was_training = model.training
    model.eval()
    try:
        # Evaluation only: without no_grad, any parameter that still has
        # requires_grad=True makes autograd record a graph for every batch.
        with torch.no_grad():
            for data, label in test_loader:
                _bd_update_meters(model, data, label,
                                  clean_loss_avgmeter, clean_acc_avgmeter)

            for data, label in test_loader:
                data, label = _poison_whole_batch(data, label, vars)
                _bd_update_meters(model, data, label,
                                  bd_loss_avgmeter, bd_acc_avgmeter)
    finally:
        # Restore the caller's mode, even on error, instead of forcing train().
        model.train(was_training)

    return (clean_loss_avgmeter.avg, clean_acc_avgmeter.avg,
            bd_loss_avgmeter.avg, bd_acc_avgmeter.avg)


def _channel_activation_order(model, loader):
    """Rank the output channels of ``model.layer4`` by mean activation on `loader`.

    Returns:
        1-D CPU tensor of channel indices sorted from least to most active.
        main() prunes channels in this order.

    Raises:
        ValueError: if `loader` yields no batches.
    """
    # Keep only a running per-channel sum instead of every activation map of
    # the whole test set. The reduction over dims 0, 2 and 3 assumes a 4-D
    # (N, C, H, W) layer4 output -- TODO confirm for non-ResNet models.
    stats = {"sum": None, "count": 0}

    def forward_hook(module, inputs, output):
        batch_sum = output.detach().double().sum(dim=(0, 2, 3))
        stats["sum"] = batch_sum if stats["sum"] is None else stats["sum"] + batch_sum
        stats["count"] += output.shape[0] * output.shape[2] * output.shape[3]

    hook = model.layer4.register_forward_hook(forward_hook)
    try:
        with torch.no_grad():
            for data, _ in loader:
                model(data.to(args.device))
    finally:
        hook.remove()

    if stats["sum"] is None:
        raise ValueError("test loader yielded no batches; cannot rank channels")
    activation = stats["sum"] / stats["count"]
    return torch.argsort(activation).cpu()


def _build_pruned_model(model, pruning_mask, num_classes):
    """Return a frozen deep copy of `model` keeping only the masked-in layer4[1].conv2 filters.

    ``layer4[1].conv2`` is replaced by a conv that keeps only the output
    filters where `pruning_mask` is True. ``linear`` is replaced by a layer
    that keeps only the matching input columns. Both reuse the original
    weights, and no fine-tuning is done. The mask is also stored as
    ``layer4[1].ind``, presumably read by the model's forward pass -- confirm
    in models/resnet.py.
    """
    total_channels = pruning_mask.shape[0]
    kept_channels = int(pruning_mask.sum())

    net_pruned = copy.deepcopy(model)
    net_pruned.layer4[1].conv2 = nn.Conv2d(
        total_channels, kept_channels, (3, 3), stride=1, padding=1, bias=False
    )
    net_pruned.linear = nn.Linear(kept_channels, num_classes)

    # Re-assign the surviving weights from the unpruned model.
    net_pruned.layer4[1].conv2.weight.data = model.layer4[1].conv2.weight.data[pruning_mask]
    net_pruned.layer4[1].ind = pruning_mask
    net_pruned.linear.weight.data = model.linear.weight.data[:, pruning_mask]
    net_pruned.linear.bias.data = model.linear.bias.data

    # The freshly created layers default to requires_grad=True; freeze them
    # like the source model so evaluation never builds autograd graphs.
    net_pruned.requires_grad_(False)
    return net_pruned.to(args.device)


def main():
    """Run a fine-pruning sweep on the last residual block of a trained model.

    Steps:
      1. Load the checkpoint selected by --attack_mode/--model/--data/
         --poison_ratio. For 'freq', also load the saved trigger variables.
      2. Rank layer4's channels by mean activation on the test set.
      3. For k = 0 .. C-1, remove the k least-active channels from
         layer4[1].conv2 and the linear head, without fine-tuning.
      4. Evaluate each pruned network and log its accuracy. For 'clean' this
         is test accuracy; otherwise it is clean and backdoor accuracy.
    """
    log_name = "pruning_log_" + args.model + "_" + args.data + "_" + args.attack_mode
    os.makedirs(args.log_dir, exist_ok=True)
    # filemode='w' gives every run a fresh log file. The previous code got the
    # same effect from a stray FileHandler(mode='w+') that was never attached
    # to a logger and leaked an open file handle.
    logging.basicConfig(filename=args.log_dir + log_name + '.txt', filemode='w',
                        level=logging.INFO,
                        format='%(asctime)s - %(levelname)s - %(message)s')

    freq_vars = None
    if args.attack_mode == "clean":
        save_name = "train_clean" + "_" + args.model + "_" + args.data
    elif args.attack_mode == "freq":
        save_name = "train_freq_attack_" + args.attack_mode + "_" + args.model + "_" + args.data + "_" + str(
            args.poison_ratio)
        freq_vars = torch.load(args.model_dir + save_name + "_vars" + '.pt', map_location=args.device)
    else:
        save_name = "train_attack_" + args.attack_mode + "_" + args.model + "_" + args.data + "_" + str(
            args.poison_ratio)

    # Known datasets map to a fixed class count; anything else falls back to
    # --num_classes instead of failing later with an unbound local.
    num_classes = {"cifar10": 10, "gtsrb": 43, "celeba": 8, "imagenet": 200}.get(
        args.data, args.num_classes)

    model = torch.load(args.model_dir + save_name + '.pt', map_location=args.device)
    model.eval()
    model.requires_grad_(False)
    model = model.to(args.device)

    # Prepare the data loaders; only the test loader is used here.
    train_dataset, test_dataset = load_dataset(args)
    _, test_loader = load_data(args, train_dataset, test_dataset)

    print("Forwarding all the validation dataset:")
    seq_sort = _channel_activation_order(model, test_loader)
    pruning_mask = torch.ones(seq_sort.shape[0], dtype=bool)

    # Per-step accuracies, collected but not otherwise used in this function.
    acc_clean = []
    acc_bd = []
    bd_vars = freq_vars if args.attack_mode == "freq" else None

    for index in range(pruning_mask.shape[0]):
        if index:
            # Prune one more channel: the next least-active one.
            pruning_mask[int(seq_sort[index - 1])] = False
        print("Pruned {} filters".format(index))

        net_pruned = _build_pruned_model(model, pruning_mask, num_classes)

        if args.attack_mode == "clean":
            _, test_avg_acc = test(net_pruned, test_loader)
            logging.info(
                f'Index: {index}, Test_Accuracy: {test_avg_acc}')
            continue

        _, clean_test_avg_acc, _, bd_test_avg_acc = test_bd(net_pruned, test_loader, vars=bd_vars)
        acc_clean.append(clean_test_avg_acc.cpu().numpy())
        acc_bd.append(bd_test_avg_acc.cpu().numpy())
        print('clean acc, bd acc:', clean_test_avg_acc, bd_test_avg_acc)
        logging.info(
            f'Index: {index}, Clean_Test_Accuracy: {clean_test_avg_acc}, Backdoor_Test_Accuracy: {bd_test_avg_acc}')

if __name__ == '__main__':
    # Run the pruning sweep only when executed as a script, not on import.
    main()
