import json
import math
from collections import defaultdict
from pathlib import Path
import re
import os
import carla


def find_weather_presets():
    """Collect the weather presets exposed on ``carla.WeatherParameters``.

    Every attribute of ``carla.WeatherParameters`` whose name starts with an
    upper-case ASCII letter and is at least two characters long is treated
    as a preset.

    Returns:
        list: ``(preset, display_name)`` tuples in ``dir()`` order, where
        ``display_name`` is the attribute name split into space-separated
        words at its camel-case boundaries.
    """
    word_pattern = re.compile('.+?(?:(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])|$)')
    preset_pattern = re.compile('[A-Z].+')

    def humanize(attr_name):
        # Join the camel-case chunks found by the pattern with single spaces.
        return ' '.join(chunk.group(0) for chunk in word_pattern.finditer(attr_name))

    found = []
    for attr_name in dir(carla.WeatherParameters):
        if preset_pattern.match(attr_name) is None:
            continue
        preset = getattr(carla.WeatherParameters, attr_name)
        found.append((preset, humanize(attr_name)))
    return found


def build_dict_lane_single(lane_waypoint):
    """Flatten a single waypoint into a plain dictionary.

    Args:
        lane_waypoint: Waypoint-like object exposing ``transform.location``,
            ``lane_id``, ``road_id``, ``lane_type``, ``lane_width``,
            ``right_lane_marking``, ``left_lane_marking``, ``lane_change``
            and ``is_junction``.

    Returns:
        dict: The waypoint position (``location_x/y/z``), its lane and road
        ids, lane width, junction flag, and the ``.name`` of its lane type,
        lane-change permission and left/right marking types.
    """
    position = lane_waypoint.transform.location

    info = {}
    info['location_x'] = position.x
    info['location_y'] = position.y
    info['location_z'] = position.z
    info['lane_id'] = lane_waypoint.lane_id
    info['road_id'] = lane_waypoint.road_id
    info['lane_type'] = lane_waypoint.lane_type.name
    info['lane_width'] = lane_waypoint.lane_width
    info['right_lane_marking_type'] = lane_waypoint.right_lane_marking.type.name
    info['left_lane_marking_type'] = lane_waypoint.left_lane_marking.type.name
    info['lane_change'] = lane_waypoint.lane_change.name
    info['is_junction'] = lane_waypoint.is_junction
    # TODO: also export the waypoint transform (rotation) if possible.
    return info

def get_actor_display_name(actor, truncate=250):
    """Build a readable label from an actor's ``type_id``.

    Underscores are treated like dots, the text is title-cased, the first
    dot-separated component is dropped and the rest is joined with spaces.

    Args:
        actor: Object exposing a string ``type_id``.
        truncate: Maximum label length; a longer label is cut to
            ``truncate - 1`` characters followed by an ellipsis character.

    Returns:
        str: The (possibly truncated) display name.
    """
    dotted = actor.type_id.replace('_', '.').title()
    label = ' '.join(dotted.split('.')[1:])
    if len(label) <= truncate:
        return label
    return label[:truncate - 1] + u'\u2026'

def to_json_friendly(vec):
    """Convert a scalar, sequence or vector-like object into JSON-serialisable data.

    Args:
        vec: One of
            * a ``bool``, ``int`` or ``float`` (returned unchanged),
            * a ``list`` or ``tuple`` (converted element by element),
            * an object exposing one of the attribute sets ``(x, y, z)``,
              ``(x, y)``, ``(pitch, yaw, roll)`` or
              ``(ratio, down_ratio, up_ratio)``.

    Returns:
        The scalar itself, a ``list`` of converted elements, or a tuple of
        floats in the attribute order listed above.

    Raises:
        NotImplementedError: If ``vec`` matches none of the supported shapes.
    """
    if isinstance(vec, (bool, int, float)):
        return vec
    if isinstance(vec, (list, tuple)):
        return [to_json_friendly(item) for item in vec]

    # Tried in order: (x, y, z) must come before (x, y) so that 3-D vectors
    # are not truncated to two components.
    attribute_layouts = (
        ('x', 'y', 'z'),
        ('x', 'y'),
        ('pitch', 'yaw', 'roll'),
        ('ratio', 'down_ratio', 'up_ratio'),
    )
    for layout in attribute_layouts:
        try:
            return tuple(float(getattr(vec, attr)) for attr in layout)
        except (AttributeError, TypeError, ValueError):
            # Missing attribute or non-numeric value: try the next layout.
            # Anything else (e.g. KeyboardInterrupt) is deliberately not swallowed.
            continue

    raise NotImplementedError(
        'cannot convert object of type {} to JSON-friendly data'.format(type(vec).__name__)
    )


