import argparse
from models.preact_resnet import *
from models.cnn_mnist import *
from models.resnet import *
from utils import *
import torch.nn as nn
import torch.optim as optim
import torchvision
import logging
import os

def _str2bool(value):
    """Convert a command-line token such as ``True``/``false``/``1``/``no`` to a bool.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is ``True``
    for every non-empty string, so boolean flags could never be switched off
    from the command line. This converter parses the text instead.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("true", "t", "yes", "y", "1"):
        return True
    if lowered in ("false", "f", "no", "n", "0"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


def parse_args(argv=None):
    """Parse command-line options for the Neural Cleanse run.

    Args:
        argv: optional list of argument strings; ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before.

    Returns:
        argparse.Namespace with all options below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--log_dir', type=str, default='./logs/')
    parser.add_argument('--data_dir', type=str, default='./data')
    parser.add_argument('--model', type=str, default="resnet")
    parser.add_argument('--data', type=str, default="celeba")
    parser.add_argument('--device', type=str, default="cuda:0")
    parser.add_argument('--attack_mode', type=str, default="clean",
                        help='clean, square, sig, refool, ftrojan, fiba, freq')
    parser.add_argument("--result", type=str, default="./NC_results")
    parser.add_argument('--model_dir', type=str, default="./saved_model/")
    parser.add_argument('--poison_ratio', type=float, default=0.1)
    parser.add_argument('--num_classes', type=int, default=10)

    # ---------------------------- For Neural Cleanse --------------------------
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--lr", type=float, default=0.01)
    # Filled in by main() from the dataset name when left as None.
    parser.add_argument("--input_height", type=int, default=None)
    parser.add_argument("--input_width", type=int, default=None)
    parser.add_argument("--input_channel", type=int, default=None)
    parser.add_argument("--init_cost", type=float, default=1e-3)
    # Percentage (0-100) of stamped inputs that must hit the target label.
    parser.add_argument("--atk_succ_threshold", type=float, default=99.0)
    parser.add_argument("--early_stop", type=_str2bool, default=True)
    # NOTE(review): train_step uses this as a ratio
    # (reg_best >= threshold * previous best). With a value above 1, such as
    # the default 99.0, that never holds for positive norms, so early stopping
    # cannot fire. Reference Neural Cleanse uses a ratio around 1.0 — TODO
    # confirm the intended value before changing the default.
    parser.add_argument("--early_stop_threshold", type=float, default=99.0)
    parser.add_argument("--early_stop_patience", type=int, default=25)
    parser.add_argument("--patience", type=int, default=5)
    parser.add_argument("--cost_multiplier", type=float, default=2)
    parser.add_argument("--epoch", type=int, default=10)
    parser.add_argument("--num_workers", type=int, default=8)

    parser.add_argument("--target_label", type=int, default=0)
    # Filled in by main() from the dataset name.
    parser.add_argument("--total_label", type=int)
    parser.add_argument("--EPSILON", type=float, default=1e-7)

    parser.add_argument("--to_file", type=_str2bool, default=True)
    parser.add_argument("--n_times_test", type=int, default=5)

    return parser.parse_args(argv)

# Command-line options, parsed once at module level.
# NOTE(review): this runs at import time, so merely importing this file parses
# sys.argv (and exits on unknown arguments). main() receives this namespace
# from the __main__ guard at the bottom of the file.
args = parse_args()

