import logging
import math
import weakref
from collections import defaultdict

import carla
import numpy as np
import pygame
from carla import ColorConverter as cc


# ==============================================================================
# -- Event sensors -------------------------------------------------------------
# ==============================================================================


class LaneInvasionDetector(object):
    """Records the frames at which a lane-invasion sensor fires.

    A ``sensor.other.lane_invasion`` actor is spawned and attached to the
    parent actor. While ``recording`` is true, the frame number of every
    event is appended to ``lane_invasion_events``.
    """

    # Number of frames after the most recent recorded event during which
    # ``is_invading_lane`` still reports an invasion.
    INVASION_WINDOW_FRAMES = 10

    def __init__(self, parent_actor):
        """Spawn the sensor on ``parent_actor`` and start listening.

        Args:
            parent_actor: Actor exposing ``get_world()``; the sensor is
                attached to it.
        """
        self.sensor = None
        self._parent = parent_actor
        # Everything the callback reads must exist before listen() is
        # called, so an event delivered before __init__ returns cannot hit
        # a missing attribute.
        self.recording = True
        self.lane_invasion_events = []

        world = self._parent.get_world()
        bp = world.get_blueprint_library().find('sensor.other.lane_invasion')
        self.sensor = world.spawn_actor(bp, carla.Transform(), attach_to=self._parent)
        # We need to pass the lambda a weak reference to self to avoid circular
        # reference.
        weak_self = weakref.ref(self)
        self.sensor.listen(lambda event: LaneInvasionDetector._on_invasion(weak_self, event))

    def toggle_recording(self):
        """Flip the ``recording`` flag."""
        self.recording = not self.recording

    def destroy(self):
        """Destroy the sensor actor, if any. Safe to call more than once."""
        if self.sensor is not None:
            self.sensor.destroy()
            # Drop the reference so a repeated call does not touch an
            # already-destroyed actor.
            self.sensor = None

    def is_invading_lane(self, frame):
        """Return True if ``frame`` is within the window after the latest event.

        Args:
            frame: Frame number to test.

        Returns:
            bool: False when no event has been recorded, otherwise whether
            ``frame`` is lower than the highest recorded frame plus
            ``INVASION_WINDOW_FRAMES``.
        """
        events = self.lane_invasion_events
        if not events:
            return False
        return frame < max(events) + self.INVASION_WINDOW_FRAMES

    @staticmethod
    def _on_invasion(weak_self, event):
        """Sensor callback: store ``event.frame`` while recording is enabled.

        Args:
            weak_self: Callable returning the detector instance or ``None``
                once it has been garbage collected.
            event: Sensor event exposing a ``frame`` attribute.
        """
        self = weak_self()
        if self is None or not self.recording:
            return
        self.lane_invasion_events.append(event.frame)


