import argparse
from data_poison import *
import matplotlib.pyplot as plt
from scipy.stats import norm

def parse_args():
    """Parse this script's command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace: one attribute per option listed below; an option
        that is not given on the command line keeps its ``default``.
    """
    # Each entry is (flag, keyword arguments for ``add_argument``); the
    # options are registered in exactly this order.
    option_table = [
        # Paths, model/data selection and training-style settings.
        ('--log_dir', dict(type=str, default='./logs/')),
        ('--data_dir', dict(type=str, default='./data')),
        ('--epochs', dict(type=int, default=200)),
        ('--model', dict(type=str, default="resnet",
                         help='cnn_mnist, preact_resnet, resnet18')),
        ('--data', dict(type=str, default="cifar10",
                        help='mnist, gtsrb, cifar10, imagenet, celeba')),
        ('--device', dict(type=str, default="cuda:0")),
        ('--lr', dict(type=float, default=0.01)),
        ('--model_dir', dict(type=str, default="./saved_model/")),
        ('--attack_mode', dict(type=str, default="freq",
                               help='square, sig, refool, ftrojan, fiba, freq')),
        ('--batch_size', dict(type=int, default=64)),
        ('--poison_ratio', dict(type=float, default=0.1)),
        ('--target_label', dict(type=int, default=7)),
        # Detection settings. The boundary default follows the original
        # paper, per the author's note in this file.
        ("--detection_boundary", dict(type=float, default=0.2)),
        ('--n_sample', dict(type=int, default=100)),
        ('--n_test', dict(type=int, default=100)),
        ('--start_point', dict(type=int, default=1300)),
    ]

    cli = argparse.ArgumentParser()
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()

# Parse the command line once, at import time. STRIP.__init__ and main() read
# this module-level namespace directly instead of receiving it as an argument.
args=parse_args()
# Cross-entropy loss object; nothing in the visible part of this file uses it.
# NOTE(review): `nn` is not imported explicitly in this file - presumably it
# arrives through `from data_poison import *`; confirm.
criterion = nn.CrossEntropyLoss()

class STRIP:
    """Entropy-based perturbation scorer for a single input image.

    An instance blends one "background" image with ``n_sample`` images drawn
    at random from a dataset, classifies every blend in one batch, and
    reports the summed base-2 entropy of the softmax outputs divided by
    ``n_sample``.

    Attributes:
        n_sample: number of random overlays blended with each background.
        device: device the blended batch is moved to before classification.
    """

    def __init__(self, n_sample=None, device=None):
        """Configure the scorer.

        Args:
            n_sample: overlays per background. ``None`` (the default) falls
                back to the module-level ``args.n_sample``.
            device: target device for the classifier input. ``None`` (the
                default) falls back to the module-level ``args.device``.

        Raises:
            ValueError: if the resulting ``n_sample`` is not positive.
        """
        super().__init__()
        self.n_sample = args.n_sample if n_sample is None else n_sample
        self.device = args.device if device is None else device
        if self.n_sample <= 0:
            # Fail here rather than later with an empty-stack / division error.
            raise ValueError(
                "n_sample must be a positive integer, got {!r}".format(self.n_sample))

    def __call__(self, background, dataset, classifier):
        """Score ``background``; shorthand for :meth:`_get_entropy`."""
        return self._get_entropy(background, dataset, classifier)

    def _superimpose(self, background, overlay):
        """Add ``overlay`` onto ``background`` with equal weights of 1.

        Both inputs go straight to ``cv2.addWeighted``; the caller in this
        class passes uint8 arrays. A 2-D result gets a trailing axis of
        length 1 so the output is always at least 3-D.
        """
        output = cv2.addWeighted(background, 1, overlay, 1, 0)
        if len(output.shape) == 2:
            output = np.expand_dims(output, 2)
        return output

    def _mean_entropy(self, probs):
        """Return ``-sum(p * log2(p)) / n_sample`` over every entry of ``probs``.

        Entries whose product is NaN (a probability of exactly 0 gives
        ``0 * -inf``) are skipped by ``nansum``, i.e. they contribute nothing.
        """
        # The zero-probability case is expected and handled by nansum, so the
        # divide-by-zero / invalid-value warnings it triggers are only noise.
        with np.errstate(divide='ignore', invalid='ignore'):
            weighted_logs = probs * np.log2(probs)
        return -np.nansum(weighted_logs) / self.n_sample

    def _get_entropy(self, background, dataset, classifier):
        """Blend ``background`` with random dataset images and score the blends.

        Args:
            background: image array handed to ``cv2.addWeighted``; the caller
                in this file passes a uint8 array scaled to 0..255.
            dataset: indexable collection whose items have the image first;
                the image must expose ``.numpy()``.
                NOTE(review): the ``* 255`` below presumes values in [0, 1] -
                confirm against the dataset transform.
            classifier: callable mapping a stacked batch tensor to per-class
                scores; softmax is applied over dimension 1.

        Returns:
            The summed base-2 entropy of all softmax rows divided by
            ``n_sample``.
        """
        overlay_ids = np.random.randint(0, len(dataset), size=self.n_sample)

        blended = []
        for overlay_id in overlay_ids:
            overlay = dataset[overlay_id][0]
            overlay = np.clip(overlay.numpy() * 255, 0, 255).astype(np.uint8)
            mixed = self._superimpose(background, overlay)
            # Back to float32 in the 0..1 range before it goes to the model.
            blended.append(torch.Tensor(mixed.astype(np.float32) / 255))

        # Pure inference: skip autograd bookkeeping for the whole batch.
        with torch.no_grad():
            scores = classifier(torch.stack(blended).to(self.device))
            probs = torch.softmax(scores, 1).detach().cpu().numpy()

        return self._mean_entropy(probs)

