import torch
import torch.optim as optim
import numpy as np
import os
import argparse
import time
import yaml
import scipy.io as scio
from argparse import Namespace
from tqdm import tqdm


from dataLoader.KITTI_depth_dataset import load_data
from models.models_SIBCL import SIBCL
from models.models_PID import PIDLoc
from networks.pid_optimizer import PIDOptimizer
from utils.wandb_logger import WandbLogger
# import ssl
# ssl._create_default_https_context = ssl._create_unverified_context  # for downloading pretrained VGG weights


def load_config(config_path):
    """Read the YAML file at ``config_path`` and return the parsed content.

    Parsing goes through ``yaml.safe_load``, so only plain YAML data types
    are constructed.
    """
    with open(config_path, 'r') as stream:
        return yaml.safe_load(stream)


########################### ranking test ############################
def test(net_test, args, save_path, best_rank_result, epoch, wandb_logger=None, optimizer=None, mode='test1'):
    """Evaluate ``net_test`` on one split, report the metrics and checkpoint on improvement.

    Runs the network over the loader that ``load_data`` builds for ``mode``,
    stores the raw predictions (``<mode>_results_all.npz`` and
    ``<mode>_results.mat``), appends a text report to ``<mode>_results.txt``
    (the same lines are printed), and forwards summary metrics to
    ``wandb_logger``.  When the returned score beats ``best_rank_result`` the
    weights are saved as ``Model_best<k>.pth`` where ``<k>`` is the last
    character of ``mode``.

    Args:
        net_test: model to evaluate; switched to eval mode for the run and
            back to train mode before the checkpoint step.
        args: configuration namespace (batch size, noise ranges, flags, ...).
        save_path: directory receiving result files and checkpoints; created
            if missing.
        best_rank_result: best score seen so far for this split.
        epoch: epoch index used in the report and in per-epoch checkpoint names.
        wandb_logger: object exposing ``log_evaluate(dict)``; ``None`` disables
            metric logging.
        optimizer: when given, checkpoints bundle epoch, model and optimizer
            state; otherwise only the model ``state_dict`` is stored.
        mode: split name handed to ``load_data`` and used as file/metric prefix.

    Returns:
        Percentage of samples whose translation error is below ``0.25`` meters
        and whose heading error is below ``0.25`` degrees at the same time.

    Raises:
        ValueError: if the loader yielded no batch at all.
    """
    net_test.eval()

    # Result files are written long before the checkpoint branch below, so the
    # target directory has to exist up front.
    os.makedirs(save_path, exist_ok=True)

    dataloader = load_data(args, args.batch_size, args.shift_range_lat, args.shift_range_lon, args.rotation_range, args.depth, mode)

    pred_shifts = []
    pred_headings = []
    gt_shifts = []
    gt_headings = []
    wandb_features = dict()

    # analysis
    pred_lats_all = []
    pred_lons_all = []
    pred_thetas_all = []
    gt_lats_all = []
    gt_lons_all = []
    gt_thetas_all = []

    num_batches = 0
    start_time = time.time()
    for i, data in enumerate(tqdm(dataloader)):
        if args.debug2 and i == 10:
            break

        # `device` is the module-level global assigned in the __main__ block.
        sat_map, left_camera_k, grd_left_imgs, gt_shift_u, gt_shift_v, gt_heading, lidar_cam, lidar_im, mask = [item.to(device) for item in data[:-1]]

        if args.direction in ['PIDLoc', 'SIBCL'] and 'lidar' in args.depth:
            gt_depth = (lidar_cam, lidar_im)
        else:
            gt_depth = None

        if args.direction in ['PIDLoc', 'SIBCL', 'LM_S2GP']:
            shifts_lat, shifts_lon, theta = net_test(sat_map, grd_left_imgs, left_camera_k, mode='test', gt_depth=gt_depth,
                                                     grd_mask=mask)

        shifts = torch.stack([shifts_lat, shifts_lon], dim=-1)  # shifts: [B, 2]
        headings = theta.unsqueeze(dim=-1)          # headings: [B, 1]

        gt_shift = torch.cat([gt_shift_v, gt_shift_u], dim=-1)  # [B, 2]

        # Dummy loss: the backward pass only exists to free the autograd graph
        # built during the forward pass (D branch gradients are released).
        loss = torch.mean(shifts_lat - gt_shift_u)
        loss.backward()

        pred_shifts.append(shifts.data.cpu().numpy())
        pred_headings.append(headings.data.cpu().numpy())
        gt_shifts.append(gt_shift.data.cpu().numpy())
        gt_headings.append(gt_heading.data.cpu().numpy())

        ## for analysis
        pred_lats_all.append(net_test.shift_lats.cpu().numpy())
        pred_lons_all.append(net_test.shift_lons.cpu().numpy())
        pred_thetas_all.append(net_test.thetas.cpu().numpy())
        gt_lats_all.append(gt_shift_v.cpu().numpy())
        gt_lons_all.append(gt_shift_u.cpu().numpy())
        gt_thetas_all.append(gt_heading.cpu().numpy())

        num_batches += 1

    if num_batches == 0:
        raise ValueError(f'no batches were evaluated for mode {mode!r}')

    # Average over the batches that really ran; len(dataloader) is wrong when
    # the debug2 early exit stops the loop after 10 batches.
    # NOTE(review): this is seconds per batch although the report labels it
    # "per image" and the wandb key is called "fps" - confirm the intended unit.
    duration = (time.time() - start_time) / num_batches

    ##############
    ## analysis ##
    ##############
    pred_lats_all = np.concatenate(pred_lats_all, axis=0)
    pred_lons_all = np.concatenate(pred_lons_all, axis=0)
    pred_thetas_all = np.concatenate(pred_thetas_all, axis=0)
    gt_lats_all = np.concatenate(gt_lats_all, axis=0)
    gt_lons_all = np.concatenate(gt_lons_all, axis=0)
    gt_thetas_all = np.concatenate(gt_thetas_all, axis=0)

    # save results
    results_all_path = os.path.join(save_path, f'{mode}_results_all')
    np.savez(results_all_path+'.npz', pred_lats_all=pred_lats_all, pred_lons_all=pred_lons_all, pred_thetas_all=pred_thetas_all,
             gt_lats_all=gt_lats_all, gt_lons_all=gt_lons_all, gt_thetas_all=gt_thetas_all)

    # Flatten every trailing axis so each row holds all intermediate estimates
    # of one sample (presumably levels x iterations - TODO confirm).
    num_samples = pred_lats_all.shape[0]
    pred_lats_all = pred_lats_all.reshape(num_samples, -1)
    pred_lons_all = pred_lons_all.reshape(num_samples, -1)
    pred_thetas_all = pred_thetas_all.reshape(num_samples, -1)

    ## variance ##
    pred_lats_error = (pred_lats_all - gt_lats_all) * args.shift_range_lat
    pred_lons_error = (pred_lons_all - gt_lons_all) * args.shift_range_lon
    pred_thetas_error = (pred_thetas_all - gt_thetas_all) * args.rotation_range

    wandb_features[f'{mode}/shift_lats_var'] = np.mean(np.var(pred_lats_error, axis=1))
    wandb_features[f'{mode}/shift_lons_var'] = np.mean(np.var(pred_lons_error, axis=1))
    wandb_features[f'{mode}/shift_thetas_var'] = np.mean(np.var(pred_thetas_error, axis=1))

    ##################################

    # Scale the network outputs and labels by the configured noise ranges.
    shift_scale = np.array([args.shift_range_lat, args.shift_range_lon]).reshape(1, 2)
    pred_shifts = np.concatenate(pred_shifts, axis=0) * shift_scale
    pred_headings = np.concatenate(pred_headings, axis=0) * args.rotation_range
    gt_shifts = np.concatenate(gt_shifts, axis=0) * shift_scale
    gt_headings = np.concatenate(gt_headings, axis=0) * args.rotation_range

    result_path = os.path.join(save_path, f'{mode}_results')
    scio.savemat(result_path + '.mat', {'gt_shifts': gt_shifts, 'gt_headings': gt_headings,
                                                         'pred_shifts': pred_shifts, 'pred_headings': pred_headings})

    distance = np.sqrt(np.sum((pred_shifts - gt_shifts) ** 2, axis=1))  # [N]
    # Wrap the heading error into [0, 180].
    angle_diff = np.remainder(np.abs(pred_headings - gt_headings), 360)
    idx0 = angle_diff > 180
    angle_diff[idx0] = 360 - angle_diff[idx0]

    # Error of the (perturbed) initial pose, reported as the baseline.
    init_dis = np.sqrt(np.sum(gt_shifts ** 2, axis=1))
    init_angle = np.abs(gt_headings)

    metrics = [0.25, 0.5, 1, 3, 5]
    angles = [0.25, 0.5, 1, 3, 5]

    diff_shifts = np.abs(pred_shifts - gt_shifts)

    # The context manager closes the report even if a metric computation fails.
    with open(os.path.join(save_path, f'{mode}_results.txt'), 'a') as f:
        f.write('====================================\n')
        f.write('       EPOCH: ' + str(epoch) + '\n')
        f.write('Time per image (second): ' + str(duration) + '\n')
        print('====================================')
        print('       EPOCH: ' + str(epoch))
        print('Time per image (second): ' + str(duration) + '\n')
        print('Validation results:')
        print('Init distance average: ', np.mean(init_dis))
        print('Pred distance average: ', np.mean(distance))
        print('Init angle average: ', np.mean(init_angle))
        print('Pred angle diff average: ', np.mean(angle_diff))
        wandb_features[f'{mode}/fps'] = duration
        wandb_features[f'{mode}/shift_rot_last'] = np.mean(angle_diff)
        wandb_features[f'{mode}/shift_rot_last_median'] = np.median(angle_diff)
        wandb_features[f'{mode}/shift_rot_last_var'] = np.var(angle_diff)
        wandb_features[f'{mode}/distance_last'] = np.mean(distance)
        wandb_features[f'{mode}/distance_last_median'] = np.median(distance)

        for meters in metrics:
            pred = np.sum(distance < meters) / distance.shape[0] * 100
            init = np.sum(init_dis < meters) / init_dis.shape[0] * 100

            line = f'distance within {meters} meters (pred, init): {pred!s} {init!s}'
            print(line)
            f.write(line + '\n')

            wandb_features[f'{mode}/percent_dis_{meters}m'] = pred

        print('-------------------------')
        f.write('------------------------\n')

        wandb_features[f'{mode}/shift_lat_last'] = np.mean(diff_shifts[:, 0])
        wandb_features[f'{mode}/shift_lat_last_median'] = np.median(diff_shifts[:, 0])
        wandb_features[f'{mode}/shift_lat_last_var'] = np.var(diff_shifts[:, 0])

        wandb_features[f'{mode}/shift_lon_last'] = np.mean(diff_shifts[:, 1])
        wandb_features[f'{mode}/shift_lon_last_median'] = np.median(diff_shifts[:, 1])
        wandb_features[f'{mode}/shift_lon_last_var'] = np.var(diff_shifts[:, 1])

        for meters in metrics:
            lat_pred = np.sum(diff_shifts[:, 0] < meters) / diff_shifts.shape[0] * 100
            init = np.sum(np.abs(gt_shifts[:, 0]) < meters) / init_dis.shape[0] * 100

            line = f'lateral      within {meters} meters (pred, init): {lat_pred!s} {init!s}'
            print(line)
            f.write(line + '\n')

            lon_pred = np.sum(diff_shifts[:, 1] < meters) / diff_shifts.shape[0] * 100
            init = np.sum(np.abs(gt_shifts[:, 1]) < meters) / diff_shifts.shape[0] * 100

            line = f'longitudinal within {meters} meters (pred, init): {lon_pred!s} {init!s}'
            print(line)
            f.write(line + '\n')

            wandb_features[f'{mode}/percent_lat_{meters}m'] = lat_pred
            wandb_features[f'{mode}/percent_lon_{meters}m'] = lon_pred

        print('-------------------------')
        f.write('------------------------\n')

        for degrees in angles:
            pred = np.sum(angle_diff < degrees) / angle_diff.shape[0] * 100
            init = np.sum(init_angle < degrees) / angle_diff.shape[0] * 100
            line = f'angle within {degrees} degrees (pred, init): {pred!s} {init!s}'
            print(line)
            f.write(line + '\n')

            # The trailing "m" in the key is kept so existing dashboards keep working.
            wandb_features[f'{mode}/percent_rot_{degrees}m'] = pred

        print('-------------------------')
        f.write('------------------------\n')

        for meters, degrees in zip(metrics, angles):
            pred = np.sum((angle_diff[:, 0] < degrees) & (diff_shifts[:, 0] < meters)) / angle_diff.shape[0] * 100
            init = np.sum((init_angle[:, 0] < degrees) & (np.abs(gt_shifts[:, 0]) < meters)) / angle_diff.shape[0] * 100
            line = f'lat within {meters} & angle within {degrees} (pred, init): {pred!s} {init!s}'
            print(line)
            f.write(line + '\n')
            wandb_features[f'{mode}/pred_lat_rot_{meters}'] = pred

        print('====================================')
        f.write('====================================\n')

    # `distance` is (N,) while `angle_diff` is (N, 1); combining them directly
    # broadcasts to an (N, N) matrix and counts sample *pairs*.  Select the
    # single heading column so the score is a per-sample percentage.
    result = np.sum((distance < metrics[0]) & (angle_diff[:, 0] < angles[0])) / distance.shape[0] * 100
    if wandb_logger is not None:
        wandb_logger.log_evaluate(wandb_features)

    net_test.train()

    ### save the best params
    # NOTE(review): the per-epoch checkpoint is only written when the score
    # improved, despite the flag being named save_every_epoch - confirm intent.
    if result > best_rank_result:
        if optimizer is None:
            payload = net_test.state_dict()
        else:
            payload = {
                'epoch': epoch,
                'state_dict': net_test.state_dict(),
                'optimizer': optimizer.state_dict()
            }
        torch.save(payload, os.path.join(save_path, f'Model_best{mode[-1]}.pth'))
        if args.save_every_epoch:
            torch.save(payload, os.path.join(save_path, f'Model_epoch{epoch}.pth'))

    return result

