# import os 
# os.environ['CUDA_VISIBLE_DEVICES'] = "7"
from argparse import ArgumentParser
import cv2
import os
# os.environ['CUDA_VISIBLE_DEVICES'] = '7'
import pytorch_lightning as pl
from omegaconf import OmegaConf
import torch
from tqdm import tqdm
from utils.common import instantiate_from_config, load_state_dict
from torchvision.utils import save_image
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

def main() -> None:
    """Run the Lightning test loop for a configured model.

    Command-line arguments:
        --config: path to an OmegaConf file. This function reads
            ``lightning.seed``, ``lightning.callbacks``, ``lightning.trainer``,
            ``data``, ``model.config`` and, optionally, ``model.resume`` from it.
        --output: output location; injected into the model configuration as
            ``params.output`` before the model is instantiated.
        --watch_step: boolean flag. It is parsed so existing command lines keep
            working, but it is not read anywhere in this function.

    The data module is set up with ``stage="fit"`` and its validation
    dataloader is the one handed to ``trainer.test``.
    """
    parser = ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--output", type=str, required=True)
    parser.add_argument("--watch_step", action='store_true')
    args = parser.parse_args()

    config = OmegaConf.load(args.config)
    pl.seed_everything(config.lightning.seed, workers=True)

    # The model has its own config file; the CLI output path is pushed into
    # its constructor parameters.
    model_config = OmegaConf.load(config.model.config)
    model_config['params']['output'] = args.output

    data_module = instantiate_from_config(config.data)
    data_module.setup(stage="fit")
    model = instantiate_from_config(model_config)

    if config.model.get("resume"):
        # NOTE: torch.load unpickles the file, so only point ``model.resume``
        # at checkpoints from a trusted source.
        checkpoint = torch.load(config.model.resume, map_location="cpu")
        # Accept both a wrapped checkpoint ({"state_dict": ...}) and a bare
        # state dict.
        state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint
        # strict=False: mismatching keys are reported instead of raising.
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
        print("missing_keys:", missing_keys)
        print("unexpected_keys:", unexpected_keys)
        print("{} model has been loaded!".format(config.model.resume))

    model.eval()

    callbacks = [
        instantiate_from_config(callback_config)
        for callback_config in config.lightning.callbacks
    ]
    trainer = pl.Trainer(callbacks=callbacks, **config.lightning.trainer)
    testloader = data_module.val_dataloader()
    # Pass the dataloader positionally: the keyword was ``test_dataloaders`` in
    # older PyTorch Lightning releases and ``dataloaders`` in newer ones, while
    # the second positional slot is the dataloader under both names.
    trainer.test(model, testloader)


def save_batch(images, imgname_batch, save_path, watch_step=False):
    """Write every image of a batch into ``save_path`` via ``save_image``.

    Args:
        images: when ``watch_step`` is false, an iterable of images; when it
            is true, an iterable of such iterables.
        imgname_batch: file names, indexed by the position of an image inside
            its (inner) iterable.
        save_path: directory the file names are joined onto.
        watch_step: when true, each file name is prefixed with the index of
            the outer iterable followed by ``"_"``.
    """
    if not watch_step:
        # Flat batch: one file per image, named by its position.
        for position, image in enumerate(images):
            target = os.path.join(save_path, imgname_batch[position])
            save_image(image, target)
        return

    # Nested batch: the outer index becomes a file-name prefix.
    for outer_idx, group in enumerate(images):
        prefix = str(outer_idx) + "_"
        for position, image in enumerate(group):
            target = os.path.join(save_path, prefix + imgname_batch[position])
            save_image(image, target)


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

'''
CUDA_VISIBLE_DEVICES=0 \
python3 \
test.py \
--config configs/test.yaml \
--output ./outputs/test/

'''