def main():
    """Score a saved model with the STRIP entropy test and print a verdict.

    Steps, all driven by the module-level ``args``:

    1. Load the saved model (and, for ``freq``, its saved ``_vars`` file)
       from ``args.model_dir``.
    2. Load the datasets and build a test loader whose batch size is the
       whole test set.
    3. For samples ``start_point .. start_point + n_test - 1`` of each batch,
       score both the poisoned and the clean version of the image.
    4. Print both entropy lists, the overall minimum and whether that minimum
       falls below ``args.detection_boundary``.

    Raises:
        ValueError: if ``args.attack_mode`` has no poisoning routine in this
            function, or if no sample was scored.
        IndexError: if ``start_point + n_test`` runs past the end of a batch.
    """
    # 'sig' and 'refool' are listed in the --attack_mode help text but have no
    # branch below; reject them up front instead of failing mid-loop on an
    # unbound variable after the model and data have already been loaded.
    supported_modes = ('freq', 'square', 'fiba', 'ftrojan')
    if args.attack_mode not in supported_modes:
        raise ValueError(
            "attack_mode {!r} is not supported here; expected one of {}".format(
                args.attack_mode, ", ".join(supported_modes)))

    # load poisoned model
    is_freq = args.attack_mode == 'freq'
    prefix = "train_freq_attack_" if is_freq else "train_attack_"
    save_name = f"{prefix}{args.attack_mode}_{args.model}_{args.data}_{args.poison_ratio}"

    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code - only point --model_dir at checkpoints you trust. Recent PyTorch
    # releases may also need an explicit `weights_only=False` to load a whole
    # pickled model object; confirm against the installed version.
    freq_vars = None
    if is_freq:
        vars_path = os.path.join(args.model_dir, f'{save_name}_vars.pt')
        freq_vars = torch.load(vars_path, map_location=args.device)

    path = os.path.join(args.model_dir, f'{save_name}.pt')
    model = torch.load(path, map_location=args.device)
    model = model.to(args.device)
    model.eval()

    train_dataset, test_dataset = load_dataset(args)
    # One batch covering the whole test set, so start_point indexes into it.
    args.batch_size = len(test_dataset)
    _, test_loader = load_data(args, train_dataset, test_dataset)

    strip_detector = STRIP()

    clean_entropy_list = []
    poison_entropy_list = []
    window_end = args.start_point + args.n_test

    for data, label in test_loader:
        # Poison every sample of the batch (ratio / fraction of 1.0).
        if is_freq:
            bd_data, _ = freq_poison(args, data, label, args.target_label,
                                     poisoning_frac=1.0, vars=freq_vars)
        elif args.attack_mode == 'square':
            bd_data, _ = square_poison(args, data, label, args.target_label, poison_ratio=1.0)
        elif args.attack_mode == 'fiba':
            bd_data, _ = fiba_poison(args, data, label, args.target_label, poison_ratio=1.0)
        else:  # 'ftrojan' - the only remaining mode after the check above
            bd_data, _ = ftrojan_poison(args, data, label, args.target_label, poison_ratio=1.0)

        if window_end > len(data) or window_end > len(bd_data):
            raise IndexError(
                "start_point + n_test = {} exceeds the batch size {}".format(
                    window_end, min(len(data), len(bd_data))))

        for i in range(args.start_point, window_end):
            background_pos = np.clip(bd_data[i].numpy() * 255, 0, 255).astype(np.uint8)
            background_clean = np.clip(data[i].numpy() * 255, 0, 255).astype(np.uint8)

            # Each score goes into the list matching its input (these were
            # previously swapped). The poisoned image is still scored first so
            # the random overlay draws happen in the same order as before.
            entropy_pos = strip_detector(background_pos, train_dataset, model)
            entropy_clean = strip_detector(background_clean, train_dataset, model)

            poison_entropy_list.append(entropy_pos)
            clean_entropy_list.append(entropy_clean)

    print(clean_entropy_list)
    print(poison_entropy_list)

    all_entropies = poison_entropy_list + clean_entropy_list
    if not all_entropies:
        raise ValueError("no samples were scored; check --n_test and the test loader")

    min_entropy = min(all_entropies)
    print("Min entropy trojan: {}, Detection boundary: {}".format(min_entropy, args.detection_boundary))
    if min_entropy < args.detection_boundary:
        print("A backdoored model\n")
    else:
        print("Not a backdoor model\n")

# Run the detection only when this file is executed as a script.
if __name__ == '__main__':
    main()
