import copy
import math
import random
import matplotlib.pyplot as plt
import numpy

import numpy as np
import math

import utils
from utils import *
import torch
import torch.nn as nn

# Module-level metric logs. One entry is appended to each list per fitness
# evaluation of SA2_universarial, and run() plots them once the annealing loop
# completes. They are shared by every instance created in the same process.
list_CE = []  # accumulated validation cross-entropy loss per evaluation
list_dif_spatial = []  # accumulated norm between poisoned and clean test images
list_dif_frequency = []  # L2 norm of the trigger strengths (DCT-domain perturbation)
# random.seed(65)
# torch.manual_seed(66)

def dataloader2DCTdataloader(data_loader, window_size=32):
    """Return a deep copy of ``data_loader`` whose samples are DCT-transformed.

    Every sample of every ``(data, label)`` batch in the copy is replaced in
    place by ``utils.DCT(sample, window_size, False)``. The input object is
    left untouched.

    Args:
        data_loader: Iterable of ``(data, label)`` batches, where ``data`` is
            an indexable batch of tensors supporting ``.numpy()``.
            NOTE(review): the in-place writes only persist if iterating the
            copy yields the same stored batch objects each time (e.g. a list
            of batches) -- confirm against the caller.
        window_size: Window size forwarded to ``utils.DCT``. Defaults to 32,
            the value that used to be hard-coded here.

    Returns:
        The transformed deep copy of ``data_loader``.
    """
    data_loader_DCT = copy.deepcopy(data_loader)
    for data, _label in data_loader_DCT:
        for sample_seq in range(len(data)):
            data[sample_seq] = torch.tensor(
                utils.DCT(data[sample_seq].numpy(), window_size, False))
    return data_loader_DCT


