'''
Generate synthetic motion-perturbed samples from the static movie Y: each sample applies a random smooth deformable warp followed by a random xy translation and in-plane rotation.
python -m scripts.generate_oscillation_deformable
'''
import os
import numpy as np
import torch
from torch import Tensor
import torch.nn.functional as F
import skimage.io as skio
import scipy.io as scio
# Seed both the torch and the NumPy global RNGs so the generated perturbations
# are reproducible from run to run.
_SEED = 0
torch.manual_seed(_SEED)
np.random.seed(_SEED)


from typing import Optional, Tuple, List

def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Tensor:
    ksize_half = (kernel_size - 1) * 0.5

    x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
    pdf = torch.exp(-0.5 * (x / sigma).pow(2))
    kernel1d = pdf / pdf.sum()

    return kernel1d

def _get_gaussian_kernel2d(
    kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device
) -> Tensor:
    """Return a separable 2-D Gaussian built as an outer product of two 1-D Gaussians.

    ``kernel_size[0]``/``sigma[0]`` describe the column (x) axis and
    ``kernel_size[1]``/``sigma[1]`` the row (y) axis. The result has shape
    ``(kernel_size[1], kernel_size[0])``, lives on ``device`` with ``dtype``,
    and sums to (approximately) one.
    """
    along_x, along_y = (
        _get_gaussian_kernel1d(kernel_size[axis], sigma[axis]).to(device, dtype=dtype)
        for axis in (0, 1)
    )
    # Column vector (rows, 1) @ row vector (1, cols) -> (rows, cols).
    return along_y.unsqueeze(1) @ along_x.unsqueeze(0)


# Wide Gaussian (221x221, sigma=40) that main() convolves with white noise to
# obtain a slowly varying deformation field.
# (An earlier setting used kernel_size=[141, 141], sigma=[25, 25].)
kernel = _get_gaussian_kernel2d(kernel_size=[221, 221], sigma=[40, 40], dtype=torch.float32, device="cpu")
# View as a conv2d weight (out_channels=1, in_channels=1, kH, kW); -1 keeps a size.
kernel = kernel.expand(1, 1, -1, -1)

# Narrow Gaussian (21x21, sigma=0.75) that main() uses to lightly pre-blur Y.
kernel_init = _get_gaussian_kernel2d(kernel_size=[21, 21], sigma=[0.75, 0.75], dtype=torch.float32, device="cpu")
kernel_init = kernel_init.expand(1, 1, -1, -1)


def generate_tau(data_shape, rand_noise, tr, ro):
    """Build one random rigid 2x3 affine matrix per frame for ``F.affine_grid``.

    Args:
        data_shape: ``(t, w, h)``; ``t`` frames, ``h`` is the last (x) axis.
        rand_noise: indexable as ``rand_noise[0]``, ``[1]``, ``[2]``, each
            broadcastable to ``(t,)``. main() passes uniform noise of shape
            ``(3, t)``. The rows scale the x-shift, y-shift and angle.
        tr: translation amplitude in pixels.
        ro: rotation amplitude in radians.

    Returns:
        A ``(t, 2, 3)`` tensor ``[[cos, -sin, tx], [sin, cos, ty]]``. The
        translations are scaled by ``2 / size`` into affine_grid's [-1, 1] range.
        NOTE(review): with align_corners=True one pixel spans ``2 / (size - 1)``,
        so the realised shift is ``tr * (size - 1) / size`` pixels.
    """
    n_frames, size_w, size_h = data_shape
    angle = ro * rand_noise[2]
    cos_a = torch.cos(angle)
    sin_a = torch.sin(angle)

    theta = torch.zeros((n_frames, 2, 3))
    # Rotation block.
    theta[:, 0, 0] = cos_a
    theta[:, 0, 1] = -sin_a
    theta[:, 1, 0] = sin_a
    theta[:, 1, 1] = cos_a
    # Translation column: x follows the last axis (h), y the middle axis (w).
    theta[:, 0, 2] = tr * (2 / size_h) * rand_noise[0]
    theta[:, 1, 2] = tr * (2 / size_w) * rand_noise[1]
    return theta


# Identity sampling grid for F.grid_sample on size[0] x size[1] frames, in the
# normalized [-1, 1] coordinates used with align_corners=True.
size = [512, 512]
vectors = [torch.arange(0, s) for s in size]
# Pass indexing="ij" explicitly. It is today's default (grids[0] varies along
# rows, grids[1] along columns), but PyTorch warns that the default will move
# to "xy", which would silently transpose the identity grid built below.
grids = torch.meshgrid(*vectors, indexing="ij")
grid = torch.stack(grids)  # (2, H, W): channel 0 = row index, channel 1 = column index
grid = torch.unsqueeze(grid, 0)  # (1, 2, H, W)
grid = grid.float()
# Map pixel indices 0..size-1 linearly onto [-1, 1].
for i, s in enumerate(size):
    grid[:, i, ...] = 2 * (grid[:, i, ...] / (s - 1) - 0.5)
# (1, W, H, 2) with [..., 0] = column (x) and [..., 1] = row (y), which is the
# layout grid_sample expects. NOTE: this is an identity grid only because
# size is square.
gtgrid = grid.permute(0, 3, 2, 1)