class Normalize:
    """Per-channel standardisation: ``out[:, c] = (x[:, c] - mean[c]) / scale[c]``.

    Args:
        args: namespace providing ``input_channel`` (number of channels).
        expected_values: per-channel means, one per channel.
        variance: per-channel divisors, one per channel. The value is used
            directly as the divisor; the CIFAR-10 numbers passed elsewhere in
            this file (0.247, 0.243, 0.261) look like standard deviations
            rather than variances. The name is kept for caller compatibility.

    Raises:
        ValueError: if ``expected_values`` or ``variance`` does not have
            exactly ``args.input_channel`` entries.
    """

    def __init__(self, args, expected_values, variance):
        self.n_channels = args.input_channel
        self.expected_values = expected_values
        self.variance = variance
        # Explicit check rather than ``assert`` so it still runs under ``python -O``.
        if len(self.expected_values) != self.n_channels or len(self.variance) != self.n_channels:
            raise ValueError(
                "Normalize expects {} per-channel values, got {} means and {} scales".format(
                    self.n_channels, len(self.expected_values), len(self.variance)
                )
            )

    def __call__(self, x):
        """Return a normalized copy of ``x``; ``x`` itself is not modified.

        ``x`` is indexed as ``x[:, c]``, so dim 1 is the channel axis
        (presumably an NCHW batch — confirm with callers).
        """
        x_clone = x.clone()
        for channel in range(self.n_channels):
            x_clone[:, channel] = (x[:, channel] - self.expected_values[channel]) / self.variance[channel]
        return x_clone


class Denormalize:
    """Inverse of per-channel standardisation: ``out[:, c] = x[:, c] * scale[c] + mean[c]``.

    Args:
        args: namespace providing ``input_channel`` (number of channels).
        expected_values: per-channel means, one per channel.
        variance: per-channel multipliers, one per channel (the same values
            that were used as divisors when normalizing).

    Raises:
        ValueError: if ``expected_values`` or ``variance`` does not have
            exactly ``args.input_channel`` entries.
    """

    def __init__(self, args, expected_values, variance):
        self.n_channels = args.input_channel
        self.expected_values = expected_values
        self.variance = variance
        # Explicit check rather than ``assert`` so it still runs under ``python -O``.
        if len(self.expected_values) != self.n_channels or len(self.variance) != self.n_channels:
            raise ValueError(
                "Denormalize expects {} per-channel values, got {} means and {} scales".format(
                    self.n_channels, len(self.expected_values), len(self.variance)
                )
            )

    def __call__(self, x):
        """Return a de-normalized copy of ``x``; ``x`` itself is not modified.

        ``x`` is indexed as ``x[:, c]``, so dim 1 is the channel axis.
        """
        x_clone = x.clone()
        for channel in range(self.n_channels):
            x_clone[:, channel] = x[:, channel] * self.variance[channel] + self.expected_values[channel]
        return x_clone


