# import torch

# checkpoint_path = ""
# model.load_state_dict(torch.load(checkpoint_path), strict=False)

from PIL import Image
import requests
import torch
from open_flamingo import create_model_and_transforms


# The language model and its tokenizer are loaded from the same local directory.
lang_model_dir = "/data/jcy/ckpt/anas-awadalla/mpt-1b-redpajama-200b-dolly"

# Build the Flamingo model together with its image preprocessor and tokenizer.
# NOTE(review): the four positional arguments are presumably the CLIP vision
# encoder name, its pretrained tag, the language-model path and the tokenizer
# path -- confirm against the installed open_flamingo signature.
model, image_processor, tokenizer = create_model_and_transforms(
    "ViT-L-14",
    "openai",
    lang_model_dir,
    lang_model_dir,
    cross_attn_every_n_layers=1,
    use_local_files=False,
    gradient_checkpointing=False,
    freeze_lm_embeddings=False,
)


# Read the fine-tuned checkpoint onto the CPU so loading does not depend on a GPU.
checkpoint = torch.load("/data/jcy/open_flamingo/OpenFlamingo-3B-DPO-symbol-gqa-10k-sym/checkpoint_final.pt", map_location="cpu")

msd = checkpoint
# NOTE(review): training scripts commonly save a wrapper dict that nests the
# weights under "model_state_dict" (next to optimizer/scheduler state). Unwrap
# it when present; a checkpoint that is already a flat state dict is unchanged.
if "model_state_dict" in msd:
    msd = msd["model_state_dict"]
# Drop the "module." fragment from parameter names (presumably added by a
# DistributedDataParallel wrapper during training -- confirm).
msd = {k.replace("module.", ""): v for k, v in msd.items()}

# strict=False is kept because the checkpoint may hold only a subset of the
# model's parameters. It also hides key mismatches, so inspect the result
# instead of discarding it.
load_result = model.load_state_dict(msd, strict=False)
loaded_keys = set(msd) - set(load_result.unexpected_keys)
if not loaded_keys:
    # Nothing matched: the script would silently run the base model.
    raise RuntimeError(
        "No checkpoint tensor matched a model parameter; "
        "check the checkpoint layout and key prefixes."
    )
if load_result.unexpected_keys:
    print("Ignored unexpected checkpoint keys: ", load_result.unexpected_keys)
print(
    "Loaded", len(loaded_keys), "tensors from checkpoint;",
    len(load_result.missing_keys), "model keys were not in the checkpoint.",
)

"""
Step 1: Load images
"""
# Open the three in-context demonstration images and the query image, in that
# order. The paths point at local COCO train2014 files.
demo_image_one, demo_image_two, demo_image_three, query_image = (
    Image.open(image_path)
    for image_path in (
        '/data/jcy/data/data/coco/train2014/COCO_train2014_000000050686.jpg',
        '/data/jcy/data/data/coco/train2014/COCO_train2014_000000560384.jpg',
        '/data/jcy/data/data/coco/train2014/COCO_train2014_000000575119.jpg',
        '/data/jcy/data/data/coco/train2014/COCO_train2014_000000263418.jpg',
    )
)

# Reference copy of the dataset record this script reproduces. Nothing below
# reads it; it is bound to a name so it is no longer a bare expression that is
# built and immediately discarded.
# NOTE(review): the "image" list order here (…050686, …560384, …263418,
# …575119) differs from the order the script feeds images to the model
# (…575119 as the third demonstration, …263418 as the query) -- confirm which
# image is meant to be the query.
sample_record = {
    "id": 1449607006,
    "prompt": "<image>Question:What sport is the man playing? ",
    "answer": "Answer:The man is playing baseball.<|endofchunk|>",
    "chosen": "Answer:The man is playing baseball.<|endofchunk|>",
    "chosen_score": 5.0,
    "rejected": "Answer:The man is playing kite flying.<|endofchunk|>",
    "rejected_score": 0.0,
    "image": [
        "coco/train2014/COCO_train2014_000000050686.jpg",
        "coco/train2014/COCO_train2014_000000560384.jpg",
        "coco/train2014/COCO_train2014_000000263418.jpg",
        "coco/train2014/COCO_train2014_000000575119.jpg",
    ],
    "context": "<image>Question:What sport is this? Answer:This is baseball.<|endofchunk|>\n<image>Question:What sport is being played? Answer:The sport being played is tennis.<|endofchunk|>\n<image>Question:What sport is she playing? Answer:She is playing skiing.<|endofchunk|>\n",
}


"""
Step 2: Preprocessing images
Details: For OpenFlamingo, we expect the image to be a torch tensor of shape 
 batch_size x num_media x num_frames x channels x height x width. 
 In this case batch_size = 1, num_media = 4 (three demonstrations plus the query), num_frames = 1,
 channels = 3, height = 224, width = 224.
"""
# Preprocess the four images (three demonstrations, then the query) and stack
# them along a new leading axis.
# Assumes image_processor returns one tensor per image with identical shapes
# (presumably channels x height x width) -- TODO confirm.
vision_x = torch.stack(
    [
        image_processor(img)
        for img in (demo_image_one, demo_image_two, demo_image_three, query_image)
    ],
    dim=0,
)
# Insert a singleton axis after the image axis, then a leading singleton axis:
# (4, ...) -> (1, 4, 1, ...).
vision_x = vision_x.unsqueeze(1).unsqueeze(0)

"""
Step 3: Preprocessing text
Details: In the text we expect an <image> special token to indicate where an image is.
 We also expect an <|endofchunk|> special token to indicate the end of the text 
 portion associated with an image.
"""
# Pad on the left so that generated tokens directly follow the prompt.
tokenizer.padding_side = "left"

# Three question/answer demonstrations followed by the open query; each
# <image> token marks where one of the stacked images belongs.
few_shot_prompt = "<image>Question:What sport is this? Answer:This is baseball.<|endofchunk|>\n<image>Question:What sport is being played? Answer:The sport being played is tennis.<|endofchunk|>\n<image>Question:What sport is she playing? Answer:She is playing skiing.<|endofchunk|>\n<image>Question:What sport is the man playing? Answer:"

# Tokenize as a batch of one and request PyTorch tensors.
lang_x = tokenizer([few_shot_prompt], return_tensors="pt")


"""
Step 4: Generate text
"""
# Put the model in evaluation mode so train-only layers such as dropout are
# disabled during generation (a no-op if it is already in eval mode).
model.eval()

# No gradients are needed for inference; disabling autograd avoids building a
# graph for any part of the forward pass that is not already guarded.
with torch.no_grad():
    generated_text = model.generate(
        vision_x=vision_x,
        lang_x=lang_x["input_ids"],
        attention_mask=lang_x["attention_mask"],
        max_new_tokens=20,  # cap on the number of newly generated tokens
        num_beams=1,
    )

# Decode the first (and only) sequence in the batch back to text.
print("Generated text: ", tokenizer.decode(generated_text[0]))