import torch

from SA2_universalTrigger import *

def RGB2YUV(x_rgb):
    """Convert a batch of RGB images to OpenCV's YCrCb colour space.

    Despite the name, the output channel order is OpenCV's (Y, Cr, Cb), as
    produced by ``cv2.COLOR_RGB2YCrCb``.

    Args:
        x_rgb: tensor in (batch, channel, height, width) layout with channels
            in RGB order. It is handed to ``cv2.cvtColor``, so the value range
            follows OpenCV's convention for the dtype (presumably float in
            [0, 1] for callers in this module -- confirm).

    Returns:
        float32 CPU tensor with the same (batch, channel, height, width) shape.
    """
    # cv2 needs host numpy arrays: detach/cpu so GPU or grad-tracking tensors
    # do not make ``.numpy()`` raise.
    x_rgb = x_rgb.detach().cpu()
    if x_rgb.dtype == torch.float64:
        # cvtColor accepts 8U/16U/32F input only; float64 would be rejected.
        x_rgb = x_rgb.float()
    x_rgb = x_rgb.permute(0, 2, 3, 1)  # bs, h, w, ch
    x_yuv = np.zeros(x_rgb.shape, dtype=np.float32)
    for i in range(x_rgb.shape[0]):
        # permute() leaves a strided view; hand cv2 a contiguous HWC array.
        img = np.ascontiguousarray(x_rgb[i].numpy())
        x_yuv[i] = cv2.cvtColor(img, cv2.COLOR_RGB2YCrCb)
    return torch.from_numpy(x_yuv).permute(0, 3, 1, 2)

def YUV2RGB(x_yuv):
    """Convert a batch of OpenCV YCrCb images back to RGB.

    The input channel order must be OpenCV's (Y, Cr, Cb), because
    ``cv2.COLOR_YCrCb2RGB`` is used.

    Args:
        x_yuv: tensor in (batch, channel, height, width) layout. It is handed
            to ``cv2.cvtColor``, so the value range follows OpenCV's
            convention for the dtype.

    Returns:
        float32 CPU tensor with the same (batch, channel, height, width) shape.
    """
    # cv2 needs host numpy arrays: detach/cpu so GPU or grad-tracking tensors
    # do not make ``.numpy()`` raise.
    x_yuv = x_yuv.detach().cpu()
    if x_yuv.dtype == torch.float64:
        # cvtColor accepts 8U/16U/32F input only; float64 would be rejected.
        x_yuv = x_yuv.float()
    x_yuv = x_yuv.permute(0, 2, 3, 1)  # bs, h, w, ch
    x_rgb = np.zeros(x_yuv.shape, dtype=np.float32)
    for i in range(x_yuv.shape[0]):
        # permute() leaves a strided view; hand cv2 a contiguous HWC array.
        img = np.ascontiguousarray(x_yuv[i].numpy())
        x_rgb[i] = cv2.cvtColor(img, cv2.COLOR_YCrCb2RGB)
    return torch.from_numpy(x_rgb).permute(0, 3, 1, 2)

def square_poison(args, data, label, target_label, poison_ratio=0.05):
    """Stamp a small square patch near the bottom-right corner of a batch prefix.

    The first ``ceil(batch_size * poison_ratio)`` samples (capped at the batch
    size) get a 3x3 patch set to 1, placed 1 pixel in from the bottom and
    right edges in every channel. Their labels become ``target_label``.

    Args:
        args: unused; kept for a uniform signature with the other poisoners.
        data: batch indexed as (batch, channel, height, width).
        label: per-sample labels, indexable by sample position.
        target_label: label assigned to poisoned samples.
        poison_ratio: fraction of the batch to poison.

    Returns:
        (data, label) -- poisoned copies; the inputs are not modified.
    """
    data = copy.deepcopy(data)
    label = copy.deepcopy(label)
    pattern_size = 3
    margin = 1
    batch_size, _, h, w = data.shape
    # Cap so a ratio > 1 cannot index past the end of the batch.
    poison_num = min(batch_size, math.ceil(batch_size * poison_ratio))

    rows = slice(h - margin - pattern_size, h - margin)
    cols = slice(w - margin - pattern_size, w - margin)
    # Writing the constant directly is equivalent to mask * ones + (1 - mask) * x,
    # and works on any device/dtype without building helper tensors.
    data[:poison_num, :, rows, cols] = 1
    for idx in range(poison_num):
        label[idx] = target_label  # poison label

    return data, label