class CollisionSensor(object):
    """Records collision events reported by a ``sensor.other.collision`` actor.

    Each recorded event is stored in ``history`` as ``(frame, intensity)``,
    where intensity is the Euclidean norm of the event's ``normal_impulse``.
    """

    # Maximum number of entries kept in ``history``; the oldest entry is
    # dropped once this is exceeded.
    MAX_HISTORY = 4000

    def __init__(self, parent_actor):
        """Spawn the sensor on ``parent_actor`` and start listening.

        Args:
            parent_actor: Actor exposing ``get_world()``; the sensor is
                attached to it.
        """
        self.sensor = None
        self.history = []
        self.collision = False
        self.recording = True
        self._parent = parent_actor
        world = self._parent.get_world()
        bp = world.get_blueprint_library().find('sensor.other.collision')
        self.sensor = world.spawn_actor(bp, carla.Transform(), attach_to=self._parent)
        # We need to pass the lambda a weak reference to self to avoid circular
        # reference.
        weak_self = weakref.ref(self)
        self.sensor.listen(lambda event: CollisionSensor._on_collision(weak_self, event))

    def get_collision_history(self):
        """Return a ``defaultdict(int)`` mapping frame to summed intensity."""
        totals = defaultdict(int)
        for frame, intensity in self.history:
            totals[frame] += intensity
        return totals

    def has_collided(self):
        """Return the ``collision`` flag (set on the first recorded event)."""
        return self.collision

    def toggle_recording(self):
        """Flip the ``recording`` flag."""
        self.recording = not self.recording

    @staticmethod
    def _on_collision(weak_self, event):
        """Sensor callback: record the event while recording is enabled.

        Args:
            weak_self: Callable returning the sensor wrapper or ``None``
                once it has been garbage collected.
            event: Collision event exposing ``frame`` and ``normal_impulse``
                (with ``x``, ``y``, ``z`` components).
        """
        self = weak_self()
        if self is None or not self.recording:
            return
        self.collision = True
        impulse = event.normal_impulse
        intensity = math.sqrt(impulse.x ** 2 + impulse.y ** 2 + impulse.z ** 2)
        self.history.append((event.frame, intensity))
        if len(self.history) > self.MAX_HISTORY:
            self.history.pop(0)

    def destroy(self):
        """Destroy the sensor actor, if any. Safe to call more than once."""
        if self.sensor is not None:
            self.sensor.destroy()
            # Drop the reference so a repeated call does not touch an
            # already-destroyed actor.
            self.sensor = None


# ==============================================================================
# -- Perception sensors --------------------------------------------------------
# ==============================================================================

class PerceptionSensor(object):
    """Base class holding the mounting transform for a vehicle-attached sensor.

    The constructor does not spawn anything: it stores the configuration and
    computes ``cam_transforms``. ``sensor`` stays ``None`` until something
    else assigns the spawned actor to it.
    """

    # Yaw passed to carla.Rotation for each supported orientation name.
    ORIENTATIONS = {"front": 0.0, "left": -90.0, "back": 180.0, "right": 90.0, }

    # Vertical offset added on top of the bounding box (centre z + extent z)
    # when placing the sensor.
    CAMERA_HEIGHT = 1

    def __init__(self, ori, parent_actor, gamma_correction, dimensions):
        """Store configuration and compute the mounting transform.

        Args:
            ori: Key of ``ORIENTATIONS`` selecting the yaw of the sensor.
            parent_actor: Actor exposing ``get_world()`` and ``bounding_box``.
            gamma_correction: Stored as ``self.gamma_correction``.
            dimensions: Stored as ``self.dimensions``.

        Raises:
            KeyError: If ``ori`` is not a key of ``ORIENTATIONS``.
        """
        self.sensor = None
        self.world = parent_actor.get_world()
        self._parent = parent_actor
        self.gamma_correction = gamma_correction
        self.dimensions = dimensions

        bounding_box = parent_actor.bounding_box
        center = bounding_box.location
        mount_z = float(center.z) + float(bounding_box.extent.z) + self.CAMERA_HEIGHT
        # (transform relative to the parent, attachment type)
        self.cam_transforms = (
            carla.Transform(
                carla.Location(x=float(center.x), y=float(center.y), z=mount_z),
                carla.Rotation(yaw=self.ORIENTATIONS[ori])),
            carla.AttachmentType.Rigid)
        self.transform_index = 0

    def destroy(self):
        """Destroy the sensor actor, if any. Safe to call more than once."""
        # ``sensor`` is None until an actor has been assigned, so guard the
        # call instead of assuming one exists.
        if self.sensor is not None:
            self.sensor.destroy()
            self.sensor = None


