from sensors import Camera, CollisionSensor, LaneInvasionDetector
from pathlib import Path
import itertools

class SensorManager:
    """Spawn, track and tear down the sensors attached to vehicle actors.

    Sensors are tracked per vehicle in ``self.sensors_groups``, keyed by
    ``vehicle.id``. Each group is a dict mapping a sensor name (``"cameras"``,
    ``"collision"``, ``"lane_invasion"``) to either a single sensor object or a
    list of sensor objects; every tracked sensor is expected to expose
    ``destroy()``.
    """

    def __init__(self, world, scene_path, extract_per_frames, warmup_frames):
        """Create the ``<scene_path>/sensor`` output directory and store settings.

        Args:
            world: World handle; stored as ``self.world`` (not used by the
                methods of this class).
            scene_path: Directory of the current scene. Only the ``sensor``
                sub-directory is created, so ``scene_path`` itself must exist.
            extract_per_frames: Forwarded to every ``Camera`` that is spawned.
            warmup_frames: Forwarded to every ``Camera`` that is spawned.
        """
        self.world = world
        # vehicle.id -> {sensor name: sensor or [sensors]}
        self.sensors_groups = dict()
        self.sensor_path = Path("%s/sensor" % scene_path).resolve()
        self.sensor_path.mkdir(exist_ok=True)
        self.extract_per_frames = extract_per_frames
        self.warmup_frames = warmup_frames

    def attach_sensors_to_vehicles(self, vehicles_list, activate_sensors='all'):
        """Spawn sensors for each vehicle actor in *vehicles_list* and track them.

        Actors whose ``type_id`` prefix (the text before the first ``.``) is not
        ``vehicle`` are skipped. For every other actor a directory
        ``<sensor_path>/vehicle/<vehicle.id>`` is created, and the requested
        sensors are stored as that vehicle's group in ``self.sensors_groups``.

        If the vehicle already has a tracked group, its sensors are destroyed
        before the new group replaces it; otherwise they would become untracked
        and could never be cleaned up by ``destory_all()``.

        Args:
            vehicles_list: Iterable of actors exposing ``id`` and ``type_id``.
            activate_sensors: ``'all'`` to attach cameras, a collision sensor
                and a lane-invasion detector; otherwise a container whose
                membership (``in``) selects among ``"cameras"``,
                ``"collision"`` and ``"lane_invasion"``.
        """
        dimensions = [1280, 720]
        gamma = 2.2
        actor_path = self.sensor_path / "vehicle"
        actor_path.mkdir(exist_ok=True)

        def wanted(name):
            # 'all' enables every sensor kind; otherwise test membership.
            return activate_sensors == 'all' or name in activate_sensors

        for vehicle in vehicles_list:
            if vehicle.type_id.split('.')[0] != 'vehicle':
                continue
            # Re-attaching would otherwise orphan the currently tracked sensors.
            previous = self.sensors_groups.pop(vehicle.id, None)
            if previous is not None:
                self._destroy_group(previous)

            vehicle_path = actor_path / str(vehicle.id)
            vehicle_path.mkdir(exist_ok=True)
            sensors_dict = {}
            if wanted("cameras"):
                sensors_dict["cameras"] = self._make_cameras(
                    vehicle, vehicle_path, gamma, dimensions)
            if wanted("collision"):
                sensors_dict["collision"] = CollisionSensor(vehicle)
            if wanted("lane_invasion"):
                sensors_dict["lane_invasion"] = LaneInvasionDetector(vehicle)
            self.sensors_groups[vehicle.id] = sensors_dict

    def _make_cameras(self, vehicle, vehicle_path, gamma, dimensions):
        """Return one ``Camera`` per (``Camera.TYPES`` index, ``Camera.ORIENTATIONS`` entry) pair."""
        return [Camera(idx, ori, vehicle, gamma, dimensions, vehicle_path,
                       self.extract_per_frames, self.warmup_frames)
                for idx, ori in itertools.product(range(len(Camera.TYPES)),
                                                  Camera.ORIENTATIONS)]

    @staticmethod
    def _destroy_group(group):
        """Call ``destroy()`` on every sensor of one vehicle's sensor group."""
        for item in group.values():
            sensors = item if isinstance(item, list) else [item]
            for sensor in sensors:
                sensor.destroy()

    # TODO possibly destroy all sensors as a batch (faster)?
    def destory_all(self):
        """Destroy every tracked sensor and forget all sensor groups.

        The misspelled name is kept for existing callers; ``destroy_all`` is a
        correctly spelled alias.
        """
        for group in self.sensors_groups.values():
            self._destroy_group(group)
        self.sensors_groups = dict()

    def destroy_all(self):
        """Correctly spelled alias of ``destory_all``."""
        self.destory_all()











# class SensorManager:

#     def __init__(self, carla_world, scene_path):
#         self.carla_world = carla_world
#         self.sensors_groups = dict()
#         self.sensor_path = Path("%s/sensor" % scene_path).resolve()
#         self.sensor_path.mkdir(exist_ok=True)

#     def attach_sensors_to_vehicles(self, vehicles_list):
#         #### Spawn and attach sensors to each vehicle ####
#         dimensions = [1280, 720]
#         gamma = 2.2
#         actor_path = Path("%s/vehicle" % self.sensor_path)
#         actor_path.mkdir(exist_ok=True)
#         for vehicle in vehicles_list:
#             print("vehicle:", vehicle)
#             # vehicle = self.carla_world.world.get_actor(vehicle_id)
#             sensors_dict = {}
#             vehicle_path = Path("%s/%s" % (actor_path, vehicle.id))
#             vehicle_path.mkdir(exist_ok=True)
#             sensors_dict["cameras"] = [Camera(self.carla_world, idx, ori, vehicle, gamma, dimensions, vehicle_path)
#                                        for idx, ori in itertools.product(range(len(Camera.TYPES)), Camera.ORIENTATIONS)]
#             sensors_dict["collision"] = CollisionSensor(vehicle)
#             # sensors_dict["lane_invasion"] = [LaneInvasionDetector(vehicle)]
#             sensors_dict["lane_invasion"] = LaneInvasionDetector(vehicle)
#             # sensors_dict["collision"] = CollisionSensor(vehicle)
#             self.sensors_groups[vehicle.id] = sensors_dict

#     # def get_sensors_group(self, group_name):
#     #     return sensors_dict[group_name]

#     # TODO possibly destory all as a batch (faster)?
#     def destory_all(self):
#         for group in self.sensors_groups.values():
#             for item in group.values():
#                 if type(item) == list:
#                     for sensor in item:
#                         sensor.destroy()
#                 else:
#                     item.destroy()
#         self.sensors_groups = dict()