def dynamic_poison(args, data, label, target_label, poison_ratio=0.05):
    """Stamp a small square patch near the bottom-right corner of a batch prefix.

    NOTE(review): despite its name, this applies the same fixed 3x3 corner
    patch as ``square_poison``. Confirm whether a per-sample (dynamic) trigger
    was intended here.

    The first ``ceil(batch_size * poison_ratio)`` samples (capped at the batch
    size) get a 3x3 patch set to 1, placed 1 pixel in from the bottom and
    right edges in every channel. Their labels become ``target_label``.

    Args:
        args: unused; kept for a uniform signature with the other poisoners.
        data: batch indexed as (batch, channel, height, width).
        label: per-sample labels, indexable by sample position.
        target_label: label assigned to poisoned samples.
        poison_ratio: fraction of the batch to poison.

    Returns:
        (data, label) -- poisoned copies; the inputs are not modified.
    """
    data = copy.deepcopy(data)
    label = copy.deepcopy(label)
    pattern_size = 3
    margin = 1
    batch_size, _, h, w = data.shape
    # Cap so a ratio > 1 cannot index past the end of the batch.
    poison_num = min(batch_size, math.ceil(batch_size * poison_ratio))

    rows = slice(h - margin - pattern_size, h - margin)
    cols = slice(w - margin - pattern_size, w - margin)
    # Writing the constant directly is equivalent to mask * ones + (1 - mask) * x.
    data[:poison_num, :, rows, cols] = 1
    for idx in range(poison_num):
        label[idx] = target_label  # poison label

    return data, label

def sig_poison(args, data, label, target_label, poison_ratio=0.05):
    """Add a horizontal sinusoidal signal to a batch prefix (SIG-style trigger).

    The first ``ceil(batch_size * poison_ratio)`` samples (capped at the batch
    size) become ``clamp(x + 20/255 * sin(2*pi*6*j / w), 0, 1)``, where ``j``
    is the column index. Their labels become ``target_label``.

    Args:
        args: unused; kept for a uniform signature with the other poisoners.
        data: batch indexed as (batch, channel, height, width); poisoned rows
            are cast to float and clamped to [0, 1].
        label: per-sample labels, indexable by sample position.
        target_label: label assigned to poisoned samples.
        poison_ratio: fraction of the batch to poison.

    Returns:
        (data, label) -- poisoned copies; the inputs are not modified.
    """
    data = copy.deepcopy(data)
    label = copy.deepcopy(label)
    batch_size, c, h, w = data.shape
    # Cap so a ratio > 1 cannot index past the end of the batch.
    poison_num = min(batch_size, math.ceil(batch_size * poison_ratio))

    delta = 20  # amplitude before the /255 below
    f = 6  # number of full sine periods across the image width
    # The signal depends only on the column, so build one row and broadcast
    # it instead of evaluating sin element by element over an (h, w) grid.
    phase = torch.tensor([torch.pi * 2 * f * j / w for j in range(w)])
    row = delta * torch.sin(phase)
    sig = row.expand(h, w).repeat(c, 1, 1) / 255
    sig = sig.to(data.device)  # avoid a CPU/GPU mismatch in the add below
    for idx in range(poison_num):
        label[idx] = target_label
        data[idx] = torch.clamp(data[idx].float() + sig, 0, 1)

    return data, label


