# -*- coding: utf-8 -*-
import math
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from kiui.op import safe_normalize


def get_loop_cameras(num_imgs_in_loop, radius=2.0,
                     max_elevation=np.pi/6, elevation_freq=0.5,
                     azimuth_freq=2.0):
    """
    Place ``num_imgs_in_loop`` cameras on a sphere of radius ``radius``, each
    looking at the origin, along a loop whose azimuth makes ``azimuth_freq``
    full turns and whose elevation oscillates sinusoidally
    (``elevation_freq`` periods, amplitude ``max_elevation`` radians).

    Cameras follow the COLMAP / OpenCV convention (x right, y down, z pointing
    from the camera towards the scene); the reference vector (0, 0, -1) is
    used to build each camera's basis.

    Returns:
        all_cameras_c2w_cmo: list of (4, 4) float32 numpy camera-to-world matrices.
        source_cameras: list of (3, 4) float32 torch tensors holding the same
            poses without the homogeneous bottom row.
    """
    ref_vector = np.array([0, 0, -1], dtype=np.float32)
    homogeneous_row = np.array([0, 0, 0, 1], dtype=np.float32)[None, :]

    poses_4x4, poses_3x4 = [], []
    for step in range(num_imgs_in_loop):
        azimuth = np.pi * 2 * azimuth_freq * step / num_imgs_in_loop
        elevation = max_elevation * np.sin(
            np.pi * step * 2 * elevation_freq / num_imgs_in_loop)

        position = np.array([
            np.cos(azimuth) * radius * np.cos(elevation),
            np.sin(azimuth) * radius * np.cos(elevation),
            np.sin(elevation) * radius,
        ], dtype=np.float32)

        # Orthonormal camera basis: z points from the camera towards the origin.
        forward = - position / radius
        right = np.cross(ref_vector, forward)
        right = right / np.linalg.norm(right)
        down = np.cross(forward, right)

        # Columns: x axis, y axis, z axis, camera position.
        pose_3x4 = np.stack([right, down, forward, position], axis=1)
        poses_3x4.append(torch.tensor(pose_3x4))
        poses_4x4.append(np.concatenate([pose_3x4, homogeneous_row], axis=0))

    return poses_4x4, poses_3x4


def get_rays(pose, h, w, fovy, opengl=True):
    """
    Generate one ray per pixel for a pinhole camera with square pixels.

    Args:
        pose: camera-to-world matrix; only ``pose[:3, :3]`` (rotation) and
            ``pose[:3, 3]`` (camera position) are read. Its device is used for
            all outputs; a floating-point pose also sets the output dtype.
        h, w: image height and width in pixels.
        fovy: vertical field of view in degrees.
        opengl: if True the camera looks along -z with y up (OpenGL);
            otherwise it looks along +z with y down (OpenCV).

    Returns:
        rays_o: (h, w, 3) ray origins (the camera position, broadcast).
        rays_d: (h, w, 3) world-space ray directions, normalized with
            ``kiui.op.safe_normalize``.
    """
    x, y = torch.meshgrid(
        torch.arange(w, device=pose.device),
        torch.arange(h, device=pose.device),
        indexing="xy",
    )  # each (h, w)
    x = x.flatten()
    y = y.flatten()

    cx = w * 0.5
    cy = h * 0.5

    # One focal length for both axes (square pixels), derived from the vertical FoV.
    focal = h * 0.5 / np.tan(0.5 * np.deg2rad(fovy))

    # Sample pixel centres (+0.5); OpenGL flips y and looks down -z.
    camera_dirs = F.pad(
        torch.stack(
            [
                (x - cx + 0.5) / focal,
                (y - cy + 0.5) / focal * (-1.0 if opengl else 1.0),
            ],
            dim=-1,
        ),
        (0, 1),
        value=(-1.0 if opengl else 1.0),
    )  # [hw, 3]

    # `@` does not promote dtypes: build the directions in the pose's floating
    # dtype so that e.g. float64 poses work instead of raising.
    if pose.is_floating_point():
        camera_dirs = camera_dirs.to(pose.dtype)

    rays_d = camera_dirs @ pose[:3, :3].transpose(0, 1)  # [hw, 3]
    rays_o = pose[:3, 3].unsqueeze(0).expand_as(rays_d)  # [hw, 3]

    rays_o = rays_o.view(h, w, 3)
    rays_d = safe_normalize(rays_d).view(h, w, 3)

    return rays_o, rays_d


