import os
import cv2

import numpy as np

import torch
from torch.utils.data import DataLoader

import tqdm
import sys
sys.path.append('../')
import utils.general as utils
from utils.rend_utils import nerf_matrix_to_ngp, rand_poses, get_rays

class NeRFDataset:
    """Posed-image dataset loaded from a NeRF-style ``transforms*.json`` file.

    The constructor reads every frame image that exists on disk, converts each
    camera pose with ``nerf_matrix_to_ngp`` and derives the camera intrinsics.
    ``collate`` turns a batch of frame indices into rays (plus ground-truth
    pixels / latents when available) and ``dataloader`` wraps it all in a
    ``torch.utils.data.DataLoader``.

    Attributes set by ``__init__``:
        poses: tensor of converted camera poses, ``[N, 4, 4]``.
        images: tensor of images scaled to ``[0, 1]``, ``[N, H, W, 3/4]``.
        intrinsics: numpy array ``[fl_x, fl_y, cx, cy]``.
        radius: mean norm of the camera translations.
        error_map: ``[N, 128 * 128]`` tensor of ones, or ``None``.

    Attributes set by ``precompute_latents``:
        latents: encoded images concatenated along dim 0.
        latent_bkgd: encoding of an all-ones image, flattened spatially.
    """

    def __init__(self, data_dir, opts, type='train', downscale=1):
        """Load images, poses and intrinsics from ``data_dir``.

        Args:
            data_dir: directory holding ``transforms.json`` ('colmap' mode) or
                ``transforms_train.json`` ('blender' mode).
            opts: option object; ``device``, ``preload``, ``scale``, ``offset``,
                ``bound``, ``fp16``, ``rand_pose`` and ``latent_mode`` are read
                here, ``num_rays`` and ``error_map`` only for training types.
            type: split name; 'train', 'all' and 'trainval' enable training
                behaviour (ray sub-sampling, error map).
            downscale: integer factor applied to image size and intrinsics.

        Raises:
            ValueError: if no known transform file exists in ``data_dir`` or
                none of the listed frames could be loaded.
            RuntimeError: if an existing image file cannot be decoded, or the
                transform file carries no focal length information.
        """
        super().__init__()
        self.opts = opts
        self.device = self.opts.device
        self.type = type # train, val, test
        self.downscale = downscale
        self.root_path = data_dir
        self.preload = self.opts.preload # preload data into GPU
        self.scale = self.opts.scale # camera radius scale to make sure camera are inside the bounding box.
        self.offset = self.opts.offset # camera offset
        self.bound = self.opts.bound # bounding box half length, also used as the radius to random sample poses.
        self.fp16 = self.opts.fp16 # if preload, load into fp16.

        self.training = self.type in ['train', 'all', 'trainval']
        self.num_rays = self.opts.num_rays if self.training else -1

        self.rand_pose = self.opts.rand_pose

        colmap_path = os.path.join(self.root_path, 'transforms.json')
        blender_path = os.path.join(self.root_path, 'transforms_train.json')
        if os.path.exists(colmap_path):
            self.mode = 'colmap' # manually split, use view-interpolation for test.
            transform_path = colmap_path
        elif os.path.exists(blender_path):
            # NOTE(review): the train split is read whatever `type` is; kept as-is.
            self.mode = 'blender' # provided split
            transform_path = blender_path
        else:
            raise ValueError("Unknown dataset format.")
        transform = utils.load_json(transform_path)

        if 'h' in transform and 'w' in transform:
            self.H = transform['h'] // downscale
            self.W = transform['w'] // downscale
        else:
            # Resolved from the first image that gets loaded.
            self.H = None
            self.W = None

        poses, images = self._load_frames(transform['frames'], downscale)
        if not images:
            # np.stack would fail on an empty list with an unhelpful message.
            raise ValueError(
                'No images could be loaded for the frames listed in {}'.format(transform_path))
        self.poses = torch.from_numpy(np.stack(poses, axis=0)) # [N, 4, 4]
        self.images = torch.from_numpy(np.stack(images, axis=0)) # [N, H, W, 3/4]

        self.num_images = self.images.shape[0]

        # Mean distance of the cameras from the origin; used to sample random poses.
        self.radius = self.poses[:, :3, 3].norm(dim=-1).mean(0).item()
        if self.training and self.opts.error_map and not self.opts.latent_mode:
            self.error_map = torch.ones([self.num_images, 128 * 128], dtype=torch.float) # [B, 128 * 128], flattened for easy indexing, fixed resolution...
        else:
            self.error_map = None

        if self.preload and not self.opts.latent_mode:
            dtype = torch.half if self.fp16 else torch.float
            self.poses = self.poses.to(self.device)
            self.images = self.images.to(self.device, dtype=dtype)
            if self.error_map is not None:
                self.error_map = self.error_map.to(self.device)

        # Needs self.H / self.W, which may only be known after the frames are read.
        self.intrinsics = self._load_intrinsics(transform, downscale)

    def _load_frames(self, frames, downscale):
        """Read the image and converted pose of every frame present on disk.

        Frames whose image file does not exist are skipped. ``self.H`` and
        ``self.W`` are filled in from the first image when still unset.

        Returns:
            ``(poses, images)``: two equally long lists of numpy arrays.
        """
        poses = []
        images = []
        for frame in frames:
            f_path = os.path.join(self.root_path, frame['file_path'])
            if self.mode == 'blender':
                f_path += '.png'
            if not os.path.exists(f_path):
                continue
            image = cv2.imread(f_path, cv2.IMREAD_UNCHANGED)
            if image is None:
                # imread signals an undecodable file by returning None.
                raise RuntimeError('Failed to read image: {}'.format(f_path))
            if self.H is None and self.W is None:
                self.H = image.shape[0] // downscale
                self.W = image.shape[1] // downscale
            # OpenCV loads channels as BGR(A); convert to RGB(A).
            if image.shape[-1] == 3:
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            else:
                image = cv2.cvtColor(image, cv2.COLOR_BGRA2RGBA)

            if downscale > 1:
                image = cv2.resize(image, (self.W, self.H), interpolation=cv2.INTER_AREA)
            image = image.astype(np.float32) / 255.

            pose = np.array(frame['transform_matrix'], dtype=np.float32)
            pose = nerf_matrix_to_ngp(pose, scale=self.scale, offset=self.offset)

            images.append(image)
            poses.append(pose)
        return poses, images

    def _load_intrinsics(self, transform, downscale):
        """Return ``np.array([fl_x, fl_y, cx, cy])`` for the downscaled images.

        Raises:
            RuntimeError: if neither focal lengths nor camera angles are given.
        """
        if 'fl_x' in transform or 'fl_y' in transform:
            # A missing axis falls back to the other one.
            fl_x = (transform['fl_x'] if 'fl_x' in transform else transform['fl_y']) / downscale
            fl_y = (transform['fl_y'] if 'fl_y' in transform else transform['fl_x']) / downscale
        elif 'camera_angle_x' in transform or 'camera_angle_y' in transform:
            # blender, assert in radians. already downscaled since we use H/W
            fl_x = self.W / (2 * np.tan(transform['camera_angle_x'] / 2)) if 'camera_angle_x' in transform else None
            fl_y = self.H / (2 * np.tan(transform['camera_angle_y'] / 2)) if 'camera_angle_y' in transform else None
            if fl_x is None:
                fl_x = fl_y
            if fl_y is None:
                fl_y = fl_x
        else:
            raise RuntimeError('Failed to load focal length, please check the transforms.json!')

        # The principal point defaults to the image centre.
        cx = (transform['cx'] / downscale) if 'cx' in transform else (self.W / 2)
        cy = (transform['cy'] / downscale) if 'cy' in transform else (self.H / 2)

        return np.array([fl_x, fl_y, cx, cy])

    @torch.no_grad()
    def precompute_latents(self, sd_model, batch_size = 1):
        """Encode all images with ``sd_model`` and cache the result.

        Stores the concatenated encodings in ``self.latents`` and the encoding
        of an all-ones image in ``self.latent_bkgd``. Inputs are cast to half
        precision when ``self.fp16`` is set and to float32 otherwise
        (previously nothing was encoded unless ``fp16`` was enabled, which made
        the final concatenation fail).

        Args:
            sd_model: object exposing ``encode_imgs(images)``; called with a
                ``[B, 3, H, W]`` tensor.
            batch_size: number of images encoded per call; must be >= 1.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            # A non-positive step would never advance through the images.
            raise ValueError('batch_size must be a positive integer, got {}'.format(batch_size))

        dtype = torch.half if self.fp16 else torch.float
        latents = []
        bar_format = '{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]'
        # The context manager closes the progress bar even if encoding fails.
        with tqdm.tqdm(total=self.num_images, bar_format=bar_format) as pbar:
            pbar.set_description('Precomputing latent codes')
            for start in range(0, self.num_images, batch_size):
                stop = min(start + batch_size, self.num_images)
                images = self.images[start:stop].to(self.device, dtype=dtype)
                images = utils.rgba2rgb(images, bkgd_map = None)
                images = images.permute(0, 3, 1, 2) # channels-last -> channels-first
                latents.append(sd_model.encode_imgs(images))
                pbar.update(stop - start)

        self.latents = torch.cat(latents, dim=0)

        #precompute background latent
        bkgd = torch.ones([1, 3, self.H, self.W], dtype=dtype, device=self.device)
        self.latent_bkgd = sd_model.encode_imgs(bkgd).reshape(1, 4, -1).permute(0, 2, 1)

        torch.cuda.empty_cache()

    def _collate_random_poses(self, batch_size):
        """Build rays for ``batch_size`` randomly sampled poses (no ground truth)."""
        poses = rand_poses(batch_size, self.device, radius=self.radius)

        # sample a low-resolution but full image for CLIP
        s = np.sqrt(self.H * self.W / self.num_rays) # only in training, assert num_rays > 0
        rH, rW = int(self.H / s), int(self.W / s)
        rays = get_rays(poses, self.intrinsics / s, rH, rW, -1)

        return {
            'H': rH,
            'W': rW,
            'rays_o': rays['rays_o'],
            'rays_d': rays['rays_d'],
        }

    def collate(self, index, pose = None):
        """Turn a batch of frame indices into a dict of rays and targets.

        Args:
            index: list of frame indices (the loader uses batches of one).
            pose: optional pose tensor; when given, rays are generated for it
                and no ground truth is attached.

        Returns:
            dict with 'H', 'W', 'rays_o', 'rays_d'. Batches built from dataset
            poses also carry 's' and 'images'; latent-mode training adds
            'latents' and 'inds'; 'index' and 'inds_coarse' are added when an
            error map is in use.

        Raises:
            RuntimeError: in latent-mode training when ``precompute_latents``
                has not been called yet.
        """
        B = len(index)
        # NOTE(review): any truthy rand_pose sends every batch down this branch.
        if self.rand_pose or index[0] >= self.num_images:
            return self._collate_random_poses(B)

        latent_mode = self.opts.latent_mode
        needs_latents = (latent_mode and self.training and pose is None
                         and self.images is not None)
        if needs_latents and getattr(self, 'latents', None) is None:
            # Fail before any ray generation instead of with a bare AttributeError.
            raise RuntimeError(
                'latent_mode training requires precompute_latents() to be called first.')

        poses = self.poses[index].to(self.device) if pose is None else pose.to(self.device) # [B, 4, 4]

        error_map = None if self.error_map is None else self.error_map[index]

        # Latent-mode rays are sampled on a grid 8x smaller than the images
        # (presumably the encoder's spatial downsampling factor -- TODO confirm).
        scale_factor = 8 if latent_mode else 1
        sampling_h = self.H // scale_factor
        sampling_w = self.W // scale_factor
        intrinsics = self.intrinsics / scale_factor # new array; self.intrinsics is untouched
        rays = get_rays(poses, intrinsics, sampling_h, sampling_w, self.num_rays, error_map, self.opts.patch_size)

        results = {
            'H': sampling_h,
            'W': sampling_w,
            's': scale_factor,
            'rays_o': rays['rays_o'],
            'rays_d': rays['rays_d'],
        }
        if pose is not None:
            return results

        if self.images is not None:
            images = self.images[index].to(self.device) # [B, H, W, 3/4]
            if self.training:
                C = images.shape[-1]
                if not latent_mode:
                    # Keep only the pixels hit by the sampled rays.
                    images = torch.gather(images.view(B, -1, C), 1, torch.stack(C * [rays['inds']], -1)) # [B, N, 3/4]
                else:
                    latents = self.latents[index].to(self.device)
                    # Use the latent tensor's own channel count: it need not
                    # equal the image channel count (e.g. RGB images).
                    latent_c = latents.shape[1]
                    latents = torch.gather(
                        latents.view(B, latent_c, -1), -1,
                        torch.stack(latent_c * [rays['inds']], 1),
                    ).permute(0, 2, 1) # [B, N, latent_c]
                    images = images.reshape(B, -1, C) # [B, H * W, 3/4]
                    results['latents'] = latents
                    results['inds'] = rays['inds']
            results['images'] = images

        # need inds to update error_map
        if error_map is not None:
            results['index'] = index
            results['inds_coarse'] = rays['inds_coarse']

        return results

    def dataloader(self):
        """Return a batch-size-1 ``DataLoader`` over the frame indices.

        For training with ``rand_pose > 0`` the index range is extended by
        ``num_images // rand_pose`` entries; ``collate`` maps those
        out-of-range indices to random poses. The loader gets two extra
        attributes: ``_data`` (this dataset) and ``has_gt``.
        """
        size = self.num_images
        if self.training and self.rand_pose > 0:
            size += size // self.rand_pose # index >= size means we use random pose.
        loader = DataLoader(list(range(size)), batch_size=1, collate_fn=self.collate, shuffle=self.training, num_workers=0)
        loader._data = self # an ugly fix... we need to access error_map & poses in trainer.
        loader.has_gt = self.images is not None
        return loader

