import argparse
from data_poison import *
from utils import *
from smooth import *

def parse_args():
    """Build the command-line parser for this script and parse ``sys.argv``.

    Options are declared as ``(flag, add_argument-kwargs)`` pairs and
    registered in order, so the table below is the single place to look for
    every flag, its type, its default and its help text.

    Returns:
        argparse.Namespace: one attribute per option, named after the flag
        without the leading dashes.
    """
    option_table = (
        ('--data_dir', dict(type=str, default='./data')),
        ('--log_dir', dict(type=str, default='./logs/')),
        ('--epochs', dict(type=int, default=200)),
        # resnet preact_resnet
        ('--model', dict(type=str, default="preact_resnet")),
        # cifar10 gtsrb celeba
        ('--data', dict(type=str, default="gtsrb")),
        ('--device', dict(type=str, default="cuda:0")),
        ('--lr', dict(type=float, default=0.1)),
        ('--batch_size', dict(type=int, default=128)),
        ('--model_dir', dict(type=str, default="./saved_model/")),
        ('--num_classes', dict(type=int, default=8)),
        ('--attack_mode', dict(type=str, default="clean",
                               help='clean, square, sig, refool, ftrojan, fiba, freq')),
        ('--poison_ratio', dict(type=float, default=0.1)),
        ('--target_label', dict(type=int, default=7)),
        # ---------------------------- For Poisoned Image Transformation ------------------------
        ('--smooth_type', dict(type=str, default='no_smooth',
                               help='gaussian, wiener, BM3D, no_smooth, jpeg')),
    )

    cli = argparse.ArgumentParser()
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    return cli.parse_args()


# Module-level configuration: the parsed CLI options and the loss function are
# read as globals by the code below.
args = parse_args()
# nn.Module.to() returns the module itself, so construction and the device
# move can be chained into a single expression.
criterion = nn.CrossEntropyLoss().to(args.device)

def _poison_batch(data, label, vars):
    """Apply the trigger selected by ``args.attack_mode`` to a whole batch.

    Every poisoning helper is called with a poisoning fraction/ratio of 1.0
    and ``args.target_label``. For any other attack mode (including
    ``'clean'``) the batch is returned untouched.
    """
    mode = args.attack_mode
    if mode == 'freq':
        return freq_poison(args, data, label, args.target_label, poisoning_frac=1.0, vars=vars)
    if mode == 'ftrojan':
        return ftrojan_poison(args, data, label, args.target_label, poison_ratio=1.0)
    if mode == 'fiba':
        return fiba_poison(args, data, label, args.target_label, poison_ratio=1.0)
    if mode == 'sig':
        return sig_poison(args, data, label, args.target_label, poison_ratio=1.0)
    if mode == 'square':
        return square_poison(args, data, label, args.target_label, poison_ratio=1.0)
    return data, label


def _evaluate_pass(model, test_loader, batch_transform=None):
    """Run one full pass over ``test_loader`` and return ``(loss_meter, acc_meter)``.

    ``batch_transform``, when given, is called as ``batch_transform(data, label)``
    before smoothing and must return the (possibly modified) pair.
    """
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()

    for data, label in test_loader:
        if batch_transform is not None:
            data, label = batch_transform(data, label)
        # Smoothing is applied after any trigger insertion, before the move to device.
        data = smoothing(data, args.smooth_type)

        data = data.to(args.device)
        label = label.view(-1).to(args.device)

        # Evaluation only: skip autograd bookkeeping for the forward pass and loss.
        with torch.no_grad():
            output = model(data)
            loss = criterion(output, label)

        # Number of correct predictions in this batch (a count, not a ratio).
        batch_correct = (output.argmax(1) == label).float().sum()

        loss_meter.update(loss.detach(), data.size(0))
        # NOTE(review): the third positional argument is forwarded unchanged to
        # AverageMeter.update; presumably it marks the value as a per-batch sum
        # rather than a mean -- confirm against the AverageMeter definition.
        acc_meter.update(batch_correct.detach(), data.size(0), True)

    return loss_meter, acc_meter


def test(model, test_loader, vars):
    """Evaluate ``model`` on triggered and on untouched test batches.

    The loader is iterated twice: first with every batch passed through the
    trigger chosen by ``args.attack_mode`` (the "backdoor" pass), then as-is
    (the "clean" pass). Both passes apply ``smoothing`` with
    ``args.smooth_type``. When ``args.attack_mode`` matches none of the known
    triggers, the backdoor pass sees unmodified data.

    Args:
        model: network to evaluate; called as ``model(data)``.
        test_loader: iterable yielding ``(data, label)`` batches.
        vars: forwarded to ``freq_poison`` when ``args.attack_mode == 'freq'``;
            ignored otherwise. The name shadows the builtin but is kept
            because callers pass it by keyword.

    Returns:
        tuple: ``(clean_loss, clean_acc, bd_loss, bd_acc)``, the ``.avg`` of
        the corresponding meters.

    Note:
        The model is switched to eval mode on entry and to **train** mode on
        exit, regardless of the mode it was in when passed in.
    """
    model.eval()

    # Backdoor pass first, then the clean pass (same order as the data is consumed).
    bd_loss_avgmeter, bd_acc_avgmeter = _evaluate_pass(
        model, test_loader, lambda data, label: _poison_batch(data, label, vars))
    clean_loss_avgmeter, clean_acc_avgmeter = _evaluate_pass(model, test_loader)

    model.train()

    return clean_loss_avgmeter.avg, clean_acc_avgmeter.avg, bd_loss_avgmeter.avg, bd_acc_avgmeter.avg

# ---- Script body: load the saved checkpoint, evaluate it, report accuracy ----

# Attack modes this script can evaluate. '--attack_mode refool' is listed in
# the CLI help but has no poisoning branch here, so reject it (and any typo)
# up front instead of failing with a NameError after the model has been loaded.
_SUPPORTED_ATTACK_MODES = ('clean', 'square', 'sig', 'ftrojan', 'fiba', 'freq')
if args.attack_mode not in _SUPPORTED_ATTACK_MODES:
    raise ValueError("Unsupported attack_mode {!r}; expected one of: {}".format(
        args.attack_mode, ", ".join(_SUPPORTED_ATTACK_MODES)))

# Checkpoint base name. Paths are built by plain concatenation with
# args.model_dir, so model_dir must end with a path separator.
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
trigger_vars = None  # extra trigger state, only present for the 'freq' attack
if args.attack_mode == 'freq':
    save_name = "_".join(["train_freq_attack", args.attack_mode, args.model, args.data,
                          str(args.poison_ratio)])
    trigger_vars = torch.load(args.model_dir + save_name + "_vars" + '.pt', map_location=args.device)
elif args.attack_mode == 'clean':
    save_name = "_".join(["train", args.attack_mode, args.model, args.data])
else:
    save_name = "_".join(["train_attack", args.attack_mode, args.model, args.data,
                          str(args.poison_ratio)])

# The checkpoint is expected to hold a whole module object (not a state_dict),
# since .eval() and .to() are called directly on the loaded value.
model = torch.load(args.model_dir + save_name + '.pt', map_location=args.device)
model.eval()
model = model.to(args.device)

train_dataset, test_dataset = load_dataset(args)
_, test_loader = load_data(args, train_dataset, test_dataset)

# One call covers every mode: only 'freq' supplies trigger variables.
_, clean_acc, _, bd_acc = test(model, test_loader, vars=trigger_vars)

if args.attack_mode == 'clean':
    print("Clean accuracy:", clean_acc)
else:
    print("Clean accuracy, Backdoor accuracy:", clean_acc, bd_acc)