class Camera(PerceptionSensor):
    """Camera sensor that writes selected frames to disk.

    ``TYPES`` lists the supported cameras as
    ``[blueprint id, colour converter, label]``; ``camera_type_id`` is an
    index into that list. The label is also used as a directory name in the
    output path.
    """

    TYPES = [
        ['sensor.camera.rgb', cc.Raw, 'Camera RGB'],
        # ['sensor.camera.depth', cc.Raw, 'Camera Depth (Raw)'],
        ['sensor.camera.depth', cc.Depth, 'Camera Depth (Gray Scale)'],
        ['sensor.camera.depth', cc.LogarithmicDepth, 'Camera Depth (Logarithmic Gray Scale)'],
        # ['sensor.camera.semantic_segmentation', cc.Raw, 'Camera Semantic Segmentation (Raw)'],
        ['sensor.camera.semantic_segmentation', cc.CityScapesPalette,
         'Camera Semantic Segmentation (CityScapes Palette)']
    ]

    def __init__(self, camera_type_id, ori, parent_actor, gamma_correction, dimensions, storing_path, extract_per_frames, warmup_frames):
        """Spawn the camera on ``parent_actor`` and start saving frames.

        Args:
            camera_type_id: Index into ``TYPES``.
            ori: Orientation key forwarded to ``PerceptionSensor``; also used
                as a directory name in the output path.
            parent_actor: Actor the camera is attached to.
            gamma_correction: Applied to the blueprint when it has a
                ``gamma`` attribute.
            dimensions: Sequence whose first two items are written to the
                blueprint's ``image_size_x`` / ``image_size_y``.
            storing_path: Root directory of the saved images.
            extract_per_frames: Only frames whose number is a multiple of
                this value are saved.
            warmup_frames: Frames numbered below this value are skipped.
        """
        super().__init__(ori, parent_actor, gamma_correction, dimensions)
        self.storing_path = storing_path

        self.camera_type_id = camera_type_id
        self.bp = self.get_blueprints(camera_type_id)

        transform, attachment = self.cam_transforms
        self.sensor = self.world.spawn_actor(
            self.bp,
            transform,
            attach_to=self._parent,
            attachment_type=attachment)

        # The callback must not close over ``self``: doing so would keep a
        # strong reference and defeat the weak reference. Only the weak
        # reference and plain constructor arguments are captured, so the
        # camera type is bound at construction time.
        weak_self = weakref.ref(self)
        self.sensor.listen(
            lambda image: Camera._parse_image(
                weak_self, camera_type_id, ori, image, extract_per_frames, warmup_frames))

    @staticmethod
    def _parse_image(weak_self, idx, orientation, image, extract_per_frames, warmup_frames):
        """Sensor callback: convert and save ``image`` if its frame is selected.

        Args:
            weak_self: Callable returning the camera instance or ``None``
                once it has been garbage collected.
            idx: Index into ``Camera.TYPES``.
            orientation: Orientation name used in the output path.
            image: Sensor image exposing ``frame``, ``convert`` and
                ``save_to_disk``.
            extract_per_frames: Frame-number stride of the saved frames.
            warmup_frames: Frames numbered below this value are skipped.
        """
        self = weak_self()
        if self is None:
            return
        if image.frame % extract_per_frames != 0 or image.frame < warmup_frames:
            return
        converter = Camera.TYPES[idx][1]
        label = Camera.TYPES[idx][-1]
        image.convert(converter)
        image.save_to_disk('%s/perception/camera/%s/%s/%08d.jpg'
                           % (str(self.storing_path), label, orientation, image.frame))

    def get_blueprints(self, camera_type_id):
        """Return the configured blueprint for ``TYPES[camera_type_id]``.

        The image size is taken from ``self.dimensions``, and ``gamma`` is
        set from ``self.gamma_correction`` when the blueprint has that
        attribute.
        """
        bp_library = self.world.get_blueprint_library()
        bp = bp_library.find(self.TYPES[camera_type_id][0])
        bp.set_attribute('image_size_x', str(self.dimensions[0]))
        bp.set_attribute('image_size_y', str(self.dimensions[1]))
        if bp.has_attribute('gamma'):
            bp.set_attribute('gamma', str(self.gamma_correction))
        return bp



# class Lidar(PerceptionSensor):
#     TYPES = [['sensor.lidar.ray_cast', None, 'Lidar (Ray-Cast)']]

#     def __init__(self, lidar_type_id, parent_actor, gamma_correction, dimensions, storing_path):
#         super().__init__(lidar_type_id, parent_actor, gamma_correction, dimensions)
#         self.storing_path = storing_path