def train(net, args, save_path, wandb_logger, optimizer, scheduler, **kwargs):
    """Train ``net`` for ``args.epochs`` epochs, evaluating both test splits after each one.

    Args:
        net: model to optimise; called with ``mode='train'`` and expected to
            return the 16-element tuple unpacked below.
        args: configuration namespace (batch size, noise ranges, flags, ...).
        save_path: directory for result files and checkpoints; created if missing.
        wandb_logger: object exposing ``log_evaluate(dict)``.
        optimizer: optimizer stepping the parameters.
        scheduler: learning-rate scheduler, stepped once per epoch.
        **kwargs: optional ``pid`` entry with the PID coefficient module.  When
            absent the module is looked up at ``net.NNrefine.pid``, which is
            where the ``__main__`` block attaches it.
    """
    best_rank_result = 0.0
    best_rank_result2 = 0.0
    wandb_features = dict()

    # Resolve the PID module locally instead of relying on a module-level
    # global that only exists when base_optimizer == 'pid_adam'.
    pid = kwargs.get('pid')
    if pid is None:
        pid = getattr(getattr(net, 'NNrefine', None), 'pid', None)

    net.train()
    trainloader = load_data(args, args.batch_size, args.shift_range_lat, args.shift_range_lon, args.rotation_range,
                            args.depth, split='train')

    for epoch in range(args.resume, args.epochs):
        # test() runs a backward pass, so stale gradients may be left from the
        # previous epoch's evaluation.
        optimizer.zero_grad()

        for step, batch in enumerate(tqdm(trainloader, leave=True), 0):
            if step == 5 and args.debug:
                break
            # get the inputs (`device` is the module-level global set in __main__)
            sat_map, left_camera_k, grd_left_imgs, gt_shift_u, gt_shift_v, gt_heading, lidar_cam, lidar_im, mask = [item.to(device) for item in batch[:-1]]
            file_name = batch[-1]

            # zero the parameter gradients
            optimizer.zero_grad()

            if args.direction in ['PIDLoc', 'SIBCL'] and 'lidar' in args.depth:
                gt_depth = (lidar_cam, lidar_im)
            else:
                gt_depth = None

            if args.direction in ['PIDLoc', 'SIBCL', 'LM_S2GP']:
                loss, loss_decrease, shift_lat_decrease, shift_lon_decrease, thetas_decrease, loss_last, \
                shift_lat_last, shift_lon_last, theta_last, \
                L1_loss, L2_loss, L3_loss, L4_loss, grd_conf_list, smooth_loss, additional_loss = \
                    net(sat_map, grd_left_imgs, left_camera_k, gt_shift_u, gt_shift_v, gt_heading, mode='train', file_name=file_name,
                        gt_depth=gt_depth, loop=step, level_first=args.level_first, grd_mask=mask)

            loss.backward()
            wandb_features.update(log_gradient_statistics(net))

            # grad_clip == -1 disables clipping.
            if args.grad_clip != -1:
                torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=args.grad_clip)
                if args.base_optimizer == 'pid_adam' and pid is not None:
                    wandb_features.update(log_gradient_statistics(pid, topic='pid_'))
                    # The PID coefficients get a 1000x tighter clipping bound.
                    torch.nn.utils.clip_grad_norm_(pid.parameters(), max_norm=args.grad_clip/1000)

            optimizer.step()  # update weights

            if step % 100 == 0:
                level = -1  # args.level - 1

                # Round once and reuse the values for both console and wandb.
                loss_delta = np.round(loss_decrease[level].item(), decimals=4)
                lat_delta = np.round(shift_lat_decrease[level].item(), decimals=2)
                lon_delta = np.round(shift_lon_decrease[level].item(), decimals=2)
                rot_delta = np.round(thetas_decrease[level].item(), decimals=2)

                loss_final = np.round(loss_last[level].item(), decimals=4)
                lat_final = np.round(shift_lat_last[level].item(), decimals=2)
                lon_final = np.round(shift_lon_last[level].item(), decimals=2)
                rot_final = np.round(theta_last[level].item(), decimals=2)

                if args.N_iters != 1:
                    print(f'Epoch: {epoch} Loop: {step} Delta: Level-{level}'
                          f' loss: {loss_delta!s}'
                          f' lat: {lat_delta!s}'
                          f' lon: {lon_delta!s}'
                          f' rot: {rot_delta!s}')

                print(f'Epoch: {epoch} Loop: {step} Last: Level-{level}'
                      f' loss: {loss_final!s}'
                      f' lat: {lat_final!s}'
                      f' lon: {lon_final!s}'
                      f' rot: {rot_final!s}')

                # log wandb features
                wandb_features['lr'] = scheduler.get_last_lr()[0]
                wandb_features['train/loss_decrease'] = loss_delta
                wandb_features['train/shift_lat_decrease'] = lat_delta
                wandb_features['train/shift_lon_decrease'] = lon_delta
                wandb_features['train/shift_rot_decrease'] = rot_delta

                wandb_features['train/loss_last'] = loss_final
                wandb_features['train/shift_lat_last'] = lat_final
                wandb_features['train/shift_lon_last'] = lon_final
                wandb_features['train/shift_rot_last'] = rot_final

                # Skip the coefficients when no PID module could be resolved
                # rather than failing on an undefined name.
                if args.pid_k == 'learnable' and pid is not None:
                    wandb_features['train/kp'] = np.round(pid.kp.item(), decimals=4)
                    wandb_features['train/ki'] = np.round(pid.ki.item(), decimals=4)
                    wandb_features['train/kd'] = np.round(pid.kd.item(), decimals=4)

                wandb_logger.log_evaluate(wandb_features)

        scheduler.step()

        os.makedirs(save_path, exist_ok=True)

        ### ranking test & save
        current = test(net, args, save_path, best_rank_result, epoch, wandb_logger, optimizer=optimizer, mode='test1')
        best_rank_result = max(best_rank_result, current)

        current2 = test(net, args, save_path, best_rank_result2, epoch, wandb_logger, optimizer=optimizer, mode='test2')
        best_rank_result2 = max(best_rank_result2, current2)

    print('Finished Training')

