import georeg, tfcv, math, cosy, sys, time
import numpy as np
from functools import partial
import tensorflow as tf
import georegdata as grd

class Prediction:
    """Scored candidate poses of the current frame relative to a prior geographic pose.

    Args:
        axy: Candidate transforms, one per row, each laid out as ``(angle, *translation)``.
            Translations are in pixels; ``axy_to_transform`` multiplies them by
            ``meters_per_pixel`` when meters are requested. The angle is passed directly
            to ``cosy.np.angle_to_rotation_matrix`` (presumably radians -- TODO confirm).
        scores: One score per row of ``axy``; ``discrete_argmax`` selects the maximum.
        prior_latlon: Latitude/longitude of the prior pose.
        prior_bearing: Bearing of the prior pose, in degrees.
        epsg3857_to_epsg4326: Transformer from EPSG:3857 to EPSG:4326.
        epsg4326_to_epsg3857: Transformer from EPSG:4326 to EPSG:3857.
        meters_per_pixel: Scale converting ``axy`` translations from pixels to meters.
    """

    def __init__(self, axy, scores, prior_latlon, prior_bearing, epsg3857_to_epsg4326, epsg4326_to_epsg3857, meters_per_pixel):
        self.axy = axy  # candidate curr_to_prior transforms; translation part in pixels
        self.scores = scores

        self.prior_latlon = prior_latlon
        self.prior_bearing = prior_bearing

        self.epsg3857_to_epsg4326 = epsg3857_to_epsg4326
        self.epsg4326_to_epsg3857 = epsg4326_to_epsg3857
        self.meters_per_pixel = meters_per_pixel

        # Pose of the prior frame in EPSG:3857: placed at the prior lat/lon and rotated by the
        # prior bearing, converted into the EPSG:3857 angle convention.
        # NOTE(review): only `.translation` of eastnorthmeters_at_latlon_to_epsg3857 is kept; if
        # that transform also carries a scale factor, it is dropped here -- confirm this is intended.
        self.prior_to_epsg3857 = cosy.np.Rigid(
            translation=cosy.np.proj.eastnorthmeters_at_latlon_to_epsg3857(self.prior_latlon).translation,
            rotation=cosy.np.angle_to_rotation_matrix(epsg4326_to_epsg3857.transform_angle(math.radians(prior_bearing))),
        )

    def to_epsg4326(self, curr_to_prior):
        """Convert a curr_to_prior transform into an absolute geographic pose.

        Args:
            curr_to_prior: ``cosy.np.Rigid`` from the current frame to the prior frame
                (presumably with translation in meters, cf. ``axy_to_transform(..., "meters")``
                -- TODO confirm).

        Returns:
            Tuple ``(latlon, bearing)``; ``bearing`` is in degrees.
        """
        curr_to_epsg3857 = self.prior_to_epsg3857 * curr_to_prior

        latlon = self.epsg3857_to_epsg4326(curr_to_epsg3857.translation.astype("float64"))
        bearing = math.degrees(self.epsg3857_to_epsg4326.transform_angle(cosy.np.rotation_matrix_to_angle(curr_to_epsg3857.rotation.astype("float64"))))
        return latlon, bearing

    def from_epsg4326(self, latlon, bearing):
        """Inverse conversion of ``to_epsg4326``: geographic pose -> curr_to_prior transform.

        Args:
            latlon: Latitude/longitude of the current pose.
            bearing: Bearing of the current pose, in degrees.

        Returns:
            ``cosy.np.Rigid`` from the current frame to the prior frame.
        """
        curr_to_epsg3857 = cosy.np.Rigid(
            rotation=cosy.np.angle_to_rotation_matrix(self.epsg4326_to_epsg3857.transform_angle(math.radians(bearing))).astype("float64"),
            translation=self.epsg4326_to_epsg3857(latlon).astype("float64"),
        )
        curr_to_prior = self.prior_to_epsg3857.inverse() * curr_to_epsg3857
        return curr_to_prior

    def axy_to_transform(self, axy, unit):
        """Build a ``cosy.np.Rigid`` from a single ``(angle, *translation)`` row.

        Args:
            axy: One candidate row (angle first, translation after), e.g. a row of ``self.axy``.
            unit: ``"pixels"`` keeps the translation as stored; ``"meters"`` scales it by
                ``self.meters_per_pixel``.

        Returns:
            ``cosy.np.Rigid`` with the row's rotation and (converted) translation.

        Raises:
            ValueError: If ``unit`` is neither ``"pixels"`` nor ``"meters"``.
        """
        translation = self._axy_translation(axy, unit)
        return cosy.np.Rigid(
            translation=translation,
            rotation=cosy.np.angle_to_rotation_matrix(axy[0]),
        )

    def _axy_translation(self, axy, unit):
        """Return the translation part of ``axy`` expressed in ``unit``; ``axy`` is never modified."""
        if unit not in ("pixels", "meters"):
            raise ValueError("Got invalid unit parameter")
        translation = axy[1:]
        if unit == "meters":
            # Out-of-place multiply: for a NumPy row, `axy[1:]` is a view into the caller's
            # array (typically self.axy), so an in-place `*=` could rescale stored candidates.
            translation = translation * self.meters_per_pixel
        return translation

    def discrete_argmax(self, unit):
        """Return the highest-scoring candidate row of ``self.axy`` as a ``cosy.np.Rigid``.

        Args:
            unit: ``"pixels"`` or ``"meters"``; see ``axy_to_transform``.
        """
        index = np.argmax(self.scores)
        return self.axy_to_transform(self.axy[index], unit)

