import os.path

import numpy as np
from PIL import Image
from torchvision import transforms
import torch
import matplotlib.pyplot as plt

def imsave(image, folder, name, norm='minmax', root_dir='/ws/external/debug_images'):
    """Save a debug image to ``<root_dir>/<folder>/<name>.png``.

    Args:
        image: either a ``PIL.Image.Image`` (saved unchanged) or a tensor in
            channel-first layout [C, H, W] (it is transposed with (1, 2, 0)).
        folder: sub-directory of ``root_dir``; created if it does not exist.
        name: file stem; ``.png`` is appended.
        norm: normalisation applied to a tensor before it is scaled by 255 and
            cast to uint8:
            'minmax' -> rescale to [0, 1] (a constant tensor becomes zeros);
            '01'     -> clamp to [0, 1];
            'l2'     -> divide each channel by its L2 norm over H and W;
            any other value -> no normalisation.
        root_dir: base directory for debug images (defaults to the location
            that used to be hard-coded).

    Tensors with 1 channel are written as grey, 3 channels as RGB, and any
    other channel count is averaged into a single grey channel.
    """
    root = f"{root_dir}/{folder}"
    os.makedirs(root, exist_ok=True)
    path = root + f'/{name}.png'

    if isinstance(image, Image.Image):
        image.save(path)
        return

    if norm == 'minmax':
        lo, hi = image.min(), image.max()
        if hi > lo:
            image = (image - lo) / (hi - lo)
        else:
            # Zero range would give 0 / 0 = NaN, which turns into arbitrary
            # bytes after the uint8 cast below.
            image = torch.zeros_like(image)
    elif norm == '01':
        image = torch.clamp(image, 0, 1)
    elif norm == 'l2':
        # clamp_min avoids 0 / 0 = NaN for an all-zero channel.
        image = image / image.norm(dim=(1, 2), keepdim=True, p=2).clamp_min(1e-12)

    array = image.detach().cpu().numpy()
    array = (array * 255).astype(np.uint8)
    array = np.transpose(array, (1, 2, 0))  # [C, H, W] -> [H, W, C]
    channels = array.shape[2]
    if channels == 3:
        plt.imsave(path, array)
        return
    if channels != 1:
        # Collapse an arbitrary channel count into one grey channel.
        array = array.mean(axis=2, keepdims=True)
    plt.imsave(path, array[:, :, 0], cmap='gray')

def get_valid_points(imr, imq, p2Dr, p2Dq):
    """Scatter query-image pixels onto a blank canvas shaped like ``imr``.

    For every correspondence i, the pixel of ``imq`` at ``p2Dq[i] = (x, y)`` is
    copied to location ``p2Dr[i] = (x, y)`` of a zero tensor shaped like
    ``imr``. Coordinates are clamped to each image's bounds and truncated to
    integers with ``.long()``.

    Args:
        imr: reference tensor [C, Hr, Wr]; only its shape/dtype/device are used.
        imq: query tensor [C, Hq, Wq].
        p2Dr: [N, 2] (x, y) points in the reference image.
        p2Dq: [N, 2] (x, y) points in the query image. Not modified.

    Returns:
        Tensor shaped like ``imr`` containing the scattered values (zeros
        elsewhere).
    """
    _, h_r, w_r = imr.size()
    _, h_q, w_q = imq.size()

    valid_imr = torch.zeros_like(imr)
    # Clamp out-of-place (so the caller's point tensors stay untouched) and per
    # axis (so non-square images cannot be indexed out of bounds).
    xr = p2Dr[:, 0].clamp(min=0, max=w_r - 1).long()
    yr = p2Dr[:, 1].clamp(min=0, max=h_r - 1).long()
    xq = p2Dq[:, 0].clamp(min=0, max=w_q - 1).long()
    yq = p2Dq[:, 1].clamp(min=0, max=h_q - 1).long()
    valid_imr[:, yr, xr] = imq[:, yq, xq]
    return valid_imr


