import abc
import os

from utils.general import load_pickle
class BaseOptions(metaclass=abc.ABCMeta):
    """Abstract root shared by every options class in this module.

    Provides the settings common to all option sets (device, experiment
    root directory, logging toggles). Concrete subclasses must implement
    ``model_name``; the class cannot be instantiated directly.
    """

    def __init__(self):
        # (attribute, value) pairs applied in this exact order.
        shared_settings = (
            ('device', 'cuda'),
            ('expdir', '../../exps'),
            ('use_wandb', True),
            ('debug', False),
        )
        for attr, value in shared_settings:
            setattr(self, attr, value)

    @abc.abstractmethod
    def model_name(self):
        """Return the name string of the model these options configure.

        Must be overridden by subclasses.
        """
        raise NotImplementedError

class TrainNGPOptions(BaseOptions):
    """Options used for training.

    ``train_type`` selects the mode, either 'rec' or 'gen'. It determines:
      * ``model_name()`` ('instant_ngp' for 'rec', 'dreamfusion' for 'gen'),
        and therefore the experiment sub-directory appended to ``expdir``;
      * the default ``eval_freq``;
      * whether ``text`` is appended to ``expname`` ('gen' only).
    ``density_thresh`` is likewise derived from ``activation``.

    All derived fields are computed once in ``__init__``. Changing
    ``train_type`` or ``activation`` on an existing instance does NOT update
    them.
    """

    def __init__(self):
        super().__init__()

        # ---------------- dataset options ----------------
        self.preload = True
        self.scale = 1.
        self.offset = [0., 0., 0.]
        self.bound = 1.
        self.patch_size = None
        self.error_map = True
        self.color_space = 'srgb'
        self.radius_range = (1.5, 2.5)
        self.fov_range = (50, 90)
        self.angle_overhead = 30
        self.angle_front = 70

        # ---------------- general options ----------------
        self.fp16 = True
        self.cuda_ray = True
        self.latent_mode = False

        self.train_type = 'gen'  # 'gen' or 'rec'; see model_name()

        self.use_new_raymarcher = True

        self.activation = 'softplus'  # 'softplus' or 'exp'
        self.blob_type = 'blob'  # 'blob' or 'gaussian'

        # ---------------- training options ----------------
        self.num_rays = 4096
        self.nepochs = 100
        self.learning_rate = 5e-3

        self.use_scheduler = True
        self.sparsity_lambda = 5e-4
        self.shape_lambda = 5e-6
        self.sketch_lambda = 5e-6
        self.sil_lambda = 1.

        self.training_progress = None  # None or list of epochs
        self.save_freq = 10
        # Derived from train_type at construction time.
        self.eval_freq = 10 if self.train_type == 'rec' else 1
        self.print_freq = 10

        # ---------------- rendering options ----------------
        self.min_near = 0.1
        # Derived from activation at construction time.
        self.density_thresh = 10 if self.activation == 'exp' else 2.5
        self.bg_radius = 1.5
        self.update_extra_interval = 16
        self.staged_rendering = True

        # ---------------- eval options ----------------
        self.render_normals = True

        # ---------------- network options ----------------
        self.num_layers = 3
        self.hidden_dim = 64
        self.num_layers_bg = 2
        self.hidden_dim_bg = 64

        # ---------------- text options ----------------
        self.text = "a red flower stem rising from a potted plant"
        self.dir_guidance = True

        # ---------------- directory and name options ----------------
        self.expname = 'compare'
        if self.train_type == 'gen':
            # The prompt becomes part of the experiment name (spaces included).
            self.expname += '_{}'.format(self.text)
        # Nest experiments under <base expdir>/<model_name()>.
        self.expdir = os.path.join(self.expdir, self.model_name())

        # ---------------- checkpoint options ----------------
        self.checkpoint_path = '../../exps/dreamfusion/horse_highres_a potted green plant/2023_02_13_14_45_52/checkpoints/latest.pth'

        # ---------------- disentanglement options ----------------
        self.use_dt = False
        self.geom_feature_size = 15
        self.num_layers_geom = 2
        self.num_layers_color = 2

        # ---------------- sketch shape options ----------------
        self.use_sketch_shape = False
        self.sketch_shape_path = '../../data/meshes/cat.obj'
        self.shape_scale = 1
        self.proximal_surface = 1.
        self.use_edit_sketch = True
        self.edit_sketch_path = '../../data/meshes/flower.obj'
        self.use_bbox_init = False

        # ---------------- 2D sketch options ----------------
        self.use_2d_sketch = False
        self.sketch_path = '../../data/plant2'
        self.use_base_mesh = False
        self.sketch_indices = None
        self.proximal_surface_2d = 0.025
        self.occ_thresh = 0.001
        self.sketch_type = 'manual'
        self.bitfield_warmup_iters = None

        self.sketch_H = 128
        self.sketch_W = 128

        self.use_kd_tree = False

        self.fill_sketch = False
        self.preprocess_sketch = True

    def model_name(self):
        """Return the model name for the current ``train_type``.

        Returns:
            'instant_ngp' for 'rec', 'dreamfusion' for 'gen'.

        Raises:
            ValueError: if ``train_type`` is neither 'rec' nor 'gen'. Previously
                this silently returned None, which made ``os.path.join`` in
                ``__init__`` fail with an unrelated-looking TypeError.
        """
        if self.train_type == 'rec':
            return 'instant_ngp'
        if self.train_type == 'gen':
            return 'dreamfusion'
        raise ValueError(
            "train_type must be 'rec' or 'gen', got {!r}".format(self.train_type)
        )

