class IGS(torch.nn.Module, SaverMixin):
    """Lift 2D motion features between two frames onto sampled anchors and render them.

    ``_forward_v3`` uses the sub-modules ``self.backbone``,
    ``self.motion_feature_lift`` and ``self.render``. None of them are
    constructed in this class body; see the placeholder in ``__init__``.
    """

    @dataclass
    class Config:
        # Both flags are read by ``_forward_v3``, so they must be declared here
        # for the structured config to accept and expose them.
        # NOTE(review): defaults chosen as False — confirm against experiment configs.

        # Apply ``condition3D`` (ray/depth injection) before lifting to 3D.
        use_condition3d: bool = False
        # When conditioning, use ``batch["local_rays"]`` instead of ``batch["rays"]``.
        local_ray: bool = False

    cfg: Config

    def __init__(self, cfg):
        super().__init__()
        self.cfg = parse_structured(self.Config, cfg)
        self._save_dir: Optional[str] = None

        # TODO: construct sub-modules here. ``_forward_v3`` expects
        # ``self.backbone``, ``self.motion_feature_lift`` and ``self.render``.

    def load_weights(self, weights: str, ignore_modules: Optional[List[str]] = None):
        """Load parameters from a checkpoint into this module.

        Args:
            weights: Checkpoint path, passed to ``load_module_weights``.
            ignore_modules: Optional module names forwarded to
                ``load_module_weights`` as ``ignore_modules``.

        The state dict is loaded on CPU and applied with ``strict=False``.
        Missing or unexpected keys are therefore not raised as errors.
        """
        state_dict = load_module_weights(
            weights, ignore_modules=ignore_modules, map_location="cpu"
        )
        self.load_state_dict(state_dict, strict=False)

    def _forward_v3(self, batch: Dict[str, Any], first_frame: bool = True) -> Dict[str, Any]:
        """Run one forward pass from a pair of multi-view frames to rendered outputs.

        Args:
            batch: Must contain ``"cur_images_input"`` and ``"next_images_input"``
                (5-D; unpacked as ``B, V, C, H, W``) and ``"gs"``. When
                ``cfg.use_condition3d`` is set, it must also contain ``"depth"``
                and ``"rays"`` (or ``"local_rays"`` if ``cfg.local_ray``). The
                whole dict is also forwarded as keyword arguments to
                ``motion_feature_lift`` and ``render``.
            first_frame: Currently unused; kept for interface compatibility.

        Returns:
            A shallow copy of the dict returned by ``self.render``.
        """
        # Only the per-image dims are needed. Unpacking all five still makes a
        # non-5-D input fail loudly here.
        _, _, C, H, W = batch["cur_images_input"].shape

        # Fold the leading (batch, view) dims together so the backbone sees plain images.
        cur_images_input = batch["cur_images_input"].reshape((-1, C, H, W))
        next_images_input = batch["next_images_input"].reshape((-1, C, H, W))

        motion_feature_2d = self.backbone(cur_images_input, next_images_input)  # [(B V) N C H/8 W/8]

        anchor_points = anchor_sampling(batch["gs"])

        if self.cfg.use_condition3d:
            # Inject 3D information (rays + depth) into the 2D features.
            rays = batch["local_rays"] if self.cfg.local_ray else batch["rays"]
            motion_feature_2d = self.condition3D(motion_feature_2d, rays, batch["depth"])

        motion_feature_3d = self.motion_feature_lift(motion_feature_2d, anchor_points, **batch)  # [B 3 C h w]

        res = self.render(motion_feature_3d, anchor_points=anchor_points, **batch)
        return {**res}

    def condition3D(self, motion_feature, rays, depth):
        """Inject ray and depth information into the 2D motion features.

        Placeholder: currently ignores ``rays`` and ``depth`` and returns
        ``motion_feature`` unchanged.
        """
        # TODO: implement the 3D conditioning.
        return motion_feature

    def forward(self, batch):
        """Module entry point; delegates to ``_forward_v3`` with default arguments."""
        return self._forward_v3(batch)

