import os
import sys

sys.path.append(os.path.dirname(sys.path[0]))
from learning.util.image_trainer import Image_Trainer
from learning.util.scenegraph_trainer import Scenegraph_Trainer
from learning.util.ss_rs2g_trainer import RS2G_Trainer
from util.config_parser import configuration
import wandb
import torch.optim as optim

# Usage:
# python 7_transfer_model.py --yaml_path ../config/transfer_rule_graph_risk_config.yaml
# python 7_transfer_model.py --yaml_path ../config/transfer_ss_graph_risk_config.yaml
# python testModel.py --yaml_path ../config/transfer_rs2g_graph_risk_config.yaml
def _select_finetune_params(model, excluded_substrings=("node_encoder", "pretext")):
    """Return the trainable parameters of ``model`` not matched by ``excluded_substrings``.

    A parameter is selected when ``requires_grad`` is True and its qualified
    name (as yielded by ``model.named_parameters()``) contains none of the
    given substrings. Each selected name is printed once so the selection
    can be inspected in the run log.

    Args:
        model: object exposing ``named_parameters()`` (e.g. a ``torch.nn.Module``).
        excluded_substrings: name fragments whose parameters are skipped.

    Returns:
        list of the selected parameter objects, in ``named_parameters()`` order.
    """
    selected = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if any(fragment in name for fragment in excluded_substrings):
            continue
        print(name)
        selected.append(param)
    return selected


def train_Trainer(learning_config):
    """Run RS2G transfer learning driven by ``learning_config``.

    Steps:
      1. start a wandb run, with the project taken from
         ``learning_config.wandb_config['project']``;
      2. build the transfer-learning dataset and load the model via
         ``RS2G_Trainer``;
      3. collect the trainable parameters whose names do not contain
         ``"node_encoder"`` or ``"pretext"``;
      4. split the dataset and call ``trainer.learn()``.

    Args:
        learning_config: configuration object (as produced by
            ``configuration(...)``). It must expose a ``wandb_config``
            mapping with a ``'project'`` key.
    """
    # wandb setup
    wandb_arg = wandb.init(config=learning_config,
                           project=learning_config.wandb_config['project'])

    trainer = RS2G_Trainer(learning_config, wandb_arg)
    trainer.build_transfer_learning_dataset()
    trainer.load_model()
    # trainer.evaluate_transfer_learning()  # disabled; re-enable to evaluate the loaded model first

    # Collect every trainable parameter except those of node_encoder and pretext.
    params_to_update = _select_finetune_params(trainer.model)
    # NOTE(review): params_to_update is not handed to the trainer, since
    # trainer.learn() takes no arguments here. This selection therefore does
    # NOT restrict which weights are trained. Presumably RS2G_Trainer builds
    # its own optimizer over all parameters; confirm. Then either pass these
    # params to that optimizer (e.g. optim.Adam(params_to_update, lr=...)) or
    # set requires_grad=False on the excluded ones.

    trainer.split_dataset()
    trainer.learn()


if __name__ == "__main__":
    # Script entry point: turn the command-line arguments (everything after
    # the script name, e.g. --yaml_path ...) into a configuration object and
    # hand it to the transfer-learning routine.
    cli_args = sys.argv[1:]
    learning_config = configuration(cli_args)
    train_Trainer(learning_config)