class CameraEmbedder(nn.Module):
    """
    Map a raw camera feature vector (last dimension ``raw_dim``) to an
    ``embed_dim``-dimensional embedding with a two-layer MLP:
    Linear -> SiLU -> Linear.

    Reference:
    DiT: https://github.com/facebookresearch/DiT/blob/main/models.py#L27
    """
    def __init__(self, raw_dim: int, embed_dim: int):
        super().__init__()
        layers = [
            nn.Linear(raw_dim, embed_dim),    # mlp.0
            nn.SiLU(),                        # mlp.1 (no parameters)
            nn.Linear(embed_dim, embed_dim),  # mlp.2
        ]
        # Attribute name and layer order fix the state_dict keys
        # ("mlp.0.*", "mlp.2.*"); keep them so saved weights still load.
        self.mlp = nn.Sequential(*layers)

    @torch.compile
    def forward(self, x):
        """Embed ``x`` (last dimension ``raw_dim``) into ``embed_dim`` features."""
        embedded = self.mlp(x)
        return embedded


def camera_normalization_objaverse(normed_dist_to_center, poses: torch.Tensor, ret_transform: bool = False):
    """
    Apply one 4x4 transform to every pose so that the first pose becomes a
    fixed canonical pose.

    Args:
        normed_dist_to_center: translation magnitude of the canonical pose;
            either a number, or ``'auto'`` to reuse the norm of the first
            pose's translation.
        poses: (N, 3, 4) pose matrices [R | t] (presumably camera-to-world —
            TODO confirm against callers).
        ret_transform: if True, also return the (4, 4) transform applied.

    Returns:
        (N, 3, 4) normalized poses, plus the (4, 4) transform when
        ``ret_transform`` is True. After normalization ``poses[0]`` equals the
        top three rows of the canonical matrix below, whose translation is
        ``(0, -dist_to_center, 0)``.
    """
    assert normed_dist_to_center is not None
    pivotal_pose = compose_extrinsic_RT(poses[:1])  # (1, 4, 4)
    dist_to_center = pivotal_pose[:, :3, 3].norm(dim=-1, keepdim=True).item() \
        if normed_dist_to_center == 'auto' else normed_dist_to_center
    # Target for the first pose. Built on the input's dtype/device so that the
    # bmm calls below also work for float64 or CUDA poses (a hard-coded CPU
    # float32 tensor raised a dtype/device mismatch there).
    canonical_camera_extrinsics = torch.tensor([[
        [1, 0, 0, 0],
        [0, 0, -1, -dist_to_center],
        [0, 1, 0, 0],
        [0, 0, 0, 1],
    ]], dtype=poses.dtype, device=poses.device)
    # M = C @ P0^-1, hence M @ P0 == C.
    pivotal_pose_inv = torch.inverse(pivotal_pose)
    camera_norm_matrix = torch.bmm(canonical_camera_extrinsics, pivotal_pose_inv)

    # Apply the same transform to every view.
    poses = compose_extrinsic_RT(poses)
    poses = torch.bmm(camera_norm_matrix.expand(poses.shape[0], -1, -1), poses)
    poses = decompose_extrinsic_RT(poses)

    if ret_transform:
        return poses, camera_norm_matrix.squeeze(dim=0)
    return poses

def compose_extrinsic_RT(RT: torch.Tensor):
    """
    Append the homogeneous row [0, 0, 0, 1] to each [R | t] matrix.

    (N, 3, 4) -> (N, 4, 4); the new row shares RT's dtype and device.
    Batched I/O.
    """
    batch_size = RT.shape[0]
    # new_tensor inherits RT's dtype and device.
    last_row = RT.new_tensor([[[0, 0, 0, 1]]]).repeat(batch_size, 1, 1)  # (N, 1, 4)
    return torch.cat((RT, last_row), dim=1)

def decompose_extrinsic_RT(E: torch.Tensor):
    """
    Drop the homogeneous bottom row, keeping the first three rows of each
    matrix (the inverse of ``compose_extrinsic_RT``: (N, 4, 4) -> (N, 3, 4)).
    The result is a view of ``E``.
    Batched I/O.
    """
    top_rows = E[:, :3]
    return top_rows

def get_normalized_camera_intrinsics(intrinsics: torch.Tensor):
    """
    Normalize pixel-space intrinsics by the image size.

    intrinsics: (N, 3, 2), [[fx, fy], [cx, cy], [width, height]]
    Return batched fx / width, fy / height, cx / width, cy / height,
    each of shape (N,).
    """
    image_size = intrinsics[:, 2:3, :]               # (N, 1, 2): [width, height]
    normalized = intrinsics[:, :2, :] / image_size   # column 0 / width, column 1 / height
    fx, fy = normalized[:, 0].unbind(-1)
    cx, cy = normalized[:, 1].unbind(-1)
    return fx, fy, cx, cy


