
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch
from utils import utils
import os

# from networks.VGG import VGGUnet
from networks.VGG import VGGUnet
from utils.jacobian import grid_sample, grid_sample_forward

from collections import defaultdict
import random

# Small epsilon constant re-exported from the project's utils module
# (presumably used for numerical-stability guards; see utils for its value).
EPS = utils.EPS

class SIBCL(nn.Module):
    def __init__(self, args):  # device='cuda:0',
        """Set up the shared encoder, the LM damping parameter and per-level geometry caches.

        Args:
            args: configuration namespace. Fields read directly here: level, N_iters,
                using_weight, loss_method, encoder, rotation_range, grdH, grdW, proj.
                When args.proj == 'geo', grd_img2cam additionally reads args.debug.

        Side effects:
            Calls torch.autograd.set_detect_anomaly(True), which switches anomaly
            detection on for the whole process, not only for this module.
        """
        super(SIBCL, self).__init__()
        '''
        S2GP & G2SP
        '''
        self.args = args

        self.level = args.level
        self.N_iters = args.N_iters
        self.using_weight = args.using_weight
        self.loss_method = args.loss_method

        # Satellite and ground branches share one VGGUnet instance (same weights).
        # NOTE(review): there is no else-branch; any other args.encoder leaves
        # SatFeatureNet / GrdFeatureNet unset.
        if args.encoder == 'vgg':
            self.SatFeatureNet = VGGUnet(args, self.level)
            self.GrdFeatureNet = self.SatFeatureNet

        # Learnable LM damping: shape (1, 3) when rotation is optimised, a scalar otherwise.
        # Read by LM_update when args.train_damping is set.
        if args.rotation_range > 0:
            self.damping = nn.Parameter(
                torch.zeros(size=(1, 3), dtype=torch.float32, requires_grad=True))
        else:
            self.damping = nn.Parameter(
            torch.zeros(size=(), dtype=torch.float32, requires_grad=True))

        ori_grdH, ori_grdW = self.args.grdH, self.args.grdW     #256, 1024
        # Pre-compute ground-image geometry for the 4 pyramid levels. Level l is the
        # original resolution divided by 2**(3 - l) (8, 4, 2, 1); grd_H / grd_W are floats.
        xyz_grds = []
        for level in range(4):
            grd_H, grd_W = ori_grdH/(2**(3-level)), ori_grdW/(2**(3-level))
            if self.args.proj == 'geo': # proj based on ground plane homography
                xyz_grd, mask, xyz_w = self.grd_img2cam(grd_H, grd_W, ori_grdH,
                                                 ori_grdW)  # [1, grd_H, grd_W, 3] under the grd camera coordinates
                xyz_grds.append((xyz_grd, mask, xyz_w)) # multi-scale, xyz_grds is wKinv[v_g]

            else:
                # Polar layout: entries are (xyz_grd, mask) only, without the ray tensor.
                xyz_grd, mask = self.grd_img2cam_polar(grd_H, grd_W, ori_grdH, ori_grdW)
                xyz_grds.append((xyz_grd, mask))

        self.xyz_grds = xyz_grds

        # Per-level scale: utils.get_meter_per_pixel() multiplied by 2**(3 - level).
        self.meters_per_pixel = []
        meter_per_pixel = utils.get_meter_per_pixel()
        for level in range(4):
            self.meters_per_pixel.append(meter_per_pixel * (2 ** (3 - level)))

        # polar_coordinates(level) is presumably defined further down in this class — confirm.
        polar_grids = []
        for level in range(4):
            grids = self.polar_coordinates(level)
            polar_grids.append(grids)
        self.polar_grids = polar_grids

        # NOTE(review): anomaly detection is a process-wide debugging aid that slows every
        # backward pass; consider enabling it only behind a debug flag.
        torch.autograd.set_detect_anomaly(True)
        # Running the forward pass with detection enabled will allow the backward pass to print the traceback of the forward operation that created the failing backward function.
        # Any backward computation that generates a "nan" value will raise an error.

    def grd_img_cood(self, B, grd_H, grd_W):
        """Return the (u, v) pixel coordinates of a grd_H x grd_W image, tiled over a batch.

        Args:
            B: batch size.
            grd_H, grd_W: image height and width.

        Returns:
            float32 tensor of shape [B, grd_H, grd_W, 2]; [..., 0] is the column index u
            and [..., 1] is the row index v.
        """
        rows = torch.arange(0, grd_H, dtype=torch.float32)
        cols = torch.arange(0, grd_W, dtype=torch.float32)
        v_grid, u_grid = torch.meshgrid(rows, cols)
        coords = torch.stack([u_grid, v_grid], dim=-1)[None]  # [1, grd_H, grd_W, 2]
        return coords.repeat((B, 1, 1, 1))

    def sample_points(self, pts_, mask_pts):
        """Randomly subsample the valid points of every batch element.

        Args:
            pts_: points indexed as pts_[b, i], e.g. [B, N, D].
            mask_pts: [B, N] (or [B, N, 1]); nonzero entries mark valid points.

        Returns:
            [B, K, ...] tensor with K = min(self.args.max_points, fewest valid points in
            any batch element). Built under torch.no_grad(), so it carries no gradient.
        """
        with torch.no_grad():
            per_batch = []
            for b in range(pts_.size(0)):
                mask_idx = mask_pts[b].nonzero()[:, 0]
                rand_idx = torch.randperm(mask_idx.shape[0])
                mask_idx = mask_idx[rand_idx][:self.args.max_points]
                per_batch.append(pts_[b, mask_idx])
            # Batch elements may have different numbers of valid points (fewer than
            # max_points); truncate all to the smallest count so torch.stack sees equal shapes.
            keep = min(p.shape[0] for p in per_batch)
            pts = torch.stack([p[:keep] for p in per_batch], dim=0)

        return pts

    def grd_img2cam(self, grd_H, grd_W, ori_grdH, ori_grdW):
        """Intersect the viewing ray of every ground-image pixel with the ground plane.

        The hard-coded intrinsics below refer to an ori_grdW x ori_grdH image and are
        rescaled to grd_W x grd_H. Each pixel (u, v, 1) becomes the ray K^-1 [u, v, 1],
        which is then scaled so that its y component equals utils.get_camera_height().

        Args:
            grd_H, grd_W: target resolution.
            ori_grdH, ori_grdW: resolution the hard-coded intrinsics refer to.

        Returns:
            xyz_grd: [1, grd_H, grd_W, 3] scaled points in ground-camera coordinates.
            mask: [1, grd_H, grd_W] float, 1 where xyz_grd's z component is > 0.
            xyz_w: [1, grd_H, grd_W, 3] unscaled rays K^-1 [u, v, 1].
        """
        
        ori_camera_k = torch.tensor([[[582.9802,   0.0000, 496.2420],
                                      [0.0000, 482.7076, 125.0034],
                                      [0.0000,   0.0000,   1.0000]]], 
                                    dtype=torch.float32, requires_grad=True)  # [1, 3, 3]

        camera_height = utils.get_camera_height()

        # Scale the first row (fx, cx) by the width ratio and the second row (fy, cy) by the height ratio.
        camera_k = ori_camera_k.clone()
        camera_k[:, :1, :] = ori_camera_k[:, :1,
                             :] * grd_W / ori_grdW  # original size input into feature get network/ output of feature get network
        camera_k[:, 1:2, :] = ori_camera_k[:, 1:2, :] * grd_H / ori_grdH
        camera_k_inv = torch.inverse(camera_k)  # [1, 3, 3]

        v, u = torch.meshgrid(torch.arange(0, grd_H, dtype=torch.float32),
                              torch.arange(0, grd_W, dtype=torch.float32))
        uv1 = torch.stack([u, v, torch.ones_like(u)], dim=-1).unsqueeze(dim=0)  # [1, grd_H, grd_W, 3]
        # Broadcasted matrix-vector product: xyz_w = K^-1 @ [u, v, 1] per pixel.
        xyz_w = torch.sum(camera_k_inv[:, None, None, :, :] * uv1[:, :, :, None, :], dim=-1)  # [1, grd_H, grd_W, 3]
        # set y_c to the distance between query camera to the ground plane
        # r = xyz_w[..., 1:2].cpu().detach().numpy()
        # Divide by the ray's y only where |y| > EPS; elsewhere divide by +EPS (sign not kept).
        w = camera_height / torch.where(torch.abs(xyz_w[..., 1:2]) > utils.EPS, xyz_w[..., 1:2],
                                        utils.EPS * torch.ones_like(xyz_w[..., 1:2]))  # [1, grd_H, grd_W, 1]
        xyz_grd = xyz_w * w  # [1, grd_H, grd_W, 3] under the grd camera coordinates
        # xyz_grd = xyz_grd.reshape(B, N, grd_H, grd_W, 3)

        mask = (xyz_grd[..., -1] > 0).float()  # # [1, grd_H, grd_W]
        # Debug only: the host-side copies below are never used or returned
        # (presumably kept for inspection in a debugger).
        if self.args.debug:
            xyz_grd_mask = xyz_grd * mask[..., None]
            xyz_w_cpu = xyz_w[0].cpu().detach().numpy()[:,:,-1]
            w_cpu = w[0].cpu().detach().numpy()[:,:,-1]
            xyz_grd_cpu = xyz_grd[0].cpu().detach().numpy()[:,:,-1]
            xyz_grd_mask_cpu = xyz_grd_mask[0].cpu().detach().numpy()[:,:,-1]

        return xyz_grd, mask, xyz_w

    def grd_img2cam_polar(self, grd_H, grd_W, ori_grdH, ori_grdW):
        """Lay out a polar grid of ground points in ground-camera coordinates.

        Column u is mapped to the angle pi/4 - (u / grd_W) * pi/4, and row v to the radius
        (1 - v / grd_H) * 30, i.e. 30 at the top row shrinking toward 0 at the bottom row.
        Every point is placed at height utils.get_camera_height() along y.

        Args:
            grd_H, grd_W: grid resolution.
            ori_grdH, ori_grdW: unused; kept for signature parity with grd_img2cam.

        Returns:
            xyz_grd: [1, grd_H, grd_W, 3] points ordered (x, y, z).
            mask: [1, grd_H, grd_W] all ones.
        """
        row_idx = torch.arange(0, grd_H, dtype=torch.float32)
        col_idx = torch.arange(0, grd_W, dtype=torch.float32)
        v, u = torch.meshgrid(row_idx, col_idx)

        theta = u/grd_W * np.pi/4
        radius = (1 - v / grd_H) * 30  # radius capped at 30 meters
        angle = np.pi/4 - theta

        z = radius * torch.cos(angle)
        x = -radius * torch.sin(angle)
        y = utils.get_camera_height() * torch.ones_like(z)

        points = torch.stack([x, y, z], dim=-1)[None]  # [1, grd_H, grd_W, 3]
        validity = torch.ones_like(z)[None]  # [1, grd_H, grd_W]
        return points, validity

    def grd2cam2world2sat(self, ori_shift_u, ori_shift_v, ori_heading, level,
                          satmap_sidelength, require_jac=False, gt_depth=None):
        '''
        Project the ground pixels of pyramid level `level` onto satellite-map pixels for
        the given pose, optionally with the pixel Jacobian w.r.t. the pose.

        Coordinate convention (author's note):
            realword: X: south, Y:down, Z: east
            camera: u:south, v: down from center (when heading east, need to rotate heading angle)

        Args:
            ori_shift_u: [B, 1] normalised shift, scaled by args.shift_range_lon.
            ori_shift_v: [B, 1] normalised shift, scaled by args.shift_range_lat.
            ori_heading: [B, 1] normalised heading, scaled by args.rotation_range (degrees).
            level: index into self.xyz_grds.
            satmap_sidelength: side length A (pixels) of the target satellite map.
            require_jac: if truthy, also return d(sat_uv)/d(ori_shift_u, ori_shift_v, ori_heading).
            gt_depth: optional depth map [B, 1, h, w]. If given, the rays
                self.xyz_grds[level][2] are scaled by the clamped, resized depth instead
                of using the cached ground-plane points self.xyz_grds[level][0].

        Returns:
            sat_uv: [B, grd_H, grd_W, 2] satellite pixel coordinates.
            mask: [B, grd_H, grd_W] float validity mask.
            duv_dshiftu, duv_dshiftv, duv_dtheta: [B, grd_H, grd_W, 2] each when
                require_jac, otherwise None.
        '''
        B, _ = ori_heading.shape
        heading = ori_heading * self.args.rotation_range / 180 * np.pi  # rotation_range +-10
        shift_u = ori_shift_u * self.args.shift_range_lon   # shift_range_lon: 20
        shift_v = ori_shift_v * self.args.shift_range_lat   # shift_range_lat: 20

        cos = torch.cos(heading)
        sin = torch.sin(heading)
        zeros = torch.zeros_like(cos)
        ones = torch.ones_like(cos)
        # Rotation about the y axis by the heading angle.
        R = torch.cat([cos, zeros, -sin, zeros, ones, zeros, sin, zeros, cos], dim=-1)  # [B, 9]
        R = R.view(B, 3, 3)  # [B, 3, 3]
        # this R is the inverse of the R in G2SP

        camera_height = utils.get_camera_height()
        # camera offset, shift[0]:east,Z, shift[1]:north,X
        height = camera_height * torch.ones_like(shift_u[:, :1])
        T0 = torch.cat([shift_v, height, -shift_u], dim=-1)  # [B, 3]
        T = torch.sum(-R * T0[:, None, :], dim=-1)   # [B, 3], i.e. T = -R @ T0

        # The above R, T define transformation from camera to world

        if gt_depth is not None:
            xyz_w = self.xyz_grds[level][2].detach().to(ori_shift_u.device).repeat(B, 1, 1, 1)
            H, W = xyz_w.shape[1:-1]
            gt_depth = gt_depth.clamp(min=0.1, max=self.args.max_depth)
            depth = F.interpolate(gt_depth, (H, W))  # TODO: 2D interpolation or 3D point grid sample?
            if self.args.depth == 'mono_half':
                # Negate the depth of the upper half so those points can be masked out below.
                depth_mask = torch.ones_like(depth)
                depth_mask[:, :, :H//2, :] = -1
                depth = depth * depth_mask.detach()
            xyz_grd = xyz_w * depth.permute(0, 2, 3, 1)
            mask = (xyz_grd[..., -1] > 0).float().detach()
        else:
            xyz_grd = self.xyz_grds[level][0].detach().to(ori_shift_u.device).repeat(B, 1, 1, 1)
            mask = self.xyz_grds[level][1].detach().to(ori_shift_u.device).repeat(B, 1, 1)  # [B, grd_H, grd_W]
        grd_H, grd_W = xyz_grd.shape[1:3]

        xyz = torch.sum(R[:, None, None, :, :] * xyz_grd[:, :, :, None, :], dim=-1) + T[:, None, None, :]
        # [B, grd_H, grd_W, 3]
        # R_sat selects (z, x) as the satellite image axes.
        R_sat = torch.tensor([0, 0, 1, 1, 0, 0], dtype=torch.float32, device=ori_shift_u.device, requires_grad=True)\
            .reshape(2, 3)
        zx = torch.sum(R_sat[None, None, None, :, :] * xyz[:, :, :, None, :], dim=-1)
        # [B, grd_H, grd_W, 2]

        meter_per_pixel = utils.get_meter_per_pixel()
        meter_per_pixel *= utils.get_process_satmap_sidelength() / satmap_sidelength
        sat_uv = zx/meter_per_pixel + satmap_sidelength / 2  # [B, grd_H, grd_W, 2] sat map uv

        if require_jac:
            # dR/d(ori_heading): derivative of R w.r.t. heading times the chain-rule factor.
            dR_dtheta = self.args.rotation_range / 180 * np.pi * \
                        torch.cat([-sin, zeros, -cos, zeros, zeros, zeros, cos, zeros, -sin], dim=-1)  # [B, 9]
            dR_dtheta = dR_dtheta.view(B, 3, 3)

            dT0_dshiftu = self.args.shift_range_lon * torch.tensor([0., 0., -1.], dtype=torch.float32, device=shift_u.device,
                                                         requires_grad=True).view(1, 3).repeat(B, 1)
            dT0_dshiftv = self.args.shift_range_lat * torch.tensor([1., 0., 0.], dtype=torch.float32, device=shift_u.device,
                                                         requires_grad=True).view(1, 3).repeat(B, 1)

            dxyz_dshiftu = torch.sum(-R * dT0_dshiftu[:, None, :], dim=-1)[:, None, None, :].\
                repeat([1, grd_H, grd_W, 1])   # [B, grd_H, grd_W, 3]
            dxyz_dshiftv = torch.sum(-R * dT0_dshiftv[:, None, :], dim=-1)[:, None, None, :].\
                repeat([1, grd_H, grd_W, 1])   # [B, grd_H, grd_W, 3]
            dxyz_dtheta = torch.sum(dR_dtheta[:, None, None, :, :] * xyz_grd[:, :, :, None, :], dim=-1) + \
                          torch.sum(-dR_dtheta * T0[:, None, :], dim=-1)[:, None, None, :]

            duv_dshiftu = 1 / meter_per_pixel * \
                     torch.sum(R_sat[None, None, None, :, :] * dxyz_dshiftu[:, :, :, None, :], dim=-1)
            # [B, grd_H, grd_W, 2]
            duv_dshiftv = 1 / meter_per_pixel * \
                     torch.sum(R_sat[None, None, None, :, :] * dxyz_dshiftv[:, :, :, None, :], dim=-1)
            # [B, grd_H, grd_W, 2]
            duv_dtheta = 1 / meter_per_pixel * \
                     torch.sum(R_sat[None, None, None, :, :] * dxyz_dtheta[:, :, :, None, :], dim=-1)
            # [B, grd_H, grd_W, 2]

            return sat_uv, mask, duv_dshiftu, duv_dshiftv, duv_dtheta

        return sat_uv, mask, None, None, None


    def grd2cam2world(self, ori_heading, level, gt_depth=None):
        '''
        Return the ground-camera 3D points and validity mask for pyramid level `level`.

        Args:
            ori_heading: [B, 1]; only its batch size and device are used here.
            level: index into self.xyz_grds.
            gt_depth: optional depth map [B, 1, h, w]. If given, the rays
                self.xyz_grds[level][2] are scaled by the depth, after clamping to
                [0.1, self.args.max_depth] and resizing to the level resolution.
                Otherwise the cached points self.xyz_grds[level][0] are used.

        Returns:
            xyz_grd: [B, grd_H, grd_W, 3] points in ground-camera coordinates.
            mask: [B, 1, grd_H, grd_W] when gt_depth is given, else [B, grd_H, grd_W].
                Callers must handle both ranks.
        '''
        B, _ = ori_heading.shape

        if gt_depth is not None:
            xyz_w = self.xyz_grds[level][2].detach().to(ori_heading.device).repeat(B, 1, 1, 1)
            H, W = xyz_w.shape[1:-1]
            gt_depth = gt_depth.clamp(min=0.1, max=self.args.max_depth)
            depth = F.interpolate(gt_depth, (H, W))  # TODO: 2D interpolation or 3D point grid sample?
            # Depths sitting on either clamp bound (including clamped ones) are invalid.
            mask1 = ((depth > 0.1) & (depth < self.args.max_depth)).float()
            if self.args.depth == 'mono_half':
                # Flip the sign of the upper half's depth; with rays whose z > 0 this makes
                # mask2 drop those pixels.
                depth_mask = torch.ones_like(depth)
                depth_mask[:, :, :H//2, :] = -1
                depth = depth * depth_mask.detach()
            xyz_grd = xyz_w * depth.permute(0, 2, 3, 1)
            mask2 = (xyz_grd[..., -1] > 0).float().unsqueeze(dim=1)
            mask = (mask1 * mask2).detach()  # [B, 1, H, W]
        else:
            xyz_grd = self.xyz_grds[level][0].detach().to(ori_heading.device).repeat(B, 1, 1, 1)
            mask = self.xyz_grds[level][1].detach().to(ori_heading.device).repeat(B, 1, 1)  # [B, grd_H, grd_W]

        return xyz_grd, mask


    def world2sat2im(self, grd_xyz, sat_f, ori_shift_u, ori_shift_v, ori_heading, require_jac=False):
        '''
        Transform ground-camera points into the satellite frame and project them onto
        satellite-image pixels, optionally with the pixel Jacobian w.r.t. the pose.

        Coordinate convention (author's note):
            realword: X: south, Y:down, Z: east
            camera: u:south, v: down from center (when heading east, need to rotate heading angle)

        Args:
            grd_xyz: [B, N, 3] or [B, grd_H, grd_W, 3] points in ground-camera coordinates.
                Any other rank leaves sat_xyz unassigned (UnboundLocalError).
            sat_f: [B, C, A, A]; only its side length A is read.
            ori_shift_u: [B, 1] normalised shift, scaled by args.shift_range_lon.
            ori_shift_v: [B, 1] normalised shift, scaled by args.shift_range_lat.
            ori_heading: [B, 1] normalised heading, scaled by args.rotation_range (degrees).
            require_jac: if truthy, also compute the Jacobian below.

        Returns:
            sat_uv: [..., 2] satellite pixel coordinates, (z, x) / meter_per_pixel + A / 2.
            sat_xyz: points in the satellite frame, same shape as grd_xyz.
            J_p2D_T: [B, N, 2, 3] d(sat_uv)/d(ori_shift_u, ori_shift_v, ori_heading) when
                require_jac, else None. NOTE(review): the Jacobian branch indexes with
                [:, :, None, :] and repeats over grd_xyz.shape[1], so it only fits the
                [B, N, 3] layout.
        '''
        B, _ = ori_heading.shape
        heading = ori_heading * self.args.rotation_range / 180 * np.pi  # rotation_range +-10
        shift_u = ori_shift_u * self.args.shift_range_lon   # shift_range_lon: 20
        shift_v = ori_shift_v * self.args.shift_range_lat   # shift_range_lat: 20

        cos = torch.cos(heading)
        sin = torch.sin(heading)
        zeros = torch.zeros_like(cos)
        ones = torch.ones_like(cos)
        # Rotation about the y axis by the heading angle.
        R = torch.cat([cos, zeros, -sin, zeros, ones, zeros, sin, zeros, cos], dim=-1)  # shape = [B, 9]
        R = R.view(B, 3, 3)  # shape = [B, 3, 3]
        # this R is the inverse of the R in G2SP

        camera_height = utils.get_camera_height()
        # camera offset, shift[0]:east,Z, shift[1]:north,X
        height = camera_height * torch.ones_like(shift_u[:, :1])
        T0 = torch.cat([shift_v, height, -shift_u], dim=-1)  # shape = [B, 3]   # why? -> shift_v: lat, shift_u: -lon,
        T = torch.sum(-R * T0[:, None, :], dim=-1)   # [B, 3], i.e. T = -R @ T0

        if grd_xyz.dim() == 4:
            sat_xyz = torch.sum(R[:, None, None, :, :] * grd_xyz[:, :, :, None, :], dim=-1) + T[:, None, None, :] #  [B, grd_H, grd_W, 3]
        elif grd_xyz.dim() == 3:
            sat_xyz = torch.sum(R[:, None, :, :] * grd_xyz[:, :, None, :], dim=-1) + T[:, None, :] #  [B, N, 3]

        # grd_H, grd_W = sat_xyz.shape[1:3]
        num_points = grd_xyz.shape[1]
        _, _, satmap_sidelength, _ = sat_f.size()

        # R_sat selects (z, x) as the satellite image axes.
        # R_sat = torch.tensor([0, 0, 1, 1, 0, 0], dtype=torch.float32,
        #                      device=sat_f.device,
        #                      requires_grad=True).reshape(2, 3)
        R_sat = torch.tensor([0, 0, 1, 1, 0, 0], dtype=torch.float32,
                             device=sat_f.device).reshape(2, 3)

        if sat_xyz.dim() == 4:
            sat_zx = torch.sum(R_sat[None, None, None, :, :] * sat_xyz[:, :, :, None, :], dim=-1)
        elif sat_xyz.dim() == 3:
            sat_zx = torch.sum(R_sat[None, None, :, :] * sat_xyz[:, :, None, :], dim=-1)
        # [..., 2]

        meter_per_pixel = utils.get_meter_per_pixel()
        meter_per_pixel *= utils.get_process_satmap_sidelength() / satmap_sidelength
        sat_uv = sat_zx / meter_per_pixel + satmap_sidelength / 2  # [..., 2] sat map uv

        if require_jac:
            # dR/d(ori_heading): derivative of R w.r.t. heading times the chain-rule factor.
            dR_dtheta = self.args.rotation_range / 180 * np.pi * \
                        torch.cat([-sin, zeros, -cos, zeros, zeros, zeros, cos, zeros, -sin], dim=-1)  # shape = [B, 9]
            dR_dtheta = dR_dtheta.view(B, 3, 3)
            # R_zeros = torch.zeros_like(dR_dtheta)

            dT0_dshiftu = self.args.shift_range_lon * torch.tensor([0., 0., -1.], dtype=torch.float32, device=shift_u.device,
                                                         requires_grad=True).view(1, 3).repeat(B, 1)
            dT0_dshiftv = self.args.shift_range_lat * torch.tensor([1., 0., 0.], dtype=torch.float32, device=shift_u.device,
                                                         requires_grad=True).view(1, 3).repeat(B, 1)
            # T0_zeros = torch.zeros_like(dT0_dx)

            dxyz_dshiftu = torch.sum(-R * dT0_dshiftu[:, None, :], dim=-1)[:, None, :].\
                repeat([1, num_points, 1])   # [B, N, 3]
            dxyz_dshiftv = torch.sum(-R * dT0_dshiftv[:, None, :], dim=-1)[:, None, :].\
                repeat([1, num_points, 1])   # [B, N, 3]
            dxyz_dtheta = torch.sum(dR_dtheta[:, None, :, :] * grd_xyz[:, :, None, :], dim=-1) + \
                          torch.sum(-dR_dtheta * T0[:, None, :], dim=-1)[:, None, :]

            duv_dshiftu = 1 / meter_per_pixel * \
                     torch.sum(R_sat[None, None, :, :] * dxyz_dshiftu[:, :, None, :], dim=-1)
            # [B, N, 2]
            duv_dshiftv = 1 / meter_per_pixel * \
                     torch.sum(R_sat[None, None, :, :] * dxyz_dshiftv[:, :, None, :], dim=-1)
            # [B, N, 2]
            duv_dtheta = 1 / meter_per_pixel * \
                     torch.sum(R_sat[None, None, :, :] * dxyz_dtheta[:, :, None, :], dim=-1)
            # [B, N, 2]
            J_p2D_T = torch.stack([duv_dshiftu, duv_dshiftv, duv_dtheta], dim=-1)  # [B, N, 2, 3]
            # duv_dshift = torch.stack([duv_dx, duv_dy], dim=0)
            # duv_dtheta = duv_dtheta.unsqueeze(dim=0)

            return sat_uv, sat_xyz, J_p2D_T

        else:
            return sat_uv, sat_xyz, None

    def world2sat(self, grd_xyz, ori_shift_u, ori_shift_v, ori_heading):
        '''
        Rigidly transform ground-camera points into the satellite-aligned frame.

        Coordinate convention (author's note):
            realword: X: south, Y:down, Z: east
            camera: u:south, v: down from center (when heading east, need to rotate heading angle)

        The transform is sat_xyz = R @ grd_xyz + T, where R rotates about the y axis by the
        heading and T = -R @ [shift_v, camera_height, -shift_u].

        Args:
            grd_xyz: [B, ..., 3] points in ground-camera coordinates, e.g. [B, N, 3] or
                [B, grd_H, grd_W, 3].
            ori_shift_u: [B, 1] normalised shift, scaled by args.shift_range_lon.
            ori_shift_v: [B, 1] normalised shift, scaled by args.shift_range_lat.
            ori_heading: [B, 1] normalised heading, scaled by args.rotation_range (degrees).

        Returns:
            sat_xyz: tensor with the same shape as grd_xyz.

        Raises:
            ValueError: if grd_xyz has fewer than 2 dimensions.
        '''
        if grd_xyz.dim() < 2:
            raise ValueError(f'grd_xyz must be [B, ..., 3], got shape {tuple(grd_xyz.shape)}')

        B, _ = ori_heading.shape
        heading = ori_heading * self.args.rotation_range / 180 * np.pi  # rotation_range +-10
        shift_u = ori_shift_u * self.args.shift_range_lon   # shift_range_lon: 20
        shift_v = ori_shift_v * self.args.shift_range_lat   # shift_range_lat: 20

        cos = torch.cos(heading)
        sin = torch.sin(heading)
        zeros = torch.zeros_like(cos)
        ones = torch.ones_like(cos)
        R = torch.cat([cos, zeros, -sin, zeros, ones, zeros, sin, zeros, cos], dim=-1)  # [B, 9]
        R = R.view(B, 3, 3)  # [B, 3, 3]
        # this R is the inverse of the R in G2SP

        camera_height = utils.get_camera_height()
        # camera offset, shift[0]:east,Z, shift[1]:north,X
        height = camera_height * torch.ones_like(shift_u[:, :1])
        T0 = torch.cat([shift_v, height, -shift_u], dim=-1)  # [B, 3]
        T = torch.sum(-R * T0[:, None, :], dim=-1)   # [B, 3]

        # Insert one singleton axis per point dimension so R and T broadcast over all of
        # them; for [B, N, 3] / [B, grd_H, grd_W, 3] this is the same as R[:, None, ...].
        n_point_dims = grd_xyz.dim() - 2
        R_b = R.view(B, *([1] * n_point_dims), 3, 3)
        T_b = T.view(B, *([1] * n_point_dims), 3)
        sat_xyz = torch.sum(R_b * grd_xyz.unsqueeze(-2), dim=-1) + T_b

        return sat_xyz


    def sat2im(self, xyz, sat_f):
        '''
        Project satellite-frame 3D points onto satellite-image pixel coordinates.

        The (z, x) components are divided by the meters-per-pixel of a map with sat_f's
        side length and offset by half the side length, so the frame origin lands on
        pixel (A / 2, A / 2).

        Args:
            xyz: [B, ..., 3] points in the satellite frame, e.g. [B, N, 3] or
                [B, grd_H, grd_W, 3].
            sat_f: [B, C, A, A] satellite feature map; only its side length A is read.

        Returns:
            sat_uv: [B, ..., 2] pixel coordinates (z / mpp + A / 2, x / mpp + A / 2).
        '''
        _, _, satmap_sidelength, _ = sat_f.size()

        # Pick (z, x). Equivalent to contracting with R_sat = [[0, 0, 1], [1, 0, 0]], but
        # valid for any number of point dimensions.
        zx = xyz[..., [2, 0]]

        meter_per_pixel = utils.get_meter_per_pixel()
        meter_per_pixel *= utils.get_process_satmap_sidelength() / satmap_sidelength
        sat_uv = zx/meter_per_pixel + satmap_sidelength / 2  # [B, ..., 2] sat map uv

        return sat_uv

    def normalize_pts(self, pts, size, pad=4):
        """Map (x, y) pixel coordinates to the [-1, 1] range used for grid sampling.

        Args:
            pts: [..., 2] pixel coordinates ordered (x, y).
            size: (H, W) of the image the points refer to.
            pad: unused; kept for signature compatibility.

        Returns:
            (normalised pts, mask): mask is a [..., 1] bool tensor that is True only when
            both normalised coordinates lie strictly inside (-1, 1).
        """
        height, width = size
        wh = torch.tensor([width, height]).to(pts)
        normed = pts / wh * 2 - 1
        inside = (normed > -1) & (normed < 1)
        both_inside = inside[..., 0] & inside[..., 1]
        return normed, both_inside.unsqueeze(dim=-1)

    def mask_in_image(self, pts, image_size, pad=4):
        """Flag points that lie at least `pad` pixels inside the image border.

        Args:
            pts: [..., 2] pixel coordinates ordered (x, y).
            image_size: (w, h).
            pad: border width in pixels.

        Returns:
            [...] bool tensor, True where pad <= coordinate <= size - pad - 1 on both axes.
        """
        width, height = image_size
        upper = torch.tensor([width - pad - 1, height - pad - 1]).to(pts)
        above_lower = pts >= pad
        below_upper = pts <= upper
        return (above_lower & below_upper).all(dim=-1)

    def project_map_to_grd(self, sat_f, sat_c, shift_u, shift_v, heading, level, require_jac=False, gt_depth=None):
        '''
        Warp satellite features (and optional confidence) into the ground view.

        Ground pixels are projected to satellite pixels under the given pose, and sat_f /
        sat_c are sampled there with grid_sample (from utils.jacobian).

        Args:
            sat_f: [B, C, A, A] satellite features.
            sat_c: [B, 1, A, A] satellite confidence, or None.
            shift_u: [B, 1] normalised pose component (see grd2cam2world2sat).
            shift_v: [B, 1]
            heading: [B, 1]
            level: pyramid level selecting the cached ground geometry.
            require_jac: if truthy, stack the projector's uv Jacobians into
                [3, B, grd_H, grd_W, 2] and pass them through grid_sample.
            gt_depth: optional depth map forwarded to the projector.

        Returns:
            sat_f_trans: [B, C, grd_H, grd_W] sampled features, zeroed where mask == 0.
            sat_c_trans: sampled confidence (masked), or None when sat_c is None.
            new_jac: second output of grid_sample, masked when require_jac.
            uv * mask: [B, grd_H, grd_W, 2] masked sampling coordinates.
            mask: [B, grd_H, grd_W] projection-validity mask from the projector.
            proj_mask: third output of grid_sample for sat_f.

        NOTE(review): only args.grd2sat in {'geo', 'decouple'} is handled; any other value
        leaves `uv` unassigned. grd2cam2world2sat2 ('decouple') is not defined in the
        part of the class shown here — confirm it exists.
        '''
        B, C, satmap_sidelength, _ = sat_f.size()
        A = satmap_sidelength

        if self.args.grd2sat == 'geo':
            uv, mask, jac_shiftu, jac_shiftv, jac_heading = self.grd2cam2world2sat(shift_u, shift_v, heading, level,
                                    satmap_sidelength, require_jac, gt_depth)
        elif self.args.grd2sat == 'decouple':
            uv, mask, jac_shiftu, jac_shiftv, jac_heading = self.grd2cam2world2sat2(shift_u, shift_v, heading, level,
                                                                                  satmap_sidelength, require_jac,
                                                                                  gt_depth)

        B, grd_H, grd_W, _ = uv.shape
        if require_jac:
            jac = torch.stack([jac_shiftu, jac_shiftv, jac_heading], dim=0)  # [3, B, H, W, 2]

            # jac = jac.reshape(3, -1, grd_H, grd_W, 2)
        else:
            jac = None

        # print('==================')
        # print(jac.shape)
        # print('==================')

        sat_f_trans, new_jac, proj_mask = grid_sample(sat_f,
                                                      uv,  # uv [1, 32, 128, 2]; [32, 128]=[w,h], [2]=uv
                                                      jac) # jac [3, 1, 32, 128, 2]; 3: jac_shiftu,v,heading
        # Zero out ground pixels whose projection is invalid.
        sat_f_trans = sat_f_trans * mask[:, None, :, :]
        if require_jac:
            new_jac = new_jac * mask[None, :, None, :, :]

        if sat_c is not None:
            # proj_c_mask is not used further.
            sat_c_trans, _, proj_c_mask = grid_sample(sat_c, uv)
            sat_c_trans = sat_c_trans * mask[:, None, :, :]
        else:
            sat_c_trans = None

        return sat_f_trans, sat_c_trans, new_jac, uv * mask[:, :, :, None], mask, proj_mask


    def project_grd_to_map_fw(self, grd_f, sat_f, sat_c, shift_u, shift_v, heading, level,
                              gt_depth=None, pred_w=None, mode='pose'):
        '''
        Forward-warp ground features onto the satellite grid.

        Ground pixels are projected to satellite pixels with grd2cam2world2sat and the
        ground features are splatted there by grid_sample_forward (from utils.jacobian).

        Args:
            grd_f: ground features; must broadcast against [B, 1, grd_H, grd_W]
                (presumably [B, C, grd_H, grd_W]).
            sat_f: [B, C, A, A]; its side length sets the projection and it is passed
                to grid_sample_forward.
            sat_c: unused here.
            shift_u: [B, 1] normalised pose component (see grd2cam2world2sat).
            shift_v: [B, 1]
            heading: [B, 1]
            level: pyramid level selecting the cached ground geometry.
            gt_depth: optional depth map, forwarded to the projector and to grid_sample_forward.
            pred_w, mode: unused here.

        Returns:
            grd_f_trans: first output of grid_sample_forward.
            uv: [B, grd_H, grd_W, 2] satellite pixel coordinates, multiplied by the
                validity mask unless args.grd_mask == 'none'.

        Side effects:
            Stores grid_sample_forward's projection mask on self.proj_mask.

        NOTE(review): only args.grd2sat == 'geo' is handled; any other value leaves `uv`
        unassigned. The 'half'/'halfcut' and 'maxd' grd_mask branches are identical.
        '''
        B, C, satmap_sidelength, _ = sat_f.size()
        A = satmap_sidelength


        # require_jac is passed as None (falsy), so no Jacobians are computed.
        if self.args.grd2sat == 'geo':
            uv, mask, _, _, _ = self.grd2cam2world2sat(shift_u, shift_v, heading, level,
                                                       satmap_sidelength, None, gt_depth)

        B, grd_H, grd_W, _ = uv.shape

        if self.args.grd_mask == 'none':
            pass
        elif self.args.grd_mask in ['half', 'halfcut']:  # default
            grd_f = grd_f * mask[:, None, :, :]
            uv = uv * mask[:, :, :, None]
        elif self.args.grd_mask == 'maxd':
            grd_f = grd_f * mask[:, None, :, :]
            uv = uv * mask[:, :, :, None]

        # forward warping
        grd_f_trans, new_jac, proj_mask = grid_sample_forward(grd_f,
                                                              uv,  # uv [1, 32, 128, 2]; [32, 128]=[w,h], [2]=uv
                                                              sat_f,
                                                              gt_depth)  # no Jacobian is passed; new_jac is not used below

        # grd_f_trans, new_jac, proj_mask = grid_sample_forwardv0_3(grd_f,
        #                                                       uv,  # uv [1, 32, 128, 2]; [32, 128]=[w,h], [2]=uv
        #                                                       sat_f,
        #                                                       gt_depth)

        # sat_f_trans = sat_f_trans * mask[:, None, :, :]
        # new_jac = None
        self.proj_mask = proj_mask

        return grd_f_trans, uv


    def project_point3d(self, grd_f, sat_f, shift_u, shift_v, heading, level, gt_depth):
        '''
        Pair ground-image features with satellite features at the projected 3D points.

        Ground pixels are lifted to 3D by self.grd2cam2world (depth-scaled when gt_depth is
        given), moved into the satellite frame by self.world2sat, projected to satellite
        pixels by self.sat2im and bilinearly sampled from sat_f.

        Args:
            grd_f: [B, C, H, W] ground features; H, W must match the level's point grid.
            sat_f: [B, C, A, A] satellite features.
            shift_u: [B, 1] normalised pose component (see world2sat).
            shift_v: [B, 1]
            heading: [B, 1]
            level: pyramid level.
            gt_depth: [B, 1, h, w] depth map, or None to use the cached ground-plane points.

        Returns:
            valid_grd_points: [B, H*W, C] ground features, zeroed where the mask is 0.
            valid_sat_points: [B, H*W, C] sampled satellite features, zeroed likewise.
            xyz_grd: ground-camera points returned by grd2cam2world.
            xyz: the same points in the satellite frame.
            sat_uv: [B, H*W, 2] sampling locations normalised to [-1, 1].
        '''
        A = sat_f.size(2)
        B, C = grd_f.shape[:2]

        xyz_grd, mask = self.grd2cam2world(heading, level, gt_depth=gt_depth)
        xyz = self.world2sat(xyz_grd, shift_u, shift_v, heading)
        sat_uv = self.sat2im(xyz, sat_f)

        # grd2cam2world returns a [B, H, W] mask without depth and [B, 1, H, W] with it;
        # give it a channel axis so it broadcasts against grd_f [B, C, H, W].
        if mask.dim() == 3:
            mask = mask.unsqueeze(dim=1)
        mask_points = mask.reshape(B, -1).contiguous().unsqueeze(dim=-1)  # [B, H*W, 1]

        valid_grd_img = (grd_f * mask).permute((0, 2, 3, 1)).contiguous()  # [B, H, W, C]
        valid_grd_points = valid_grd_img.reshape(B, -1, C) * mask_points

        # Pixel coordinates in [0, A] -> grid_sample's [-1, 1] range (same mapping as
        # normalize_pts): multiply by 2, not divide.
        sat_uv = sat_uv / A * 2 - 1
        sat_uv = sat_uv.reshape(B, -1, 2)
        valid_sat_points = F.grid_sample(sat_f, sat_uv[:, None], mode='bilinear', align_corners=True)
        valid_sat_points = valid_sat_points.reshape(B, C, -1).transpose(-1, -2) * mask_points

        return valid_grd_points, valid_sat_points, xyz_grd, xyz, sat_uv

    def LM_update(self, shift_u, shift_v, theta, sat_feat_proj, sat_conf_proj, grd_feat, grd_conf, dfeat_dpose):
        '''
        Take one damped Gauss-Newton (Levenberg-Marquardt style) step on the pose.

        With r = normalised(sat_feat_proj) - normalised(grd_feat) and J = dfeat_dpose,
        the step is  delta = -(J^T W J + damping * D)^-1 J^T W r,  where W is the per-entry
        weight and D is diag(J^T W J) when args.use_hessian, else the identity. Only the
        pose parameters being optimised are updated.

        Args:
            shift_u: [B, 1]
            shift_v: [B, 1]
            theta: [B, 1]
            sat_feat_proj: projected satellite features, flattenable to [B, C, P]
                (e.g. [B, C, H, W] with P = H * W).
            sat_conf_proj: satellite confidence with P entries per batch element.
            grd_feat: ground features, same layout as sat_feat_proj.
            grd_conf: ground confidence with P entries per batch element; it is the
                weight W when self.using_weight.
            dfeat_dpose: feature Jacobian w.r.t. (shift_u, shift_v, theta). The permute
                below implies the layout [B, P, C, 3]. NOTE(review): confirm with caller.

        Returns:
            shift_u_new, shift_v_new, theta_new: [B, 1] each. In the full 3-DoF case,
            shifts that leave (-2.5, 2.5) are replaced by uniform samples in [-1, 1).
        '''
        # [B, P, C, 3] -> [N, B, C, P]
        dfeat_dpose = dfeat_dpose.permute(3, 0, 2, 1).contiguous()
        # Keep only the Jacobian rows of the optimised pose parameters. Slicing after the
        # permute selects along the pose axis rather than the batch axis.
        if self.args.rotation_range == 0:
            dfeat_dpose = dfeat_dpose[:2, ...]
        elif self.args.shift_range_lat == 0 and self.args.shift_range_lon == 0:
            dfeat_dpose = dfeat_dpose[2:, ...]
        N, B, C, P = dfeat_dpose.shape

        if self.args.train_damping:
            # Learned damping, mapped log-uniformly into (1e-6, 1e5).
            min_, max_ = -6, 5
            damping = 10.**(min_ + self.damping.sigmoid()*(max_ - min_))
        else:
            damping = (self.args.damping * torch.ones(size=(1, N), dtype=torch.float32, requires_grad=True)).to(
                dfeat_dpose.device)

        if self.args.dropout > 0:
            # Use a random half of the P feature locations for this step.
            inds = np.random.permutation(np.arange(P))[: P // 2]
            dfeat_dpose = dfeat_dpose.reshape(N, B, C, -1)[:, :, :, inds].reshape(N, B, -1)
            sat_feat_proj = sat_feat_proj.reshape(B, C, -1)[:, :, inds].reshape(B, -1)
            grd_feat = grd_feat.reshape(B, C, -1)[:, :, inds].reshape(B, -1)
            sat_conf_proj = sat_conf_proj.reshape(B, -1)[:, inds]
            grd_conf = grd_conf.reshape(B, -1)[:, inds]
        else:
            dfeat_dpose = dfeat_dpose.reshape(N, B, -1)
            sat_feat_proj = sat_feat_proj.reshape(B, -1)
            grd_feat = grd_feat.reshape(B, -1)
            sat_conf_proj = sat_conf_proj.reshape(B, -1)
            grd_conf = grd_conf.reshape(B, -1)

        # L2-normalise both feature vectors; the Jacobian is scaled with the satellite norm.
        sat_feat_norm = torch.norm(sat_feat_proj, p=2, dim=-1)
        sat_feat_norm = torch.maximum(sat_feat_norm, 1e-6 * torch.ones_like(sat_feat_norm))
        sat_feat_proj = sat_feat_proj / sat_feat_norm[:, None]
        dfeat_dpose = dfeat_dpose / sat_feat_norm[None, :, None]  # [N, B, D]

        grd_feat_norm = torch.norm(grd_feat, p=2, dim=-1)
        grd_feat_norm = torch.maximum(grd_feat_norm, 1e-6 * torch.ones_like(grd_feat_norm))
        grd_feat = grd_feat / grd_feat_norm[:, None]


        r = sat_feat_proj - grd_feat  # [B, D]

        if self.using_weight:
            # weight = (sat_conf_proj * grd_conf).repeat(1, C, 1, 1).reshape(B, C * H * W)
            weight = (grd_conf[:, None, :]).repeat(1, C, 1).reshape(B, -1)
        else:
            weight = torch.ones([B, grd_feat.shape[-1]], dtype=torch.float32, device=shift_u.device, requires_grad=True)

        J = dfeat_dpose.permute(1, 2, 0)  # [B, C*P, N]
        temp = J.transpose(1, 2) * weight.unsqueeze(dim=1)  # J^T W
        Hessian = temp @ J  # [B, N, N]
        if self.args.use_hessian:
            diag_H = torch.diag_embed(torch.diagonal(Hessian, dim1=1, dim2=2))  # [B, N, N]
        else:
            diag_H = torch.eye(Hessian.shape[-1], requires_grad=True).unsqueeze(dim=0).repeat(B, 1, 1).to(
                Hessian.device)
        delta_pose = - torch.inverse(Hessian + damping * diag_H) \
                     @ temp @ r.reshape(B, -1, 1)

        if self.args.rotation_range == 0:
            shift_u_new = shift_u + delta_pose[:, 0:1, 0]
            shift_v_new = shift_v + delta_pose[:, 1:2, 0]
            theta_new = theta
        elif self.args.shift_range_lat == 0 and self.args.shift_range_lon == 0:
            theta_new = theta + delta_pose[:, 0:1, 0]
            shift_u_new = shift_u
            shift_v_new = shift_v
        else:
            shift_u_new = shift_u + delta_pose[:, 0:1, 0]
            shift_v_new = shift_v + delta_pose[:, 1:2, 0]
            theta_new = theta + delta_pose[:, 2:3, 0]

            # Re-initialise shifts that diverged outside (-2.5, 2.5) with random values.
            rand_u = torch.distributions.uniform.Uniform(-1, 1).sample([B, 1]).to(shift_u.device)
            rand_v = torch.distributions.uniform.Uniform(-1, 1).sample([B, 1]).to(shift_u.device)
            rand_u.requires_grad = True
            rand_v.requires_grad = True
            shift_u_new = torch.where((shift_u_new > -2.5) & (shift_u_new < 2.5), shift_u_new, rand_u)
            shift_v_new = torch.where((shift_v_new > -2.5) & (shift_v_new < 2.5), shift_v_new, rand_v)

            if torch.any(torch.isnan(theta_new)):
                print('theta_new is nan')
                print(theta, delta_pose[:, 2:3, 0], Hessian)

        return shift_u_new, shift_v_new, theta_new


    def forward(self, sat_map, grd_img_left, left_camera_k, gt_shiftu=None, gt_shiftv=None, gt_heading=None, mode='train',
                file_name=None, gt_depth=None, loop=0, level_first=0, grd_mask=None):
        '''
        Module entry point; all the work is done by forward_level_first.

        :param sat_map: [B, C, A, A] A--> sidelength
        :param grd_img_left: [B, C, H, W]
        :param level_first: accepted for interface compatibility but ignored -- the
            level-outer schedule of forward_level_first is always used.
        :return: whatever forward_level_first returns for the given mode.
        '''
        return self.forward_level_first(
            sat_map, grd_img_left, left_camera_k,
            gt_shiftu=gt_shiftu,
            gt_shiftv=gt_shiftv,
            gt_heading=gt_heading,
            mode=mode,
            file_name=file_name,
            gt_depth=gt_depth,
            loop=loop,
            grd_mask=grd_mask,
        )


    def forward_level_first(self, sat_map, grd_img_left, left_camera_k, gt_shiftu=None, gt_shiftv=None, gt_heading=None, mode='train',
                file_name=None, gt_depth=None, loop=0, grd_mask=None):
        '''
        Coarse-to-fine pose estimation.

        The outer loop walks the feature levels produced by the encoders. At every
        level, self.N_iters LM iterations refine (shift_u, shift_v, heading), which
        start at zero. Each level continues from the previous level's estimate.

        :param sat_map: [B, C, A, A] A--> sidelength
        :param grd_img_left: [B, C, H, W]
        :param left_camera_k: not used on this code path (kept for interface compatibility).
        :param gt_shiftu, gt_shiftv, gt_heading: ground-truth pose; column 0 is read.
            Only needed when mode == 'train'.
        :param mode: 'train' returns the loss tuple; anything else returns the final estimate.
        :param file_name, loop: unused (kept for interface compatibility).
        :param gt_depth: pair (xyz_pts, uv_pts) of ground points. The last channel of
            uv_pts is dropped before the values are normalised with self.normalize_pts.
        :param grd_mask: per-point weights, unsqueezed on the last axis and multiplied into
            the sampled ground features/confidences. None keeps every point.
        :return: in train mode, the 16-tuple (loss, loss_decrease, shift_lat_decrease,
            shift_lon_decrease, thetas_decrease, loss_last, shift_lat_last, shift_lon_last,
            theta_last, L1_loss, L2_loss, L3_loss, L4_loss, grd_conf_list, smooth_loss,
            additional_loss). Otherwise (shift_lat, shift_lon, theta), each [B], taken from
            the last level and last iteration.
        :raises ValueError: if self.args.Optimizer is not 'LM' (the only solver wired here).
        '''
        if self.args.Optimizer != 'LM':
            # Without this check the loop below hit an UnboundLocalError on `shift_u_new`.
            raise ValueError(
                f"Unsupported optimizer {self.args.Optimizer!r}; only 'LM' is implemented")

        B, _, ori_grdH, ori_grdW = grd_img_left.shape
        device = sat_map.device

        sat_feat_list, sat_conf_list = self.SatFeatureNet(sat_map)
        grd_feat_list, grd_conf_list = self.GrdFeatureNet(grd_img_left)

        # Pose parameters (normalised units) start at zero for every sample.
        shift_u = torch.zeros([B, 1], dtype=torch.float32, requires_grad=True, device=device)
        shift_v = torch.zeros([B, 1], dtype=torch.float32, requires_grad=True, device=device)
        heading = torch.zeros([B, 1], dtype=torch.float32, requires_grad=True, device=device)

        grd_xyz_pts = gt_depth[0].to(device).float()
        grd_uv_pts = gt_depth[1].to(device)[..., :-1].float()
        grd_uv_pts_norm, _ = self.normalize_pts(grd_uv_pts, size=[ori_grdH, ori_grdW])

        # The point mask does not depend on the level, so build it once.
        point_mask = None if grd_mask is None else grd_mask.unsqueeze(dim=-1).float()

        shift_us_all = []
        shift_vs_all = []
        headings_all = []
        pred_xyz_dict = defaultdict()
        pred_uv_dict = defaultdict()

        additional_loss = torch.tensor(0., dtype=torch.float32, device=device)

        for level in range(len(sat_feat_list)):
            sat_feat = sat_feat_list[level]
            sat_conf = sat_conf_list[level]
            grd_feat = grd_feat_list[level]
            grd_conf = grd_conf_list[level]

            # Ground features/confidences at the ground points are fixed for the whole level.
            b, c, _, _ = grd_feat.size()
            grd_feat_pts = F.grid_sample(grd_feat, grd_uv_pts_norm[:, None], mode='bilinear', align_corners=True)
            grd_feat_pts = grd_feat_pts.reshape(b, c, -1).transpose(-1, -2)  # [b, N, c]
            grd_conf_pts = F.grid_sample(grd_conf, grd_uv_pts_norm[:, None], mode='bilinear', align_corners=True)
            grd_conf_pts = grd_conf_pts.reshape(b, 1, -1).transpose(-1, -2)  # [b, N, 1]
            if point_mask is not None:
                grd_feat_pts = grd_feat_pts * point_mask
                grd_conf_pts = grd_conf_pts * point_mask

            shift_us = []
            shift_vs = []
            headings = []

            for _ in range(self.N_iters):
                sat_uv_pts, sat_xyz_pts, J_p2D_T = self.world2sat2im(
                    grd_xyz_pts, sat_feat, shift_u, shift_v, heading, require_jac=self.args.require_jac)

                # Satellite features/confidences at the current projection of the points.
                _, mask, sat_feat_pts, J_f_p2D = self.interpolate_feature(sat_uv_pts, sat_feat)
                _, _, sat_conf_pts, _ = self.interpolate_feature(sat_uv_pts, sat_conf)

                # interpolate_feature's mask is False for points outside the map; zero those out.
                mask = mask.float().detach()
                sat_feat_pts = sat_feat_pts * mask
                sat_conf_pts = sat_conf_pts * mask

                if self.args.require_jac:
                    # J_f_p2D: feature gradient w.r.t. the 2-D point (from interpolate_feature);
                    # J_p2D_T comes from world2sat2im (presumably d point / d pose -- confirm).
                    J = (J_f_p2D @ J_p2D_T) * mask.unsqueeze(dim=-1)
                else:
                    J = None

                shift_u_new, shift_v_new, heading_new = self.LM_update(
                    shift_u, shift_v, heading, sat_feat_pts, sat_conf_pts, grd_feat_pts, grd_conf_pts, J)

                shift_us.append(shift_u_new[:, 0])  # [B]
                shift_vs.append(shift_v_new[:, 0])  # [B]
                headings.append(heading_new[:, 0])  # [B]

                shift_u = shift_u_new.clone()
                shift_v = shift_v_new.clone()
                heading = heading_new.clone()

                # Ends up holding the projection used by the last iteration of this level.
                pred_xyz_dict[level] = sat_xyz_pts
                pred_uv_dict[level] = sat_uv_pts

            shift_us_all.append(torch.stack(shift_us, dim=1))  # [B, N_iters]
            shift_vs_all.append(torch.stack(shift_vs, dim=1))  # [B, N_iters]
            headings_all.append(torch.stack(headings, dim=1))  # [B, N_iters]

        # Levels are the outer loop, so the stacked layout is [B, Level, N_iters].
        shift_lats = torch.stack(shift_vs_all, dim=1)  # [B, Level, N_iters]
        shift_lons = torch.stack(shift_us_all, dim=1)  # [B, Level, N_iters]
        thetas = torch.stack(headings_all, dim=1)  # [B, Level, N_iters]

        # for analysis
        self.shift_lats = shift_lats.detach()
        self.shift_lons = shift_lons.detach()
        self.thetas = thetas.detach()

        if mode != 'train':
            return shift_lats[:, -1, -1], shift_lons[:, -1, -1], thetas[:, -1, -1]

        coe_heading = 0 if self.args.rotation_range == 0 else self.args.coe_heading

        # Ground-truth projection of the ground points, at the finest (last) level's resolution.
        with torch.no_grad():
            gt_sat_xyz = self.world2sat(grd_xyz_pts, gt_shiftu, gt_shiftv, gt_heading)
            gt_sat_uv = self.sat2im(gt_sat_xyz, sat_feat_list[-1])

        loss, loss_decrease, shift_lat_decrease, shift_lon_decrease, thetas_decrease, loss_last, \
            shift_lat_last, shift_lon_last, theta_last, \
            L1_loss, L2_loss, L3_loss, L4_loss \
            = loss_func(self.args.loss_method,
                        shift_lats, shift_lons, thetas, gt_shiftv[:, 0], gt_shiftu[:, 0], gt_heading[:, 0],
                        pred_xyz_dict, gt_sat_xyz, pred_uv_dict, gt_sat_uv,
                        self.args.coe_shift_lat, self.args.coe_shift_lon, coe_heading,
                        awl=getattr(self, 'awl', None))

        smooth_loss = None
        loss = loss + additional_loss

        return loss, loss_decrease, shift_lat_decrease, shift_lon_decrease, thetas_decrease, loss_last, \
            shift_lat_last, shift_lon_last, theta_last, \
            L1_loss, L2_loss, L3_loss, L4_loss, grd_conf_list, smooth_loss, additional_loss

    def get_warp_sat2real(self, satmap_sidelength, device=None):
        '''
        Map every satellite-map pixel to homogeneous world coordinates on the ground plane.

        Satellite pixel (u, v) is measured from the top-left corner (u: east, v: south)
        and re-centred at satmap_sidelength // 2. The world frame is X: south, Y: down,
        Z: east, with the origin on the ground plane, so Y is always 0.

        :param satmap_sidelength: side length in pixels of the square satellite map/feature.
        :param device: device for the returned tensor. None keeps the historical behaviour
            of using the current CUDA device.
        :return: [sidelength, sidelength, 4] tensor of (X, 0, Z, 1).
        '''
        if device is None:
            device = torch.device('cuda')

        idx = torch.arange(0, satmap_sidelength, device=device)
        ii, jj = torch.meshgrid(idx, idx)  # ii: row (h), jj: column (w)

        # (u, v) from the top-left corner, then relative to the map centre.
        uv = torch.stack([jj, ii], dim=-1).float()  # [sidelength, sidelength, 2]
        u0 = v0 = satmap_sidelength // 2
        uv_center = uv - torch.tensor([u0, v0], device=device)

        # Affine map scale * R. R swaps the axes: u_center -> Z, v_center -> X.
        meter_per_pixel = utils.get_meter_per_pixel()
        meter_per_pixel *= utils.get_process_satmap_sidelength() / satmap_sidelength
        R = torch.tensor([[0, 1], [1, 0]], device=device).float()
        Aff_sat2real = meter_per_pixel * R  # [2, 2]

        XZ = torch.einsum('ij, hwj -> hwi', Aff_sat2real, uv_center)  # [sidelength, sidelength, 2]

        Y = torch.zeros_like(XZ[..., 0:1])
        ones = torch.ones_like(Y)
        return torch.cat([XZ[:, :, :1], Y, XZ[:, :, 1:], ones], dim=-1)  # [sidelength, sidelength, 4]

    def seq_warp_real2camera(self, ori_shift_u, ori_shift_v, ori_heading, XYZ_1, ori_camera_k, grd_H, grd_W, ori_grdH,
                             ori_grdW, require_jac=True):
        '''
        Project homogeneous world points into the ground camera at feature resolution.
        Optionally also returns the Jacobian of the pixel coordinates w.r.t. the pose.

        Frames: world X: south, Y: down, Z: east. Camera u: south, v: down from the
        centre when heading east; other headings apply the rotation below.

        :param ori_shift_u, ori_shift_v: normalised shifts [B, 1], scaled to metres by
            self.args.shift_range_lon / self.args.shift_range_lat.
        :param ori_heading: normalised heading [B, 1], scaled to radians via
            self.args.rotation_range (in degrees).
        :param XYZ_1: homogeneous world points [H, W, 4].
        :param ori_camera_k: intrinsics [B, 3, 3] for an ori_grdH x ori_grdW image. Rows 0
            and 1 are rescaled to a grd_H x grd_W feature map.
        :param require_jac: also compute d(uv)/d(shift_u, shift_v, heading).
        :return: (uv [B, H, W, 2], duv_dx, duv_dy, duv_dtheta, mask).
            With require_jac, the three Jacobians are [B, H, W, 2], zeroed where the
            point is not in front of the camera, and mask is [B, H, W, 1].
            Without it they are None and mask is squeezed to [B, H, W].
        '''
        B = ori_heading.shape[0]
        shift_u_meters = self.args.shift_range_lon * ori_shift_u
        shift_v_meters = self.args.shift_range_lat * ori_shift_v
        heading = ori_heading * self.args.rotation_range / 180 * np.pi

        # Rotation about the vertical (Y) axis by -heading.
        cos = torch.cos(-heading)
        sin = torch.sin(-heading)
        zeros = torch.zeros_like(cos)
        ones = torch.ones_like(cos)
        R = torch.cat([cos, zeros, -sin, zeros, ones, zeros, sin, zeros, cos], dim=-1).view(B, 3, 3)

        # Camera offset: shift_v -> X (north/south), camera height -> Y, shift_u -> -Z.
        height = utils.get_camera_height() * torch.ones_like(shift_u_meters)
        T = torch.cat([shift_v_meters, height, -shift_u_meters], dim=-1).unsqueeze(dim=-1)  # [B, 3, 1]

        # Rescale intrinsics from the network input size to the feature-map size.
        camera_k = ori_camera_k.clone()
        camera_k[:, :1, :] = ori_camera_k[:, :1, :] * grd_W / ori_grdW
        camera_k[:, 1:2, :] = ori_camera_k[:, 1:2, :] * grd_H / ori_grdH

        def apply(P):
            # Batch of [3, 4] matrices applied to every homogeneous point -> [B, H, W, 3].
            return torch.sum(P[:, None, None, :, :] * XYZ_1[None, :, :, None, :], dim=-1)

        uv1 = apply(camera_k @ torch.cat([R, T], dim=-1))  # P = K [R | T]
        depth = uv1[:, :, :, 2:]
        # Clamp depth away from zero; only points in front of the camera are kept by `mask`.
        uv1_last = torch.maximum(depth, torch.ones_like(depth) * 1e-6)
        uv = uv1[:, :, :, :2] / uv1_last  # [B, H, W, 2]
        mask = torch.greater(uv1_last, torch.ones_like(depth) * 1e-6)

        if not require_jac:
            return uv, None, None, None, torch.squeeze(mask, dim=-1)

        dev = ori_shift_u.device

        def const_column(values, scale):
            # scale * a fixed [3, 1] column, tiled over the batch.
            return scale * torch.tensor(values, dtype=torch.float32,
                                        device=dev, requires_grad=True).view(1, 3, 1).repeat(B, 1, 1)

        dT_dx = const_column([0., 0., -1.], self.args.shift_range_lon)
        dT_dy = const_column([1., 0., 0.], self.args.shift_range_lat)
        T_zeros = torch.zeros([B, 3, 1], dtype=torch.float32, device=dev, requires_grad=True)
        R_zeros = torch.zeros([B, 3, 3], dtype=torch.float32, device=dev, requires_grad=True)
        dR_dtheta = self.args.rotation_range / 180 * np.pi * torch.cat(
            [sin, zeros, cos, zeros, zeros, zeros, -cos, zeros, sin], dim=-1).view(B, 3, 3)

        def perspective_derivative(duv1):
            # Quotient rule for uv = uv1[:2] / uv1[2], zeroed outside `mask`.
            duv = duv1[..., 0:2] / uv1_last - uv1[:, :, :, :2] * duv1[..., 2:] / (uv1_last ** 2)
            return torch.where(mask, duv, torch.zeros_like(duv))

        duv_dx = perspective_derivative(apply(camera_k @ torch.cat([R_zeros, dT_dx], dim=-1)))
        duv_dy = perspective_derivative(apply(camera_k @ torch.cat([R_zeros, dT_dy], dim=-1)))
        duv_dtheta = perspective_derivative(apply(camera_k @ torch.cat([dR_dtheta, T_zeros], dim=-1)))

        return uv, duv_dx, duv_dy, duv_dtheta, mask

    def project_grd_to_map(self, grd_f, grd_c, shift_u, shift_v, heading, camera_k, satmap_sidelength, ori_grdH,
                           ori_grdW):
        '''
        Warp ground features (and optionally confidences) onto the satellite-map grid.

        :param grd_f: ground features [B, C, H, W].
        :param grd_c: ground confidences, or None to skip them.
        :param shift_u, shift_v, heading: normalised pose, each [B, 1].
        :param camera_k: 3x3 intrinsics of the left colour camera, [B, 3, 3]
            (only used when self.args.proj == 'geo').
        :param satmap_sidelength: side length of the output satellite grid.
        :param ori_grdH, ori_grdW: size of the ground image the intrinsics refer to.
        :return: (grd_f_trans, grd_c_trans or None, new_jac, uv, mask, proj_mask), where
            new_jac and proj_mask are returned by the project `grid_sample`. When
            self.args.require_jac is set, the pose Jacobian passed to it is
            [3, B, sidelength, sidelength, 2].
        :raises ValueError: if self.args.proj is neither 'geo' nor 'nn'.
        '''
        H, W = grd_f.shape[-2:]

        if self.args.proj == 'geo':
            XYZ_1 = self.get_warp_sat2real(satmap_sidelength)  # [sidelength, sidelength, 4]
            uv, jac_shiftu, jac_shiftv, jac_heading, mask = self.seq_warp_real2camera(
                shift_u, shift_v, heading, XYZ_1, camera_k, H, W, ori_grdH, ori_grdW,
                require_jac=self.args.require_jac)
        elif self.args.proj == 'nn':
            # Always request the Jacobian: without it inplane_grd_to_map returns only (uv, mask).
            uv, jac_shiftu, jac_shiftv, jac_heading, mask = self.inplane_grd_to_map(
                shift_u, shift_v, heading, satmap_sidelength, require_jac=True)
        else:
            # Previously this fell through to an UnboundLocalError on `uv`.
            raise ValueError(f"Unsupported projection {self.args.proj!r}; expected 'geo' or 'nn'")

        if self.args.require_jac:
            jac = torch.stack([jac_shiftu, jac_shiftv, jac_heading], dim=0)  # [3, B, H', W', 2]
        else:
            jac = None

        grd_f_trans, new_jac, proj_mask = grid_sample(grd_f, uv, jac)

        if grd_c is not None:
            grd_c_trans, _, _ = grid_sample(grd_c, uv)
        else:
            grd_c_trans = None

        return grd_f_trans, grd_c_trans, new_jac, uv, mask, proj_mask

    def inplane_grd_to_map(self, ori_shift_u, ori_shift_v, ori_heading, satmap_sidelength, require_jac=True):
        '''
        In-plane 2-D rigid warp of a square pixel grid (the 'nn' projection).

        Each grid pixel p (measured from the top-left) is mapped to
        R(heading) @ (p - c) + T + c, where c = satmap_sidelength / 2 and
        T = (-shift_u, shift_v) converted to pixels.

        :param ori_shift_u, ori_shift_v: normalised shifts [B, 1].
        :param ori_heading: normalised heading [B, 1], scaled by self.args.rotation_range (degrees).
        :param satmap_sidelength: side length of the grid.
        :param require_jac: also return the pose Jacobians.
        :return: with require_jac, (uv, duv_dshiftu, duv_dshiftv, duv_dtheta, mask), each
            [B, A, A, 2] except mask [B, A, A] (all ones). Without it, only (uv, mask).
            Callers unpacking five values must keep require_jac=True.
        '''
        meter_per_pixel = utils.get_meter_per_pixel()
        meter_per_pixel *= utils.get_process_satmap_sidelength() / satmap_sidelength

        B = ori_heading.shape[0]
        device = ori_heading.device
        shift_u_pixels = self.args.shift_range_lon * ori_shift_u / meter_per_pixel
        shift_v_pixels = self.args.shift_range_lat * ori_shift_v / meter_per_pixel
        T = torch.cat([-shift_u_pixels, shift_v_pixels], dim=-1)  # [B, 2]

        heading = ori_heading * self.args.rotation_range / 180 * np.pi
        cos = torch.cos(heading)
        sin = torch.sin(heading)
        R = torch.cat([cos, -sin, sin, cos], dim=-1).view(B, 2, 2)

        # Build the pixel grid on the inputs' device. A hard-wired .cuda() failed for CPU
        # inputs and for inputs on a non-current GPU.
        i = j = torch.arange(0, satmap_sidelength, device=device)
        v, u = torch.meshgrid(i, j)  # v: row (h), u: column (w)
        uv_2 = torch.stack([u, v], dim=-1).unsqueeze(dim=0).repeat(B, 1, 1, 1).float()  # [B, H, W, 2]
        uv_2 = uv_2 - satmap_sidelength / 2  # centre the grid

        uv_1 = torch.einsum('bij, bhwj->bhwi', R, uv_2)  # rotate about the centre
        uv_0 = uv_1 + T[:, None, None, :]  # translate, [B, H, W, 2]
        uv = uv_0 + satmap_sidelength / 2  # back to a top-left origin
        mask = torch.ones_like(uv[..., 0])

        if not require_jac:
            return uv, mask

        dT_dshiftu = self.args.shift_range_lon / meter_per_pixel \
            * torch.tensor([-1., 0], dtype=torch.float32, device=ori_shift_u.device,
                           requires_grad=True).view(1, 2).repeat(B, 1)
        dT_dshiftv = self.args.shift_range_lat / meter_per_pixel \
            * torch.tensor([0., 1], dtype=torch.float32, device=ori_shift_u.device,
                           requires_grad=True).view(1, 2).repeat(B, 1)
        dR_dtheta = self.args.rotation_range / 180 * np.pi * torch.cat(
            [-sin, -cos, cos, -sin], dim=-1).view(B, 2, 2)

        # Translation Jacobians are constant over the grid; the rotation one depends on the pixel.
        duv_dshiftu = dT_dshiftu[:, None, None, :].repeat(1, satmap_sidelength, satmap_sidelength, 1)
        duv_dshiftv = dT_dshiftv[:, None, None, :].repeat(1, satmap_sidelength, satmap_sidelength, 1)
        duv_dtheta = torch.einsum('bij, bhwj->bhwi', dR_dtheta, uv_2)

        return uv, duv_dshiftu, duv_dshiftv, duv_dtheta, mask


    def polar_transform(self, sat_feat, level):
        '''
        Resample a square satellite feature map onto a polar grid centred on the map.

        Rows sweep the radius from 40 m (row 0) towards 0; the radius is converted to
        pixels with self.meters_per_pixel[level]. Columns advance the angle by
        2*pi / (2A) each, and there are 4 * 2A columns.
        NOTE(review): 4 * 2A columns at 2*pi / (2A) per column wrap around the circle
        four times; confirm this is intended.

        :param sat_feat: [B, C, A, A]
        :param level: index into self.meters_per_pixel.
        :return: polar_sat, the first output of the project `grid_sample` evaluated on a
            [B, A // 2, 8 * A, 2] grid.
        '''
        scale = self.meters_per_pixel[level]
        B, C, A, _ = sat_feat.shape
        n_rows = A // 2
        n_cols = A * 2

        rows = torch.arange(0, n_rows, dtype=torch.float32)
        cols = torch.arange(0, 4 * n_cols, dtype=torch.float32)
        v, u = torch.meshgrid(rows, cols)
        v = v.to(sat_feat.device)
        u = u.to(sat_feat.device)

        angle = u / n_cols * np.pi * 2
        radius = (1 - v / n_rows) * 40 / scale  # 40 m outer radius, in pixels
        phase = np.pi / 4 - angle
        centre = A / 2

        grid = torch.stack([centre + radius * torch.cos(phase),
                            centre - radius * torch.sin(phase)], dim=-1)
        grids = grid.unsqueeze(dim=0).repeat(B, 1, 1, 1)  # [B, A // 2, 8 * A, 2]

        polar_sat, _, _ = grid_sample(sat_feat, grids)
        return polar_sat

    def polar_coordinates(self, level):
        '''
        Build the polar sampling grid for a satellite map at `level`, without sampling.

        The map side is taken as 512 / 2**(3 - level) pixels (a float). Rows sweep
        the radius from 40 m (row 0) towards 0, converted to pixels with
        self.meters_per_pixel[level]. Columns advance the angle by 2*pi / (2A) each,
        and there are 4 * 2A columns.
        NOTE(review): that wraps around the circle four times; confirm it is intended.

        :param level: feature level in [0, 3]; also indexes self.meters_per_pixel.
        :return: float32 grid [1, A // 2, 8 * A, 2] of (u, v) pixel coordinates, on CPU.
        '''
        scale = self.meters_per_pixel[level]
        side = 512 / 2 ** (3 - level)
        n_rows = side // 2
        n_cols = side * 2

        rows = torch.arange(0, n_rows, dtype=torch.float32)
        cols = torch.arange(0, 4 * n_cols, dtype=torch.float32)
        v, u = torch.meshgrid(rows, cols)

        angle = u / n_cols * np.pi * 2
        radius = (1 - v / n_rows) * 40 / scale  # 40 m outer radius, in pixels
        phase = np.pi / 4 - angle
        centre = side / 2

        grid = torch.stack([centre + radius * torch.cos(phase),
                            centre - radius * torch.sin(phase)], dim=-1)
        return grid.unsqueeze(dim=0)  # [1, A // 2, 8 * A, 2]

    def orien_corr(self, sat_map, grd_img_left, gt_shiftu=None, gt_shiftv=None, gt_heading=None, mode='train',
                file_name=None, gt_depth=None):
        '''
        Estimate heading by sliding the ground feature map across a circularly padded
        polar view of the satellite feature map.

        :param sat_map: [B, C, A, A] A--> sidelength
        :param grd_img_left: [B, C, H, W]
        :param gt_heading: needed in train mode (forwarded to triplet_loss).
            gt_shiftu, gt_shiftv, file_name and gt_depth are unused.
        :return: in train mode, the summed triplet loss over all levels. Otherwise the
            heading estimate [B] of the LAST level, (argmin - n) * degree_per_pixel with
            degree_per_pixel = 90 / W.
            NOTE(review): 90 / W assumes the W ground-feature columns span 90 degrees;
            confirm against the camera FOV.
        '''
        sat_feat_list, _ = self.SatFeatureNet(sat_map)
        grd_feat_list, _ = self.GrdFeatureNet(grd_img_left)

        corr_list = []
        for level in range(len(sat_feat_list)):
            sat_feat = sat_feat_list[level]
            grd_feat = grd_feat_list[level]  # [B, C, H, W]
            B, C, H, W = grd_feat.shape
            # L2-normalise each sample's whole ground feature map; it is used as a conv kernel.
            grd_feat = F.normalize(grd_feat.reshape(B, -1)).reshape(B, -1, H, W)

            # self.polar_grids is not set in this chunk; presumably built from polar_coordinates.
            grids = self.polar_grids[level].detach().to(sat_feat.device).repeat(B, 1, 1, 1)
            polar_sat, _, _ = grid_sample(sat_feat, grids)  # [B, C, H, sat_W]

            degree_per_pixel = 90 / W
            n = int(np.ceil(self.args.rotation_range / degree_per_pixel))
            sat_W = polar_sat.shape[-1]

            # Circular padding: prepend the last n columns so that window k of the
            # correlation corresponds to an offset of (k - n) columns. The slice starts at
            # max(sat_W - n, 0) rather than -n, because `x[..., -0:]` is the whole tensor and
            # made n == 0 (rotation_range == 0) prepend every column.
            wrapped_tail = polar_sat[:, :, :, max(sat_W - n, 0):]
            if sat_W - W < n:
                polar_sat1 = torch.cat([wrapped_tail, polar_sat, polar_sat[:, :, :, : (n - sat_W + W)]], dim=-1)
            else:
                polar_sat1 = torch.cat([wrapped_tail, polar_sat[:, :, :, : (W + n)]], dim=-1)

            # Grouped conv = per-sample cross-correlation with that sample's ground features.
            polar_sat2 = polar_sat1.reshape(1, B * C, H, -1)
            corr = F.conv2d(polar_sat2, grd_feat, groups=B)[0, :, 0, :]  # [B, n_windows]

            # L2 norm of every satellite window (sum of squares over C, H and the window width).
            denominator = F.avg_pool2d(polar_sat1.pow(2), (H, W), stride=1, divisor_override=1)[:, :, 0, :]  # [B, C, n_windows]
            denominator = torch.sum(denominator, dim=1)  # [B, n_windows]
            denominator = torch.maximum(torch.sqrt(denominator), torch.ones_like(denominator) * 1e-6)
            corr = 2 - 2 * corr / denominator  # lower = more similar

            orien = torch.argmin(corr, dim=-1)  # [B]
            orien = (orien - n) * degree_per_pixel

            corr_list.append((corr, degree_per_pixel))

        if mode == 'train':
            return self.triplet_loss(corr_list, gt_heading)
        return orien

    def triplet_loss(self, corr_list, gt_heading):
        '''
        Soft-margin ranking loss over orientation correlation scores (lower = better).

        For every level, the score at the ground-truth column (`pos`) is compared with
        every column j through softplus(10 * (pos - corr_j)), averaged over
        B * (W - 1). The ground-truth column itself contributes the constant
        softplus(0) = log 2 per sample, as before.

        softplus replaces log(1 + exp(.)): the explicit form overflowed to inf once
        10 * (pos - corr_j) exceeded roughly 88.

        :param corr_list: list of (corr [B, W], degree_per_pixel). The centre column
            (W - 1) / 2 is taken as zero offset.
        :param gt_heading: [B, 1]. gt_heading * self.args.rotation_range is assumed to be in
            the same unit as degree_per_pixel -- TODO confirm.
        :return: scalar, the sum of the per-level losses.
        '''
        gt = gt_heading * self.args.rotation_range

        losses = []
        for corr, degree_per_pixel in corr_list:
            B, W = corr.shape
            gt_idx = ((W - 1) / 2 + torch.round(gt[:, 0] / degree_per_pixel)).long()

            pos = corr[range(B), gt_idx]  # [B]
            pos_neg = pos[:, None] - corr  # [B, W]
            loss = torch.sum(F.softplus(pos_neg * 10)) / (B * (W - 1))
            losses.append(loss)

        return torch.sum(torch.stack(losses, dim=0))


    def interpolate_feature(self, pts, feat):
        """
        Bilinearly sample `feat` at pixel locations `pts`.

        pts: sat_uv pixel coordinates (x, y), presumably [b, N, 2] -- TODO confirm with callers.
        feat: sat_feat [b, c, a, a]

        Returns (pts_norm, mask, feat_pts, gradients):
            pts_norm  -- pts / [w, h] * 2 - 1, the grid passed to grid_sample (align_corners=True)
            mask      -- bool [b, N, 1], True where both normalised coordinates are strictly in (-1, 1)
            feat_pts  -- sampled features [b, N, c]
            gradients -- [b, N, c, 2] central differences with a one-pixel step in x and y,
                         or None when self.args.require_jac is False
        """
        b, c, h, w = feat.size()
        size_wh = torch.tensor([w, h]).to(pts)
        norm = (pts / size_wh) * 2 - 1

        valid = (norm < 1) & (norm > -1)
        inside = (valid[..., 0] & valid[..., 1]).unsqueeze(dim=-1)

        def sample(grid):
            # grid_sample on a [b, 1, M, 2] grid, reshaped to [b, M, c].
            out = F.grid_sample(feat, grid[:, None], mode='bilinear', align_corners=True)
            return out.reshape(b, c, -1).transpose(-1, -2)

        feat_pts = sample(norm)

        if not self.args.require_jac:
            return norm, inside, feat_pts, None

        # One pixel expressed in normalised units, separately for x and y.
        step = torch.tensor([[1, 0], [0, 1]])[:, None].to(pts) / size_wh * 2
        step_x, step_y = step.chunk(2, dim=0)
        shifted = sample(torch.cat([norm - step_x, norm + step_x, norm - step_y, norm + step_y], 1))
        x_lo, x_hi, y_lo, y_hi = shifted.chunk(4, dim=1)
        grads = torch.stack([(x_hi - x_lo) / 2, (y_hi - y_lo) / 2], dim=-1)

        return norm, inside, feat_pts, grads


def loss_func(loss_method, shift_lats, shift_lons, thetas, gt_shift_lat, gt_shift_lon, gt_theta,
              pred_xyz_dict, gt_xyz, pred_uv_dict, gt_uv,
              coe_shift_lat=100, coe_shift_lon=100, coe_theta=100, **kwargs):
    '''
    Weighted L1 pose loss plus diagnostic statistics.

    The prediction tensors are [B, D1, D2]. This used to be documented as
    [B, N_iters, Level], but SIBCL.forward_level_first (levels in the outer loop)
    passes [B, Level, N_iters]. Every "decrease" / "last" statistic below is taken
    along D1, and loss_method 0.1 keeps only the last entry along D2.

    Args:
        loss_method: 0   -> mean of the weighted errors over all (D1, D2) entries;
                     0.1 -> the same, restricted to the last entry along D2.
        shift_lats, shift_lons, thetas: predictions, [B, D1, D2].
        gt_shift_lat, gt_shift_lon, gt_theta: ground truth, [B].
        pred_xyz_dict, gt_xyz, pred_uv_dict, gt_uv: unused by the current loss methods.
        coe_shift_lat, coe_shift_lon, coe_theta: per-term weights.
        **kwargs: ignored (e.g. `awl`).
    Returns:
        13-tuple (loss, loss_decrease, shift_lat_decrease, shift_lon_decrease,
        thetas_decrease, loss_last, shift_lat_last, shift_lon_last, theta_last,
        loss, None, None, None). `loss` is repeated in slot 10 and the last three
        slots are placeholders.
    Raises:
        ValueError: if loss_method is not 0 or 0.1. This previously surfaced as an
            UnboundLocalError.
    '''
    if loss_method not in (0, 0.1):
        raise ValueError(f'Unsupported loss_method: {loss_method!r} (expected 0 or 0.1)')

    # Batch-mean absolute errors, [D1, D2].
    shift_lat_delta = torch.mean(torch.abs(shift_lats - gt_shift_lat[:, None, None]), dim=0)
    shift_lon_delta = torch.mean(torch.abs(shift_lons - gt_shift_lon[:, None, None]), dim=0)
    thetas_delta = torch.mean(torch.abs(thetas - gt_theta[:, None, None]), dim=0)

    shift_lat_decrease = shift_lat_delta[0] - shift_lat_delta[-1]  # [D2]
    shift_lon_decrease = shift_lon_delta[0] - shift_lon_delta[-1]  # [D2]
    thetas_decrease = thetas_delta[0] - thetas_delta[-1]  # [D2]

    if loss_method == 0:
        lat, lon, theta = shift_lat_delta, shift_lon_delta, thetas_delta
    else:  # 0.1: only the last entry along D2
        lat, lon, theta = shift_lat_delta[..., -1:], shift_lon_delta[..., -1:], thetas_delta[..., -1:]

    losses = coe_shift_lat * lat + coe_shift_lon * lon + coe_theta * theta  # [D1, D2] or [D1, 1]
    loss_decrease = losses[0] - losses[-1]
    loss = torch.mean(losses)
    loss_last = losses[-1]

    return loss, loss_decrease, shift_lat_decrease, shift_lon_decrease, thetas_decrease, loss_last, \
        shift_lat_delta[-1], shift_lon_delta[-1], thetas_delta[-1], loss, None, None, None


def scaled_barron(a, c):
    """Return ``x -> (loss, d1, d2)``: Barron's robust loss with shape ``a`` and scale ``c``.

    ``x`` must already be squared (see ``scaled_loss``). ``a`` is turned into a tensor
    matching ``x`` on every call.
    """
    def robust(y):
        return barron_loss(y, y.new_tensor(a))

    def apply(x):
        return scaled_loss(x, robust, c)

    return apply


def scaled_loss(x, fn, a):
    """Evaluate a loss function at ``x / a**2`` and rescale the result back.

    Args:
        x: the data tensor, already squared (``x = y**2``).
        fn: loss function with signature ``fn(x) -> (loss, d1, d2)``.
        a: the scale parameter.
    Returns:
        ``(loss * a**2, d1, d2 / a**2)``: the value and first/second derivatives
        with respect to ``x``.
    """
    scale_sq = a ** 2
    value, first, second = fn(x / scale_sq)
    return value * scale_sq, first, second / scale_sq


def barron_loss(x, alpha, derivatives: bool = True, eps: float = 1e-7):
    """Parameterized and adaptive robust loss on a pre-squared, unit-scale input.

    Reference: "A General and Adaptive Robust Loss Function", Barron, CVPR 2019.

    Unlike the reference implementation, ``x`` is assumed to be already squared and
    scaled (i.e. scale = 1). alpha == 0 and alpha == 2 use their closed forms; every
    other alpha uses the general formula.

    Returns ``(loss, d_loss/dx, zeros)`` when ``derivatives`` is true, otherwise
    ``(loss, zeros, zeros)``. The second derivative is not implemented (TODO if
    needed) and is always zeros.
    """
    def by_case(at_zero, at_two, general):
        # Element-wise choice between the alpha == 0, alpha == 2 and general formulas.
        return torch.where(alpha == 0, at_zero, torch.where(alpha == 2, at_two, general))

    # |alpha - 2| and alpha, each kept at least eps away from zero so the general formula
    # can divide by them. alpha keeps its sign, and alpha == 0 counts as positive.
    beta_safe = torch.abs(alpha - 2.).clamp(min=eps)
    sign = torch.where(alpha >= 0, torch.ones_like(alpha), -torch.ones_like(alpha))
    alpha_safe = sign * torch.abs(alpha).clamp(min=eps)

    base = x / beta_safe + 1.
    general = 2 * (beta_safe / alpha_safe) * (torch.pow(base, 0.5 * alpha) - 1.)
    at_zero = 2 * torch.log1p(torch.clamp(0.5 * x, max=33e37))
    loss = by_case(at_zero, x, general)

    zeros = torch.zeros_like(x)
    if not derivatives:
        return loss, zeros, zeros

    d1 = by_case(2 / (x + 2), torch.ones_like(x), torch.pow(base, 0.5 * alpha - 1.))
    return loss, d1, zeros