def get_valid_points2(imr, imq, p2Dr):
    """Scatter one feature vector per point onto a blank canvas shaped like ``imr``.

    Args:
        imr: reference tensor [C, Hr, Wr]; only its shape/dtype/device are used.
        imq: [N, C] feature vectors, one row per point.
        p2Dr: [N, 2] (x, y) positions in the reference image; clamped to the
            image bounds and truncated to integers with ``.long()``.

    Returns:
        Tensor shaped like ``imr`` where channel c at ``p2Dr[i]`` holds
        ``imq[i, c]`` (zeros elsewhere).

    Raises:
        ValueError: if ``imq`` is not 2-D.
    """
    if imq.dim() != 2:
        raise ValueError(f"imq must be [N, C], got shape {tuple(imq.shape)}")
    _, h_r, w_r = imr.size()

    valid_imr = torch.zeros_like(imr)
    # Clamp each axis to its own bound so non-square images stay in range.
    xr = p2Dr[:, 0].clamp(min=0, max=w_r - 1).long()
    yr = p2Dr[:, 1].clamp(min=0, max=h_r - 1).long()
    valid_imr[:, yr, xr] = imq.transpose(1, 0)  # [N, C] -> [C, N]
    return valid_imr


def save_valid_points2(imr, imq, p2Dr, save_path='intetral_pose', name='valid_points'):
    """Scatter point features onto a canvas and save its channel mean as a debug image.

    The canvas is built by ``get_valid_points2`` (``imq`` is [N, C], ``p2Dr``
    is [N, 2] (x, y)). The image is written by ``imsave`` into sub-folder
    ``save_path`` under the name ``name``.

    Returns:
        The scattered canvas, shaped like ``imr``.
    """
    # NOTE(review): the default folder name looks like a typo of
    # 'integral_pose'; it is kept so existing output paths do not move.
    valid_imr = get_valid_points2(imr, imq, p2Dr)
    # Average over channels so a single grey image is written.
    imsave(valid_imr.mean(dim=0, keepdim=True), save_path, name=name)

    return valid_imr


def features_to_RGB(sat_feat_list, grd_feat_list, pred_feat_dict, gt_sat_feat_proj, loop=0, save_dir='./visualize/'):
    """Project multi-level d-dimensional feature maps to RGB images using PCA.

    For each level, one 3-component PCA is fitted on the pixel-wise
    L2-normalised satellite features and another on the ground features. The
    satellite PCA is reused for the ground-truth projection and for the
    final-iteration prediction, so those images share the satellite colour
    space. Ground-view maps are visualised from row ``H // 2`` downwards only.

    Args:
        sat_feat_list: per-level satellite features, each [B, C, A, A].
        grd_feat_list: per-level ground features, each [B, C, H, W].
        pred_feat_dict: indexed by level; each entry is a sequence (one item
            per iteration) of [B, C, H, W] tensors. Only the last is saved.
        gt_sat_feat_proj: per-level ground-truth projections, [B, C, H, W].
        loop: batch counter; file indices are ``loop * B + idx``.
        save_dir: output directory, created if it does not exist.
    """
    from sklearn.decomposition import PCA

    def l2_normalize(x):
        # Row-wise L2 normalisation; all-zero rows stay zero.
        denominator = np.linalg.norm(x, axis=-1, keepdims=True)
        return x / np.where(denominator == 0, 1, denominator)

    def flatten_normalize(x):
        # [B, C, H, W] -> [B*H*W, C] with every pixel vector L2-normalised.
        channels = x.shape[1]
        return l2_normalize(x.transpose([0, 2, 3, 1]).reshape([-1, channels]))

    def to_rgb(pca, x):
        # Project [B, C, H, W] features to [B, H, W, 3] colours in [0, 1].
        # Reshaping with x's own shape keeps odd heights working: the
        # bottom-half slice then has H - H // 2 rows, not H // 2.
        batch, _, height, width = x.shape
        rgb = (l2_normalize(pca.transform(flatten_normalize(x))) + 1) / 2
        return rgb.reshape(batch, height, width, 3)

    def save_rgb(rgb, size, suffix):
        # Write each [H, W, 3] colour map of the batch as '<index><suffix>.png'.
        batch = rgb.shape[0]
        for idx in range(batch):
            img = Image.fromarray((rgb[idx] * 255).astype(np.uint8)).resize(size)
            img.save(os.path.join(save_dir, str(loop * batch + idx) + suffix + '.png'))

    os.makedirs(save_dir, exist_ok=True)
    for level in range(len(sat_feat_list)):
        sat_feat = sat_feat_list[level].detach().cpu().numpy()    # [B, C, A, A]
        grd_feat = grd_feat_list[level].detach().cpu().numpy()    # [B, C, H, W]
        gt_s2g = gt_sat_feat_proj[level].detach().cpu().numpy()   # [B, C, H, W]
        preds = pred_feat_dict[level]
        last_iter = len(preds) - 1
        pred = preds[last_iter].detach().cpu().numpy()            # [B, C, H, W]
        half = grd_feat.shape[2] // 2  # first row of the visualised bottom half

        pca = PCA(n_components=3)
        pca.fit(flatten_normalize(sat_feat))
        pca_grd = PCA(n_components=3)
        pca_grd.fit(flatten_normalize(grd_feat))

        tag = '_level_' + str(level)
        save_rgb(to_rgb(pca, sat_feat), (512, 512), '_sat_feat' + tag)
        save_rgb(to_rgb(pca_grd, grd_feat[:, :, half:, :]), (1024, 128), '_grd_feat' + tag)
        save_rgb(to_rgb(pca, gt_s2g[:, :, half:, :]), (1024, 128), '_s2g_gt_feat' + tag)
        save_rgb(to_rgb(pca, pred[:, :, half:, :]), (1024, 128),
                 '_s2g_feat' + tag + '_iter_' + str(last_iter))

    return