class RegressionModel(nn.Module):
    """Neural Cleanse trigger model wrapped around a frozen, pre-trained classifier.

    Two trainable tensors, ``mask_tanh`` and ``pattern_tanh``, are kept in an
    unconstrained space and squashed into (0, 1) by ``get_raw_mask`` /
    ``get_raw_pattern``. ``forward`` stamps the pattern onto the input through
    the mask and classifies the result with the frozen classifier.

    Args:
        args: run configuration (uses ``EPSILON``, ``data``, ``attack_mode``,
            ``model``, ``poison_ratio``, ``model_dir``, ``device`` and, for the
            normalizers, ``input_channel``).
        init_mask: initial value for ``mask_tanh`` (array-like).
        init_pattern: initial value for ``pattern_tanh`` (array-like).
    """

    def __init__(self, args, init_mask, init_pattern):
        # nn.Module.__init__ must run before any attribute is assigned so that
        # Module.__setattr__ has its internal registries available.
        super(RegressionModel, self).__init__()
        self._EPSILON = args.EPSILON
        self.mask_tanh = nn.Parameter(torch.tensor(init_mask))
        self.pattern_tanh = nn.Parameter(torch.tensor(init_pattern))

        self.classifier = self._get_classifier(args)
        self.normalizer = self._get_normalize(args)
        # Not used inside this class; kept as a public attribute.
        self.denormalizer = self._get_denormalize(args)

    def forward(self, x):
        """Blend the learned pattern into ``x`` through the mask, then classify."""
        mask = self.get_raw_mask()
        pattern = self.get_raw_pattern()
        if self.normalizer is not None:
            # The pattern lives in [0, 1] pixel space while ``x`` is fed to the
            # classifier as-is (presumably already normalized by the data
            # loader), so bring the pattern into the same space before blending.
            pattern = self.normalizer(pattern)
        x = (1 - mask) * x + mask * pattern
        return self.classifier(x)

    def get_raw_mask(self):
        """Return the mask squashed into (0, 1): ``tanh(m) / (2 + eps) + 0.5``."""
        mask = nn.Tanh()(self.mask_tanh)
        return mask / (2 + self._EPSILON) + 0.5

    def get_raw_pattern(self):
        """Return the pattern squashed into (0, 1): ``tanh(p) / (2 + eps) + 0.5``."""
        pattern = nn.Tanh()(self.pattern_tanh)
        return pattern / (2 + self._EPSILON) + 0.5

    def _get_classifier(self, args):
        """Load the frozen, pre-trained classifier under inspection.

        The checkpoint file name is derived from the attack mode, model name,
        dataset and (for attacked models) poison ratio, and is looked up in
        ``args.model_dir``. All classifier parameters are frozen and the model
        is put in eval mode.

        Raises:
            Exception: if ``args.data`` is not a supported dataset.
        """
        # Only validate the dataset here. The whole pickled model comes from
        # torch.load below, so building an architecture first was wasted work.
        if args.data not in ("mnist", "cifar10", "gtsrb", "imagenet", "celeba"):
            raise Exception("Invalid Dataset")

        if args.attack_mode == 'clean':
            save_name = "train_clean" + "_" + args.model + "_" + args.data
        else:
            prefix = "train_freq_attack_" if args.attack_mode == 'freq' else "train_attack_"
            save_name = prefix + args.attack_mode + "_" + args.model + "_" + args.data + "_" + str(args.poison_ratio)

        # os.path.join also works when model_dir lacks a trailing separator.
        model_path = os.path.join(args.model_dir, save_name + '.pt')
        # NOTE(review): torch.load unpickles a complete nn.Module, so only load
        # trusted checkpoints, and the model classes must be importable here.
        # On PyTorch versions where torch.load defaults to weights_only=True,
        # loading a full module may need weights_only=False — TODO confirm
        # against the pinned torch version.
        classifier = torch.load(model_path, map_location=args.device)

        for param in classifier.parameters():
            param.requires_grad = False
        classifier.eval()
        return classifier.to(args.device)

    def _get_denormalize(self, args):
        """Return a Denormalize for cifar10/mnist, ``None`` for datasets fed unnormalized.

        Raises:
            Exception: if ``args.data`` is not a supported dataset.
        """
        if args.data == "cifar10":
            return Denormalize(args, [0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261])
        if args.data == "mnist":
            return Denormalize(args, [0.5], [0.5])
        if args.data in ("gtsrb", "imagenet", "celeba"):
            return None
        raise Exception("Invalid dataset")

    def _get_normalize(self, args):
        """Return a Normalize for cifar10/mnist, ``None`` for datasets fed unnormalized.

        Raises:
            Exception: if ``args.data`` is not a supported dataset.
        """
        if args.data == "cifar10":
            return Normalize(args, [0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261])
        if args.data == "mnist":
            return Normalize(args, [0.5], [0.5])
        if args.data in ("gtsrb", "imagenet", "celeba"):
            return None
        raise Exception("Invalid dataset")