def log_gradient_statistics(model, topic=''):
    """Summarise the gradients currently stored on ``model``'s parameters.

    For every parameter that has a non-empty gradient the mean, standard
    deviation, maximum, minimum and norm of that gradient are computed; each
    statistic is then averaged over the parameters.

    Args:
        model: object exposing ``named_parameters()`` (e.g. a ``torch.nn.Module``).
        topic: prefix inserted into the metric names, e.g. ``'pid_'`` gives
            ``train/pid_grad_mean``.

    Returns:
        Dict with the keys ``train/<topic>grad_mean``, ``..._std``, ``..._max``,
        ``..._min`` and ``..._magnitude``.  The dict is empty when no parameter
        carries a gradient, so ``dict.update`` with it is a no-op.
    """
    grad_means = []
    grad_stds = []
    grad_maxs = []
    grad_mins = []
    grad_magnitudes = []

    for _name, param in model.named_parameters():
        if param.grad is None or param.grad.numel() == 0:
            continue
        grad = param.grad.detach()
        grad_means.append(grad.mean().item())
        # The default (unbiased) std of a single element is NaN, which would
        # poison the average below; a lone value has no spread, so report 0.
        grad_stds.append(grad.std().item() if grad.numel() > 1 else 0.0)
        grad_maxs.append(grad.max().item())
        grad_mins.append(grad.min().item())
        grad_magnitudes.append(grad.norm().item())

    # Without any gradient there is nothing to average; np.mean([]) would
    # return NaN and emit a RuntimeWarning.
    if not grad_means:
        return {}

    return {
        f'train/{topic}grad_mean': np.mean(grad_means),
        f'train/{topic}grad_std': np.mean(grad_stds),
        f'train/{topic}grad_max': np.mean(grad_maxs),
        f'train/{topic}grad_min': np.mean(grad_mins),
        f'train/{topic}grad_magnitude': np.mean(grad_magnitudes),
    }