class RendererOptions(BaseOptions):
    """Options for rendering from a trained experiment's checkpoint.

    ``model_opts`` holds the training options used to rebuild the model. By
    default these are a fresh ``TrainNGPOptions()``. With
    ``load_saved_opts=True``, ``<exp_dir>/opts.pkl`` is loaded instead when it
    exists. Any attribute missing from the loaded options is back-filled from
    ``TrainNGPOptions`` defaults.

    Args:
        load_saved_opts: load training options saved in the experiment
            directory instead of using defaults. Defaults to False, which
            matches the previous hard-disabled behavior.
    """

    def __init__(self, load_saved_opts=False):
        super().__init__()
        self.exp_dir = '../exps/dreamfusion/compare_a 3D rendering of unicorn head/2023_03_11_22_43_01'

        self.ckpt_path = None if self.exp_dir is None else os.path.join(self.exp_dir, 'checkpoints/ckpt_0100.pth')

        # Built once. It serves both as the fallback options and as the source
        # for back-filling attributes missing from a loaded options object.
        default_opts = TrainNGPOptions()

        opt_path = None
        if load_saved_opts and self.exp_dir is not None:
            opt_path = os.path.join(self.exp_dir, 'opts.pkl')

        if opt_path is not None and os.path.exists(opt_path):
            # NOTE(review): load_pickle presumably unpickles the file; unpickling
            # can run arbitrary code, so only load opts files you produced.
            self.model_opts = load_pickle(opt_path)
            for k, v in default_opts.__dict__.items():
                self.model_opts.__dict__.setdefault(k, v)
        else:
            self.model_opts = default_opts

        self.H = 512
        self.W = 512
        self.fov_range = [60, 80]

        self.out_dir = './' if self.exp_dir is None else os.path.join(self.exp_dir, 'outputs')

    def model_name(self):
        """Return the fixed name 'renderer'."""
        return 'renderer'
class LatentGuidancePredictorOptions(BaseOptions):
    """Options for the latent guidance predictor.

    Holds the block-index selections, the timestep-embedding count, the MLP
    architecture and the checkpoint location. How each field is interpreted
    is up to the consumer of these options.
    """

    def __init__(self):
        super().__init__()
        # Block index selections for the input / middle / output groups.
        self.input_blocks = [2, 4, 8]
        self.middle_blocks = [0, 1, 2]
        self.output_blocks = [2, 4, 8]
        self.t_embeds_count = 10
        # MLP architecture: hidden layer widths, then input/output dimensions.
        self.mlp_layers = [512, 256, 128]
        self.mlp_in_dim, self.mlp_out_dim = 9290, 4
        # Activation given by name as a string, not a class object.
        self.mlp_activations = 'nn.ReLU'
        self.ckpt_path = 'checkpoint/model_3000.pt'

    def model_name(self):
        """Return the fixed name 'latent_guidance_predictor'."""
        return 'latent_guidance_predictor'

class SketchOptions(BaseOptions):
    def __init__(self):
        super(SketchOptions, self).__init__()