class Recorder:
    """Book-keeping for one Neural Cleanse optimisation run.

    Holds:
    - the best (smallest average mask norm) mask/pattern that reached the
      attack-success threshold;
    - the counters and flags that drive the adaptive regularisation cost;
    - the early-stop counter.
    """

    def __init__(self, args):
        # Best optimization results
        self.mask_best = None
        self.pattern_best = None
        self.reg_best = float("inf")

        # Logs and counters for adjusting balance cost
        self.logs = []
        self.cost_set_counter = 0
        self.cost_up_counter = 0
        self.cost_down_counter = 0
        self.cost_up_flag = False
        self.cost_down_flag = False

        # Counter for early stop
        self.early_stop_counter = 0
        self.early_stop_reg_best = self.reg_best

        # Cost: the cost is raised gently and lowered more aggressively
        # (multiplier ** 1.5).
        self.cost = args.init_cost
        self.cost_multiplier_up = args.cost_multiplier
        self.cost_multiplier_down = args.cost_multiplier ** 1.5

    def reset_state(self, args):
        """Restore the initial cost and clear the up/down counters and flags."""
        self.cost = args.init_cost
        self.cost_up_counter = 0
        self.cost_down_counter = 0
        self.cost_up_flag = False
        self.cost_down_flag = False
        print("Initialize cost to {:f}".format(self.cost))

    def save_result_to_dir(self, args):
        """Write the best mask, pattern and masked trigger as PNG images.

        Files are written to
        ``<args.result>/<args.data>/<args.attack_mode>/<args.target_label>/``
        as ``mask.png``, ``pattern.png`` and ``trigger.png`` (pattern * mask).

        Raises:
            ValueError: if no mask/pattern has been recorded yet.
        """
        if self.mask_best is None or self.pattern_best is None:
            raise ValueError("Recorder has no mask/pattern recorded yet; nothing to save.")

        result_dir = os.path.join(args.result, args.data, args.attack_mode, str(args.target_label))
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(result_dir, exist_ok=True)

        pattern_best = self.pattern_best
        mask_best = self.mask_best
        trigger = pattern_best * mask_best

        path_mask = os.path.join(result_dir, "mask.png")
        path_pattern = os.path.join(result_dir, "pattern.png")
        path_trigger = os.path.join(result_dir, "trigger.png")

        torchvision.utils.save_image(mask_best, path_mask, normalize=True)
        torchvision.utils.save_image(pattern_best, path_pattern, normalize=True)
        torchvision.utils.save_image(trigger, path_trigger, normalize=True)


def train(args, init_mask, init_pattern):
    """Reverse-engineer a trigger for ``args.target_label``.

    Builds a fresh RegressionModel around the frozen classifier and optimises
    its mask/pattern on the test split for up to ``args.epoch`` epochs. The
    loop stops early when ``train_step`` reports it. Afterwards the recorder's
    best result is written to disk.

    Returns:
        tuple: ``(recorder, args)`` — the Recorder holding the best mask and
        pattern, and the argument namespace that was passed in.
    """
    train_set, test_set = load_dataset(args)
    # Only the test split is used to optimise the trigger.
    _, loader = load_data(args, train_set, test_set)

    model = RegressionModel(args, init_mask, init_pattern).to(args.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, betas=(0.5, 0.9))
    recorder = Recorder(args)

    for epoch_idx in range(args.epoch):
        should_stop = train_step(model, optimizer, loader, recorder, epoch_idx, args)
        if should_stop:
            break

    recorder.save_result_to_dir(args)
    return recorder, args


