@dataclass
class OptConfig:
    """Structured schema handed to ``parse_structured`` for ``cfg.opt``.

    No fields are declared in this definition.

    NOTE(review): ``main`` reads attributes such as ``debug``, ``workspace``,
    ``resume``, ``start_epoch``, ``batch_size``, ``num_workers``, ``lr``,
    ``num_epochs``, ``lambda_rgb``, ``lambda_ssim`` and ``gradient_clip`` from
    the parsed object -- confirm where those are supposed to be declared.
    """
def main(cfg, accelerator):
    """Build the model, data loaders and optimizer, then run the training loop.

    Args:
        cfg: Experiment configuration. This function reads ``cfg.opt``
            (parsed into an ``OptConfig``), ``cfg.system`` (passed to ``IGS``)
            and ``cfg.data.data_cls`` / ``cfg.data.data`` (dataset lookup and
            construction).
        accelerator: The ``Accelerator`` used to prepare the training objects,
            run backward, clip gradients and print from the main process.

    The loss is ``lambda_rgb * l1_loss + lambda_ssim * (1 - ssim)``, where each
    term is only computed when its weight is positive.
    """
    opt: OptConfig = parse_structured(OptConfig, cfg.opt)

    # TensorBoard writer only on the main process, and never in debug mode.
    writer = None
    if accelerator.is_main_process and not opt.debug:
        writer = SummaryWriter(os.path.join(opt.workspace, "runs"))

    # model
    model = IGS(cfg.system)

    start_epoch = 0
    if opt.resume is not None:
        start_epoch = opt.start_epoch
        _load_matching_weights(model, opt.resume, accelerator)

    train_dataset = igs.find(cfg.data.data_cls)(cfg.data.data, training=True)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=opt.batch_size,
        shuffle=True,  # false only debug
        num_workers=opt.num_workers,
        pin_memory=True,
        drop_last=True,
        collate_fn=train_dataset.collate,
    )

    test_dataset = igs.find(cfg.data.data_cls)(cfg.data.data, training=False)
    test_dataloader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=opt.batch_size,
        shuffle=False,
        num_workers=opt.num_workers,
        pin_memory=True,
        drop_last=False,
        collate_fn=test_dataset.collate,
    )

    # Only parameters that require gradients are optimized.
    optimizer = torch.optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=opt.lr,
        weight_decay=0.05,
        betas=(0.9, 0.95),
    )

    # Warm up over the first 3000 scheduler steps. OneCycleLR only accepts a
    # float pct_start in [0, 1], so clamp it for runs shorter than the warm-up;
    # max(..., 1) avoids a ZeroDivisionError and lets OneCycleLR report a
    # non-positive total_steps itself.
    warmup_steps = 3000
    total_steps = opt.num_epochs * len(train_dataloader)
    pct_start = min(warmup_steps / max(total_steps, 1), 1.0)

    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=opt.lr, total_steps=total_steps, pct_start=pct_start
    )

    model, optimizer, train_dataloader, test_dataloader, scheduler = accelerator.prepare(
        model, optimizer, train_dataloader, test_dataloader, scheduler
    )

    for epoch in range(start_epoch, opt.num_epochs):
        # train
        model.train()
        progress_bar = tqdm(
            train_dataloader,
            desc=f"Epoch {epoch}",
            disable=not accelerator.is_local_main_process,
        )

        # Iterate through the bar (not the raw loader) so it actually advances.
        for data in progress_bar:
            with accelerator.accumulate(model):
                # NOTE(review): if both lambda_rgb and lambda_ssim are <= 0 the
                # loss stays the int 0 and backward() cannot run -- confirm at
                # least one weight is always positive.
                loss = 0

                out = model(data)  # model forward

                pred_images = out['images_pred']  # [B, V, C, output_size, output_size]
                # [B, V, 3, output_size, output_size], ground-truth novel views.
                # Follow the prediction's device instead of hard-coding "cuda".
                gt_images = data['images_output'].to(pred_images.device)

                if opt.lambda_rgb > 0:
                    # Stored under the key "loss_mse", but computed with l1_loss.
                    loss_mse = l1_loss(pred_images, gt_images)
                    out["loss_mse"] = loss_mse.item()
                    loss = loss + opt.lambda_rgb * loss_mse

                if opt.lambda_ssim > 0:
                    # Fold the view axis into the batch axis before calling ssim.
                    img1 = rearrange(pred_images, " B V C H W -> (B V) C H W")
                    img2 = rearrange(gt_images, " B V C H W -> (B V) C H W")
                    ssim_value, _ = ssim(img1, img2)

                    loss_ssim = 1.0 - ssim_value
                    out['loss_ssim'] = loss_ssim.item()
                    loss = loss + opt.lambda_ssim * loss_ssim

                accelerator.backward(loss)  # loss backward here

                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), opt.gradient_clip)

                optimizer.step()
                scheduler.step()
                # Clear gradients AFTER the step. Zeroing at the top of the
                # iteration would, on the micro-step where gradients are
                # synced, discard everything accumulated by the previous
                # micro-steps when gradient_accumulation_steps > 1.
                optimizer.zero_grad()

    # Flush and release the TensorBoard event file.
    if writer is not None:
        writer.close()