################## figure
def blend_poison(args, data, label, target_label, poison_ratio=0.05):
    """Alpha-blend a fixed image trigger into a batch prefix (blended attack).

    The first ``ceil(batch_size * poison_ratio)`` samples (capped at the batch
    size) become ``clamp(0.2 * trigger + 0.8 * x, 0, 1)``. Their labels
    become ``target_label``.

    Args:
        args: namespace whose ``data`` attribute selects the trigger image:
            'cifar10'/'gtsrb' use hellokitty_32.png; 'imagenet'/'celeba' use
            hellokitty_224.png resized so its shorter side is 64.
        data: batch indexed as (batch, channel, height, width). The trigger is
            broadcast against each sample, so sizes must agree (presumably
            32x32 or 64x64 per dataset -- TODO confirm against the loaders).
        label: per-sample labels, indexable by sample position.
        target_label: label assigned to poisoned samples.
        poison_ratio: fraction of the batch to poison.

    Returns:
        (data, label) -- poisoned copies; the inputs are not modified.

    Raises:
        ValueError: if ``args.data`` has no trigger image configured.
    """
    if args.data == 'cifar10' or args.data == 'gtsrb':
        img_path = './trigger_image/hellokitty_32.png'
    elif args.data == 'imagenet' or args.data == 'celeba':
        img_path = './trigger_image/hellokitty_224.png'
    else:
        raise ValueError(f"unsupported dataset for blend trigger: {args.data!r}")

    # Work on copies, like the other *_poison functions, so the caller's
    # batch and labels are not modified in place.
    data = copy.deepcopy(data)
    label = copy.deepcopy(label)

    # Close the file handle; force 3 channels so RGBA/palette PNGs broadcast
    # against RGB samples.
    with Image.open(img_path) as img:
        trigger = T.ToTensor()(img.convert('RGB'))

    if args.data == 'imagenet' or args.data == 'celeba':
        trigger = T.Resize(64)(trigger)
    trigger = trigger.to(data.device)

    batch_size = data.shape[0]
    # Cap so a ratio > 1 cannot index past the end of the batch.
    poison_num = min(batch_size, math.ceil(batch_size * poison_ratio))

    for idx in range(poison_num):
        label[idx] = target_label
        data[idx] = torch.clamp(0.2 * trigger + 0.8 * data[idx], 0, 1)

    return data, label

def get_updated_IDCT_mat_with_vars(args, img_DCT, vars):  # freq
    """Add per-channel frequency perturbations to DCT coefficients, then invert.

    For each channel ``i`` (count chosen by dataset), three entries
    ``vars[i * 3 + j]`` (j = 0..2) are read as ``(freq_x, freq_y, s)``.
    ``s`` is added to ``img_DCT[i][freq_x][freq_y]``.

    Args:
        args: namespace whose ``data`` attribute selects the channel count and
            the ``IDCT`` window size.
        img_DCT: per-channel DCT coefficients, indexable as [channel][x][y]
            (presumably a numpy array -- ``freq_poison`` passes ``.numpy()``).
            Not modified; a deep copy is edited.
        vars: sequence of at least ``3 * num_channel`` ``(freq_x, freq_y, s)``
            entries.

    Returns:
        float32 result of ``IDCT`` on the perturbed coefficients.

    Raises:
        ValueError: if ``args.data`` is not a supported dataset.
    """
    # dataset -> (num_channel, window_size)
    dataset_layout = {
        'cifar10': (3, 32),
        'gtsrb': (3, 32),
        'imagenet': (3, 64),
        'celeba': (3, 64),
        'mnist': (1, 28),
    }
    if args.data not in dataset_layout:
        raise ValueError(f"unsupported dataset for frequency trigger: {args.data!r}")
    num_channel, window_size = dataset_layout[args.data]

    img_DCT = copy.deepcopy(img_DCT)
    pixels_per_channel = 3
    for i in range(num_channel):
        for j in range(pixels_per_channel):
            entry = vars[i * pixels_per_channel + j]
            freq_x, freq_y, s = entry[0], entry[1], entry[2]
            img_DCT[i][freq_x][freq_y] += s
    pic_IDCT = IDCT(img_DCT, window_size=window_size, transpose=False).astype(np.float32)
    return pic_IDCT

