import numpy as np
import os
from PIL import Image
from torch.utils.data import Dataset

import torch
from utils import utils
import torchvision.transforms.functional as TF
import torch.nn.functional as F

from torch.utils.data import DataLoader
from torchvision import transforms
import matplotlib.pyplot as plt

from dataLoader.kitti_utils import generate_depth_map, project_lidar2im, project_lidar2camim, project_lidar2camim_all

# ---- dataset location -------------------------------------------------------
# Root directory of the dataset; handed to SatGrdDataset as ``root``.
root_dir = "/ws/data/kitti-vo"

# Sub-directories joined onto the dataset root when building file paths.
satmap_dir = "satmap"
grdimage_dir = "raw_data"
left_color_camera_dir = "image_02/data"
right_color_camera_dir = "image_03/data"
oxts_dir = "oxts/data"

# CSV file names (not referenced in this part of the file).
test_csv_file_name = "test.csv"
ignore_csv_file_name = "ignore.csv"

# ---- image geometry ---------------------------------------------------------
# Size of the original ground images; used to rescale the camera intrinsics.
GrdOriImg_H, GrdOriImg_W = 375, 1242
# Nominal network input size for ground images (the dataset itself reads
# args.grdH / args.grdW instead).
GrdImg_H, GrdImg_W = 256, 1024

# Number of DataLoader worker processes.
num_thread_workers = 4

# ---- split list files -------------------------------------------------------
train_file = "/ws/external/dataLoader/train_files.txt"
test1_file = "/ws/external/dataLoader/test1_files.txt"
test2_file = "/ws/external/dataLoader/test2_files.txt"


def imsave(image, folder, name):
    """Write a debug image to ``/ws/external/debug_images/<folder>/<name>.png``.

    Args:
        image: Either a ``PIL.Image.Image`` (saved unchanged) or a torch
            tensor. A tensor is min-max normalised to [0, 1], scaled to
            uint8 and transposed with axes ``(1, 2, 0)`` before saving, so it
            is expected to be 3-D in channel-first layout.
        folder: Name of the sub-directory (created if missing).
        name: Output file name without the ``.png`` extension.
    """
    root = f"/ws/external/debug_images/{folder}"
    os.makedirs(root, exist_ok=True)
    out_path = root + f'/{name}.png'

    if isinstance(image, Image.Image):
        image.save(out_path)
        return

    lowest = image.min()
    span = image.max() - lowest
    image = image - lowest
    # A constant image has zero span; dividing would produce NaNs that turn
    # into arbitrary values in the uint8 cast, so leave it as all zeros.
    if span > 0:
        image = image / span
    image = image.cpu().detach().numpy()
    image = (image * 255).astype(np.uint8)
    image = np.transpose(image, (1, 2, 0))
    plt.imsave(out_path, np.asarray(image))


