import torch
import numpy as np
import os

import sys
sys.path.append('../../code')
sys.path.append('../../code/models')
import utils.general as utils
import utils.rend_utils as rend_utils
import cv2

from options import RendererOptions
from render_wrapper import Renderer

from PIL import Image


def extract_sketch(image, edit, save_path = 'diff.png'):
    """Extract a binary sketch mask as the pixel difference between two images.

    Args:
        image: base image, PIL image or numpy array (H, W, C) or (H, W).
        edit: the base image with the sketch drawn on it, same formats as ``image``.
        save_path: where to write a debug visualisation of the mask; ``None``
            disables it. Defaults to ``'diff.png'`` (current working directory).

    Returns:
        (H, W) float numpy array: 1.0 for sketch pixels (where the images differ), 0.0 otherwise.
    """
    # Widen to a signed type before subtracting: uint8 numpy inputs would otherwise
    # wrap around (e.g. 10 - 20 == 246) and np.abs could not undo it.
    image = np.asarray(image).astype('int32')
    edit = np.asarray(edit).astype('int32')
    # Compare colour channels only (drops an alpha channel if present).
    if image.ndim == 3:
        image = image[..., :3]
    if edit.ndim == 3:
        edit = edit[..., :3]
    diff = np.abs(image - edit).astype('uint8')
    if diff.ndim == 3:
        diff = cv2.cvtColor(diff, cv2.COLOR_RGB2GRAY)
    # Smooth the difference with a 5x5 Gaussian: strokes get slightly wider and
    # faint isolated differences can round down to 0 in uint8.
    diff = cv2.GaussianBlur(diff, (5, 5), 0)
    # diff is uint8, so this keeps every pixel with a non-zero blurred difference.
    mask = diff > 0.5
    if save_path is not None:
        Image.fromarray((mask * 255).astype('uint8')).save(save_path)
    return mask.astype('float')


def extract_bbox(sketch):
    """Return the tight bounding box of all sketch pixels.

    Args:
        sketch: (H, W) numpy array, 1 for sketch pixels and 0 otherwise.

    Returns:
        ``(x1, y1, x2, y2)`` as Python ints. x1/y1 are inclusive and x2/y2 exclusive,
        following the ``cv2.boundingRect`` convention (x2 = x1 + width).

    Raises:
        ValueError: if the sketch contains no sketch pixels.
    """
    # Same foreground test as a uint8 image handed to cv2.findContours:
    # any pixel that is non-zero after scaling to [0, 255] and truncating.
    foreground = (np.asarray(sketch) * 255.).astype('uint8') > 0
    # Bound every stroke, not just one contour, so multi-stroke sketches are fully covered.
    ys, xs = np.nonzero(foreground)
    if xs.size == 0:
        raise ValueError('extract_bbox: sketch is empty, no bounding box')
    return int(xs.min()), int(ys.min()), int(xs.max()) + 1, int(ys.max()) + 1