def freq_poison(args, data, target, target_label, poisoning_frac, vars):
    """Poison a batch prefix by perturbing DCT coefficients.

    The first ``ceil(len(target) * poisoning_frac)`` samples (capped at the
    batch size) are DCT-transformed. They are then perturbed and inverted via
    ``get_updated_IDCT_mat_with_vars``, and their labels become
    ``target_label``.

    Args:
        args: namespace whose ``data`` attribute selects the DCT window size.
        data: batch of per-sample tensors indexable by position.
        target: per-sample labels.
        target_label: label assigned to poisoned samples.
        poisoning_frac: fraction of the batch to poison.
        vars: frequency/offset entries forwarded to
            ``get_updated_IDCT_mat_with_vars``.

    Returns:
        (data, target) -- poisoned copies; the inputs are not modified.

    Raises:
        ValueError: if ``args.data`` is not a supported dataset.
    """
    window_sizes = {'cifar10': 32, 'gtsrb': 32, 'imagenet': 64, 'celeba': 64, 'mnist': 28}
    if args.data not in window_sizes:
        raise ValueError(f"unsupported dataset for frequency trigger: {args.data!r}")
    window_size = window_sizes[args.data]

    data = copy.deepcopy(data)
    target = copy.deepcopy(target)
    # Cap so a fraction > 1 cannot index past the end of the batch.
    poison_number = min(len(target), math.ceil(len(target) * poisoning_frac))

    for index in range(poison_number):
        target[index] = target_label
        # Keep the DCT coefficients in a local array instead of writing them
        # into ``data`` as scratch space: they are not pixel values and would
        # be truncated if ``data`` had an integer dtype.
        img = data[index].detach().cpu().numpy()
        img_dct = np.asarray(DCT(img, window_size=window_size, transpose=False))
        data[index] = torch.tensor(get_updated_IDCT_mat_with_vars(args, img_dct, vars=vars))

    return data, target

def get_updated_IDCT_mat_with_ftrojan(args, img_DCT, window_size, vars):
    """Overwrite DCT coefficients with FTrojan trigger values, then invert.

    Args:
        args: unused here; kept for signature compatibility.
        img_DCT: 3-D coefficient array ``(channels, x, y)``. ``.numpy()`` is
            called on it, so presumably a torch tensor. Not modified; a deep
            copy is edited.
        window_size: forwarded to ``IDCT``.
        vars: rows of ``(freq_x, freq_y, value)``. Each value is written (not
            added) at ``[freq_x][freq_y]`` in every channel except channel 0.

    Returns:
        float32 numpy array: ``IDCT`` of the edited coefficients divided by 255.
    """
    coeffs = copy.deepcopy(img_DCT)
    num_channel, _, _ = coeffs.shape
    for entry in vars:
        freq_x, freq_y, s = entry[0], entry[1], entry[2]
        # Channel 0 is left untouched; the trigger goes into channels 1..C-1.
        for ch in range(1, num_channel):
            coeffs[ch][freq_x][freq_y] = s
    spatial = IDCT(coeffs.numpy(), window_size=window_size, transpose=False)
    spatial = spatial.astype(np.float32)
    spatial /= 255
    return spatial

def ftrojan_poison(args, data, target, target_label, poison_ratio):
    """Poison a batch prefix with an FTrojan-style chroma frequency trigger.

    The first ``ceil(len(target) * poison_ratio)`` samples (capped at the
    batch size) are converted to YCrCb, scaled by 255 and cast to uint8. They
    are then DCT-transformed, and fixed coefficients in the non-luma channels
    are overwritten. The result is inverted, rescaled and converted back to
    RGB, and the labels become ``target_label``. Clean samples are left
    untouched.

    Args:
        args: namespace whose ``data`` attribute selects the window size and
            the trigger frequencies/magnitudes.
        data: (batch, channel, height, width) RGB batch (presumably float in
            [0, 1] given the *255 / uint8 cast -- confirm).
        target: per-sample labels.
        target_label: label assigned to poisoned samples.
        poison_ratio: fraction of the batch to poison.

    Returns:
        (data, target) -- poisoned copies; the inputs are not modified.

    Raises:
        ValueError: if ``args.data`` is not a supported dataset.
    """
    # dataset -> (window_size, rows of (freq_x, freq_y, magnitude))
    trigger_config = {
        'cifar10': (32, [[15, 15, 30], [31, 31, 30]]),
        'gtsrb': (32, [[15, 15, 30], [31, 31, 30]]),
        'imagenet': (64, [[31, 31, 50], [63, 63, 50]]),
        'celeba': (64, [[31, 31, 50], [63, 63, 50]]),
        'mnist': (28, [[13, 13, 30], [27, 27, 30]]),
    }
    if args.data not in trigger_config:
        raise ValueError(f"unsupported dataset for FTrojan trigger: {args.data!r}")
    window_size, freq_rows = trigger_config[args.data]
    freq_triggers = torch.tensor(freq_rows)

    data = copy.deepcopy(data)
    target = copy.deepcopy(target)
    # Cap so a ratio > 1 cannot index past the end of the batch.
    poison_number = min(len(target), math.ceil(len(target) * poison_ratio))

    # Only the poisoned prefix needs the colour-space round trip; converting
    # the whole batch would also perturb (and re-type) the clean samples.
    yuv = RGB2YUV(data[:poison_number])
    for index in range(poison_number):
        target[index] = target_label
        img = (yuv[index] * 255).numpy().astype(np.uint8)
        img_dct = torch.tensor(DCT(img, window_size=window_size, transpose=False))
        yuv[index] = torch.tensor(
            get_updated_IDCT_mat_with_ftrojan(args, img_dct, window_size=window_size, vars=freq_triggers)
        )
    data[:poison_number] = YUV2RGB(yuv)

    return data, target