def get_traffic_light_attributes(traffic_light):
    """Describe a traffic light as a dictionary.

    Args:
        traffic_light: Object exposing ``get_transform()``, ``get_state()``,
            ``get_pole_index()`` and ``bounding_box``.

    Returns:
        defaultdict: Created without a default factory, so it behaves like a
        plain dict. Keys:
            * ``location``: ``(x, y, z)`` floats,
            * ``rotation``: ``(yaw, roll, pitch)`` floats -- note this order
              differs from the ``(pitch, yaw, roll)`` order used for the
              bounding-box rotation,
            * ``bounding_box``: ``extent`` / ``location`` / ``rotation``
              converted with ``to_json_friendly``,
            * ``state``: the value returned by ``get_state()``, stored as is,
            * ``pole_idx``: the value returned by ``get_pole_index()``.
    """
    transform = traffic_light.get_transform()
    position = transform.location
    orientation = transform.rotation
    light_state = traffic_light.get_state()
    pole_index = traffic_light.get_pole_index()

    box = traffic_light.bounding_box
    box_info = {
        'extent': to_json_friendly(box.extent),
        'location': to_json_friendly(box.location),
        'rotation': to_json_friendly(box.rotation),
    }

    return defaultdict(None, {
        'location': (float(position.x), float(position.y), float(position.z)),
        'rotation': (float(orientation.yaw), float(orientation.roll), float(orientation.pitch)),
        'bounding_box': box_info,
        'state': light_state,
        'pole_idx': pole_index,
    })


def get_actor_attributes(actor, waypoint=None):
    """Snapshot an actor's kinematic state and blueprint attributes.

    Args:
        actor: Actor-like object exposing ``id``, ``type_id``, ``attributes``,
            ``bounding_box``, ``get_velocity()``, ``get_transform()``,
            ``get_angular_velocity()`` and ``get_acceleration()``. Actors whose
            ``type_id`` starts with ``"vehicle"`` must also offer
            ``get_control()``.
        waypoint: Optional waypoint the actor was projected onto. When truthy,
            ``lane_id``, ``road_id`` and ``waypoint_info`` are added.

    Returns:
        defaultdict: Created without a default factory, so it behaves like a
        plain dict. Always contains ``id``, ``velocity_abs`` (3.6 times the
        Euclidean norm of the velocity vector, i.e. km/h if the velocity is
        in m/s), ``velocity``, ``location``, ``rotation`` (ordered
        ``(yaw, roll, pitch)``, unlike the ``(pitch, yaw, roll)`` order of the
        bounding-box rotation), ``ang_velocity``, ``name``, ``acceleration``
        and ``bounding_box``. Vehicles additionally get ``color``
        (``'Unknown'`` when the blueprint has none), ``generation``,
        ``number_of_wheels`` and ``vehicle_control``.

    Raises:
        KeyError: For a vehicle whose attributes lack ``generation`` or
            ``number_of_wheels``.
    """
    velocity = actor.get_velocity()
    transform = actor.get_transform()
    location = transform.location
    rotation = transform.rotation
    angular_velocity = actor.get_angular_velocity()
    acceleration = actor.get_acceleration()
    attributes = actor.attributes

    return_dict = defaultdict()
    return_dict['id'] = actor.id
    return_dict['velocity_abs'] = float(
        3.6 * math.sqrt(velocity.x**2 + velocity.y**2 + velocity.z**2)
    )
    return_dict['velocity'] = float(velocity.x), float(velocity.y), float(velocity.z)
    return_dict['location'] = float(location.x), float(location.y), float(location.z)
    return_dict['rotation'] = float(rotation.yaw), float(rotation.roll), float(rotation.pitch)
    return_dict['ang_velocity'] = (
        float(angular_velocity.x), float(angular_velocity.y), float(angular_velocity.z)
    )
    return_dict['name'] = get_actor_display_name(actor)
    return_dict['acceleration'] = float(acceleration.x), float(acceleration.y), float(acceleration.z)

    if waypoint:
        return_dict['lane_id'] = waypoint.lane_id
        return_dict['road_id'] = waypoint.road_id
        return_dict["waypoint_info"] = build_dict_lane_single(waypoint)

    bounding_box = actor.bounding_box
    return_dict['bounding_box'] = {
        'extent': to_json_friendly(bounding_box.extent),
        'location': to_json_friendly(bounding_box.location),
        'rotation': to_json_friendly(bounding_box.rotation)
    }

    if actor.type_id.startswith('vehicle'):
        # Only a missing key is an expected condition here; other errors
        # should surface instead of being silently replaced by 'Unknown'.
        try:
            return_dict['color'] = attributes['color']
        except KeyError:
            return_dict['color'] = 'Unknown'
        return_dict['generation'] = attributes['generation']
        return_dict['number_of_wheels'] = attributes['number_of_wheels']

        vehicle_control = actor.get_control()
        control_fields = (
            "throttle", "steer", "brake", "hand_brake",
            "reverse", "manual_gear_shift", "gear",
        )
        return_dict['vehicle_control'] = {
            field: to_json_friendly(getattr(vehicle_control, field))
            for field in control_fields
        }
        # TODO: export the vehicle physics control (torque curve, gears, ...)
        # once its values can be converted by to_json_friendly.

    # TODO: add type-specific attributes for walkers (gender, adult/child, ...),
    # traffic lights and traffic signs.
    return return_dict