# def parse_args():
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--config', type=str, required=True, help='Path to config.yaml')
#
#     parser.add_argument('--resume', type=int, default=0, help='resume the trained model')
#     parser.add_argument('--checkpoint', type=str, help='checkpoint model')
#     parser.add_argument('--test', action='store_true', default=False, help='test with trained model')
#     parser.add_argument('--save', '-s', type=str, required=True, default='debug', help='save_path')
#     parser.add_argument('--save_every_epoch', '-see', action='store_true', default=False, help='save_every_epoch')
#     parser.add_argument('--debug', action='store_true', help='debug to visualize images of dataloader ')
#     parser.add_argument('--debug2', action='store_true', help='checkout performance quickly')
#     parser.add_argument('--wandb', '-wb', action='store_true', help='Turn on wandb log')
#
#     ## PID parameters ##
#     parser.add_argument('--require-jac', action='store_true', default=True, help='D branch or not')
#     parser.add_argument('--require-int', action='store_true', default=True, help='I branch or not')
#     parser.add_argument('--j-norm', type=str, default='L2', choices=['L2', 'zsn3', 'none', 'clip', 'std'])
#     parser.add_argument('--j-ver', type=str, default='sum', choices=['none', 'lat', 'long', 'rot', 'trans', 'sum'])
#     parser.add_argument('--pid-k', type=str, default='learnable', choices=['const', 'learnable', 'learnable'], help='pid-k')
#     parser.add_argument('--kp', type=float, default=1.0, help='pid kp')
#     parser.add_argument('--ki', type=float, default=1.0, help='pid ki')
#     parser.add_argument('--kd', type=float, default=1.0, help='pid kd')
#     parser.add_argument('--integral', type=str, default='spatial', choices=['spatial', 'spatial2', 'temporal'],
#                         help='spatial candidate pose type')
#     parser.add_argument('--integral-num', type=int, default=5, help='total number of spatial candidate pose')
#
#     ##  architecture ##
#     parser.add_argument('--direction', type=str, default='PIDLoc', choices=['PIDLoc', 'SIBCL'], help='PIDLoc or SIBCL')
#     parser.add_argument('--encoder', type=str, default='vgg', help='encoder architecture')
#     parser.add_argument('--Optimizer', type=str, default='SPE', choices=['LM', 'SPE'], help='LM or SPE')
#     parser.add_argument('--Optimizer_input', type=str, default='resconcat', choices=['concat', 'res', 'resconcat'], help='input to the optimizer nn')
#     parser.add_argument('--level_first',  action='store_true', default=True, help='0 or 1, level first: 1 or iter first: 0')
#     parser.add_argument('--level', type=float, default=3.2, help='2, 3, 4, -1, -2, -3, -4')
#     parser.add_argument('--N_iters', type=int, default=5, help='any integer')
#
#     ## network ##
#     parser.add_argument('--point-norm', type=str, default='zsn3', choices=['L2', 'none', 'zsn', 'zsn3'],
#                         help='point feature normalization')
#     parser.add_argument('--grd2sat', type=str, default='geo', choices=['geo', 'decouple'],
#                         help='geometric transformation')
#     parser.add_argument('--pool', type=str, default='embed_aap2', choices=['gap', 'max',  'embed_aap2', 'all_ap'],
#                         help='pooling PID-branch feature in NN optimizer')
#     parser.add_argument('--hidden-dim', '-hin', type=int, default=64, help='hidden dimension of NN pose estimator')
#
#     ## optimizer ##
#     parser.add_argument('--epochs', type=int, default=30, help='number of training epochs')
#     parser.add_argument('--lr', type=float, default=1e-4, help='learning rate for network parameters')  # 1e-2
#     parser.add_argument('--lr2', type=float, default=1e-4, help='learning rate for PID coefficients')  # 1e-2
#     parser.add_argument('--batch_size', type=int, default=4, help='batch size')
#     parser.add_argument('--base-optimizer', type=str, default='adam', choices=['adam', 'sgd', 'pid_adam'],
#                         help='base optimizer: SGD, ADAM')
#     parser.add_argument('--momentum', type=float, default=0.9, help='Momentum.')
#     parser.add_argument('--decay', '-wd', type=float, default=0.0005, help='Weight decay (L2 penalty).')
#     parser.add_argument('--beta1', type=float, default=0.9, help='coefficients for adam optimizer')
#     parser.add_argument('--beta2', type=float, default=0.999, help='coefficients for adam optimizer')
#     parser.add_argument('--grad_clip', type=float, default=10.0, help='grad_clip')
#
#     ## etc setting ##
#     parser.add_argument('--max-shift', type=float, default=2.5, help='max shift for u, v, and rotation')
#
#     # noise range and loss coefficient
#     parser.add_argument('--rotation_range', type=float, default=10., help='degree') # default +-10
#     parser.add_argument('--shift_range_lat', type=float, default=20., help='meters')
#     parser.add_argument('--shift_range_lon', type=float, default=20., help='meters')
#
#     parser.add_argument('--coe_shift_lat', type=float, default=100., help='meters')
#     parser.add_argument('--coe_shift_lon', type=float, default=100., help='meters')
#     parser.add_argument('--coe_heading', type=float, default=100., help='degree')
#     parser.add_argument('--loss_method', type=float, default=0.1, help='0, 1, 2, 3, 4, 5')
#
#     # LM parameters
#     parser.add_argument('--using_weight', type=int, default=0, help='weighted LM or not')
#     parser.add_argument('--damping', type=float, default=0.1, help='coefficient in LM optimization')
#     parser.add_argument('--train_damping', type=int, default=0, help='coefficient in LM optimization')
#     parser.add_argument('--grd_mask', type=str, default='halfcut', choices=['half', 'none', 'maxd', 'fov', 'halfcut'], help='G2SP proj mask')
#     parser.add_argument('--sat_mask', type=str, default='half', choices=['none', 'half', 'fov'], help='G2SP proj mask')
#     parser.add_argument('--dropout', type=int, default=0, help='0 or 1')
#     parser.add_argument('--use_hessian', type=int, default=0, help='0 or 1')
#
#     ## Dataset ##
#     parser.add_argument('--use_gt_depth', action='store_true', help='use monodepth or not')
#     parser.add_argument('--depth', type=str, default='lidar', choices=['none', 'mono', 'lidar', 'mono_half', 'both'])
#     parser.add_argument('--depth-range', type=str, default='all', choices=['fov', 'all'])
#     parser.add_argument('--max-depth', type=float, default=70)
#     parser.add_argument('--max-points', type=int, default=5000)
#     parser.add_argument('--max-out-points', type=int, default=10000)
#     parser.add_argument('--grdH', type=int, default=375)
#     parser.add_argument('--grdW', type=int, default=1242)
#
#     args = parser.parse_args()
#
#     return args

