import argparse
import lpips
import pytorch_ssim
from data_poison import *

def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace holding general data/model/optimisation options plus
        the backdoor-attack options (``attack_mode``, ``poison_ratio``,
        ``target_label``, ``epsilon``).
    """
    p = argparse.ArgumentParser()
    add = p.add_argument

    # General data / model / optimisation settings.
    add('--data_dir', type=str, default='./data')
    add('--log_dir', type=str, default='./logs/')
    add('--model_dir', type=str, default="./saved_model/")
    add('--epochs', type=int, default=200)
    add('--model', type=str, default="resnet")  # e.g. cnn_mnist / resnet18(43) / resnet18
    add('--data', type=str, default="celeba")  # gtsrb / cifar10 / imagenet / celeba
    add('--device', type=str, default="cuda:0")
    # Suggested lr: 0.05 for gtsrb, 0.001 for mnist, 0.01 for cifar/imagenet.
    add('--lr', type=float, default=0.01)
    add('--batch_size', type=int, default=64)
    add('--num_classes', type=int, default=10)

    # Backdoor-attack settings.
    add('--attack_mode', type=str, default="fiba")  # square / sig / blend / ftrojan / fiba / freq
    add('--poison_ratio', type=float, default=0.1)
    add('--target_label', type=int, default=7)
    add('--epsilon', type=float, default=1.5,
        help='perturbation: cifar10 1; gtsrb 0.5; imagenet 1.0; mnist 0.2')

    return p.parse_args()

# Parse CLI options once at import time; the metric functions below read this
# module-level ``args`` (and pass it on to the poisoning routines).
args=parse_args()
# Load the train/test datasets.
train_dataset, test_dataset = load_dataset(args)
# Independent deep copies wrapped in loaders built with the CLI batch size.
# NOTE(review): sa_train_loader / sa_test_loader are not referenced anywhere
# else in the visible code of this file — confirm whether they are still needed.
sa_train_dataset = copy.deepcopy(train_dataset)
sa_test_dataset = copy.deepcopy(test_dataset)
sa_train_loader, sa_test_loader = load_data(args, sa_train_dataset, sa_test_dataset)
# Rebuild the test loader with a batch size equal to the whole test set, so the
# metric functions can index positions drawn from range(len(test_dataset)) in
# the batch they receive (presumably a single batch — assumes load_data does
# not drop samples; confirm). This permanently changes args.batch_size, which
# is later passed to the poisoning functions, and loads the full test set into
# memory at once.
args.batch_size = len(test_dataset)
_, test_loader = load_data(args, train_dataset, test_dataset)

print('finish loading dataset')

def ssim():
    """Print the mean SSIM between clean and poisoned test images.

    Every batch from ``test_loader`` is poisoned with the attack named by
    ``args.attack_mode`` at ratio 1.0. SSIM is then averaged over up to 500
    randomly chosen sample positions of the last batch the loader yielded.
    Only that batch survives the loop; it is presumably the only one, since
    the module sets ``args.batch_size`` to the test-set size.

    Raises:
        Exception: if ``args.attack_mode`` is not a supported attack.
    """
    # Draw the evaluation positions before poisoning (keeps the RNG call order).
    picked = np.random.permutation(len(test_dataset))[:500]

    poisoners = {
        'freq': lambda x, y: freq_poison(args, x, y, args.target_label, poisoning_frac=1.0, vars=vars),
        'sig': lambda x, y: sig_poison(args, x, y, args.target_label, poison_ratio=1.0),
        'square': lambda x, y: square_poison(args, x, y, args.target_label, poison_ratio=1.0),
        'blend': lambda x, y: blend_poison(args, x, y, args.target_label, poison_ratio=1.0),
        'fiba': lambda x, y: fiba_poison(args, x, y, args.target_label, poison_ratio=1.0),
        'ftrojan': lambda x, y: ftrojan_poison(args, x, y, args.target_label, poison_ratio=1.0),
    }

    for clean, labels in test_loader:
        if args.attack_mode not in poisoners:
            raise Exception(f'Error, unknown attack mode{args.attack_mode}')
        poisoned, _ = poisoners[args.attack_mode](clean, labels)

    # pytorch_ssim expects a leading batch dimension, hence unsqueeze(0).
    scores = [
        pytorch_ssim.ssim(poisoned[i].unsqueeze(0), clean[i].unsqueeze(0)).item()
        for i in picked
    ]

    print('ssim:', np.mean(scores))


def Lpips():
    """Print the mean LPIPS distance (AlexNet backbone) between clean and poisoned test images.

    Every batch from ``test_loader`` is poisoned with the attack named by
    ``args.attack_mode`` at ratio 1.0. LPIPS is then evaluated on up to 500
    randomly chosen sample positions of the last batch the loader yielded,
    and the mean is printed.

    Raises:
        Exception: if ``args.attack_mode`` is not a supported attack.
    """
    metric = lpips.LPIPS(net='alex')  # AlexNet variant: "best forward scores"
    distances = []
    sample_ids = np.random.permutation(len(test_dataset))[:500]

    def poison(batch, targets):
        # Route the batch to the poisoning routine for the configured attack.
        mode = args.attack_mode
        if mode == 'freq':
            return freq_poison(args, batch, targets, args.target_label, poisoning_frac=1.0, vars=vars)
        if mode == 'sig':
            return sig_poison(args, batch, targets, args.target_label, poison_ratio=1.0)
        if mode == 'square':
            return square_poison(args, batch, targets, args.target_label, poison_ratio=1.0)
        if mode == 'blend':
            return blend_poison(args, batch, targets, args.target_label, poison_ratio=1.0)
        if mode == 'fiba':
            return fiba_poison(args, batch, targets, args.target_label, poison_ratio=1.0)
        if mode == 'ftrojan':
            return ftrojan_poison(args, batch, targets, args.target_label, poison_ratio=1.0)
        raise Exception(f'Error, unknown attack mode{args.attack_mode}')

    for clean_batch, batch_labels in test_loader:
        poisoned_batch, _ = poison(clean_batch, batch_labels)

    # NOTE(review): lpips.LPIPS expects inputs scaled to [-1, 1] unless it is
    # called with normalize=True. psnr() in this file uses a peak value of 1.0,
    # which suggests [0, 1] data. Confirm the range produced by load_dataset.
    for i in sample_ids:
        d = metric(clean_batch[i].unsqueeze(0), poisoned_batch[i].unsqueeze(0))
        distances.append(d.item())

    print('lpips:', np.mean(distances))


def psnr():
    """Print the mean PSNR between clean and poisoned test images.

    Every batch from ``test_loader`` is poisoned with the attack named by
    ``args.attack_mode`` at ratio 1.0. PSNR is then averaged over up to 500
    randomly chosen sample positions of the last batch the loader yielded.

    PSNR is ``20 * log10(max_value / sqrt(MSE))`` with ``max_value = 1.0``.
    NOTE(review): this assumes pixel values lie in [0, 1]; confirm against
    load_dataset. Identical pairs (MSE == 0) are scored as 100 dB instead of
    +inf and are included in the average.

    Raises:
        Exception: if ``args.attack_mode`` is not a supported attack.
    """
    import math  # local import so this does not rely on the star import

    random_idx = np.random.permutation(len(test_dataset))
    random_idx = random_idx[:500]
    max_value = 1.0
    colls = []

    for data, label in test_loader:
        if args.attack_mode == 'freq':
            bd_data, bd_label = freq_poison(args, data, label, args.target_label, poisoning_frac=1.0, vars=vars)
        elif args.attack_mode == 'sig':
            bd_data, bd_label = sig_poison(args, data, label, args.target_label, poison_ratio=1.0)
        elif args.attack_mode == 'square':
            bd_data, bd_label = square_poison(args, data, label, args.target_label, poison_ratio=1.0)
        elif args.attack_mode == 'blend':
            bd_data, bd_label = blend_poison(args, data, label, args.target_label, poison_ratio=1.0)
        elif args.attack_mode == 'fiba':
            bd_data, bd_label = fiba_poison(args, data, label, args.target_label, poison_ratio=1.0)
        elif args.attack_mode == 'ftrojan':
            bd_data, bd_label = ftrojan_poison(args, data, label, args.target_label, poison_ratio=1.0)
        else:
            raise Exception(f'Error, unknown attack mode{args.attack_mode}')

    for idx in random_idx:
        cln_img = data[idx]
        bd_img = bd_data[idx]
        mse = torch.mean((bd_img - cln_img) ** 2).item()

        if mse == 0:
            # Identical images have infinite PSNR. Record the 100 dB cap rather
            # than silently dropping the sample (which biased the mean before).
            value = 100
        else:
            value = 20 * math.log10(max_value / math.sqrt(mse))
        colls.append(value)

    print('psnr:', np.mean(colls))



if __name__ == '__main__':
    import os

    if args.attack_mode == 'freq':
        # The 'freq' attack needs trigger variables saved at training time.
        # They are bound to the module-level name ``vars`` on purpose: ssim(),
        # Lpips() and psnr() pass the global ``vars`` to freq_poison. This
        # shadows the builtin ``vars`` for the rest of the module.
        vars_save_name = (f"train_freq_attack_{args.attack_mode}_{args.model}_"
                          f"{args.data}_{args.poison_ratio}_vars")
        # os.path.join also works when --model_dir has no trailing slash.
        # NOTE(review): torch.load unpickles the file; only load trusted paths.
        vars = torch.load(os.path.join(args.model_dir, vars_save_name + '.pt'),
                          map_location=args.device)
    Lpips()
    ssim()
    psnr()