#         self.bp = self.get_blueprints(lidar_type_id)
#         self.sensor = self.world.spawn_actor(
#             self.bp,
#             self._camera_transforms[self.transform_index][0],
#             attach_to=self._parent,
#             attachment_type=self._camera_transforms[self.transform_index][1])

#         weak_self = weakref.ref(self)
#         self.sensor.listen(lambda image: Lidar._parse_image(weak_self, self.lidar_type_id, image))

#     # TODO How to save lidar info?
#     @staticmethod
#     def _parse_image(weak_self, idx, image):
#         self = weak_self()
#         if not self:
#             return
#         points = np.frombuffer(image.raw_data, dtype=np.dtype('f4'))
#         points = np.reshape(points, (int(points.shape[0] / 3), 3))
#         lidar_data = np.array(points[:, :2])
#         lidar_data *= min(self.dimensions) / 100.0
#         lidar_data += (0.5 * self.dimensions[0], 0.5 * self.dimensions[1])
#         lidar_data = np.fabs(lidar_data)  # pylint: disable=E1111
#         lidar_data = lidar_data.astype(np.int32)
#         lidar_data = np.reshape(lidar_data, (-1, 2))
#         lidar_img_size = (self.dimensions[0], self.dimensions[1], 3)
#         lidar_img = np.zeros((lidar_img_size), dtype=int)
#         lidar_img[tuple(lidar_data.T)] = (255, 255, 255)
#         self.surface = pygame.surfarray.make_surface(lidar_img)

#         image.save_to_disk(
#             '%s/perception/lidar/%s/%08d.jpg' % (str(self.storing_path), Lidar.TYPES[idx][-1], image.frame))

#     def get_blueprints(self, lidar_type_id):
#         bp_library = self.world.get_blueprint_library()
#         bp = bp_library.find(self.TYPES[lidar_type_id][0])
#         bp.set_attribute('image_size_x', str(self.dimensions[0]))
#         bp.set_attribute('image_size_y', str(self.dimensions[1]))
#         if bp.has_attribute('gamma'):
#             bp.set_attribute('gamma', str(self.gamma_correction))
#         return bp


# class CameraManager(object):
#     # TODO definable camera sensors list
#     def __init__(self, client, parent_actor, gamma_correction, dimensions, storing_path):
#         self.sensors = []
#         self.client = client
#         self.world = parent_actor.get_world()
#         self.surface = None
#         self._parent = parent_actor
#         self.recording = False
#         self.dimensions = dimensions
#         self.storing_path = storing_path
#         bound_y = 0.5 + self._parent.bounding_box.extent.y
#         Attachment = carla.AttachmentType
#         self._camera_transforms = [
#             (carla.Transform(carla.Location(x=-5.5, z=2.5), carla.Rotation(pitch=8.0)), Attachment.SpringArm),
#             (carla.Transform(carla.Location(x=1.6, z=1.7)), Attachment.Rigid),
#             (carla.Transform(carla.Location(x=5.5, y=1.5, z=1.5)), Attachment.SpringArm),
#             (carla.Transform(carla.Location(x=-8.0, z=6.0), carla.Rotation(pitch=6.0)), Attachment.SpringArm),
#             (carla.Transform(carla.Location(x=-1, y=-bound_y, z=0.5)), Attachment.Rigid)]
#         self.transform_index = 1
#         # TODO do better dict reference instead of list indexing
#         self.sensor_types = [
#             ['sensor.camera.rgb', cc.Raw, 'Camera RGB'],
#             ['sensor.camera.depth', cc.Raw, 'Camera Depth (Raw)'],
#             ['sensor.camera.depth', cc.Depth, 'Camera Depth (Gray Scale)'],
#             ['sensor.camera.depth', cc.LogarithmicDepth, 'Camera Depth (Logarithmic Gray Scale)'],
#             ['sensor.camera.semantic_segmentation', cc.Raw, 'Camera Semantic Segmentation (Raw)'],
#             ['sensor.camera.semantic_segmentation', cc.CityScapesPalette,
#              'Camera Semantic Segmentation (CityScapes Palette)'],
#             ['sensor.lidar.ray_cast', None, 'Lidar (Ray-Cast)']]
#         bp_library = self.world.get_blueprint_library()
#         for item in self.sensor_types:
#             bp = bp_library.find(item[0])
#             if item[0].startswith('sensor.camera'):
#                 bp.set_attribute('image_size_x', str(dimensions[0]))
#                 bp.set_attribute('image_size_y', str(dimensions[1]))
#                 if bp.has_attribute('gamma'):
#                     bp.set_attribute('gamma', str(gamma_correction))
#             elif item[0].startswith('sensor.lidar'):
#                 bp.set_attribute('range', '5000')
#             item.append(bp)