def parse_args():
    """Parse the command line and return the resulting namespace.

    The only option is the mandatory ``--config``, the path of the YAML
    configuration file.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument('--config', type=str, required=True, help='Path to config.yaml')
    return cli.parse_args()

if __name__ == '__main__':
    # Script entry point: read the YAML config, build model/optimizer,
    # optionally restore a checkpoint, then either evaluate or train.

    # `device` is read as a module-level global by train() and test().
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    np.random.seed(2025)
    config = load_config(parse_args().config)
    args = Namespace(**config)

    # Log with wandb
    if args.wandb:
        wandb_config = dict(project="cvl", entity='kaist-url-ai28', name=args.save)
        wandb_logger = WandbLogger(wandb_config, args)
    else:
        wandb_logger = WandbLogger(None)
    wandb_logger.before_run()
    save_path = f'/ws/external/checkpoints/kitti/{args.save}'
    # Debug runs do not create the checkpoint directory here.
    if not args.debug:
        os.makedirs(save_path, exist_ok=True)

    # Build network
    # An explicit registry replaces eval() on a config value: the config file
    # should only be able to select a model, not execute arbitrary expressions.
    model_registry = {'PIDLoc': PIDLoc, 'SIBCL': SIBCL}
    if args.direction not in model_registry:
        raise ValueError(f"unsupported direction {args.direction!r}; expected one of {sorted(model_registry)}")
    net = model_registry[args.direction](args)
    net.to(device)

    ## Optimizer
    def linear_decay(epoch):
        # Learning-rate factor falling linearly from 1 towards 0 over args.epochs.
        return 1.0 - float(epoch) / args.epochs

    optimizer = None
    scheduler = None
    if args.base_optimizer == 'adam':
        optimizer = optim.Adam(net.parameters(), lr=args.lr)

    elif args.base_optimizer == 'pid_adam' and args.pid_k == 'learnable':
        # NOTE(review): pid is created after net.to(device) and is not moved
        # explicitly - confirm PIDOptimizer takes care of device placement.
        pid = PIDOptimizer(args)
        optimizer = optim.Adam([
            {'params': net.parameters(), 'lr': args.lr},
            {'params': pid.parameters(), 'lr': args.lr2, 'weight_decay': 0},
        ])
        net.NNrefine.pid = pid

    if optimizer is not None:
        scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=linear_decay)

    if args.checkpoint:
        # map_location lets a checkpoint written on GPU be restored on a CPU-only host.
        checkpoint = torch.load(args.checkpoint, map_location=device)
        # A bare state_dict is itself a dict, so the wrapped format has to be
        # recognised by its 'state_dict' entry rather than by type.
        if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
            net.load_state_dict(checkpoint['state_dict'])
            if optimizer is not None and 'optimizer' in checkpoint:
                optimizer.load_state_dict(checkpoint['optimizer'])
        else:
            net.load_state_dict(checkpoint)
        print(f"load model from {args.checkpoint}")

    if args.test:
        test(net, args, save_path, 0., epoch=0, wandb_logger=wandb_logger, mode='test1')
        test(net, args, save_path, 0., epoch=0, wandb_logger=wandb_logger, mode='test2')
    else:   # train
        if optimizer is None:
            raise ValueError(
                f"no optimizer configured for base_optimizer={args.base_optimizer!r} "
                f"with pid_k={args.pid_k!r}; use 'adam', or 'pid_adam' together with pid_k='learnable'")
        train(net, args, save_path, wandb_logger, optimizer=optimizer, scheduler=scheduler)

