import sys
SRUNNER_PATH = r'./scenario_runner'
sys.path.append(SRUNNER_PATH)

from scenario_runner import ScenarioRunner
from srunner.scenariomanager.scenario_manager import ScenarioManager

import time
import traceback
import random
import json
import os
import logging
import carla
import cv2
import numpy as np
import glob

import py_trees

from srunner.scenariomanager.carla_data_provider import CarlaDataProvider
from srunner.tools.scenario_parser import ScenarioConfigurationParser
from srunner.scenariomanager.timer import GameTime
from srunner.scenariomanager.watchdog import Watchdog
from sensor_manager import SensorManager
from pathlib import Path
import pprint
from data_extractor import DataExtractor
from tqdm import tqdm
from collections import defaultdict


class CostumeScenarioManager(ScenarioManager):
    """
    ScenarioManager that additionally drives per-frame data extraction.

    ``run_scenario`` attaches collision / lane-invasion sensors, lets the
    scenario adjust the ego control every step, extracts a frame per step, stops
    a fixed number of steps after the first ego collision (or after
    ``args.max_episode_len`` steps) and returns a metadata dict for the run.
    """

    # Number of loop iterations still executed after the ego vehicle's first
    # collision has been detected.
    _STEPS_AFTER_COLLISION = 30

    def __init__(self, args):
        super(CostumeScenarioManager, self).__init__(debug_mode=args.debug, sync_mode=args.sync, timeout=args.timeout)
        # Parsed command-line arguments; run_scenario reads extract_per_frames,
        # warmup_frames, response_distance and max_episode_len from it.
        self.args = args

    def run_scenario(self, scenario_name, repeat_path, extra_vehicles_list):
        """
        Trigger the start of the scenario and wait for it to finish/fail.

        :param scenario_name: value stored under "scenario" in the returned dict.
        :param repeat_path: output directory of this repetition, passed to the
            SensorManager and the DataExtractor.
        :param extra_vehicles_list: additional actors that, like the ego vehicles
            and the scenario's other actors, get collision / lane-invasion sensors.
        :return: dict with keys "map", "scenario", "weather" and "result"; the
            result is "COLLISION", "SUCCESS", "TIMEOUT" or "FAILED".
        """
        # World handle used for sensors, extraction and the final metadata.
        world = CarlaDataProvider.get_world()
        ego = self.ego_vehicles[0]

        # ================
        # Sensor Manager
        # ================
        self.sensor_manager = SensorManager(world, repeat_path, self.args.extract_per_frames, self.args.warmup_frames)
        self.sensor_manager.attach_sensors_to_vehicles(self.ego_vehicles + self.other_actors + extra_vehicles_list,
                                                       ["collision", "lane_invasion"])

        # ================
        # Data Extractor
        # ================
        extractor = DataExtractor(world,
                                  self.sensor_manager,
                                  ego,
                                  self.other_actors,
                                  repeat_path,
                                  self.args.response_distance,
                                  self.args.extract_per_frames,
                                  self.args.warmup_frames)

        self.start_system_time = time.time()
        start_game_time = GameTime.get_time()

        self._watchdog = Watchdog(float(self._timeout))
        self._watchdog.start()
        self._running = True
        idx = 0
        before_stopping = self._STEPS_AFTER_COLLISION
        collision_happens = False
        try:
            with tqdm(total=self.args.max_episode_len) as pbar:
                while before_stopping > 0 and self._running and idx < self.args.max_episode_len:
                    pbar.update(1)
                    idx += 1
                    timestamp = None
                    # Separate name so the `world` used for the metadata is never rebound.
                    current_world = CarlaDataProvider.get_world()
                    # The scenario class may modify the ego's current control each step.
                    ego.apply_control(self.scenario_class.change_control(ego.get_control()))
                    if current_world:
                        snapshot = current_world.get_snapshot()
                        if snapshot:
                            timestamp = snapshot.timestamp
                    if timestamp:
                        self._tick_scenario(timestamp)

                    if not collision_happens and self.sensor_manager.sensors_groups[ego.id]["collision"].has_collided():
                        collision_happens = True
                        print('\033[93m'+"collision happens"+'\033[0m')
                    if collision_happens:
                        before_stopping -= 1
                    if self._running:
                        extractor.extract_frame()
        except Exception:  # pylint: disable=broad-except
            print(traceback.format_exc())
        finally:
            # Release the sensors even if the loop is left through a
            # non-Exception interrupt (e.g. KeyboardInterrupt).
            self.sensor_manager.destory_all()
            self.sensor_manager = None

        extractor.export_data()

        scenario_result = self._classify_result(collision_happens)
        weather_dict = self._weather_to_dict(world.get_weather())
        meta_dict = {"map": os.path.basename(os.path.normpath(world.get_map().name)),
                     "scenario": scenario_name,
                     "weather": weather_dict,
                     "result": scenario_result}

        self.cleanup()

        self.end_system_time = time.time()
        end_game_time = GameTime.get_time()

        self.scenario_duration_system = self.end_system_time - self.start_system_time
        self.scenario_duration_game = end_game_time - start_game_time

        if self.scenario_tree.status == py_trees.common.Status.FAILURE:
            print("ScenarioManager: Terminated due to failure")

        return meta_dict

    def _classify_result(self, collision_happens):
        """Map the outcome of the run to "COLLISION", "TIMEOUT", "SUCCESS" or "FAILED"."""
        if collision_happens:
            return "COLLISION"
        status = self.scenario_tree.status
        if status == py_trees.common.Status.SUCCESS:
            # A successful tree whose timeout node has fired is reported as a timeout.
            return "TIMEOUT" if self.scenario.timeout_node.timeout else "SUCCESS"
        if status == py_trees.common.Status.RUNNING:
            # The tree has not finished: the loop stopped on the episode-length
            # cap (or was aborted by an exception).
            return "TIMEOUT"
        return "FAILED"

    @staticmethod
    def _weather_to_dict(weather):
        """Return the weather parameters that are stored in the run metadata."""
        return {"wetness": weather.wetness, "wind_intensity": weather.wind_intensity,
                "precipitation_deposits": weather.precipitation_deposits,
                "precipitation": weather.precipitation, "cloudiness": weather.cloudiness,
                "fog_density": weather.fog_density, "fog_distance": weather.fog_distance,
                "sun_altitude_angle": weather.sun_altitude_angle,
                "sun_azimuth_angle": weather.sun_azimuth_angle}

    def _tick_scenario(self, timestamp):
        """
        Run next tick of scenario and the agent.
        If running synchronously, it also handles the ticking of the world.
        """
        # Only process a timestamp newer than the last processed one.
        if self._timestamp_last_run < timestamp.elapsed_seconds and self._running:
            self._timestamp_last_run = timestamp.elapsed_seconds

            self._watchdog.update()

            if self._debug_mode:
                print("\n--------- Tick ---------\n")

            # Update game time and actor information
            GameTime.on_carla_tick(timestamp)
            CarlaDataProvider.on_carla_tick()

            if self._agent is not None:
                ego_action = self._agent()  # pylint: disable=not-callable
                self.ego_vehicles[0].apply_control(ego_action)

            # Tick scenario
            self.scenario_tree.tick_once()

            if self._debug_mode:
                print("\n")
                py_trees.display.print_ascii_tree(self.scenario_tree, show_status=True)
                sys.stdout.flush()

            if self.scenario_tree.status != py_trees.common.Status.RUNNING:
                self._running = False

        if self._sync_mode and self._running and self._watchdog.get_status():
            CarlaDataProvider.get_world().tick()



