from dbst import DBSTPipeline
from torchvision.io import read_image, ImageReadMode
import torch
import torchvision.transforms.v2 as v2
import argparse


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--image-path", type=str, required=True)
    parser.add_argument("--lora-save-path", type=str, required=True)
    parser.add_argument("--class-name", type=str, required=True)
    args = parser.parse_args()

    to_tenor = v2.Compose(
        [v2.Resize((512, 512)), v2.ToDtype(torch.float32, scale=True)]
    )

    image = (
        to_tenor(read_image(args.image_path, mode=ImageReadMode.RGB))
        .unsqueeze(0)
        .cuda()
    )

    pipeline: DBSTPipeline = DBSTPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-1-base", torch_dtype=torch.float32
    )
    pipeline = pipeline.to("cuda")

    pipeline.train_lora(
        image,
        f"a photo of {args.class_name}",
        args.lora_save_path,
        lora_steps=200,
        lora_lr=1e-3,
    )