def _run_epoch(regression_model, optimizer, dataloader, cost, args):
    """Optimise the trigger over one pass of ``dataloader`` and return epoch statistics.

    Each batch minimises ``CE(model(x), target_label) + cost * ||mask||_2``.

    Returns:
        tuple: ``(avg_loss_ce, avg_loss_reg, avg_batch_acc, overall_acc)``.
        The accuracies are percentages of samples classified as
        ``args.target_label``: the mean of per-batch accuracies, and the
        sample-weighted accuracy over the whole epoch.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    cross_entropy = nn.CrossEntropyLoss()
    loss_ce_list = []
    loss_reg_list = []
    batch_acc_list = []
    total_pred = 0
    true_pred = 0

    for inputs, _labels in dataloader:
        optimizer.zero_grad()

        inputs = inputs.to(args.device)
        sample_num = inputs.shape[0]
        total_pred += sample_num
        # Every stamped input is pushed towards the target label.
        target_labels = torch.full((sample_num,), args.target_label, dtype=torch.int64, device=args.device)
        predictions = regression_model(inputs)

        loss_ce = cross_entropy(predictions, target_labels)
        # L2 norm of the mask: a smaller mask means a smaller trigger.
        loss_reg = torch.norm(regression_model.get_raw_mask(), 2)
        total_loss = loss_ce + cost * loss_reg
        total_loss.backward()
        optimizer.step()

        # Accuracy is measured on the predictions made before this update.
        hits = torch.sum(torch.argmax(predictions, dim=1) == target_labels).detach()
        loss_ce_list.append(loss_ce.detach())
        loss_reg_list.append(loss_reg.detach())
        batch_acc_list.append(hits * 100.0 / sample_num)
        true_pred += hits

    if not loss_ce_list:
        raise ValueError("train_step: dataloader yielded no batches")

    avg_loss_ce = torch.mean(torch.stack(loss_ce_list))
    avg_loss_reg = torch.mean(torch.stack(loss_reg_list))
    avg_batch_acc = torch.mean(torch.stack(batch_acc_list))
    overall_acc = true_pred * 100.0 / total_pred
    return avg_loss_ce, avg_loss_reg, avg_batch_acc, overall_acc


def _update_early_stop(recorder, args):
    """Update the early-stop counter on ``recorder``; return True if optimisation should stop.

    The counter only grows when ``reg_best`` is not below
    ``early_stop_threshold`` times the best value seen so far, so the
    threshold acts as a ratio. Stopping additionally requires that the cost
    has been moved both up and down at least once.

    NOTE(review): with a threshold above 1 (the CLI default is 99.0) the
    counter can never grow for positive norms, so early stopping never fires.
    """
    if recorder.reg_best < float("inf"):
        if recorder.reg_best >= args.early_stop_threshold * recorder.early_stop_reg_best:
            recorder.early_stop_counter += 1
        else:
            recorder.early_stop_counter = 0

    recorder.early_stop_reg_best = min(recorder.early_stop_reg_best, recorder.reg_best)

    if (
        recorder.cost_down_flag
        and recorder.cost_up_flag
        and recorder.early_stop_counter >= args.early_stop_patience
    ):
        print("Early_stop !!!")
        return True
    return False


def _adjust_cost(recorder, avg_loss_acc, args):
    """Adapt the regularisation cost from this epoch's attack success.

    After ``args.patience`` consecutive successful epochs the cost is
    multiplied by ``cost_multiplier_up``. After ``args.patience`` consecutive
    failing epochs it is divided by ``cost_multiplier_down``. A zero cost that
    keeps succeeding is reset to the initial cost.
    """
    if recorder.cost == 0 and avg_loss_acc >= args.atk_succ_threshold:
        recorder.cost_set_counter += 1
        if recorder.cost_set_counter >= args.patience:
            recorder.reset_state(args)
    else:
        recorder.cost_set_counter = 0

    if avg_loss_acc >= args.atk_succ_threshold:
        recorder.cost_up_counter += 1
        recorder.cost_down_counter = 0
    else:
        recorder.cost_up_counter = 0
        recorder.cost_down_counter += 1

    if recorder.cost_up_counter >= args.patience:
        recorder.cost_up_counter = 0
        print("Up cost from {} to {}".format(recorder.cost, recorder.cost * recorder.cost_multiplier_up))
        recorder.cost *= recorder.cost_multiplier_up
        recorder.cost_up_flag = True

    elif recorder.cost_down_counter >= args.patience:
        recorder.cost_down_counter = 0
        print("Down cost from {} to {}".format(recorder.cost, recorder.cost / recorder.cost_multiplier_down))
        recorder.cost /= recorder.cost_multiplier_down
        recorder.cost_down_flag = True


def train_step(regression_model, optimizerR, dataloader, recorder, epoch, args):
    """Run one Neural Cleanse optimisation epoch for ``args.target_label``.

    Steps:
    1. Optimise the mask/pattern for one pass over ``dataloader``.
    2. Update the best mask on ``recorder``; the result is saved to disk on
       every improvement.
    3. Update the early-stop state.
    4. Unless stopping, adapt the regularisation cost.

    Returns:
        bool: True if early stopping was triggered in this epoch.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    print("Epoch {} - Label: {} | {} - {}:".format(epoch, args.target_label, args.data, args.attack_mode))

    avg_loss_ce, avg_loss_reg, avg_loss_acc, overall_acc = _run_epoch(
        regression_model, optimizerR, dataloader, recorder.cost, args
    )

    # Keep the smallest mask (by average L2 norm) that still reaches the
    # attack-success threshold.
    if avg_loss_acc >= args.atk_succ_threshold and avg_loss_reg < recorder.reg_best:
        recorder.mask_best = regression_model.get_raw_mask().detach()
        recorder.pattern_best = regression_model.get_raw_pattern().detach()
        recorder.reg_best = avg_loss_reg
        recorder.save_result_to_dir(args)
        print(" Updated !!!")

    print(
        "  Result: Accuracy: {:.3f} | Cross Entropy Loss: {:.6f} | Reg Loss: {:.6f} | Reg best: {:.6f}".format(
            overall_acc, avg_loss_ce, avg_loss_reg, recorder.reg_best
        )
    )

    inner_early_stop_flag = False
    if args.early_stop:
        inner_early_stop_flag = _update_early_stop(recorder, args)

    if not inner_early_stop_flag:
        _adjust_cost(recorder, avg_loss_acc, args)

        # Always leave a mask behind, even if the threshold was never reached.
        if recorder.mask_best is None:
            recorder.mask_best = regression_model.get_raw_mask().detach()
            recorder.pattern_best = regression_model.get_raw_pattern().detach()

    return inner_early_stop_flag