#         # self.sensor = self._parent.get_world().spawn_actor(
#         #     self.sensor_types[index][-1],
#         #     self._camera_transforms[self.transform_index][0],
#         #     attach_to=self._parent,
#         #     attachment_type=self._camera_transforms[self.transform_index][1])

#         # TODO Understand each of the camera_transforms
#         # TODO no Command adaptation for attachment_type = [rigid,...]
#         print(self.sensor_types)
#         # batch = [
#         #     carla.command.SpawnActor(
#         #         sensor[idx][-1],
#         #         self._camera_transforms[self.transform_index][0],
#         #         self._parent)
#         #     for idx, sensor in enumerate(self.sensor_types)]

#         batch = []
#         for idx, sensor in enumerate(self.sensor_types):
#             s = sensor[-1]
#             t = self._camera_transforms[self.transform_index][0]
#             p = self._parent
#             print(s)
#             print(t)
#             print(p)
#             batch.append(carla.command.SpawnActor(s, t, p))

#         results = self.client.apply_batch_sync(batch, True)
#         # for _ in range(10):
#         #     print("here")
#         for i in range(len(results)):
#             if results[i].error:
#                 logging.error(results[i].error)
#             else:
#                 actid = results[i].actor_id
#                 self.sensors.append(self.world.get_actor(actid))

#         weak_self = weakref.ref(self)
#         for idx, sensor in enumerate(self.sensors):
#             sensor.listen(lambda image: CameraManager._parse_image(weak_self, idx, image))

#     @staticmethod
#     def _parse_image(weak_self, idx, image):
#         print(idx)
#         self = weak_self()
#         print(self.sensor_types[idx][0])
#         if not self:
#             return
#         if self.sensor_types[idx][0].startswith('sensor.lidar'):
#             # points = np.frombuffer(image.raw_data, dtype=np.dtype('f4'))
#             # points = np.reshape(points, (int(points.shape[0] / 3), 3))
#             # lidar_data = np.array(points[:, :2])
#             # lidar_data *= min(self.dimensions) / 100.0
#             # lidar_data += (0.5 * self.dimensions[0], 0.5 * self.dimensions[1])
#             # lidar_data = np.fabs(lidar_data)  # pylint: disable=E1111
#             # lidar_data = lidar_data.astype(np.int32)
#             # lidar_data = np.reshape(lidar_data, (-1, 2))
#             # lidar_img_size = (self.dimensions[0], self.dimensions[1], 3)
#             # lidar_img = np.zeros((lidar_img_size), dtype = int)
#             # lidar_img[tuple(lidar_data.T)] = (255, 255, 255)
#             # self.surface = pygame.surfarray.make_surface(lidar_img)
#             return
#         else:
#             image.convert(self.sensor_types[idx][1])
#             array = np.frombuffer(image.raw_data, dtype=np.dtype("uint8"))
#             array = np.reshape(array, (image.height, image.width, 4))
#             array = array[:, :, :3]
#             array = array[:, :, ::-1]
#             self.surface = pygame.surfarray.make_surface(array.swapaxes(0, 1))

#         image.save_to_disk('%s/raw_images/%08d.jpg' % (str(self.storing_path), image.frame))

#     # def destroy(self):
#     #     if self.sensor:
#     #         self.sensor.destroy()