def get_vehicle_attributes(vehicle, waypoint=None):
    """Return ``get_actor_attributes`` output extended with light flags.

    Args:
        vehicle: Vehicle actor; must additionally expose ``get_light_state()``.
        waypoint: Optional waypoint forwarded to ``get_actor_attributes``.

    Returns:
        The dictionary from ``get_actor_attributes`` with three extra boolean
        keys: ``left_blinker_on``, ``right_blinker_on`` and ``brake_light_on``.
    """
    return_dict = get_actor_attributes(vehicle, waypoint)

    # Per the CARLA Python API, get_light_state() returns a
    # carla.VehicleLightState bit mask. Each flag is tested against the
    # *returned* value: reading e.g. `light_state.LeftBlinker` only yields the
    # enum constant itself and says nothing about the vehicle's lights.
    active_lights = int(vehicle.get_light_state())
    flags = carla.VehicleLightState

    return_dict['left_blinker_on'] = bool(active_lights & int(flags.LeftBlinker))
    return_dict['right_blinker_on'] = bool(active_lights & int(flags.RightBlinker))
    return_dict['brake_light_on'] = bool(active_lights & int(flags.Brake))
    return return_dict


class DataExtractor(object):
    """Buffer per-frame scene descriptions around an ego vehicle and dump them as JSON.

    ``extract_frame`` stores one record per processed simulation frame in
    ``self.framedict`` (keyed by frame number); ``export_data`` writes the
    buffered records to ``<store_path>/scene_raw/<first>-<last>.json`` and
    empties the buffer.
    """

    def __init__(self, world, sensor_manager, ego, other_actors, store_path, detect_range=100, extract_per_frames=1, warmup_frames=30):
        """Set up the extractor and create the output directory.

        Args:
            world: World object providing ``get_snapshot()``, ``get_map()``
                and ``get_actors()``.
            sensor_manager: Object whose ``sensors_groups[actor_id]`` maps
                ``"lane_invasion"`` and ``"collision"`` to sensor wrappers.
            ego: The ego vehicle actor.
            other_actors: Iterable of further actors; those whose ``type_id``
                starts with ``vehicle.`` or ``static.`` are recorded.
            store_path: Base directory; output goes to its ``scene_raw``
                sub-directory.
            detect_range: Distance threshold around the ego for recording
                actors; also used as the look-ahead / look-behind distance
                for lane waypoints.
            extract_per_frames: Only frames divisible by this number are
                recorded.
            warmup_frames: Frames with a smaller frame number are skipped.
        """
        self.world = world
        self.output_dir = (Path(store_path) / 'scene_raw').resolve()
        self.sensor_manager = sensor_manager

        # parents=True: a store_path that does not exist yet must not abort
        # construction.
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # No default factory is given, so this behaves like a plain dict.
        self.framedict = defaultdict()
        self.ego = ego
        self.other_actors = other_actors
        self.orig_ego_lane_idx = None

        self.detect_range = detect_range
        self.extract_per_frames = extract_per_frames
        self.warmup_frames = warmup_frames

    @staticmethod
    def _project_to_lane(carla_map, location):
        """Project ``location`` onto a driving, shoulder or sidewalk lane of the map."""
        lane_types = carla.LaneType.Driving | carla.LaneType.Shoulder | carla.LaneType.Sidewalk
        return carla_map.get_waypoint(location, project_to_road=True, lane_type=lane_types)

    def _describe_lane(self, lane_waypoint):
        """Describe a lane by its waypoint plus those ``detect_range`` ahead and behind."""
        reach = self.detect_range
        return {
            "curr": [build_dict_lane_single(lane_waypoint)],
            "next": [build_dict_lane_single(wp) for wp in lane_waypoint.next(reach)],
            "prev": [build_dict_lane_single(wp) for wp in lane_waypoint.previous(reach)],
        }

    def _collect_side_lanes(self, src_lane, direction="left"):
        """Walk sideways from ``src_lane`` and describe each lane, nearest first."""
        lanes = []
        cur_lane = src_lane
        while True:
            lane = cur_lane.get_left_lane() if direction == "left" else cur_lane.get_right_lane()
            if lane is None:
                break
            if lane.lane_type in (carla.LaneType.Shoulder, carla.LaneType.Sidewalk):
                break
            # A sign change of the lane id ends the walk (presumably the
            # opposite driving direction, per the OpenDRIVE convention).
            if cur_lane.lane_id * lane.lane_id < 0:
                break
            lanes.append(self._describe_lane(lane))
            cur_lane = lane
        return lanes

    @staticmethod
    def _find_lane_idx(lanes, lane_id, road_id):
        """Locate ``(lane_id, road_id)`` in ``lanes``.

        Returns ``(index, key)`` where ``key`` is ``"curr"``, ``"next"`` or
        ``"prev"``, or ``(None, None)`` when no waypoint matches.
        """
        for idx, lane in enumerate(lanes):
            for key, waypoint_infos in lane.items():
                for info in waypoint_infos:
                    if info['lane_id'] == lane_id and info['road_id'] == road_id:
                        return idx, key
        return None, None

    def extract_frame(self):
        """Record the current frame into ``self.framedict``.

        Nothing is recorded when the world has no snapshot, when the frame
        number is not a multiple of ``extract_per_frames`` or when it is below
        ``warmup_frames``.

        Lanes are indexed by this extractor rather than by their OpenDRIVE
        ids: index 0 is the outermost collected lane on the left, followed by
        the remaining left lanes, the ego lane and the lanes to the right.

        The stored record has the keys ``ego``, ``actors``, ``pedestrians``,
        ``trafficlights``, ``signs``, ``statics`` and ``lane``.
        """
        snapshot = self.world.get_snapshot()
        if not snapshot:
            # Without a snapshot there is no frame number to key the record on.
            return
        frame = snapshot.timestamp.frame
        if frame % self.extract_per_frames != 0 or frame < self.warmup_frames:
            return

        carla_map = self.world.get_map()
        ego_location = self.ego.get_location()
        detect_range = self.detect_range

        world_actors = self.world.get_actors()
        pedestrians = world_actors.filter('walker.*')
        trafficlights = world_actors.filter('traffic.traffic_light')
        signs = world_actors.filter('traffic.traffic_sign')

        # Lane topology around the ego.
        ego_waypoint = self._project_to_lane(carla_map, ego_location)
        left_lanes = self._collect_side_lanes(ego_waypoint)[::-1]
        right_lanes = self._collect_side_lanes(ego_waypoint, direction="right")
        lanes = left_lanes + [self._describe_lane(ego_waypoint)] + right_lanes
        lanedict = {'lanes': lanes, 'ego_lane_idx': len(left_lanes)}

        # Ego vehicle.
        egodict = get_vehicle_attributes(self.ego, ego_waypoint)
        egodict['lane_idx'] = lanedict['ego_lane_idx']
        if self.orig_ego_lane_idx is None:
            # Remember the lane index seen on the first recorded frame.
            self.orig_ego_lane_idx = lanedict['ego_lane_idx']
        egodict['orig_lane_idx'] = self.orig_ego_lane_idx

        ego_sensors = self.sensor_manager.sensors_groups[self.ego.id]
        egodict["lane_invasion"] = ego_sensors["lane_invasion"].is_invading_lane(frame)
        # TODO: derive the invaded lane from the vehicle's position relative
        # to the lane centre; for now the original ego lane index is stored.
        egodict["invading_lane"] = self.orig_ego_lane_idx
        egodict["collision"] = ego_sensors["collision"].has_collided()

        # Tracked vehicles and static props.
        actordict = {}
        staticdict = {}
        for actor in self.other_actors:
            actor_type = actor.type_id.split('.')[0]
            if actor_type == 'vehicle':
                actor_location = actor.get_location()
                # TODO: replace the distance condition with a field-of-view test.
                if actor_location.distance(ego_location) >= detect_range:
                    continue
                vehicle_wp = self._project_to_lane(carla_map, actor_location)
                vehicle_dict = get_vehicle_attributes(actor, vehicle_wp)
                waypoint_info = vehicle_dict["waypoint_info"]
                vehicle_dict["lane_idx"], vehicle_dict["relative_position"] = self._find_lane_idx(
                    lanes, waypoint_info['lane_id'], waypoint_info['road_id'])
                if vehicle_dict["lane_idx"] is None:
                    # Vehicles that are not on one of the collected lanes are dropped.
                    continue
                actor_sensors = self.sensor_manager.sensors_groups[actor.id]
                vehicle_dict["lane_invasion"] = actor_sensors["lane_invasion"].is_invading_lane(frame)
                vehicle_dict["collision"] = actor_sensors["collision"].has_collided()
                actordict[actor.id] = vehicle_dict
            elif actor_type == 'static':
                staticdict[actor.id] = get_actor_attributes(actor)

        # Pedestrians, traffic lights and signs within range of the ego.
        peddict = {}
        for pedestrian in pedestrians:
            ped_location = pedestrian.get_location()
            if ped_location.distance(ego_location) < detect_range:
                ped_wp = self._project_to_lane(carla_map, ped_location)
                peddict[pedestrian.id] = get_actor_attributes(pedestrian, ped_wp)

        lightdict = {}
        for t_light in trafficlights:
            if t_light.get_location().distance(ego_location) < detect_range:
                lightdict[t_light.id] = get_traffic_light_attributes(t_light)

        signdict = {}
        for sign in signs:
            if sign.get_location().distance(ego_location) < detect_range:
                signdict[sign.id] = get_actor_attributes(sign)

        self.framedict[frame] = {
            "ego": egodict,
            "actors": actordict,
            "pedestrians": peddict,
            "trafficlights": lightdict,
            "signs": signdict,
            "statics": staticdict,
            "lane": lanedict}

    def export_data(self):
        """Write the buffered frames to ``<first>-<last>.json`` and clear the buffer.

        ``<first>`` and ``<last>`` are the first and last frame numbers in
        insertion order. Does nothing when no frame has been recorded.
        """
        if not self.framedict:
            # Nothing recorded: there are no frame numbers to name the file after.
            return
        frames = list(self.framedict)
        filename = '{}-{}.json'.format(frames[0], frames[-1])
        with open(self.output_dir / filename, 'w') as file:
            json.dump(self.framedict, file)
        self.framedict.clear()