def outlier_detection(l1_norm_list, idx_mapping, args):
    """Decide whether the model is backdoored using MAD-based outlier detection.

    The anomaly index of the smallest norm is
    ``|min(norms) - median| / (1.4826 * median(|norms - median|))``. An index
    of 2 or more reports the model as backdoored. Every label whose norm is
    at most the median and more than 2 MADs away is flagged. Results are
    printed and, when ``args.to_file`` is set, appended to
    ``<args.result>/<args.data>/<args.attack_mode>/<attack_mode>_<data>_output.txt``.

    Args:
        l1_norm_list: 1-D tensor of per-label mask L1 norms.
        idx_mapping: dict mapping label -> index into ``l1_norm_list``.
        args: namespace providing ``to_file``, ``result``, ``data`` and
            ``attack_mode``.
    """
    print("-" * 30)
    print("Determining whether model is backdoor")
    if not torch.is_floating_point(l1_norm_list):
        # torch.quantile only accepts floating-point input.
        l1_norm_list = l1_norm_list.float()

    consistency_constant = 1.4826
    # torch.median returns the *lower* of the two middle values for
    # even-length input. quantile(0.5) averages them, which is the standard
    # median the MAD test is defined with.
    median = torch.quantile(l1_norm_list, 0.5)
    mad = consistency_constant * torch.quantile(torch.abs(l1_norm_list - median), 0.5)
    min_dev = torch.abs(torch.min(l1_norm_list) - median)
    if mad > 0:
        min_mad = min_dev / mad
    elif min_dev == 0:
        # No spread and the smallest norm sits on the median: nothing is
        # anomalous. 0 / 0 would give NaN, and because NaN < 2 is False the
        # model would be wrongly reported as backdoored.
        min_mad = torch.zeros_like(min_dev)
    else:
        min_mad = torch.full_like(min_dev, float("inf"))

    print("Median: {}, MAD: {}".format(median, mad))
    print("Anomaly index: {}".format(min_mad))

    if min_mad < 2:
        print("Not a backdoor model")
    else:
        print("This is a backdoor model")

    if args.to_file:
        result_path = os.path.join(args.result, args.data, args.attack_mode)
        output_path = os.path.join(result_path, "{}_{}_output.txt".format(args.attack_mode, args.data))
        with open(output_path, "a+") as f:
            f.write(
                str(median.cpu().numpy()) + ", " + str(mad.cpu().numpy()) + ", " + str(min_mad.cpu().numpy()) + "\n"
            )
            l1_norm_list_to_save = [str(value) for value in l1_norm_list.cpu().numpy()]
            f.write(", ".join(l1_norm_list_to_save) + "\n")

    flag_list = []
    for y_label, idx in idx_mapping.items():
        l_norm = l1_norm_list[idx]
        # Only unusually *small* triggers indicate a backdoor target.
        if l_norm > median:
            continue
        # If mad == 0, this is NaN for norms on the median (not flagged) and
        # inf for norms below it (flagged).
        if torch.abs(l_norm - median) / mad > 2:
            flag_list.append((y_label, l_norm))

    flag_list.sort(key=lambda item: item[1])

    print(
        "Flagged label list: {}".format(",".join(["{}: {}".format(y_label, l_norm) for y_label, l_norm in flag_list]))
    )