def features_to_gray(sat_feat_list, grd_feat_list, pred_feat_dict, gt_feat_proj, loop=0, save_dir='./visualize/', name='sat', gtw_feat_dict=None):
    """Save channel-averaged feature maps of every level as grey-scale PNGs.

    Args:
        sat_feat_list: per-level satellite features [B, C, A, A].
        grd_feat_list: per-level ground features [B, C, H, W]; rows < H // 2
            are blanked before saving.
        pred_feat_dict: indexed by level; sequence of per-iteration
            predictions [B, C, H, W]. Only the last one is saved.
        gt_feat_proj: per-level ground-truth projections [B, C, H, W].
        loop: batch counter; file indices are ``loop * B + idx``.
        save_dir: output directory (must already exist).
        name: tag inserted into the gt / gtw / pred file names.
        gtw_feat_dict: optional, indexed by level; ``None`` or empty means
            no gtw images are written.
    """
    if gtw_feat_dict is None:
        gtw_feat_dict = {}

    def normalize_img(img):
        # Min-max rescale a [H, W] array to uint8. A constant map (zero range)
        # becomes all zeros instead of NaN garbage.
        lo, hi = img.min(), img.max()
        if hi <= lo:
            return np.zeros(img.shape, dtype=np.uint8)
        return ((img - lo) / (hi - lo) * 255).astype(np.uint8)

    def to_gray(t):
        # [B, C, H, W] tensor -> [B, H, W] numpy mean over channels.
        return t.mean(dim=1).detach().cpu().numpy()

    def save_gray(img, size, filename):
        Image.fromarray(normalize_img(img)).resize(size).save(os.path.join(save_dir, filename))

    for level in range(len(sat_feat_list)):
        sat_feat = to_gray(sat_feat_list[level])        # [B, A, A]
        grd_feat = to_gray(grd_feat_list[level])        # [B, H, W]
        gt_feat = to_gray(gt_feat_proj[level])          # [B, H, W]
        pred_feat = to_gray(pred_feat_dict[level][-1])  # last iteration only
        gtw_feat = to_gray(gtw_feat_dict[level]) if gtw_feat_dict else None

        B, H, _ = grd_feat.shape
        # Blank the top half of the whole batch once, before any image is saved.
        grd_feat[:, :H // 2, :] = 0

        level_tag = '_level' + str(level) + '.png'
        for idx in range(B):
            prefix = str(loop * B + idx)
            save_gray(sat_feat[idx], (512, 512), prefix + '_sat_feat' + level_tag)
            save_gray(grd_feat[idx], (1024, 128), prefix + '_grd_feat' + level_tag)
            save_gray(gt_feat[idx], (1024, 128), prefix + f'_{name}_feat_gt' + level_tag)
            if gtw_feat is not None:
                save_gray(gtw_feat[idx], (1024, 128), prefix + f'_{name}_feat_gtw' + level_tag)
            save_gray(pred_feat[idx], (1024, 128), prefix + f'_{name}_feat_pred' + level_tag)

    return

def ws_to_gray(ws, loop=0, save_dir='./visualize/', mean_dim=-1):
    """Save per-level weight maps as grey-scale PNGs resized to 1024x128.

    Args:
        ws: sequence indexed by level; each entry is a tensor that must be
            [B, H, W] after averaging over ``mean_dim``, or ``None``.
        loop: batch counter; file indices are ``loop * B + idx``.
        save_dir: output directory (must already exist).
        mean_dim: dimension averaged out before saving.
    """

    def normalize_img(img):
        # Min-max rescale a [H, W] array to uint8. A constant map (zero range)
        # becomes all zeros instead of NaN garbage.
        lo, hi = img.min(), img.max()
        if hi <= lo:
            return np.zeros(img.shape, dtype=np.uint8)
        return ((img - lo) / (hi - lo) * 255).astype(np.uint8)

    for level in range(len(ws)):
        if ws[level] is None:
            # NOTE(review): this stops entirely, so levels after the first None
            # are skipped too; confirm `continue` was not intended.
            return
        w_feat = ws[level].mean(dim=mean_dim).detach().cpu().numpy()  # [B, H, W]
        B, _, _ = w_feat.shape

        for idx in range(B):
            w = Image.fromarray(normalize_img(w_feat[idx])).resize((1024, 128))
            w.save(os.path.join(save_dir, str(loop * B + idx) + 'w_feat' + '_level' + str(level) + '.png'))

    return



def RGB_iterative_pose(sat_img, grd_img, shift_lats, shift_lons, thetas, gt_shift_u, gt_shift_v, gt_theta,
                       meter_per_pixel, args, loop=0, save_dir='./visualize/'):
    '''
    Plot predicted and ground-truth poses on the satellite image (KITTI variant).

    Three files are written to ``save_dir`` for each batch element, with
    ``i = loop * B + idx``:
    - ``<i>_points.png``: the satellite image with markers and heading arrows.
      Red is the initial position (image centre), magenta the intermediate
      updates, green the final prediction and blue the ground truth.
    - ``<i>_grd.png``: the ground image.
    - ``<i>_sat.png``: the satellite image.

    Args:
        sat_img: [B, C, A, A] satellite images.
        grd_img: ground images, one per batch element.
        shift_lats: [B, Niters, Level] normalised latitude offsets, scaled by
            ``args.shift_range_lat``.
        shift_lons: [B, Niters, Level] normalised longitude offsets, scaled by
            ``args.shift_range_lon``.
        thetas: [B, Niters, Level] normalised headings, scaled by
            ``args.rotation_range``.
        gt_shift_u, gt_shift_v, gt_theta: normalised ground-truth lon / lat /
            heading, one value per batch element.
        meter_per_pixel: scalar, metres per satellite pixel.
        args: provides ``shift_range_lat``, ``shift_range_lon`` and ``rotation_range``.
        loop: batch counter used in the file names.
        save_dir: output directory, created if it does not exist.
    '''
    B, _, A, _ = sat_img.shape
    os.makedirs(save_dir, exist_ok=True)

    def to_np(t):
        # detach() so inputs that require grad can be converted too.
        return t.detach().cpu().numpy()

    # Offsets in metres, shown in the plot title.
    shift_lats_meter = (to_np(shift_lats) * args.shift_range_lat).reshape([B, -1])
    shift_lons_meter = (to_np(shift_lons) * args.shift_range_lon).reshape([B, -1])
    gt_v_meter = (to_np(gt_shift_v) * args.shift_range_lat).reshape([B, -1])
    gt_u_meter = (to_np(gt_shift_u) * args.shift_range_lon).reshape([B, -1])

    # Pixel positions: longitude shifts the column right of the centre,
    # latitude shifts the row up (smaller row index).
    shift_lats = (A/2 - to_np(shift_lats) * args.shift_range_lat / meter_per_pixel).reshape([B, -1])
    shift_lons = (A/2 + to_np(shift_lons) * args.shift_range_lon / meter_per_pixel).reshape([B, -1])
    thetas = (- to_np(thetas) * args.rotation_range).reshape([B, -1])

    gt_u = A/2 + to_np(gt_shift_u) * args.shift_range_lon / meter_per_pixel
    gt_v = A/2 - to_np(gt_shift_v) * args.shift_range_lat / meter_per_pixel
    gt_theta = - to_np(gt_theta) * args.rotation_range

    for idx in range(B):
        img = np.array(transforms.functional.to_pil_image(sat_img[idx], mode='RGB'))

        fig, ax = plt.subplots()
        ax.imshow(img)
        ax.scatter(A/2, A/2, color='r', s=20, zorder=2)  # initial position
        ax.scatter(shift_lons[idx, :-1], shift_lats[idx, :-1], color='m', s=15, zorder=2)  # intermediate
        ax.scatter(shift_lons[idx, -1], shift_lats[idx, -1], color='g', s=20, zorder=2)  # final prediction
        ax.scatter(gt_u[idx], gt_v[idx], color='b', s=20, zorder=2)  # ground truth

        # Heading arrows for the initial pose, the prediction and the ground truth.
        ax.quiver(A/2, A/2, 1, 1, angles=0, color='r', zorder=2)
        ax.quiver(shift_lons[idx, -1], shift_lats[idx, -1], 1, 1, angles=thetas[idx, -1], color='g', zorder=2)
        ax.quiver(gt_u[idx], gt_v[idx], 1, 1, angles=gt_theta[idx], color='b', zorder=2)
        # A space separates the prediction and ground-truth parts of the title.
        ax.set_title(f'pred: lon{str(shift_lons_meter[idx, -1])} lat{str(shift_lats_meter[idx, -1])} rot{str(thetas[idx, -1])} '
                     f'gt: lon{gt_u_meter[idx]} lat{gt_v_meter[idx]} rot{gt_theta[idx]}')
        ax.axis('off')

        plt.savefig(os.path.join(save_dir, str(loop * B + idx) + '_points' + '.png'),
                    transparent=True, dpi=A, bbox_inches='tight')
        plt.close(fig)

        grd = transforms.functional.to_pil_image(grd_img[idx], mode='RGB')
        grd.save(os.path.join(save_dir, str(loop * B + idx) + '_grd' + '.png'))

        sat = transforms.functional.to_pil_image(sat_img[idx], mode='RGB')
        sat.save(os.path.join(save_dir, str(loop * B + idx) + '_sat' + '.png'))


def save_proj_images(img_proj, gt_img_proj, sat_map, grd_img, img_projw=None, loop=0, resize=(1024, 256),
                     save_dir='./visualize/', name='sat', gtw_proj=None):
    """Save projected images, the ground image (full and bottom half) and the satellite map.

    All inputs are [B, C, H, W] tensors. Values are multiplied by 255 and cast
    to uint8, so they are presumably in [0, 1]. Files are named
    ``<loop * B + idx><suffix>.png`` inside ``save_dir``.

    Args:
        img_proj: projected (predicted) images, resized to ``resize``.
        gt_img_proj: ground-truth projections, same batch size as ``img_proj``.
        sat_map: satellite images, saved at native size.
        grd_img: ground images, saved at native size both in full ('_grd')
            and with rows < H // 2 zeroed ('_grd_half'). The caller's
            tensor is not modified.
        img_projw, gtw_proj: optional extra projections with the same batch
            size as ``img_proj``; skipped when ``None``.
        loop: batch counter used in the file names.
        resize: (width, height) applied to the projected images.
        save_dir: output directory (must already exist).
        name: tag inserted into the projection file names.
    """

    def to_hwc(t):
        # [B, C, H, W] tensor -> [B, H, W, C] numpy array.
        return t.detach().cpu().permute(0, 2, 3, 1).numpy()

    def save_batch(images, batch, suffix, size=None):
        # Write images[0:batch] as '<loop * batch + idx><suffix>.png'.
        for idx in range(batch):
            img = Image.fromarray((images[idx] * 255).astype(np.uint8))
            if size is not None:
                img = img.resize(size)
            img.save(os.path.join(save_dir, str(loop * batch + idx) + suffix + '.png'))

    # Projected images all use img_proj's batch size.
    B = img_proj.size(0)
    save_batch(to_hwc(img_proj), B, f'_{name}_img_proj', resize)
    save_batch(to_hwc(gt_img_proj), B, f'_{name}_img_proj_gt', resize)
    if img_projw is not None:
        save_batch(to_hwc(img_projw), B, f'_{name}_img_projw_sat', resize)
    if gtw_proj is not None:
        save_batch(to_hwc(gtw_proj), B, f'_{name}_img_proj_gtw', resize)

    # Ground images: full first, then a copy with the top half blanked.
    # Blanking a copy matters for two reasons. On CPU, .numpy() shares memory
    # with the caller's tensor. And blanking the whole batch inside the
    # per-image loop would also blank the full '_grd' images of idx > 0.
    grd = to_hwc(grd_img)
    B_grd, H_grd = grd.shape[0], grd.shape[1]
    save_batch(grd, B_grd, '_grd')
    grd_half = grd.copy()
    grd_half[:, :H_grd // 2] = 0
    save_batch(grd_half, B_grd, '_grd_half')

    sat = to_hwc(sat_map)
    save_batch(sat, sat.shape[0], '_sat')

    return


def save_images(imgs, loop=0, resize=(1024, 256), save_dir='./visualize/', name='sat'):
    """Write every image of a [B, C, H, W] batch to ``save_dir`` as a resized PNG.

    Pixel values are multiplied by 255 and cast to uint8, so inputs are
    presumably in [0, 1]. Each file is named ``<loop * B + idx>_<name>.png``.
    """
    batch, _, _, _ = imgs.size()
    channels_last = imgs.permute(0, 2, 3, 1)  # [B, H, W, C]
    for offset in range(batch):
        pixels = (channels_last[offset].detach().cpu().numpy() * 255).astype(np.uint8)
        filename = str(loop * batch + offset) + f'_{name}' + '.png'
        Image.fromarray(pixels).resize(resize).save(os.path.join(save_dir, filename))

    return


def save_features(imgs, loop=0, resize=(1024, 256), save_dir='./visualize/', name='sat'):
    """Save the channel mean of each [C, H, W] map in a batch as a grey PNG.

    The mean is multiplied by 255 and cast to uint8 without further
    normalisation, so values are presumably expected in [0, 1]. Files are
    named ``<loop * B + idx>_<name>.png`` and resized to ``resize``.
    """
    batch, _, _, _ = imgs.size()
    gray = imgs.mean(dim=1)  # [B, H, W]
    for offset in range(batch):
        pixels = (gray[offset].detach().cpu().numpy() * 255).astype(np.uint8)
        filename = str(loop * batch + offset) + f'_{name}' + '.png'
        Image.fromarray(pixels).resize(resize).save(os.path.join(save_dir, filename))

    return

def log_images(imgs, loop=0, resize=(1024, 256), name='sat'):
    """Log every image of a [B, C, H, W] batch to Weights & Biases.

    Each image is scaled by 255, cast to uint8 and resized to ``resize``. It
    is logged under the key ``<loop * B + idx>_<name>.png``, which also serves
    as its caption.
    """
    import wandb

    batch, _, _, _ = imgs.size()
    channels_last = imgs.permute(0, 2, 3, 1)  # [B, H, W, C]
    for offset in range(batch):
        key = str(loop * batch + offset) + f'_{name}' + '.png'
        pixels = (channels_last[offset].detach().cpu().numpy() * 255).astype(np.uint8)
        picture = Image.fromarray(pixels).resize(resize)
        wandb.log({key: wandb.Image(picture, caption=key)})


def log_proj_images(img_proj, gt_img_proj, sat_map, grd_img, img_projw=None, loop=0, resize=(1024, 256),
                     save_dir='./visualize/', name='sat', gtw_proj=None):
    """Log projected images, ground images and satellite maps to Weights & Biases.

    All inputs are [B, C, H, W] tensors. Values are multiplied by 255 and cast
    to uint8, so they are presumably in [0, 1]. Each image is logged under
    the key ``<loop * B + idx><suffix>.png``, which is also its caption.

    Args:
        img_proj: projected (predicted) images, resized to ``resize``.
        gt_img_proj: ground-truth projections, same batch size as ``img_proj``.
        sat_map: satellite images, logged at native size.
        grd_img: ground images, logged at native size.
        img_projw, gtw_proj: optional extra projections with the same batch
            size as ``img_proj``; skipped when ``None``.
        loop: batch counter used in the keys.
        resize: (width, height) applied to the projected images.
        save_dir: unused; kept so the signature matches ``save_proj_images``.
        name: tag inserted into the projection keys.
    """
    import wandb

    def to_hwc(t):
        # [B, C, H, W] tensor -> [B, H, W, C] numpy array.
        return t.detach().cpu().permute(0, 2, 3, 1).numpy()

    def log_batch(images, batch, suffix, size=None):
        # Log images[0:batch] under '<loop * batch + idx><suffix>.png'.
        for idx in range(batch):
            img = Image.fromarray((images[idx] * 255).astype(np.uint8))
            if size is not None:
                img = img.resize(size)
            key = str(loop * batch + idx) + suffix + '.png'
            wandb.log({key: wandb.Image(img, caption=key)})

    # Projected images all use img_proj's batch size.
    B = img_proj.size(0)
    log_batch(to_hwc(img_proj), B, f'_{name}_img_proj', resize)
    log_batch(to_hwc(gt_img_proj), B, f'_{name}_img_proj_gt', resize)
    # `is not None` rather than `!= None`: the latter is elementwise for arrays.
    if img_projw is not None:
        log_batch(to_hwc(img_projw), B, f'_{name}_img_projw_sat', resize)
    if gtw_proj is not None:
        log_batch(to_hwc(gtw_proj), B, f'_{name}_img_proj_gtw', resize)

    grd = to_hwc(grd_img)
    log_batch(grd, grd.shape[0], '_grd')

    sat = to_hwc(sat_map)
    log_batch(sat, sat.shape[0], '_sat')

    return


def RGB_iterative_pose_ford(sat_img, grd_img, shift_lats, shift_lons, thetas, gt_shift_u, gt_shift_v, gt_theta,
                            meter_per_pixel, args, loop=0, save_dir='./visualize/'):
    '''
    Plot predicted and ground-truth poses on the satellite image (Ford variant).

    This differs from ``RGB_iterative_pose`` in three ways:
    - The scatter x-coordinate comes from ``shift_lats`` / ``gt_shift_u`` and
      the y-coordinate from ``shift_lons`` / ``gt_shift_v``.
    - Both offsets are subtracted from the image centre.
    - Heading arrows are offset by 90.

    Three files are written to ``save_dir`` for each batch element:
    ``points_<i>.png``, ``grd_<i>.png`` and ``sat_<i>.png``, with
    ``i = loop * B + idx``.

    Args:
        sat_img: [B, C, A, A] satellite images.
        grd_img: ground images, one per batch element.
        shift_lats: [B, Niters, Level] normalised offsets, scaled by
            ``args.shift_range_lat``.
        shift_lons: [B, Niters, Level] normalised offsets, scaled by
            ``args.shift_range_lon``.
        thetas: [B, Niters, Level] normalised headings, scaled by
            ``args.rotation_range``.
        gt_shift_u, gt_shift_v, gt_theta: normalised ground truth, one value
            per batch element.
        meter_per_pixel: scalar, metres per satellite pixel.
        args: provides ``shift_range_lat``, ``shift_range_lon`` and ``rotation_range``.
        loop: batch counter used in the file names.
        save_dir: output directory, created if it does not exist.
    '''
    B, _, A, _ = sat_img.shape
    os.makedirs(save_dir, exist_ok=True)

    def to_np(t):
        # detach() so inputs that require grad can be converted too.
        return t.detach().cpu().numpy()

    # Pixel positions; both offsets are subtracted from the image centre.
    shift_lats = (A/2 - to_np(shift_lats) * args.shift_range_lat / meter_per_pixel).reshape([B, -1])
    shift_lons = (A/2 - to_np(shift_lons) * args.shift_range_lon / meter_per_pixel).reshape([B, -1])
    thetas = (- to_np(thetas) * args.rotation_range).reshape([B, -1])

    # gt_u is plotted on the same axis as shift_lats and gt_v on the same axis
    # as shift_lons, hence the matching lat / lon ranges.
    gt_u = A/2 - to_np(gt_shift_u) * args.shift_range_lat / meter_per_pixel
    gt_v = A/2 - to_np(gt_shift_v) * args.shift_range_lon / meter_per_pixel
    gt_theta = - to_np(gt_theta) * args.rotation_range

    for idx in range(B):
        img = np.array(transforms.functional.to_pil_image(sat_img[idx], mode='RGB'))

        fig, ax = plt.subplots()
        ax.imshow(img)
        ax.scatter(A/2, A/2, color='r', s=20, zorder=2)  # initial position
        ax.scatter(shift_lats[idx, :-1], shift_lons[idx, :-1], color='m', s=15, zorder=2)  # intermediate
        ax.scatter(shift_lats[idx, -1], shift_lons[idx, -1], color='g', s=20, zorder=2)  # final prediction
        ax.scatter(gt_u[idx], gt_v[idx], color='b', s=20, zorder=2)  # ground truth

        # Heading arrows, offset by 90 relative to RGB_iterative_pose.
        ax.quiver(A/2, A/2, 1, 1, angles=90, color='r', zorder=2)
        ax.quiver(shift_lats[idx, -1], shift_lons[idx, -1], 1, 1, angles=thetas[idx, -1] + 90, color='g', zorder=2)
        ax.quiver(gt_u[idx], gt_v[idx], 1, 1, angles=gt_theta[idx] + 90, color='b', zorder=2)
        ax.axis('off')

        plt.savefig(os.path.join(save_dir, 'points_' + str(loop * B + idx) + '.png'),
                    transparent=True, dpi=A, bbox_inches='tight')
        plt.close(fig)

        grd = transforms.functional.to_pil_image(grd_img[idx], mode='RGB')
        grd.save(os.path.join(save_dir, 'grd_' + str(loop * B + idx) + '.png'))

        sat = transforms.functional.to_pil_image(sat_img[idx], mode='RGB')
        sat.save(os.path.join(save_dir, 'sat_' + str(loop * B + idx) + '.png'))


def plot_images(imgs, titles=None, cmaps='gray', dpi=100, pad=.5,
                adaptive=True, autoscale=True):
    """Draw several images side by side in one row of a matplotlib figure.

    Args:
        imgs: sequence of images, RGB (H, W, 3) or mono (H, W).
        titles: optional sequence of per-image titles.
        cmaps: a single colormap for every image, or a list/tuple with one
            per image; used for monochrome images.
        dpi: figure resolution.
        pad: padding passed to ``tight_layout``.
        adaptive: size each panel by its image's W / H ratio instead of 4:3.
        autoscale: when False, autoscaling is disabled on every axis.
    """
    count = len(imgs)
    cmap_list = cmaps if isinstance(cmaps, (list, tuple)) else [cmaps] * count
    ratios = [img.shape[1] / img.shape[0] for img in imgs] if adaptive else [4/3] * count

    fig, axes = plt.subplots(
        1, count, figsize=[sum(ratios) * 4.5, 4.5], dpi=dpi,
        gridspec_kw={'width_ratios': ratios})
    if count == 1:
        axes = [axes]

    for i in range(count):
        axis = axes[i]
        axis.imshow(imgs[i], cmap=plt.get_cmap(cmap_list[i]))
        axis.get_yaxis().set_ticks([])
        axis.get_xaxis().set_ticks([])
        axis.set_axis_off()
        # Hide the frame around each panel.
        for spine in axis.spines.values():
            spine.set_visible(False)
        if titles:
            axis.set_title(titles[i])
        if not autoscale:
            axis.autoscale(False)
    fig.tight_layout(pad=pad)


def plot_keypoints(kpts, colors='lime', ps=6):
    """Overlay keypoints on the axes of the current figure, one array per axis.

    Args:
        kpts: list of (N, 2) arrays of (x, y) points; ``None`` entries skip
            the corresponding axis.
        colors: one colour for all axes, or a list with one entry per axis.
        ps: marker size passed to ``scatter``.
    """
    color_list = colors if isinstance(colors, list) else [colors] * len(kpts)
    for axis, points, color in zip(plt.gcf().axes, kpts, color_list):
        if points is None:
            continue
        axis.scatter(points[:, 0], points[:, 1], c=color, s=ps, linewidths=0)