def build_camera_principle(RT: torch.Tensor, intrinsics: torch.Tensor):
    """
    Flatten extrinsics and size-normalized intrinsics into one vector per camera.

    RT: (N, 3, 4)
    intrinsics: (N, 3, 2), [[fx, fy], [cx, cy], [width, height]]
    Returns (N, 16): the 12 entries of RT (row-major), then
    fx / width, fy / height, cx / width, cy / height.
    """
    normed_intrinsics = torch.stack(get_normalized_camera_intrinsics(intrinsics), dim=-1)  # (N, 4)
    flat_extrinsics = RT.reshape(-1, 12)
    return torch.cat((flat_extrinsics, normed_intrinsics), dim=-1)

def cal_plucker_embedding(K, E, W=None, H=None):
    """
    Compute the per-pixel Plücker ray embedding of one or more pinhole cameras.

    Each pixel's ray has origin ``o`` (the camera centre) and unit direction
    ``d`` in world space; its Plücker coordinates are ``(o x d, d)``.

    Args:
        K: intrinsics in pixel units, OpenCV convention (x right, y down,
            camera looking along +z). Either ``[fx, fy, cx, cy]`` (sequence
            or tensor of shape (4,)) or a 3x3 pinhole matrix.
        E: camera-to-world pose(s) of shape (4, 4), (3, 4), (N, 4, 4) or
            (N, 3, 4); only the top 3x4 block [R | t] is read.
        W, H: image width and height in pixels. When omitted they default to
            ``round(2 * cx)`` and ``round(2 * cy)`` (a centred principal point).

    Returns:
        Tensor of shape (6, H, W) for a single pose or (N, 6, H, W) for a
        batch. Channels 0-2 hold the moment ``o x d``, channels 3-5 the unit
        direction ``d``. dtype/device follow ``E`` (non-float poses are cast
        to float32).

    Raises:
        ValueError: if ``K`` or ``E`` has an unsupported shape.
    """
    E = torch.as_tensor(E)
    if not E.is_floating_point():
        E = E.float()
    single = E.dim() == 2
    if single:
        E = E.unsqueeze(0)
    if E.dim() != 3 or E.shape[1] not in (3, 4) or E.shape[2] != 4:
        raise ValueError(
            f"E must have shape (3|4, 4) or (N, 3|4, 4), got {tuple(E.shape)}")

    K = torch.as_tensor(K, dtype=E.dtype, device=E.device)
    if K.shape == (3, 3):
        fx, fy, cx, cy = K[0, 0], K[1, 1], K[0, 2], K[1, 2]
    elif K.shape == (4,):
        fx, fy, cx, cy = K.unbind(0)
    else:
        raise ValueError(
            f"K must be [fx, fy, cx, cy] or a 3x3 matrix, got shape {tuple(K.shape)}")

    if W is None:
        W = int(round(2 * float(cx)))
    if H is None:
        H = int(round(2 * float(cy)))

    # Sample each pixel at its centre (+0.5).
    rows = torch.arange(H, dtype=E.dtype, device=E.device) + 0.5
    cols = torch.arange(W, dtype=E.dtype, device=E.device) + 0.5
    v, u = torch.meshgrid(rows, cols, indexing="ij")  # each (H, W)
    # Back-project through K^-1: camera-space direction with z = 1.
    dirs_cam = torch.stack(
        [(u - cx) / fx, (v - cy) / fy, torch.ones_like(u)], dim=-1)  # (H, W, 3)

    rotation = E[:, :3, :3]  # (N, 3, 3)
    origin = E[:, :3, 3]     # (N, 3)
    rays_d = torch.einsum("nij,hwj->nhwi", rotation, dirs_cam)  # (N, H, W, 3)
    rays_d = F.normalize(rays_d, dim=-1)
    rays_o = origin[:, None, None, :].expand_as(rays_d)
    moment = torch.cross(rays_o, rays_d, dim=-1)

    plucker = torch.cat([moment, rays_d], dim=-1).permute(0, 3, 1, 2).contiguous()  # (N, 6, H, W)
    return plucker[0] if single else plucker
  

if __name__ == '__main__':
    # Smoke test: Plücker embedding for 8 random camera-to-world poses.
    # (The previous call targeted `cal_plucker_emb`, which does not exist.)
    H, W = 128, 128
    # [fx, fy, cx, cy]; principal point at the image centre, so the
    # embedding's image size is (2 * cy, 2 * cx) == (H, W).
    intrinsics = [140.0, 140.0, W / 2, H / 2]
    view2world = torch.rand(8, 4, 4)
    plucker_embedding = cal_plucker_embedding(intrinsics, view2world)
    print(plucker_embedding.shape)