def main(args):
    """Run Neural Cleanse over every label, ``args.n_times_test`` times.

    Steps:
    1. Fill in dataset-dependent fields on ``args`` (``total_label``,
       ``input_height``, ``input_width``, ``input_channel``).
    2. Create the result directory and, when ``args.to_file`` is set,
       truncate the per-run output file.
    3. For each test round, reverse-engineer one mask per label and run MAD
       outlier detection on the masks' L1 norms.

    Raises:
        Exception: if ``args.data`` is not a supported dataset.
    """
    # dataset -> (total_label, input_height, input_width, input_channel)
    dataset_specs = {
        "mnist": (10, 28, 28, 1),
        "cifar10": (10, 32, 32, 3),
        "gtsrb": (43, 32, 32, 3),
        "imagenet": (200, 64, 64, 3),
        "celeba": (8, 64, 64, 3),
    }
    if args.data not in dataset_specs:
        raise Exception("Invalid Dataset")
    args.total_label, args.input_height, args.input_width, args.input_channel = dataset_specs[args.data]

    result_path = os.path.join(args.result, args.data, args.attack_mode)
    if not os.path.exists(result_path):
        os.makedirs(result_path)
    output_path = os.path.join(result_path, "{}_{}_output.txt".format(args.attack_mode, args.data))
    if args.to_file:
        with open(output_path, "w+") as out_file:
            out_file.write("Output for neural cleanse: {} - {}".format(args.attack_mode, args.data) + "\n")

    # Start from an all-ones mask/pattern (in tanh-space) for every label.
    height, width, channels = args.input_height, args.input_width, args.input_channel
    init_mask = np.ones((1, height, width), dtype=np.float32)
    init_pattern = np.ones((channels, height, width), dtype=np.float32)

    for round_idx in range(args.n_times_test):
        print("Test {}:".format(round_idx))
        if args.to_file:
            with open(output_path, "a+") as out_file:
                out_file.write("-" * 30 + "\n")
                out_file.write("Test {}:".format(str(round_idx)) + "\n")

        masks = []
        idx_mapping = {}
        for label in range(args.total_label):
            print("----------------- Analyzing label: {} -----------------".format(label))
            args.target_label = label
            recorder, args = train(args, init_mask, init_pattern)
            idx_mapping[label] = len(masks)
            masks.append(recorder.mask_best)

        l1_norm_list = torch.stack([torch.sum(torch.abs(mask)) for mask in masks])
        print("{} labels found".format(len(l1_norm_list)))
        print("Norm values: {}".format(l1_norm_list))
        outlier_detection(l1_norm_list, idx_mapping, args)


if __name__ == "__main__":
    # Script entry point: run the full Neural Cleanse analysis with the
    # module-level `args` parsed from the command line.
    main(args)