class SatGrdDataset(Dataset):
    """Dataset of (satellite map, ground image) pairs with a known pose offset.

    Every entry of the list file names one satellite image below
    ``<root>/<satmap_dir>``. The first 10 characters of that name are the day
    directory, the first 38 the drive directory and the remainder the image
    file name. For every split other than ``'train'`` each line additionally
    carries three space-separated values:
    ``<file_name> <shift_x> <shift_y> <heading>``.

    For each sample the satellite map is rotated by the vehicle heading,
    translated by ``utils.CameraGPS_shift_left``, then perturbed by a
    translation and a rotation (random for ``'train'``, read from the list
    file otherwise) and finally centre-cropped.

    Args:
        args: Option namespace. The attributes read here are ``grdH``,
            ``grdW``, ``debug``, ``save``, ``depth``, ``depth_range``,
            ``max_points``, ``max_out_points`` and ``max_depth``.
        root: Dataset root directory.
        file: Path of the text file listing the samples.
        transform: Optional pair ``(satmap_transform, grdimage_transform)``.
            When omitted, images are returned as PIL images.
        shift_range_lat: Maximum lateral perturbation in meters.
        shift_range_lon: Maximum longitudinal perturbation in meters.
        rotation_range: Maximum rotation perturbation in degrees.
        use_gt_depth: Stored on the instance; not used by this class itself.
        split: ``'train'`` draws random perturbations; any other value reads
            them from the list file.
    """

    def __init__(self, args, root, file,
                 transform=None, shift_range_lat=20, shift_range_lon=20, rotation_range=10,
                 use_gt_depth=False, split='train'):
        self.args = args
        self.root = root
        self.split = split

        self.meter_per_pixel = utils.get_meter_per_pixel(scale=1)
        # Shift ranges are given in meters; keep a pixel version for PIL.
        self.shift_range_meters_lat = shift_range_lat
        self.shift_range_meters_lon = shift_range_lon
        self.shift_range_pixels_lat = shift_range_lat / self.meter_per_pixel
        self.shift_range_pixels_lon = shift_range_lon / self.meter_per_pixel

        self.rotation_range = rotation_range  # in degrees
        self.use_gt_depth = use_gt_depth
        self.grdH = args.grdH
        self.grdW = args.grdW

        self.skip_in_seq = 2  # skip 2 in sequence: 6,3,1~

        # Always define both attributes so that __getitem__ can test them
        # against None even when no transform was supplied.
        self.satmap_transform = None
        self.grdimage_transform = None
        if transform is not None:
            self.satmap_transform = transform[0]
            self.grdimage_transform = transform[1]

        self.pro_grdimage_dir = 'raw_data'
        self.satmap_dir = satmap_dir

        self.file_name = self._read_file_list(file)

    @staticmethod
    def _read_file_list(path):
        """Return the non-blank lines of ``path`` without their newline."""
        with open(path, 'r') as f:
            # rstrip('\n') instead of [:-1]: the last line may lack a newline.
            return [line.rstrip('\n') for line in f if line.strip()]

    def __len__(self):
        """Return the number of samples listed in the split file."""
        return len(self.file_name)

    def get_file_list(self):
        """Return the list of raw entries read from the split file."""
        return self.file_name

    def _read_left_camera_k(self, day_dir):
        """Return the left colour camera intrinsics as a 3x3 float32 tensor.

        The values come from the ``P_rect_02`` entry of the day's
        ``calib_cam_to_cam.txt`` and are rescaled from the original image
        size (``GrdOriImg_W`` x ``GrdOriImg_H``) to ``grdW`` x ``grdH``.

        Raises:
            ValueError: If the calibration file has no ``P_rect_02`` entry.
        """
        calib_file_name = os.path.join(self.root, grdimage_dir, day_dir, 'calib_cam_to_cam.txt')
        with open(calib_file_name, 'r') as f:
            for line in f:
                if 'P_rect_02' not in line:
                    continue
                # P_rect_** is stored row-major; pick fx, cx, fy, cy from it.
                values = line.split(':')[1].strip().split(' ')
                fx = float(values[0]) * self.grdW / GrdOriImg_W
                cx = float(values[2]) * self.grdW / GrdOriImg_W
                fy = float(values[5]) * self.grdH / GrdOriImg_H
                cy = float(values[6]) * self.grdH / GrdOriImg_H
                camera_k = [[fx, 0, cx], [0, fy, cy], [0, 0, 1]]
                return torch.from_numpy(np.asarray(camera_k, dtype=np.float32))
        raise ValueError(f"no 'P_rect_02' entry found in {calib_file_name}")

    def _read_heading(self, drive_dir, image_no):
        """Return the heading (sixth oxts field) of a frame as a float."""
        # oxts file names look like 0000000000.txt
        oxts_file_name = os.path.join(self.root, grdimage_dir, drive_dir, oxts_dir,
                                      image_no.lower().replace('.png', '.txt'))
        with open(oxts_file_name, 'r') as f:
            return float(f.readline().split(' ')[5])

    def _load_lidar_depth(self, day_dir, drive_dir, image_no, grd_img_left):
        """Project the frame's lidar scan and return ``(cam_depth, im_depth, mask)``.

        Raises:
            ValueError: If ``args.depth_range`` is neither ``'fov'`` nor ``'all'``.
        """
        calib_dir = os.path.join(self.root, 'raw_data', day_dir)
        velo_filename = os.path.join(self.root, 'raw_data', drive_dir, 'velodyne_points/data',
                                     image_no.replace('.png', '.bin'))
        if self.args.depth_range == 'fov':
            cam_depth, im_depth = project_lidar2camim(calib_dir, velo_filename, 2, True,
                                                      self.args.max_points, self.args.max_depth)
            # Every returned point counts as valid in 'fov' mode.
            mask = np.ones((cam_depth.shape[0]), dtype=bool)
        elif self.args.depth_range == 'all':
            cam_depth, im_depth, mask = project_lidar2camim_all(calib_dir, velo_filename, 2, True,
                                                                self.args.max_points,
                                                                self.args.max_out_points,
                                                                self.args.max_depth)
        else:
            raise ValueError(
                f"unsupported depth_range {self.args.depth_range!r}; expected 'fov' or 'all'")

        if self.args.debug:
            from utils.visualize_utils import plot_images, plot_keypoints
            gt_depth_point = project_lidar2im(calib_dir, velo_filename, 2, True)
            plot_images([grd_img_left.permute(1, 2, 0).cpu().detach().numpy()], dpi=100)
            plot_keypoints([gt_depth_point[:, :2]], colors='lime')
            plt.savefig(f'/ws/external/debug_images/{self.args.save}/gt_depth_points.png')

        return cam_depth, im_depth, mask

    def __getitem__(self, idx):
        """Load and assemble the sample at position ``idx``.

        Returns:
            A 10-tuple ``(sat_map, left_camera_k, grd_img_left, shift_x,
            shift_y, heading, cam_depth, im_depth, mask, file_name)``.
            ``shift_x``, ``shift_y`` and ``heading`` are float32 tensors of
            shape ``(1,)`` expressed as fractions of the configured ranges;
            for non-train splits they equal the values stored in the list
            file. ``cam_depth``, ``im_depth`` and ``mask`` are ``None``
            unless ``args.depth`` is ``'lidar'`` or ``'both'``.
        """
        if self.split == 'train':
            file_name = self.file_name[idx]
        else:
            file_name, gt_shift_x, gt_shift_y, gt_heading = self.file_name[idx].split(' ')

        # day_dir is the first 10 characters, drive_dir the first 38.
        day_dir = file_name[:10]
        drive_dir = file_name[:38]
        image_no = file_name[38:]

        # =================== camera intrinsics of the left colour camera ==========
        left_camera_k = self._read_left_camera_k(day_dir)

        # =================== read satellite map ===================================
        SatMap_name = os.path.join(self.root, self.satmap_dir, file_name)
        with Image.open(SatMap_name, 'r') as SatMap:
            sat_map = SatMap.convert('RGB')

        # =================== read heading and ground image ========================
        heading = self._read_heading(drive_dir, image_no)

        left_img_name = os.path.join(self.root, self.pro_grdimage_dir, drive_dir, left_color_camera_dir,
                                     image_no.lower())
        with Image.open(left_img_name, 'r') as GrdImg:
            grd_img_left = GrdImg.convert('RGB')
        if self.grdimage_transform is not None:
            grd_img_left = self.grdimage_transform(grd_img_left)

        # =================== align satellite map with the camera ==================
        # heading is in radians, PIL rotates in degrees.
        sat_rot = sat_map.rotate(-heading / np.pi * 180)
        sat_align_cam = sat_rot.transform(sat_rot.size, Image.AFFINE,
                                          (1, 0, utils.CameraGPS_shift_left[0] / self.meter_per_pixel,
                                           0, 1, utils.CameraGPS_shift_left[1] / self.meter_per_pixel),
                                          resample=Image.BILINEAR)

        # the homography is defined on: from target pixel to source pixel
        # now east direction is the real vehicle heading direction

        # =================== pose perturbation ====================================
        if self.split == 'train':
            gt_shift_x = np.random.uniform(-1, 1)  # --> right as positive, parallel to the heading direction
            gt_shift_y = np.random.uniform(-1, 1)  # --> up as positive, vertical to the heading direction
            gt_heading = np.random.uniform(-1, 1)
        else:
            # Negated here and negated again in the return statement, so the
            # values stored in the list file are returned unchanged.
            gt_shift_x = -float(gt_shift_x)
            gt_shift_y = -float(gt_shift_y)
            gt_heading = float(gt_heading)

        sat_rand_shift = sat_align_cam.transform(
            sat_align_cam.size, Image.AFFINE,
            (1, 0, gt_shift_x * self.shift_range_pixels_lon,
             0, 1, -gt_shift_y * self.shift_range_pixels_lat),
            resample=Image.BILINEAR)

        sat_rand_shift_rand_rot = sat_rand_shift.rotate(gt_heading * self.rotation_range)

        if self.args.debug:
            imsave(sat_map, self.args.save, 'sat_1orig')
            imsave(sat_rot, self.args.save, 'sat_2rot')
            imsave(sat_align_cam, self.args.save, 'sat_3align')
            imsave(sat_rand_shift, self.args.save, 'sat_4rshift')
            imsave(sat_rand_shift_rand_rot, self.args.save, 'sat_5rshift_rrot')
            imsave(grd_img_left, self.args.save, 'grd_orig')

        sat_map = TF.center_crop(sat_rand_shift_rand_rot, utils.SatMap_process_sidelength)

        if self.satmap_transform is not None:
            sat_map = self.satmap_transform(sat_map)

        if self.args.debug:
            imsave(sat_map, self.args.save, 'sat_6final')

        # =================== optional lidar depth =================================
        # All three values default to None so the return statement never hits
        # an unbound name when depth loading is disabled.
        # NOTE(review): torch's default collate function rejects None, so
        # batching depth-less samples presumably needs a custom collate_fn.
        cam_depth = im_depth = mask = None
        if self.args.depth in ['lidar', 'both']:
            cam_depth, im_depth, mask = self._load_lidar_depth(day_dir, drive_dir, image_no, grd_img_left)

        return sat_map, left_camera_k, grd_img_left, \
               torch.tensor(-gt_shift_x, dtype=torch.float32).reshape(1), \
               torch.tensor(-gt_shift_y, dtype=torch.float32).reshape(1), \
               torch.tensor(gt_heading, dtype=torch.float32).reshape(1), \
               cam_depth, \
               im_depth, \
               mask, \
               file_name