class Predictor:
    """Predict scored candidate poses of a ground frame relative to a prior pose using ``model``.

    Args:
        model: Callable invoked as ``model(inputs, training=False)`` and returning
            ``(corr, valid_corr)``.
        model_constants: Object providing ``meters_per_pixel`` (its last entry is used),
            ``aerial_stride`` and ``aerial_image_shape``.
        angles: Candidate angles forwarded to ``ModelInput.from_frames``.
        preprocess_aerial: Forwarded to ``ModelInput.to_list`` inside the compiled function.
        preprocess_ground: Forwarded to ``ModelInput.to_list`` inside the compiled function.
        config: Mapping providing ``config["validation"]["points_num"]`` and, unless
            ``corr_shape`` is given, ``config["validation"]["corr_shape"]``.
        corr_shape: Optional override of ``config["validation"]["corr_shape"]``.
    """

    def __init__(self, model, model_constants, angles, preprocess_aerial, preprocess_ground, config, corr_shape=None):
        self.epsg4326_to_epsg3857 = cosy.np.proj.Transformer("epsg:4326", "epsg:3857")
        self.epsg3857_to_epsg4326 = cosy.np.proj.Transformer("epsg:3857", "epsg:4326")

        self.model_constants = model_constants
        self.corr_shape = config["validation"]["corr_shape"] if corr_shape is None else corr_shape
        points_num = config["validation"]["points_num"]
        self.angles = angles

        # Frame -> ModelInput conversion for inference: no augmentation, no point shuffling,
        # no training mask, and bearings are not aligned.
        self.frames_to_model_input = partial(
            georeg.model.io.ModelInput.from_frames,
            angles=angles,
            align_bearing=False,
            augment_aerial_image=None,
            augment_ground_image=None,
            shuffle_points=False,
            points_num=points_num,
            corr_shape=self.corr_shape,
            training_mask=False,
        )

        @tf.function
        def tf_function(model_input):
            # `model_input` arrives as the flat list produced by ModelInput.to_list().
            model_input = georeg.model.io.ModelInput(*model_input)
            corr, valid_corr = model(model_input.to_list(preprocess_aerial=preprocess_aerial, preprocess_ground=preprocess_ground), training=False)

            # Collapse the `h` axis: correlations are averaged, and a cell stays valid only
            # if it is valid for every `h` entry.
            corr = tfcv.model.einops.apply("b h a s... -> b a s...", corr, reduction="mean")
            valid_corr = tfcv.model.einops.apply("b h a s... -> b a s...", valid_corr, reduction="all")
            corr = georeg.model.correlation.math.softmax(corr, valid_corr)

            axy = georeg.model.correlation.math.axy_volume(tf.shape(corr)[-2:], angles=model_input.angles) # b a s... 3

            # Flatten angles and spatial cells into a single candidate axis.
            axy = tfcv.model.einops.apply("b a s... i -> b (a s...) i", axy)
            corr = tfcv.model.einops.apply("b a s... -> b (a s...)", corr)

            return axy, corr

        self.tf_function = tf_function

    def __call__(self, aligned_ground_frame, tileloader, prior_latlon, prior_bearing, silent=False):
        """Predict candidate poses for ``aligned_ground_frame`` around the prior pose.

        Args:
            aligned_ground_frame: Ground-side frame, passed as the first argument of ``grd.Frame``.
            tileloader: Aerial tile source; its ``name`` and ``zoom`` are used for the aerial frame.
            prior_latlon: Latitude/longitude of the prior pose.
            prior_bearing: Bearing of the prior pose, in degrees.
            silent: When False, print progress and elapsed seconds for loading and prediction.

        Returns:
            Tuple ``(prediction, frame)``: a ``Prediction`` built from batch element 0 and the
            loaded ``grd.Frame``.
        """
        # Load frame. perf_counter is monotonic, so durations are immune to wall-clock jumps.
        if not silent:
            print("Loading frame...", end="", flush=True)
            start = time.perf_counter()
        frame = grd.Frame(
            aligned_ground_frame,
            grd.aerial.FrameId(tileloader, tileloader.name, tileloader.zoom, prior_latlon, prior_bearing, self.model_constants.meters_per_pixel[-1] / self.model_constants.aerial_stride, self.model_constants.aerial_image_shape).load(),
        )
        model_input = self.frames_to_model_input([frame])
        if not silent:
            print(f" done. Seconds: {time.perf_counter() - start}")

        # Predict transforms and scores
        if not silent:
            print("Predicting...", end="", flush=True)
            start = time.perf_counter()
        pred_axy, pred_scores = self.tf_function(model_input.to_list())
        if not silent:
            print(f" done. Seconds: {time.perf_counter() - start}")

        # Only one frame was batched, so keep batch element 0.
        pred_axy = pred_axy[0].numpy() # bevembedpixels_to_aerialembedpixels
        pred_scores = pred_scores[0].numpy()

        return Prediction(
            axy=pred_axy,
            scores=pred_scores,
            prior_latlon=prior_latlon,
            prior_bearing=prior_bearing,
            epsg3857_to_epsg4326=self.epsg3857_to_epsg4326,
            epsg4326_to_epsg3857=self.epsg4326_to_epsg3857,
            meters_per_pixel=self.model_constants.meters_per_pixel[-1],
        ), frame