def fill_sketch_func(sketch):
    """Fill the closed strokes of a binary sketch into solid regions.

    Args:
        sketch: (H, W) numpy array, 1 for sketch pixels and 0 otherwise.

    Returns:
        (H, W) float numpy array: 1 inside (and on) every outer contour of the
        sketch, 0 elsewhere. An empty sketch yields an all-zero array.
    """
    sketch = (sketch * 255.).astype('uint8')
    # [-2] selects the contour list under both the OpenCV 3.x API (3 return
    # values) and the 4.x API (2 return values).
    contours = cv2.findContours(sketch, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    mask = np.zeros(sketch.shape, dtype = np.uint8)
    if len(contours) > 0:
        # Fill every outer contour so sketches made of several strokes are fully covered.
        cv2.drawContours(mask, contours, -1, 255, thickness = cv2.FILLED)
    return (mask / 255.).astype('float')
def box_bitfield(sketch_poses, sketch_intrinsics, bboxes = None, resolution = 128):
    """Select grid points whose projection falls inside the bounding box of every sketch.

    Args:
        sketch_poses: per-sketch camera poses, forwarded to
            ``rend_utils.batch_proj_points2image``.
        sketch_intrinsics: per-sketch intrinsics, forwarded likewise.
        bboxes: per-sketch ``(x1, y1, x2, y2)`` pixel boxes (e.g. from ``extract_bbox``).
            Required; the ``None`` default only keeps the two-argument call form
            valid and turns a missing value into a clear error.
        resolution: grid resolution passed to ``rend_utils.grid_coords`` (default 128).

    Returns:
        Boolean tensor with one entry per grid point: True when the point's
        projection lies strictly inside the box of every sketch.

    Raises:
        ValueError: if ``bboxes`` is missing or has fewer entries than there are sketches.
    """
    if bboxes is None:
        raise ValueError('box_bitfield: bboxes (one (x1, y1, x2, y2) box per sketch) is required')
    num_sketches = len(sketch_poses)
    if len(bboxes) < num_sketches:
        raise ValueError(f'box_bitfield: got {len(bboxes)} bboxes for {num_sketches} sketches')

    coords = rend_utils.grid_coords(resolution, normalize = True)
    projections = rend_utils.batch_proj_points2image(coords, sketch_poses, sketch_intrinsics)
    # Allocate on the projections' device instead of assuming CUDA.
    mask = torch.ones((coords.shape[0]), dtype = torch.bool, device = projections.device)
    for i, bbox in enumerate(bboxes[:num_sketches]):
        p = projections[i]
        # Plain Python floats avoid mixing numpy scalars into torch comparisons.
        x1, y1, x2, y2 = (float(v) for v in bbox[:4])
        inside = (p[:, 0] > x1) & (p[:, 0] < x2) & (p[:, 1] > y1) & (p[:, 1] < y2)
        mask &= inside
    return mask
class Sketches:
    """Load 2D sketch masks together with their camera poses and intrinsics.

    Two on-disk layouts are supported, selected by ``type``:

    * ``'blender'``: ``<sketch_dir>/transforms_sketch.json`` in NeRF/Blender
      style; each frame's ``file_path`` + ``'.png'`` is read as a mask.
    * any other value ("manual"): ``<sketch_dir>/meta_data.pkl`` with
      ``'intrinsics'`` and ``'poses'`` entries, plus ``sketches/``, ``bases/``
      and ``shapes/`` sub-directories whose files share names.

    Attributes set by the constructor:
        sketches: float32 tensor (N, H, W), 1 for sketch pixels.
        poses: camera poses (a tensor for the blender layout; whatever the
            pickle stores for the manual layout).
        intrinsics: 4-element numpy array at resolution (H, W); ``[fx, fy, cx, cy]``
            for the blender layout (NOTE(review): the manual layout is assumed
            to use the same order — confirm).
        bounding_boxes: (N, 4) numpy array of ``(x1, y1, x2, y2)`` boxes, or
            ``None`` for the blender layout.
        silhouettes: float32 tensor, or ``None`` for the blender layout.
        rays: rays precomputed with ``rend_utils.get_rays``.
    """

    def __init__(self, sketch_dir, H = None, W = None, device = 'cuda', scale = 1, offset = None, type = 'blender', fill_sketch = True,  preprocess_sketch = True, debug_dir = None):
        """
        Args:
            sketch_dir: dataset directory (layout depends on ``type``).
            H, W: target resolution. If both are None, the first sketch's size is
                used and later sketches are resized to it.
            device: torch device for the loaded tensors.
            scale, offset: forwarded to ``rend_utils.nerf_matrix_to_ngp`` (blender
                layout only); ``offset`` defaults to ``[0, 0, 0]``.
            type: ``'blender'``, or anything else for the manual layout. (Shadows
                the builtin but is kept for API compatibility.)
            fill_sketch: manual layout with ``preprocess_sketch``: fill closed strokes.
            preprocess_sketch: manual layout: derive each sketch from its difference
                with ``bases/<name>``; otherwise threshold the sketch image directly.
            debug_dir: manual layout: if set, save each sketch with its bounding box
                drawn into this directory (nothing is written otherwise).
        """
        if offset is None:
            offset = [0, 0, 0]
        self.H = H
        self.W = W
        self.sketches = []
        # Only the manual layout produces these; keep the attributes defined so
        # get_sketches() works for both layouts.
        self.bounding_boxes = None
        self.silhouettes = None

        if type == 'blender':
            self._load_blender(sketch_dir, device, scale, offset)
        else:
            self._load_manual(sketch_dir, device, fill_sketch, preprocess_sketch, debug_dir)

        # Precompute rays once to speed up rendering.
        self.rays = rend_utils.get_rays(self.poses, self.intrinsics, self.H, self.W, -1)

    def _load_blender(self, sketch_dir, device, scale, offset):
        """Load masks, poses and intrinsics from a transforms_sketch.json layout."""
        self.transforms_file = os.path.join(sketch_dir, 'transforms_sketch.json')
        self.transforms = utils.load_json(self.transforms_file)

        # Maps (x, y, z) -> (-z, y, -x); applied after the NeRF -> NGP conversion.
        axis_swap = np.array([[0, 0, -1, 0], [0, 1, 0, 0], [-1, 0, 0, 0], [0, 0, 0, 1]], dtype=np.float32)

        masks, poses = [], []
        for frame in self.transforms['frames']:
            sketch_file = os.path.join(sketch_dir, frame['file_path']) + '.png'
            sketch = cv2.imread(sketch_file, cv2.IMREAD_UNCHANGED)
            if sketch is None:
                raise FileNotFoundError(f'could not read sketch image: {sketch_file}')
            if sketch.ndim == 2:
                # single-channel mask
                mask = (sketch > 0.5).astype('uint8')
            elif sketch.shape[-1] == 4:
                # RGBA sketch: alpha is channel 3 in both BGRA and RGBA order
                mask = (sketch[..., 3] > 0.5).astype('uint8')
            else:
                # RGB mask whose channels are all the same
                mask = (sketch[..., 0] > 0.5).astype('uint8')

            if self.H is None and self.W is None:
                self.H, self.W = mask.shape[:2]
            else:
                mask = cv2.resize(mask, (self.W, self.H), interpolation = cv2.INTER_AREA)

            pose = np.array(frame['transform_matrix'], dtype=np.float32)
            pose = rend_utils.nerf_matrix_to_ngp(pose, scale=scale, offset=offset)
            poses.append(axis_swap @ pose)
            masks.append(mask)

        self.poses = torch.from_numpy(np.stack(poses, axis=0)).to(device)
        self.sketches = torch.from_numpy(np.stack(masks, axis=0)).to(device, dtype=torch.float32)

        # Focal lengths from the field of view; a missing axis borrows the other one.
        fl_x = self.W / (2 * np.tan(self.transforms['camera_angle_x'] / 2)) if 'camera_angle_x' in self.transforms else None
        fl_y = self.H / (2 * np.tan(self.transforms['camera_angle_y'] / 2)) if 'camera_angle_y' in self.transforms else None
        if fl_x is None:
            fl_x = fl_y
        if fl_y is None:
            fl_y = fl_x
        # NOTE(review): cx/cy from the JSON are not rescaled when the masks are
        # resized — confirm they are given at the target resolution.
        cx = self.transforms['cx'] if 'cx' in self.transforms else self.W / 2
        cy = self.transforms['cy'] if 'cy' in self.transforms else self.H / 2
        self.intrinsics = np.array([fl_x, fl_y, cx, cy])

    def _load_manual(self, sketch_dir, device, fill_sketch, preprocess_sketch, debug_dir):
        """Load sketches, bounding boxes and silhouettes from a meta_data.pkl + sketches/bases/shapes layout."""
        self.transforms_file = os.path.join(sketch_dir, 'meta_data.pkl')
        # NOTE(review): presumably unpickles the file — only load trusted data.
        self.transforms = utils.load_pickle(self.transforms_file)
        intrinsics = self.transforms['intrinsics']
        self.poses = self.transforms['poses']

        sketches_subdir = os.path.join(sketch_dir, 'sketches')
        sketch_files = sorted(os.listdir(sketches_subdir))
        if not sketch_files:
            raise ValueError(f'no sketch images found in {sketches_subdir}')

        sketches, bounding_boxes, silhouettes = [], [], []
        src_h = src_w = None
        for f in sketch_files:
            sketch = self._read_sketch_rgb(os.path.join(sketches_subdir, f))
            if preprocess_sketch:
                with Image.open(os.path.join(sketch_dir, 'bases', f)) as base:
                    sketch = extract_sketch(base, (sketch * 255.).astype('uint8'))
                if fill_sketch:
                    sketch = fill_sketch_func(sketch)
            else:
                # Pixels darker than mid-grey become sketch (1), the rest background (0).
                gray = cv2.cvtColor((sketch * 255.).astype('uint8'), cv2.COLOR_RGB2GRAY)
                sketch = (gray < 128).astype('float32')

            src_h, src_w = sketch.shape[:2]
            if self.H is None and self.W is None:
                self.H, self.W = src_h, src_w
            else:
                sketch = cv2.resize(sketch, (self.W, self.H), interpolation = cv2.INTER_AREA)
            sketches.append(sketch)

            bbox = extract_bbox(sketch)
            bounding_boxes.append(bbox)

            if debug_dir is not None:
                # Optional visual check: the sketch with its bounding box drawn in grey.
                vis = (sketch * 255).astype('uint8')
                cv2.rectangle(vis, (bbox[0], bbox[1]), (bbox[2], bbox[3]), 128, 2)
                Image.fromarray(vis).save(os.path.join(debug_dir, f))

            with Image.open(os.path.join(sketch_dir, 'shapes', f)) as shape_img:
                silhouette = np.array(shape_img)
            silhouette = cv2.resize(silhouette, (self.W, self.H), interpolation = cv2.INTER_AREA)
            silhouette = (silhouette / 255.).astype('float32')
            # The silhouette must cover every sketched pixel.
            silhouette[sketch > 0.5] = 1
            silhouettes.append(silhouette)

        self.silhouettes = torch.from_numpy(np.stack(silhouettes, axis=0)).to(device, dtype=torch.float32)
        self.sketches = torch.from_numpy(np.stack(sketches, axis=0)).to(device, dtype=torch.float32)
        self.bounding_boxes = np.stack(bounding_boxes, axis=0)
        # Rescale from the source image size to (H, W), separately per axis.
        self.intrinsics = self._rescale_intrinsics(intrinsics, self.W / src_w, self.H / src_h)

    @staticmethod
    def _read_sketch_rgb(path):
        """Read an image as float32 RGB in [0, 1]; RGBA is composited via utils.rgba2rgb."""
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if img is None:
            raise FileNotFoundError(f'could not read sketch image: {path}')
        img = (img / 255.).astype('float32')
        if img.ndim == 2:
            return cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
        if img.shape[-1] == 4:
            # cv2 returns BGRA; convert to RGBA before compositing.
            return utils.rgba2rgb(cv2.cvtColor(img, cv2.COLOR_BGRA2RGBA), input_type = 'numpy')
        # cv2 returns BGR; base images and grey conversion expect RGB.
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    @staticmethod
    def _rescale_intrinsics(intrinsics, sx, sy):
        """Scale ``[fx, fy, cx, cy]`` by ``sx`` horizontally and ``sy`` vertically.

        Keeps a floating dtype of the input (integer input becomes float64). Inputs
        that do not have exactly 4 elements are scaled uniformly by ``sy``.
        """
        intr = np.asarray(intrinsics)
        dtype = intr.dtype if np.issubdtype(intr.dtype, np.floating) else np.float64
        intr = intr.astype(dtype)
        if intr.size != 4:
            return intr * np.asarray(sy, dtype=dtype)
        factors = np.array([sx, sy, sx, sy], dtype=dtype).reshape(intr.shape)
        return intr * factors

    def _select(self, value, indices):
        """Index ``value`` along its first axis; tolerates None, numpy arrays and dicts."""
        if value is None:
            return None
        if isinstance(value, dict):
            # Ray data is a dict of arrays; only entries whose leading dimension equals
            # the number of sketches are treated as per-sketch and indexed.
            # NOTE(review): assumes get_rays stacks per-sketch rays on dim 0 — confirm.
            n = len(self.sketches)
            return {
                k: (self._select(v, indices) if getattr(v, 'shape', None) is not None and len(v.shape) > 0 and v.shape[0] == n else v)
                for k, v in value.items()
            }
        if isinstance(value, np.ndarray) and torch.is_tensor(indices):
            indices = indices.cpu().numpy()
        return value[indices]

    def get_sketches(self, indices = None):
        """Return ``(sketches, poses, intrinsics, bounding_boxes, silhouettes, rays)``.

        Args:
            indices: optional selector along the sketch dimension (list, index
                array/tensor or slice). ``None`` returns every sketch.

        Returns:
            The per-sketch data restricted to ``indices``. Intrinsics come back as an
            (N, 4) tensor on the sketches' device, one row per selected sketch.
            ``bounding_boxes``/``silhouettes`` are ``None`` when the loader did not
            produce them (blender layout).
        """
        if indices is None:
            sketches, poses = self.sketches, self.poses
            bboxes, silhouettes, rays = self.bounding_boxes, self.silhouettes, self.rays
        else:
            sketches = self.sketches[indices]
            poses = self._select(self.poses, indices)
            bboxes = self._select(self.bounding_boxes, indices)
            silhouettes = self._select(self.silhouettes, indices)
            rays = self._select(self.rays, indices)
        intrinsics = torch.from_numpy(np.asarray(self.intrinsics)).to(self.sketches.device).reshape(1, 4).repeat(sketches.shape[0], 1)
        return sketches, poses, intrinsics, bboxes, silhouettes, rays

if __name__ == '__main__':
    # Debug script: render a view, pick one pixel, sample points along its camera
    # ray and project them into the first sketch to check pose/intrinsics consistency.
    sketches = Sketches('../../data/glass', H = 256, W = 256, type = 'manual', fill_sketch= False, preprocess_sketch=False)

    # bboxes/silhouettes/rays are unpacked for interactive inspection.
    canvases, poses, intrinsics, bboxes, silhouettes, rays = sketches.get_sketches()
    renderer = Renderer(RendererOptions())

    pose = renderer.get_pose_from_angles(radius = 3.5, thetas=[np.pi / 2], phis=[np.pi / 3])
    image = renderer.get_image(poses = pose, intrinsics = sketches.intrinsics, return_pil=True, save_path=None)

    # (x, y) pixel in the rendered view
    sample = (128, 70)

    # Mark the sampled pixel on the render (kept in memory only; not saved).
    image = cv2.circle(np.array(image), sample, 5, (255, 0, 0), -1)
    rays = rend_utils.get_rays(pose, sketches.intrinsics, sketches.H, sketches.W, -1)
    rays_o = rays['rays_o'].reshape(sketches.H, sketches.W, 3)
    rays_d = rays['rays_d'].reshape(sketches.H, sketches.W, 3)

    # Image arrays are indexed [row, col] = [y, x].
    ro = rays_o[sample[1], sample[0]]
    rd = rays_d[sample[1], sample[0]]

    # 10 points along the ray for t in [3, 4], on the same device as the rays.
    t = torch.linspace(3, 4, 10, device = ro.device)
    points = ro + rd * t.reshape(-1, 1)

    proj_points = rend_utils.batch_proj_points2image(points, poses, intrinsics)
    p1 = proj_points[0].reshape(-1, 2).detach().cpu().numpy()

    canvas = (canvases[0].unsqueeze(-1).repeat(1, 1, 3).cpu().numpy() * 255.).astype('uint8')
    # cv2.circle draws in place, so the canvas need not be copied per point.
    for px, py in p1:
        cv2.circle(canvas, (int(px), int(py)), 2, (0, 255, 0), -1)

    Image.fromarray(canvas).save('canvas_1.png')