def fiba_poison(args, data, target, target_label, poison_ratio):
    """Blend a trigger image's low-frequency amplitude into a batch prefix (FIBA).

    For the first ``ceil(len(target) * poison_ratio)`` samples (capped at the
    batch size), the shifted FFT amplitude inside a centred square of
    half-width ``floor(min(h, w) * beta)`` is mixed with the trigger's
    amplitude: ``(1 - alpha) * img + alpha * trigger``. The sample's own phase
    is kept, and its label becomes ``target_label``.

    Args:
        args: namespace; ``args.data == 'mnist'`` forces a grayscale trigger.
        data: batch of (channel, height, width) samples, all the same size.
        target: per-sample labels.
        target_label: label assigned to poisoned samples.
        poison_ratio: fraction of the batch to poison.

    Returns:
        (data, target) -- poisoned copies; the inputs are not modified.

    Raises:
        FileNotFoundError: if the trigger image cannot be read.
    """
    data = copy.deepcopy(data)
    target = copy.deepcopy(target)

    # Cap so a ratio > 1 cannot index past the end of the batch.
    poison_number = min(len(target), math.ceil(len(target) * poison_ratio))
    if poison_number == 0:
        return data, target

    c, h, w = data[0].shape

    trigger_pth = 'trigger_image/coco_val75/000000002157.jpg'
    trigger = cv2.imread(trigger_pth)  # np.array [0,255], BGR channel order
    if trigger is None:
        raise FileNotFoundError(f"cannot read FIBA trigger image: {trigger_pth}")
    if args.data == 'mnist' or c == 1:
        trigger = cv2.cvtColor(trigger, cv2.COLOR_BGR2GRAY)
    else:
        # cv2 loads BGR; samples are treated as RGB elsewhere (e.g. RGB2YUV),
        # so align channel order before mixing spectra channel by channel.
        trigger = cv2.cvtColor(trigger, cv2.COLOR_BGR2RGB)
    trigger = T.ToTensor()(trigger)  # tensor [0,1]
    # Match the sample size so the centred frequency windows line up.
    trigger = T.Resize([h, w])(trigger)
    amp_trig_shift, _ = fft_function(trigger)
    amp_trig_shift = amp_trig_shift.to(data[0].device)

    beta = 0.1
    alpha = 0.15
    b = int(math.floor(min(h, w) * beta))
    c_h, c_w = h // 2, w // 2  # centre of the shifted spectrum
    rows = slice(c_h - b, c_h + b + 1)
    cols = slice(c_w - b, c_w + b + 1)

    for index in range(poison_number):
        target[index] = target_label
        amp_img_shift, phase_img = fft_function(data[index])
        amp_img_shift[:, rows, cols] = (
            amp_img_shift[:, rows, cols] * (1 - alpha) + amp_trig_shift[:, rows, cols] * alpha
        )
        # fft_function applies fftshift over all dims, so invert over all dims
        # too before recombining with the (unshifted) phase.
        amp_img = torch.fft.ifftshift(amp_img_shift)
        poisoned_img_ifft = torch.fft.ifft2(amp_img * torch.exp(1j * phase_img))
        data[index] = torch.real(poisoned_img_ifft)

    return data, target

def fft_function(img):
    """Return the shifted FFT amplitude and the unshifted phase of ``img``.

    ``torch.fft.fft2`` transforms the last two dims. The amplitude is then
    passed through ``torch.fft.fftshift`` with no ``dim`` argument, so every
    dimension (including any leading channel dim) is shifted. The phase is
    returned without any shift.
    """
    spectrum = torch.fft.fft2(img)
    phase = spectrum.angle()
    magnitude = torch.fft.fftshift(spectrum.abs())
    return magnitude, phase







