import torch


def test_interpolate_feature(pts, feat, require_jac=True):
    """
    pts: sat_uv,
    feat: sat_feat [b,c,a,a]
    """
    b, c, h, w = feat.size()
    # mask = self.mask_in_image(pts, [W, H], pad=pad)
    scale = torch.tensor([w, h]).to(pts)
    pts = (pts / scale) * 2 - 1
    mask_ = (pts < 1) & (pts > -1)
    mask = torch.where(mask_[..., 0] & mask_[..., 1], True, False)
    mask = mask.unsqueeze(dim=-1)

    feat_pts = F.grid_sample(feat, pts[:, None], mode='bilinear', align_corners=True)  # [b, N, c]
    feat_pts = feat_pts.reshape(b, c, -1).transpose(-1, -2)

    if require_jac:
        dxdy = torch.tensor([[1, 0], [0, 1]])[:, None].to(pts) / scale * 2
        dx, dy = dxdy.chunk(2, dim=0)
        pts_d = torch.cat([pts - dx, pts + dx, pts - dy, pts + dy], 1)
        tensor_d = torch.nn.functional.grid_sample(
            feat, pts_d[:, None], mode='bilinear', align_corners=True)
        tensor_d = tensor_d.reshape(b, c, -1).transpose(-1, -2)
        tensor_x0, tensor_x1, tensor_y0, tensor_y1 = tensor_d.chunk(4, dim=1)
        gradients = torch.stack([
            (tensor_x1 - tensor_x0) / 2, (tensor_y1 - tensor_y0) / 2], dim=-1)
    else:
        gradients = None

    return pts, mask, feat_pts, gradients

def test_run_all(seed=0):
    """Smoke-test ``test_interpolate_feature`` on random points of a random map.

    Args:
        seed: seed for torch's global RNG so the run is reproducible.
    """
    torch.random.manual_seed(seed)
    w, h = 480, 240
    n_pts, channels = 1000, 16

    # Pixel coordinates inside [0, w-1] x [0, h-1] and a [c, h, w] feature map.
    pts = torch.rand(n_pts, 2) * torch.tensor([w - 1, h - 1])
    tensor = torch.rand(channels, h, w) * 100

    # The sampler expects batched inputs: pts [b, N, 2], feat [b, c, h, w].
    _, mask, feat_pts, J_analytical = test_interpolate_feature(
        pts[None], tensor[None])

    assert mask.shape == (1, n_pts, 1)
    assert feat_pts.shape == (1, n_pts, channels)
    assert J_analytical.shape == (1, n_pts, channels, 2)

if __name__ == "__main__":
    # Executing the file directly runs the smoke test with seed 0 (its default).
    test_run_all(seed=0)