def _load_matching_weights(model, ckpt_path, accelerator):
    """Copy checkpoint tensors into ``model`` where both name and shape match.

    Entries whose key is absent from ``model.state_dict()`` or whose shape
    differs are skipped with a warning printed through ``accelerator.print``.

    Args:
        model: Module whose ``state_dict()`` tensors are overwritten in place.
        ckpt_path: Checkpoint path; a path ending in ``safetensors`` is read
            with ``load_file``, anything else with ``torch.load``.
        accelerator: Used only for printing.
    """
    if ckpt_path.endswith('safetensors'):
        ckpt = load_file(ckpt_path, device='cpu')
    else:
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code -- only resume from checkpoints you trust.
        ckpt = torch.load(ckpt_path, map_location='cpu')

    state_dict = model.state_dict()
    for k, v in ckpt.items():
        if k not in state_dict:
            accelerator.print(f'[WARN] unexpected param {k}: {v.shape}')
        elif state_dict[k].shape != v.shape:
            accelerator.print(f'[WARN] mismatching shape for param {k}: ckpt {v.shape} != model {state_dict[k].shape}, ignored.')
        else:
            state_dict[k].copy_(v)
    accelerator.print("load ckpt ok")





if __name__ == "__main__":
    import argparse
    import subprocess
    from igs.utils.config import ExperimentConfig, load_config
    from igs.utils.misc import todevice, get_device

    # Only --config is declared here; every unrecognized argument is handed
    # to load_config as cli_args.
    cli = argparse.ArgumentParser("IGS")
    cli.add_argument("--config", required=True, help="path to config file")
    known_args, override_args = cli.parse_known_args()

    device = get_device()

    cfg: ExperimentConfig = load_config(known_args.config, cli_args=override_args)

    accelerator = Accelerator(
        mixed_precision=cfg.opt.mixed_precision,
        gradient_accumulation_steps=cfg.opt.gradient_accumulation_steps,
    )

    # Workspace setup happens on the main process only: create the directory,
    # dump the resolved config, and call saveRuntimeCode on a "backup" folder.
    if accelerator.is_main_process:
        workspace = cfg.opt.workspace
        os.makedirs(workspace, exist_ok=True)

        resolved_cfg = OmegaConf.to_container(cfg, resolve=True)
        config_path = os.path.join(workspace, 'experiment_config.yaml')
        with open(config_path, 'w') as f:
            OmegaConf.save(config=resolved_cfg, f=f)

        saveRuntimeCode(os.path.join(workspace, "backup"))

    # Fixed seeds for the NumPy, torch CPU and torch CUDA generators.
    np.random.seed(6)
    torch.manual_seed(1111)
    torch.cuda.manual_seed(2222)
    torch.cuda.manual_seed_all(3333)

    # Turn off icecream debug output before training starts.
    ic.disable()
    main(cfg, accelerator)