def load_data(args, batch_size, shift_range_lat=20, shift_range_lon=20, rotation_range=10,
                    use_gt_depth=False, split='train'):
    """Build a ``DataLoader`` over ``SatGrdDataset`` for the requested split.

    Args:
        args: Option namespace; ``grdH`` and ``grdW`` are read here and the
            whole object is forwarded to the dataset.
        batch_size: Number of samples per batch.
        shift_range_lat: Maximum lateral perturbation in meters.
        shift_range_lon: Maximum longitudinal perturbation in meters.
        rotation_range: Maximum rotation perturbation in degrees.
        use_gt_depth: Forwarded to the dataset.
        split: One of ``'train'``, ``'test1'`` or ``'test2'``. Only the
            training split is shuffled.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding the dataset's samples.

    Raises:
        ValueError: If ``split`` is not one of the supported names.
    """
    if split == 'train':
        shuffle, list_file = True, train_file
    elif split == 'test1':
        shuffle, list_file = False, test1_file
    elif split == 'test2':
        shuffle, list_file = False, test2_file
    else:
        # Fail early with a clear message instead of an UnboundLocalError
        # further down.
        raise ValueError(f"unknown split {split!r}; expected 'train', 'test1' or 'test2'")

    SatMap_process_sidelength = utils.get_process_satmap_sidelength()

    satmap_transform = transforms.Compose([
        transforms.Resize(size=[SatMap_process_sidelength, SatMap_process_sidelength]),
        transforms.ToTensor(),
    ])

    grdimage_transform = transforms.Compose([
        transforms.Resize(size=[args.grdH, args.grdW]),
        transforms.ToTensor(),
    ])

    dataset = SatGrdDataset(args=args, root=root_dir, file=list_file,
                            transform=(satmap_transform, grdimage_transform),
                            shift_range_lat=shift_range_lat,
                            shift_range_lon=shift_range_lon,
                            rotation_range=rotation_range,
                            use_gt_depth=use_gt_depth,
                            split=split)

    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, pin_memory=True,
                             num_workers=num_thread_workers, drop_last=False)
    return data_loader