class CostumeScenarioRunner(ScenarioRunner):
    """
    ScenarioRunner that runs every configuration ``args.repetitions`` times.

    Each repetition uses a randomized traffic-manager autopilot for the ego
    vehicle and optional extra vehicles / walkers. Per-repetition metadata is
    written to ``<save_dir>/<config.name>/<n>/metadata.txt``, and a summary of
    result counts per configuration is appended to
    ``<save_dir>/meta_result.txt``.
    """

    def __init__(self, args):
        self.frame_rate = 10.0      # in Hz; used for fixed_delta_seconds in sync mode
        super(CostumeScenarioRunner, self).__init__(args)
        print("sync mode", self._args.sync)
        # Use the custom, data-collecting scenario manager.
        self.manager = CostumeScenarioManager(self._args)
        root_path = Path(self._args.save_dir)
        root_path.mkdir(exist_ok=True)
        self.scene_path = root_path
        # Extra actors of the current run; emptied by _destroy_extra_actors().
        self.vehicles_list = []
        self.walkers_list = []
        self.all_id = []        # [controller, walker, controller, walker, ...]
        self.all_actors = []    # actors for all_id, same order

    def _load_and_wait_for_world(self, town, ego_vehicles=None):
        """
        Load a new CARLA world and provide data to CarlaDataProvider.

        :return: True if the server runs the map required by ``town``
            (or "OpenDriveMap"), False otherwise.
        """

        if self._args.reloadWorld:
            self.world = self.client.load_world(town)
        else:
            # if the world should not be reloaded, wait at least until all ego vehicles are ready
            ego_vehicle_found = False
            if self._args.waitForEgo:
                while not ego_vehicle_found and not self._shutdown_requested:
                    vehicles = self.client.get_world().get_actors().filter('vehicle.*')
                    for ego_vehicle in ego_vehicles:
                        ego_vehicle_found = False
                        for vehicle in vehicles:
                            if vehicle.attributes['role_name'] == ego_vehicle.rolename:
                                ego_vehicle_found = True
                                break
                        if not ego_vehicle_found:
                            print("Not all ego vehicles ready. Waiting ... ")
                            time.sleep(1)
                            break

        self.world = self.client.get_world()

        if self._args.sync:
            import math
            settings = self.world.get_settings()
            settings.synchronous_mode = True
            settings.fixed_delta_seconds = 1.0 / self.frame_rate
            # Enough physics substeps to cover one fixed step.
            settings.max_substeps = math.ceil(settings.fixed_delta_seconds / settings.max_substep_delta_time)
            self.world.apply_settings(settings)

        CarlaDataProvider.set_client(self.client)
        CarlaDataProvider.set_world(self.world)

        # Wait for the world to be ready
        if CarlaDataProvider.is_sync_mode():
            self.world.tick()
        else:
            self.world.wait_for_tick()

        map_name = CarlaDataProvider.get_map().name.split('/')[-1]
        if map_name not in (town, "OpenDriveMap"):
            print("The CARLA server uses the wrong map: {}".format(map_name))
            print("This scenario requires to use map: {}".format(town))
            return False

        return True

    def _set_ego_vehicle_behaviors(self, traffic_manager):
        """
        Put the first ego vehicle on the traffic manager's autopilot with
        randomly drawn behavior parameters.

        Each parameter is drawn uniformly from its (min, max) range below.

        :return: dict of the drawn parameter values (stored in the run metadata).
        """
        param_range = {
            "distance_to_leading_vehicle": (0, 5),
            "percentage_speed_difference": (-100, 30),
            "ignore_lights_percentage": (0, 50),
            "ignore_signs_percentage": (0, 50),
            "ignore_vehicle_percentage": (20, 40),
            "ignore_walker_percentage": (20, 40)
        }
        param_values = {key: random.uniform(min_v, max_v) for key, (min_v, max_v) in param_range.items()}

        ego = self.ego_vehicles[0]
        ego.set_autopilot(True, traffic_manager.get_port())
        traffic_manager.distance_to_leading_vehicle(ego, param_values["distance_to_leading_vehicle"])
        traffic_manager.vehicle_percentage_speed_difference(ego, param_values["percentage_speed_difference"])
        traffic_manager.ignore_lights_percentage(ego, param_values["ignore_lights_percentage"])
        traffic_manager.ignore_signs_percentage(ego, param_values["ignore_signs_percentage"])
        traffic_manager.ignore_vehicles_percentage(ego, param_values["ignore_vehicle_percentage"])
        traffic_manager.ignore_walkers_percentage(ego, param_values["ignore_walker_percentage"])

        return param_values

    def _run_scenarios(self):
        """
        Run conventional scenarios (e.g. implemented using the Python API of ScenarioRunner).

        Every configuration is run ``args.repetitions`` times; the count of each
        result (failed runs are counted as "ERROR") is appended as one JSON line
        to ``meta_result.txt``.

        :return: the result of the last repetition, or False if no configuration was found.
        """
        result = False

        # Load the scenario configurations provided in the config file
        scenario_configurations = ScenarioConfigurationParser.parse_scenario_configuration(
            self._args.scenario,
            self._args.configFile)
        if not scenario_configurations:
            print("Configuration for scenario {} cannot be found!".format(self._args.scenario))
            return result

        # Execute each configuration
        results_meta = dict()
        for config in scenario_configurations:
            results_meta_scenario = defaultdict(int)
            for _ in range(self._args.repetitions):
                result = self._load_and_run_scenario(config)
                results_meta_scenario[result if result else "ERROR"] += 1
            results_meta[config.name] = results_meta_scenario
        with open(self.scene_path/"meta_result.txt", 'a+') as mf:
            mf.write(json.dumps(results_meta))
            mf.write('\n')
        return result

    def spawn_walkers(self, counts):
        """
        Spawn up to ``counts`` AI-controlled walkers at random navigation locations.

        Walkers whose controller cannot be spawned are destroyed again. The
        remaining ones are recorded in ``self.walkers_list`` as
        ``{"id": walker_id, "con": controller_id}`` and in ``self.all_id`` as
        ``[controller, walker, ...]``. Each controller is started towards a
        random location with a capped speed.
        """
        blueprints_walkers = self.world.get_blueprint_library().filter('walker.pedestrian.*')
        percentage_pedestrians_running = 0.0      # how many pedestrians will run
        percentage_pedestrians_crossing = 0.0     # how many pedestrians will walk through the road

        # 1. take all the random locations to spawn (locations that come back as None are skipped)
        spawn_points = []
        for _ in range(counts):
            loc = self.world.get_random_location_from_navigation()
            if loc is not None:
                spawn_point = carla.Transform()
                spawn_point.location = loc
                spawn_points.append(spawn_point)

        # 2. we spawn the walker object
        batch = []
        walker_speed = []
        for spawn_point in spawn_points:
            walker_bp = random.choice(blueprints_walkers)
            if walker_bp.has_attribute('is_invincible'):
                walker_bp.set_attribute('is_invincible', 'false')
            if walker_bp.has_attribute('speed'):
                if random.random() > percentage_pedestrians_running:
                    # walking
                    walker_speed.append(walker_bp.get_attribute('speed').recommended_values[1])
                else:
                    # running
                    walker_speed.append(walker_bp.get_attribute('speed').recommended_values[2])
            else:
                print("Walker has no speed")
                walker_speed.append(0.0)
            batch.append(carla.command.SpawnActor(walker_bp, spawn_point))
        spawned = []  # (walker_id, max_speed) of successfully spawned walkers
        for response, speed in zip(self.client.apply_batch_sync(batch, True), walker_speed):
            if response.error:
                logging.error(response.error)
            else:
                spawned.append((response.actor_id, speed))

        # 3. we spawn one walker controller attached to each new walker
        walker_controller_bp = self.world.get_blueprint_library().find('controller.ai.walker')
        batch = [carla.command.SpawnActor(walker_controller_bp, carla.Transform(), walker_id)
                 for walker_id, _ in spawned]
        controlled = []  # (walker_id, controller_id, max_speed)
        orphan_ids = []
        for (walker_id, speed), response in zip(spawned, self.client.apply_batch_sync(batch, True)):
            if response.error:
                logging.error(response.error)
                orphan_ids.append(walker_id)
            else:
                controlled.append((walker_id, response.actor_id, speed))
        if orphan_ids:
            # A walker without controller would break the [controller, walker]
            # pairing below and would never be destroyed, so remove it now.
            self.client.apply_batch([carla.command.DestroyActor(x) for x in orphan_ids])

        # 4. record walker and controller ids; only the walkers spawned by this call are processed below
        first_new = len(self.all_id)
        for walker_id, controller_id, _ in controlled:
            self.walkers_list.append({"id": walker_id, "con": controller_id})
            self.all_id.append(controller_id)
            self.all_id.append(walker_id)
        self.all_actors = self.world.get_actors(self.all_id)

        # wait for a tick to ensure client receives the last transform of the walkers we have just created
        if not self._args.sync:
            self.world.wait_for_tick()
        else:
            self.world.tick()

        # 5. initialize each new controller and set target to walk to (list is [controller, actor, ...])
        self.world.set_pedestrians_cross_factor(percentage_pedestrians_crossing)
        for k, (_, _, speed) in enumerate(controlled):
            controller = self.all_actors[first_new + 2 * k]
            # start walker
            controller.start()
            # set walk to random point
            controller.go_to_location(self.world.get_random_location_from_navigation())
            # max speed
            controller.set_max_speed(float(speed))

    def spawn_vehicles(self, counts):
        """
        Spawn up to ``counts`` autopilot vehicles at shuffled map spawn points.

        The vehicles are handed to the traffic manager on
        ``args.trafficManagerPort``, the same one that drives the ego vehicle.
        The spawned actors replace ``self.vehicles_list``.
        """
        blueprints = self.world.get_blueprint_library().filter('vehicle.*')
        blueprints = [x for x in blueprints if not x.id.endswith('cybertruck')]

        spawn_points = self.world.get_map().get_spawn_points()
        number_of_spawn_points = len(spawn_points)

        if counts < number_of_spawn_points:
            random.shuffle(spawn_points)
        elif counts > number_of_spawn_points:
            msg = 'requested %d vehicles, but could only find %d spawn points'
            logging.warning(msg, counts, number_of_spawn_points)
            counts = number_of_spawn_points

        tm_port = int(self._args.trafficManagerPort)
        batch = []
        for transform in spawn_points[:counts]:
            blueprint = random.choice(blueprints)
            if blueprint.has_attribute('color'):
                color = random.choice(blueprint.get_attribute('color').recommended_values)
                blueprint.set_attribute('color', color)
            if blueprint.has_attribute('driver_id'):
                driver_id = random.choice(blueprint.get_attribute('driver_id').recommended_values)
                blueprint.set_attribute('driver_id', driver_id)
            blueprint.set_attribute('role_name', 'autopilot')
            # Without an explicit port, SetAutopilot uses the default traffic-manager
            # port, which may differ from the (synchronous) one configured for this run.
            batch.append(
                carla.command.SpawnActor(blueprint, transform)
                .then(carla.command.SetAutopilot(carla.command.FutureActor, True, tm_port))
                )
        vehicles_ids = []
        for response in self.client.apply_batch_sync(batch, self._args.sync):
            if response.error:
                logging.error(response.error)
            else:
                vehicles_ids.append(response.actor_id)
        self.vehicles_list = list(self.world.get_actors(vehicles_ids))

    def _create_repeat_dir(self, scenario_name):
        """
        Create the next repetition directory ``<scene_path>/<scenario_name>/<n>``.

        ``n`` is one more than the largest purely numeric entry already present,
        or 0 if there is none. Other entries are ignored instead of making ``int()`` fail.

        :return: tuple ``(n, resolved_path)``.
        """
        scenario_path = (self.scene_path / scenario_name).resolve()
        scenario_path.mkdir(exist_ok=True)
        existed = [int(c) for c in os.listdir(scenario_path) if c.isdecimal()]
        index = max(existed) + 1 if existed else 0
        repeat_path = (scenario_path / str(index)).resolve()
        repeat_path.mkdir(exist_ok=False)
        return index, repeat_path

    def _destroy_extra_actors(self):
        """
        Stop the walker controllers and destroy the extra vehicles / walkers of the
        current run, then leave fresh, independent bookkeeping lists behind.

        The lists are swapped out before any server call, so a failure part-way
        cannot make a later call try to destroy the same actors again.
        """
        vehicles, actor_ids, actors = self.vehicles_list, self.all_id, self.all_actors
        # Four distinct list objects (a chained `a = b = c = []` would alias them).
        self.vehicles_list = []
        self.walkers_list = []
        self.all_id = []
        self.all_actors = []

        if vehicles:
            self.client.apply_batch([carla.command.DestroyActor(x) for x in vehicles])
        try:
            # Controllers are at the even indices of all_id / all_actors.
            for i in range(0, len(actor_ids), 2):
                actors[i].stop()
        finally:
            if actor_ids:
                self.client.apply_batch([carla.command.DestroyActor(x) for x in actor_ids])

    def _load_and_run_scenario(self, config):
        """
        Load and run the scenario given by config (one repetition).

        :return: the "result" string of the manager's metadata, or False when the
            world, agent or scenario could not be set up or the run failed before
            the manager returned its metadata.
        """
        if not self._load_and_wait_for_world(config.town, config.ego_vehicles):
            self._cleanup()
            return False

        if self._args.agent:
            agent_class_name = self.module_agent.__name__.title().replace('_', '')
            try:
                self.agent_instance = getattr(self.module_agent, agent_class_name)(self._args.agentConfig)
                config.agent = self.agent_instance
            except Exception as e:          # pylint: disable=broad-except
                traceback.print_exc()
                print("Could not setup required agent due to {}".format(e))
                self._cleanup()
                return False

        CarlaDataProvider.set_traffic_manager_port(int(self._args.trafficManagerPort))
        tm = self.client.get_trafficmanager(int(self._args.trafficManagerPort))
        tm.set_random_device_seed(int(self._args.trafficManagerSeed))
        if self._args.sync:
            tm.set_synchronous_mode(True)

        # Prepare scenario
        print("Preparing scenario: " + config.name)
        try:
            self._prepare_ego_vehicles(config.ego_vehicles)
            if self._args.openscenario:
                # Imported here because this module does not import it at top level.
                from srunner.scenarios.open_scenario import OpenScenario
                scenario = OpenScenario(world=self.world,
                                        ego_vehicles=self.ego_vehicles,
                                        config=config,
                                        config_file=self._args.openscenario,
                                        timeout=100000)
            elif self._args.route:
                from srunner.scenarios.route_scenario import RouteScenario
                scenario = RouteScenario(world=self.world,
                                         config=config,
                                         debug_mode=self._args.debug)
            else:
                scenario_class = self._get_scenario_class_or_fail(config.type)
                scenario = scenario_class(self.world,
                                          self.ego_vehicles,
                                          config,
                                          self._args.randomize,
                                          self._args.debug)
        except Exception as exception:                  # pylint: disable=broad-except
            print("The scenario cannot be loaded")
            traceback.print_exc()
            print(exception)
            self._cleanup()
            return False

        # Stays None unless the manager finished the run; checked on return.
        metadata = None
        try:
            ego_behaviors_param = self._set_ego_vehicle_behaviors(tm)
            if self._args.extra_vehicles:
                self.spawn_vehicles(self._args.extra_vehicles)
            if self._args.extra_walkers:
                self.spawn_walkers(self._args.extra_walkers)

            repeat, repeat_path = self._create_repeat_dir(config.name)

            if self._args.record:
                recorder_name = "{}/{}/{}.log".format(
                    os.getenv('SCENARIO_RUNNER_ROOT', "./"), self._args.record, config.name)
                self.client.start_recorder(recorder_name, True)

            # Load scenario and run it
            self.manager.load_scenario(scenario, self.agent_instance)
            print("ScenarioManager: Running scenario {}, repeat {}".format(config.name, repeat))
            metadata = self.manager.run_scenario(config.name, repeat_path, self.vehicles_list)

            # ================
            # Save meta info
            # ================
            metadata['ego_behaviors_param'] = ego_behaviors_param
            with open((repeat_path / 'metadata.txt').resolve(), 'w') as file:
                file.write(json.dumps(metadata))

            # Provide outputs if required
            self._analyze_scenario(config)

            # Remove all actors, stop the recorder and save all criterias (if needed)
            scenario.remove_all_actors()
            self._destroy_extra_actors()
            if self._args.record:
                self.client.stop_recorder()
                self._record_criteria(self.manager.scenario.get_criteria(), recorder_name)

        except Exception as e:              # pylint: disable=broad-except
            traceback.print_exc()
            print(e)

        # If the run failed part-way, the extra traffic is still alive; this is a
        # no-op when _destroy_extra_actors already ran above.
        try:
            self._destroy_extra_actors()
        except Exception:                   # pylint: disable=broad-except
            traceback.print_exc()

        self._cleanup()
        return metadata['result'] if metadata is not None else False




























    # def _get_scenario_class_or_fail(self, scenario):
    #     import importlib
    #     import inspect
    #     import glob
    #     """
    #     Get scenario class by scenario name
    #     If scenario is not supported or not found, exit script
    #     """

    #     # Path of all scenario at "srunner/scenarios" folder + the path of the additional scenario argument
    #     scenarios_list = glob.glob("{}/srunner/scenarios/*.py".format(os.getenv('SCENARIO_RUNNER_ROOT', "./")))
    #     scenarios_list.append(self._args.additionalScenario)

    #     for scenario_file in scenarios_list:

    #         # Get their module
    #         module_name = os.path.basename(scenario_file).split('.')[0]
            
    #         sys.path.insert(0, os.path.dirname(scenario_file))
    #         scenario_module = importlib.import_module(module_name)
    #         print(scenario_module)

    #         # And their members of type class
    #         for member in inspect.getmembers(scenario_module, inspect.isclass):
    #             if scenario in member:
    #                 return member[1]

    #         # Remove unused Python paths
    #         sys.path.pop(0)

    #     print("Scenario '{}' not supported ... Exiting".format(scenario))
    #     sys.exit(-1)
    