class SA2_universarial:
    """Simulated-annealing search for a frequency-domain (DCT) trigger.

    A candidate trigger (``vars``) is a list of ``num_units`` units
    ``[freq_x, freq_y, strength]``; unit ``i * pixels_per_channel + j`` adds
    ``strength`` to DCT coefficient ``[freq_x][freq_y]`` of channel ``i``.
    A candidate is scored by retraining a copy of the model on data carrying
    the trigger and measuring the loss towards ``target_label`` on the test
    data carrying the trigger, plus a penalty on the trigger strength.
    ``run()`` performs the annealing loop and returns the final ``vars``.
    """

    # DCT/IDCT window size used for each supported dataset.
    _WINDOW_SIZES = {
        'cifar10': 32,
        'gtsrb': 32,
        'imagenet': 64,
        'celeba': 64,
        'mnist': 28,
    }
    # Largest frequency index a trigger unit may take for each dataset.
    _FREQ_UPPER_BOUNDS = {
        'cifar10': 31,
        'gtsrb': 31,
        'imagenet': 27,
        'celeba': 27,
        'mnist': 13,
    }

    def __init__(self,
                 #iter=100, T0=100, Tf=0.01, alpha=0.99,
                 iter=5, T0=1, Tf=0.01, alpha=0.99,
                 model=None, training_set=None, test_set=None, source_label=None, target_label=None,
                 train_poisoning_frac=0.1, test_poisoning_frac=1.0,
                 epsilon=0.5, num_channels=3,
                 dataset = 'cifar10',
                 device='cpu'):
        """Set up the search and evaluate a random initial trigger.

        Args:
            iter: Number of candidates tried at each temperature.
            T0: Initial temperature.
            Tf: Final temperature; ``run()`` stops once ``T <= Tf``.
            alpha: Multiplicative cooling factor applied after each level.
            model: Model to attack (required); it is moved to ``device``.
            training_set: Iterable of ``(data, label)`` training batches.
            test_set: Iterable of ``(data, label)`` test batches.
            source_label: Stored on the instance; not used by this class.
            target_label: Label assigned to every poisoned sample.
            train_poisoning_frac: Fraction of each training batch poisoned.
            test_poisoning_frac: Fraction of each test batch poisoned.
            epsilon: Upper bound used to scale trigger strengths.
            num_channels: Number of image channels carrying trigger units.
            dataset: One of 'cifar10', 'gtsrb', 'imagenet', 'celeba', 'mnist'.
            device: Torch device used for training and evaluation.

        Raises:
            ValueError: If ``dataset`` is not one of the supported names.
        """
        # Fail early and clearly instead of hitting a missing attribute later.
        if dataset not in self._FREQ_UPPER_BOUNDS:
            raise ValueError('Error, unsupported dataset: ' + str(dataset))

        self.prob_mut_freq = 0.8
        self.prob_mut_strength = 0.8
        self.Metrospolis_param = 0.2
        self.freq_mut_scale = 0.2
        self.strength_mut_scale = 0.2
        self.device = device
        self.model = model.to(device)

        self.retrain_epochs = 10

        # concerning training and test data
        self.trainings_set = training_set
        self.test_set = test_set
        # NOTE(review): the conversion below uses the helper's default DCT
        # window, while get_updated_IDCT_mat_with_vars uses a per-dataset
        # window for the inverse transform -- confirm this is intended for
        # datasets whose window is not the default.
        self.DCT_training_set = dataloader2DCTdataloader(training_set)
        self.DCT_test_set = dataloader2DCTdataloader(test_set)
        self.train_poisoning_frac = train_poisoning_frac
        self.test_poisoning_frac = test_poisoning_frac

        # Keep the first test sample as the reference image for visualisation.
        for data, _label in test_set:
            self.img = data[0]
            break

        self.source_label = source_label
        self.target_label = target_label

        self.iter = iter
        self.alpha = alpha
        self.T0 = T0
        self.Tf = Tf
        self.T = T0
        self.vars = None
        self.fitness = None

        self.num_channels = num_channels
        self.epsilon = epsilon
        # Number of trigger units (perturbed DCT coefficients) per channel.
        self.pixels_per_channel = 3
        self.num_units = self.pixels_per_channel * self.num_channels
        self.dataset = dataset
        # Bounds of one unit: [freq_x, freq_y, strength].
        freq_max = self._FREQ_UPPER_BOUNDS[dataset]
        self.unit_l_bound = np.array([0, 0, 0])
        self.unit_u_bound = np.array([freq_max, freq_max, self.epsilon])

        self.history = {'f': [], 'T': []}
        self.terminate_flag = False
        self.criterion = nn.CrossEntropyLoss()

        self.best_vars = None
        self.best_fitness = None

        # Kept as an attribute so existing callers of ``self.func`` still work.
        self.func = self._evaluate

        # unit_u_bound is a float array (it also holds epsilon), so the
        # frequency bounds are cast to int before calling random.randint.
        x_lo, x_hi = int(self.unit_l_bound[0]), int(self.unit_u_bound[0])
        y_lo, y_hi = int(self.unit_l_bound[1]), int(self.unit_u_bound[1])
        vec = []
        for _ in range(self.num_units):
            p1 = random.randint(x_lo, x_hi)
            p2 = random.randint(y_lo, y_hi)
            p3 = random.random() * (self.unit_u_bound[2] - self.unit_l_bound[2]) * self.strength_mut_scale
            vec.append([p1, p2, p3])
        self.vars = vec
        self.fitness, _ = self.func(self.vars)

    @staticmethod
    def _to_numpy(img):
        """Return ``img`` as a NumPy array, moving torch tensors to the CPU."""
        if torch.is_tensor(img):
            return img.detach().cpu().numpy()
        return np.asarray(img)

    def _window_size(self):
        """Return the DCT window size configured for ``self.dataset``."""
        try:
            return self._WINDOW_SIZES[self.dataset]
        except KeyError:
            raise ValueError('Error, unsupported dataset: ' + str(self.dataset)) from None

    def _poison_batch(self, data, label, vars, frac):
        """Apply the trigger in place to the leading samples of a DCT batch.

        Samples are processed from index 0 until the poisoned share reaches
        ``frac``; each is replaced by its triggered inverse DCT and relabelled
        as ``self.target_label``. At least one sample is always processed.
        """
        poisoned_samples = 0
        for sample_idx in range(len(data)):
            data[sample_idx] = torch.tensor(
                self.get_updated_IDCT_mat_with_vars(data[sample_idx].numpy(), vars))
            label[sample_idx] = self.target_label
            poisoned_samples += 1
            if poisoned_samples / len(data) >= frac:
                break

    def _retrain_with_trigger(self, vars):
        """Return a copy of ``self.model`` retrained on triggered data."""
        model = copy.deepcopy(self.model)
        model.train()
        optimizer = torch.optim.SGD(model.parameters(), lr=0.1, )  # lr=0.01, momentum=0.9, weight_decay=5e-4
        training_set_poisoned = copy.deepcopy(self.DCT_training_set)
        for epoch in range(self.retrain_epochs):
            for batch_idx, (data, label) in enumerate(training_set_poisoned):
                self._poison_batch(data, label, vars, self.train_poisoning_frac)
                data = data.to(device=self.device)
                label = label.to(device=self.device)
                output = model(data)
                optimizer.zero_grad()
                loss = self.criterion(output, label.view(-1))
                loss.backward()
                optimizer.step()
                print("T: "+str(self.T)+", retrain epoch: "+str(epoch)+", batch= "+str(batch_idx)+", training loss: ", loss.item())
                # NOTE(review): only the first batch is used in each epoch.
                break
        return model

    def _evaluate(self, vars):
        """Score a candidate trigger; lower fitness is better.

        Retrains a copy of the model with the trigger, evaluates it on the
        triggered test data, appends the metrics to the module-level logs and
        updates ``best_vars`` / ``best_fitness``. Sets ``terminate_flag`` as
        soon as any test prediction equals ``target_label``.

        Returns:
            ``(fitness, predictions)`` where ``predictions`` is a NumPy array
            with the predicted labels of the last test batch.

        Raises:
            Exception: If the DCT data sets are missing.
            ValueError: If the test set yields no samples.
        """
        if self.DCT_test_set is None or self.DCT_training_set is None:
            raise Exception('Error, the training or test set has not been converted to DCT set')

        model = self._retrain_with_trigger(vars)

        #validation
        loss_accumulate = 0
        L2Norm_diff_spatial_space_accumulate = 0
        model.eval()
        test_set_poisoned = copy.deepcopy(self.DCT_test_set)
        correctly_labeled_samples = 0
        total_test_number = 0
        pred_np = None
        for (data_poisoned, label_poisoned), (data_clean, _label_clean) in zip(test_set_poisoned, self.test_set):
            self._poison_batch(data_poisoned, label_poisoned, vars, self.test_poisoning_frac)
            data_poisoned = data_poisoned.to(device=self.device)
            label_flat = label_poisoned.to(device=self.device).view(-1)
            with torch.no_grad():
                output = model(data_poisoned)
            total_test_number += len(output)
            _, pred_labels = torch.max(output, 1)
            pred_labels = pred_labels.view(-1)
            # Tensors on a non-CPU device must be moved before .numpy().
            pred_np = pred_labels.cpu().numpy()
            if np.sum(pred_np == self.target_label) > 0:
                self.terminate_flag = True
                print('***some samples are successfully poisoned with respect to target label***')

            correctly_labeled_samples += torch.sum(torch.eq(pred_labels, label_flat)).item()
            loss_accumulate += self.criterion(output, label_flat).item()
            L2Norm_diff_spatial_space_accumulate += self.calc_batch_diff_norm(data_poisoned, data_clean, 2)

        if total_test_number == 0:
            raise ValueError('Error, the test set is empty, the trigger cannot be evaluated')

        perturb_norm_DCT = np.linalg.norm([unit[2] for unit in vars], ord=2)
        acc = correctly_labeled_samples / total_test_number
        # NOTE(review): this evaluates the original model, not the retrained
        # copy, so the printed value does not depend on ``vars``.
        bn_acc = utils.test_model(self.model, self.test_set, self.device, False)
        print('Test accuracy (benign):'+str(bn_acc))
        print("validation loss (accumulated loss of all batches): " + str(loss_accumulate) +", validation accuracy:" + str(acc))

        # NOTE(review): the penalty weight used to be written as
        # (T - Tf) / (T - Tf), which is always 1 (and divides by zero when
        # T == Tf). The constant weight is kept here; if an annealed weight
        # (T - Tf) / (T0 - Tf) was intended, change it deliberately.
        fit = loss_accumulate + perturb_norm_DCT * 0.2
        list_CE.append(loss_accumulate)
        list_dif_spatial.append(L2Norm_diff_spatial_space_accumulate)
        list_dif_frequency.append(perturb_norm_DCT)
        print('fitness: ', fit)
        if self.best_fitness is None or self.best_fitness > fit:
            self.best_fitness = fit
            self.best_vars = copy.deepcopy(vars)
        return fit, pred_np

    def calc_diff_norm(self, img1, img2, p=None):
        """Return the ``p``-norm of the flattened difference of two images.

        Accepts NumPy arrays or torch tensors (on any device).

        Raises:
            Exception: If either image has more than 3 channels or the
                channel counts differ.
        """
        if img1.shape[0] > 3 or img2.shape[0] > 3:
            raise Exception('Error, in calc_diff_norm, the imput image has more than 3 channels')
        if img1.shape[0] != img2.shape[0]:
            raise Exception('error, the dim of original image and poisoned image is different')
        # Convert first so that a device tensor and a CPU tensor can be compared.
        diff = self._to_numpy(img1).flatten() - self._to_numpy(img2).flatten()
        return np.linalg.norm(diff, ord=p)#L2 norm should below 1.5

    def calc_batch_diff_norm(self, img1_batch, img2_batch, p=None):
        """Return the sum of per-image difference norms over two batches."""
        diff_sum = 0
        for idx_img in range(len(img1_batch)):
            diff_sum += self.calc_diff_norm(img1_batch[idx_img], img2_batch[idx_img], p)
        return diff_sum

    def get_updated_IDCT_mat_with_vars(self, img_DCT, vars):
        """Add the trigger to a copy of ``img_DCT`` and return its inverse DCT.

        Args:
            img_DCT: DCT coefficients indexed as ``[channel][x][y]``.
            vars: Trigger units ``[freq_x, freq_y, strength]``.

        Returns:
            The triggered image as a ``float32`` array. ``img_DCT`` itself is
            not modified.

        Raises:
            ValueError: If ``self.dataset`` is not supported.
        """
        window_size = self._window_size()
        img_DCT = copy.deepcopy(img_DCT)
        for i in range(self.num_channels):
            for j in range(self.pixels_per_channel):
                freq_x, freq_y, s = vars[i * self.pixels_per_channel + j][:3]
                img_DCT[i][freq_x][freq_y] += s
        pic_IDCT = IDCT(img_DCT, window_size=window_size, transpose=False).astype(np.float32)
        return pic_IDCT

    def check_pixel_legality(self, img):
        """Return True if all pixels lie in the range implied by their mean.

        Negative pixels are always illegal. If the mean is at most 1 the
        pixels may not exceed 1.0; otherwise they may not exceed 255.
        """
        if np.any(img < 0):
            return False
        upper = 1.0 if np.mean(img) <= 1.0 else 255
        return not np.any(img > upper)

    def generate_new(self, vars, img_ori_spatial=None):
        """Return a mutated deep copy of ``vars``.

        With probability ``prob_mut_freq`` the frequency position of one
        random unit is shifted, and with probability ``prob_mut_strength`` the
        strength of one random unit is shifted. Both shifts shrink as the
        temperature approaches ``Tf``. ``img_ori_spatial`` is unused and kept
        for interface compatibility.
        """
        # Integer frequency bounds: unit_u_bound is a float array, and the
        # shifted position must respect the dataset bound, not a fixed 31.
        x_lo, x_hi = int(self.unit_l_bound[0]), int(self.unit_u_bound[0])
        y_lo, y_hi = int(self.unit_l_bound[1]), int(self.unit_u_bound[1])
        t_ratio = (self.T - self.Tf) / (self.T0 - self.Tf)
        strength_span = self.unit_u_bound[2] - self.unit_l_bound[2]

        revised = False
        while True:
            vars_new = copy.deepcopy(vars)
            if random.random() < self.prob_mut_freq:
                revised = True
                unit_idx = random.randint(0, self.num_units - 1)
                # Redraw the shift until the unit stays inside its bounds.
                while True:
                    p1 = random.randint(x_lo, x_hi)
                    p2 = random.randint(x_lo, x_hi)
                    x_pert = int(t_ratio * self.freq_mut_scale * (p1 - p2))
                    p3 = random.randint(y_lo, y_hi)
                    p4 = random.randint(y_lo, y_hi)
                    y_pert = int(t_ratio * self.freq_mut_scale * (p3 - p4))
                    new_x = vars_new[unit_idx][0] + x_pert
                    new_y = vars_new[unit_idx][1] + y_pert
                    if x_lo <= new_x <= x_hi and y_lo <= new_y <= y_hi:
                        vars_new[unit_idx][0] = new_x
                        vars_new[unit_idx][1] = new_y
                        break
            if random.random() < self.prob_mut_strength:
                revised = True
                unit_idx = random.randint(0, self.num_units - 1)
                strength_pert = t_ratio * self.strength_mut_scale * strength_span * (2 * random.random() - 1)
                vars_new[unit_idx][2] += strength_pert
            if revised:
                break
        return vars_new

    def Metrospolis(self, f, f_new):  # Metropolis criterion
        """Return 1 if the candidate fitness ``f_new`` is accepted, else 0.

        Improvements are always accepted; a worse candidate is accepted with
        probability ``exp(-(f_new - f) / (T * Metrospolis_param))``.
        """
        if f_new < f:
            return 1
        p = math.exp(-(f_new - f) / (self.T * self.Metrospolis_param))
        return 1 if random.random() < p else 0

    def best(self):
        # NOTE(review): this method reads self.x and self.y, which are never
        # assigned in this class, and calls self.func with two arguments
        # although it takes one -- it cannot run as written. Left unchanged
        # because the intended behaviour is unclear; best_vars / best_fitness
        # already track the best candidate.
        f_list = []
        for i in range(self.iter):
            f, _ = self.func(self.x[i], self.y[i])
            f_list.append(f)
        f_best = min(f_list)

        idx = f_list.index(f_best)
        return f_best, idx

    def obtain_final_image(self):
        """Return the reference image ``self.img`` with the current trigger.

        The trigger is added to the DCT coefficients in the same way as
        during training and evaluation.
        """
        img_DCT = utils.DCT(self.img.numpy(), window_size=self._window_size(), transpose=False)
        return self.get_updated_IDCT_mat_with_vars(img_DCT, self.vars)

    def _accept(self, vars_new, f_new):
        """Make ``vars_new`` the current state and record it in the history."""
        self.vars = vars_new
        self.fitness = f_new
        self.history['f'].append(self.fitness)
        self.history['T'].append(self.T)

    def run(self):
        """Run the annealing loop and return the final trigger ``vars``.

        Stops early, after showing the clean image, the poisoned image and
        the trigger, as soon as ``terminate_flag`` is set. Otherwise it cools
        until ``T <= Tf``, then shows the final image and plots the metrics.
        """
        while self.T > self.Tf:
            for i in range(self.iter):
                print('T='+str(self.T)+', iter='+str(i), end='')
                f = self.fitness
                vars_new = self.generate_new(self.vars)
                f_new, _ = self.func(vars_new)
                if self.terminate_flag:
                    self._accept(vars_new, f_new)
                    show_pic(self.img, title="clean image (freq attack)")
                    show_pic(self.obtain_final_image(), title="poisoned image (freq attack)")
                    show_trigger(self.vars, self.num_channels, self.pixels_per_channel, dataset=self.dataset)
                    return self.vars

                if self.Metrospolis(f, f_new):
                    self._accept(vars_new, f_new)
            self.T = self.T * self.alpha

        show_pic(self.obtain_final_image())

        x_axis_values = list(range(len(list_CE)))
        plt.plot(x_axis_values,list_CE,'r--',label='CE')
        plt.plot(x_axis_values, list_dif_frequency, 'g--', label='diff_freq')
        plt.plot(x_axis_values,list_dif_spatial,'b--',label='diff_spatial')
        plt.title('metric statistic along training')
        plt.legend()
        plt.show()

        return self.vars