def _synthesize_sample(Y_blur, tr, ro, flow_kernel, device):
    """Warp ``Y_blur`` by a random smooth deformation, then by a random rigid transform.

    Args:
        Y_blur: (t, 1, w, h) movie on the CPU; (w, h) must match ``gtgrid``.
        tr: translation amplitude in pixels, forwarded to ``generate_tau``.
        ro: rotation amplitude in radians, forwarded to ``generate_tau``.
        flow_kernel: (1, 1, k, k) Gaussian conv2d weight already on ``device``.
        device: device on which the flow-smoothing convolution runs.

    Returns:
        Y_reg: (t, w, h) warped movie.
        flow: (t, w, h, 2) deformable displacement in normalized grid units.
        flow_syn: (t, w, h, 2) deformable + affine displacement.
        tau_3x3: (t, 3, 3) homogeneous affine matrices.
    """
    t, _, w, h = Y_blur.shape
    # Affine noise is drawn before the flow noise; keep this order so a given
    # seed reproduces the same samples.
    rand_noise = 2 * torch.rand((3, t)) - 1  # uniform in [-1, 1)
    tau = generate_tau((t, w, h), rand_noise, tr, ro)

    tau_3x3 = torch.zeros((t, 3, 3))
    tau_3x3[:, 0:2, :] = tau
    tau_3x3[:, 2, 2] = 1

    # Deformable part: Gaussian white noise, one field per frame (batch = t,
    # so it matches Y_blur for any movie length), smoothed by the wide
    # Gaussian into a slowly varying displacement field.
    flow = torch.randn([t, *gtgrid.size()[1:]]) * 500 / 512
    flow = flow.permute(0, 3, 1, 2)  # (t, 2, H, W)
    b_, c_, y_, x_ = flow.shape
    # Smooth each displacement component independently as a single-channel image.
    flow = flow.reshape(b_ * c_, 1, y_, x_).to(device)
    flow = F.conv2d(flow, flow_kernel, bias=None, stride=1, padding=flow_kernel.shape[-1] // 2).cpu()
    flow = flow.reshape(b_, c_, y_, x_).permute(0, 2, 3, 1)  # back to (t, H, W, 2)

    Y_reg = F.grid_sample(Y_blur, gtgrid + flow, align_corners=True)

    # Affine part: rotation + translation applied on top of the deformed movie.
    aff_grid = F.affine_grid(tau, Y_blur.size(), align_corners=True)
    affine_flow = aff_grid - gtgrid
    Y_reg = F.grid_sample(Y_reg, aff_grid, align_corners=True)[:, 0, :, :]  # (t, w, h)

    # NOTE(review): the exact composed displacement is affine_flow(p) + flow(aff_grid(p));
    # summing with flow(p) is a first-order approximation, accurate while flow varies slowly.
    flow_syn = affine_flow + flow
    return Y_reg, flow, flow_syn, tau_3x3


def main():
    """Write ``n_samples`` perturbed copies of Y for every (translation, rotation) level.

    For each level pair, a sub-directory ``tr_<tr>_ro_<deg>`` is created under
    ``<data_dir>/deformable_REALS_data_1.0``. For each sample i it holds:
    the warped movie (``Y_*.tif``), the deformable flow (``flow_only_*.tif``),
    the combined flow (``flow_syn_*.tif``) and the affine matrices (``tau_*.mat``).

    Raises:
        ValueError: if the frames of Y do not match the module-level identity grid.
    """
    data_dir = './data/2d_zebrafish_brain_data'
    path = f'{data_dir}/Y.tif'
    out_root = f"{data_dir}/deformable_REALS_data_1.0"
    translation_level = np.arange(0, 15, 3, dtype=np.float32)  # pixels: 0, 3, ..., 12
    rotation_level = np.arange(0, 10, 2, dtype=np.float32) * (np.pi * 1 / 180)  # 0..8 degrees, in radians
    n_samples = 3  # 5
    # Fall back to the CPU instead of crashing when no GPU is present.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    os.makedirs(out_root, exist_ok=True)

    # Y and its light pre-blur do not depend on the loop variables, so load and
    # blur once rather than once per sample.
    # NOTE(review): [:512] truncates the first (frame) axis, not a spatial one.
    Y = torch.from_numpy(skio.imread(path).astype(float)).float()[:512]  # (t,w,h)
    t, w, h = Y.size()
    if (w, h) != tuple(gtgrid.shape[1:3]):
        raise ValueError(
            f"Y frames are {w}x{h}, but the module-level identity grid is "
            f"{gtgrid.shape[1]}x{gtgrid.shape[2]}"
        )
    Y_blur = F.conv2d(
        Y.view(t, 1, w, h).to(device), kernel_init.to(device),
        bias=None, stride=1, padding=kernel_init.shape[-1] // 2,  # same-size output
    ).cpu()
    flow_kernel = kernel.to(device)

    for tr in translation_level:
        for ro in rotation_level:
            ro_deg = ro * (180 / np.pi)
            tag = f"tr_{np.round(tr)}_ro_{np.round(ro_deg)}"
            out_dir = f"{out_root}/{tag}"
            os.makedirs(out_dir, exist_ok=True)
            for i in range(n_samples):
                Y_reg, flow, flow_syn, tau_3x3 = _synthesize_sample(Y_blur, tr, ro, flow_kernel, device)

                # save the result
                skio.imsave(f"{out_dir}/Y_{tag}_{i}.tif", Y_reg.numpy())
                skio.imsave(f"{out_dir}/flow_only_{tag}_{i}.tif", flow.permute(0, 3, 1, 2).numpy())
                skio.imsave(f"{out_dir}/flow_syn_{tag}_{i}.tif", flow_syn.permute(0, 3, 1, 2).numpy())
                scio.savemat(f"{out_dir}/tau_{tag}_{i}.mat", {'tau': tau_3x3.numpy()})


if __name__ == "__main__":
    main()