from cavsam2.cavsam2_video_predictor import build_cavsam2_video_predictor
import torch
import torch.nn.functional as F
from dbst import DBSTPipeline
import numpy as np
from utils import seg2box
import argparse
from utils import otsu_mask
from torchvision.tv_tensors import Mask
import torchvision.transforms.v2 as v2
from torchvision.io import read_image, ImageReadMode
from pathlib import Path


class SAM2FTBackbone(torch.nn.Module):
    """SAM2 image encoder wrapped to emit a similarity-based pseudo mask.

    The encoder is the ``image_encoder`` of the predictor returned by
    ``build_cavsam2_video_predictor`` for the hard-coded
    ``sam2_hiera_t.yaml`` config and ``checkpoints/sam2_hiera_tiny.pt``
    checkpoint, built on the ``"cuda"`` device.  Parameters of the encoder's
    ``trunk`` are frozen; other encoder parameters keep their
    ``requires_grad`` setting.
    """

    def __init__(self) -> None:
        super().__init__()

        sam2_checkpoint = "checkpoints/sam2_hiera_tiny.pt"
        sam2_backbone_cfg = "sam2_hiera_t.yaml"
        predictor = build_cavsam2_video_predictor(
            config_file=sam2_backbone_cfg, ckpt_path=sam2_checkpoint, device="cuda"
        )
        self.backbone = predictor.image_encoder
        # Freeze the trunk only; whatever else the encoder owns stays as built.
        for p in self.backbone.trunk.parameters():
            p.requires_grad = False

    def forward(
        self,
        support_image: torch.Tensor,
        query_image: torch.Tensor,
        support_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Return a pseudo mask for ``query_image`` guided by the support pair.

        Args:
            support_image: 4-D support image batch.
            query_image: 4-D query image batch.
            support_mask: 4-D mask for ``support_image``; it is resized to the
                support feature map inside :meth:`get_pseudo_mask`.

        Returns:
            A single-channel similarity map, bilinearly resized to the spatial
            size of ``query_image``.
        """
        assert (
            support_image.dim() == 4
            and query_image.dim() == 4
            and support_mask.dim() == 4
        ), "support_image, query_image and support_mask must all be 4-D"
        supp_feat = self.backbone(support_image)["vision_features"]
        query_feat = self.backbone(query_image)["vision_features"]
        pseudo_mask = self.get_pseudo_mask(supp_feat, query_feat, support_mask)
        # The map holds one value per *query* location, so it is resized to the
        # query image (identical to the support size when both images match).
        return F.interpolate(pseudo_mask, query_image.shape[-2:], mode="bilinear")

    def get_pseudo_mask(
        self, supp_feat: torch.Tensor, query_feat: torch.Tensor, mask: torch.Tensor
    ) -> torch.Tensor:
        """Score each query location by its best cosine match in the masked support.

        Args:
            supp_feat: Support features, ``(B, C, Hs, Ws)``.
            query_feat: Query features, ``(B, C, Hq, Wq)``.
            mask: Support mask, ``(B, 1, H, W)``; resized bilinearly (with
                ``align_corners=True``) to ``(Hs, Ws)`` and multiplied into
                ``supp_feat``.

        Returns:
            Tensor of shape ``(B, 1, Hq, Wq)`` holding, for every query
            location, the maximum cosine similarity over all support locations.
            Feature maps may be non-square.
        """
        supp_h, supp_w = supp_feat.shape[-2:]
        resized_mask = F.interpolate(
            mask,
            size=(supp_h, supp_w),
            mode="bilinear",
            align_corners=True,
        )
        masked_supp = supp_feat * resized_mask

        bsize, ch_sz, query_h, query_w = query_feat.shape

        # (B, C, Nq) and its per-location L2 norm (B, 1, Nq).
        query_flat = query_feat.reshape(bsize, ch_sz, -1)
        query_norm = torch.norm(query_flat, 2, 1, True)

        # (B, Ns, C) and its per-location L2 norm (B, Ns, 1).
        supp_flat = masked_supp.reshape(bsize, ch_sz, -1).permute(0, 2, 1)
        supp_norm = torch.norm(supp_flat, 2, 2, True)

        # Epsilon keeps masked-out (all-zero) support vectors from dividing by 0.
        cosine_eps = 1e-7
        similarity = torch.bmm(supp_flat, query_flat) / (
            torch.bmm(supp_norm, query_norm) + cosine_eps
        )
        # Reduce over the support axis: best match per query location.
        best_match = similarity.max(1)[0]
        return best_match.reshape(bsize, 1, query_h, query_w)


if __name__ == "__main__":
    # Script entry point: build an interpolated image series between a
    # reference and a target image, fine-tune the SAM2 backbone on the
    # reference pair, derive pseudo masks for some intermediate frames, then
    # let the video predictor propagate masks through the series and save
    # every predicted mask and frame to --pred-save-path.
    parser = argparse.ArgumentParser()
    parser.add_argument("--reference-image-path", type=str, required=True)
    parser.add_argument("--reference-mask-path", type=str, required=True)
    parser.add_argument("--target-image-path", type=str, required=True)
    parser.add_argument("--reference-lora-path", type=str, required=True)
    parser.add_argument("--target-lora-path", type=str, required=True)
    parser.add_argument("--pred-save-path", type=str, required=True)
    parser.add_argument("--class-name", type=str, required=True)
    args = parser.parse_args()

    sam2_checkpoint = "checkpoints/sam2_hiera_tiny.pt"
    model_cfg = "sam2_hiera_t.yaml"

    predictor = build_cavsam2_video_predictor(
        config_file=model_cfg, ckpt_path=sam2_checkpoint, device="cuda"
    )

    # Resize to 512x512 and convert to float32 (scaling integer images).
    to_tensor = v2.Compose(
        [
            v2.Resize((512, 512), antialias=True),
            v2.ToDtype(torch.float32, scale=True),
        ]
    )

    # The reference image and its mask go through the transform together.
    i_r, m_r = to_tensor(
        read_image(args.reference_image_path, mode=ImageReadMode.RGB),
        Mask(read_image(args.reference_mask_path, mode=ImageReadMode.GRAY)),
    )
    # Any non-zero mask pixel counts as foreground.
    m_r = (m_r != 0).float()
    i_t = to_tensor(read_image(args.target_image_path, mode=ImageReadMode.RGB))

    # Add a batch dimension and move everything to the GPU.
    i_r, m_r, i_t = (
        i_r.unsqueeze(0).cuda(),
        m_r.unsqueeze(0).cuda(),
        i_t.unsqueeze(0).cuda(),
    )

    # --- DBST ---
    # Produce a 9-frame series running from the reference (alpha 0) to the
    # target (alpha 1).
    pipeline: DBSTPipeline = DBSTPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-1-base", torch_dtype=torch.float32
    )
    pipeline = pipeline.to("cuda")
    image_series = pipeline(
        img_0=i_r,
        img_1=i_t,
        prompt_0=f"a photo of {args.class_name}",
        prompt_1=f"a photo of {args.class_name}",
        lora_path_0=args.reference_lora_path,
        lora_path_1=args.target_lora_path,
        num_frames=9,
        num_inference_steps=20,
        alpha_list=[0, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 1],
    )

    # --- Change image size for TTGA ---
    # Nearest-neighbour keeps the mask binary; images are resized bilinearly.
    image_series = F.interpolate(image_series, (1024, 1024), mode="bilinear")
    m_r = F.interpolate(m_r, (1024, 1024), mode="nearest")
    i_r = F.interpolate(i_r, (1024, 1024), mode="bilinear")

    # --- TTGA ---
    augmentation = v2.Compose(
        [
            v2.RandomPhotometricDistort(),
            v2.RandomAffine(degrees=90, scale=(0.5, 2), shear=20),
            v2.Resize((1024, 1024), antialias=True),
            v2.ToDtype(torch.float32, scale=True),
        ]
    )
    # One constant drives both the loop length and the cosine schedule period.
    num_ft_iters = 100
    sam2_backbone = SAM2FTBackbone().cuda()
    optimizer = torch.optim.Adam(sam2_backbone.parameters(), lr=1e-3, weight_decay=5e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=num_ft_iters
    )
    # Frames strictly between the first frame and the middle of the series.
    i_ft_inter = image_series[1 : len(image_series) // 2]
    for _ in range(num_ft_iters):
        optimizer.zero_grad()
        # Image and mask receive the same random augmentation.
        aug_i_r, aug_m_r = augmentation(i_r, Mask(m_r))
        aug_m_r = aug_m_r.float()
        aug_m_r[aug_m_r != 0] = 1

        # Forward direction: original as support, augmented view as query.
        aug_pm_r = sam2_backbone(i_r, aug_i_r, m_r)
        # Backward direction: augmented view as support, with the
        # otsu_mask-processed prediction (project helper, presumably Otsu
        # thresholding) as its mask, and the original as query.
        pm_r = sam2_backbone(aug_i_r, i_r, otsu_mask(aug_pm_r))
        loss = F.binary_cross_entropy_with_logits(
            aug_pm_r, aug_m_r
        ) + F.binary_cross_entropy_with_logits(pm_r, m_r)

        loss.backward()
        optimizer.step()
        scheduler.step()

    # Fine-tuning is finished: no autograd graph is needed for these passes.
    with torch.no_grad():
        additional_masks = [
            otsu_mask(sam2_backbone(i_r, i_ft.unsqueeze(0), m_r))
            for i_ft in i_ft_inter
        ]

    # --- SAM2 prediction ---
    inference_state = predictor.init_state(image_series, 1024, 1024)
    with torch.amp.autocast("cuda"):
        # Box prompts: reference mask on frame 0, pseudo masks on the
        # following frames (skipping empty ones).
        predictor.add_new_points_or_box(
            inference_state,
            frame_idx=0,
            obj_id=0,
            box=np.array(list(seg2box(m_r))),
        )
        for i, inter_mask in enumerate(additional_masks):
            if torch.count_nonzero(inter_mask) != 0:
                predictor.add_new_points_or_box(
                    inference_state,
                    frame_idx=i + 1,
                    obj_id=0,
                    box=np.array(list(seg2box(inter_mask))),
                )
        # NOTE(review): the result of this call is discarded. If
        # propagate_in_video is a generator function (it is consumed by
        # iteration further down), this line performs no propagation at all -
        # confirm whether a box-prompted pass was intended here.
        predictor.propagate_in_video(inference_state, start_frame_idx=0, reverse=False)
        # Mask prompts on the same frames.
        # NOTE(review): unlike the box prompts above, empty pseudo masks are
        # not skipped here - confirm this is intentional.
        predictor.add_new_mask(
            inference_state, frame_idx=0, obj_id=0, mask=m_r.squeeze()
        )
        for i, inter_mask in enumerate(additional_masks):
            predictor.add_new_mask(
                inference_state,
                frame_idx=i + 1,
                obj_id=0,
                mask=inter_mask.squeeze(),
            )
        # Collect the per-frame outputs and binarise them at 0.
        preds = (
            torch.cat(
                [
                    video_res_masks
                    for _, _, video_res_masks in predictor.propagate_in_video(
                        inference_state, start_frame_idx=0, reverse=False
                    )
                ]
            )
            > 0
        ).float()

    save_path = Path(args.pred_save_path)
    save_path.mkdir(parents=True, exist_ok=True)

    # Masks are written as 0/255 PNGs, frames as JPEGs.
    for i, pred in enumerate(preds):
        v2.functional.to_pil_image(pred.byte() * 255).save(save_path / f"pred-{i}.png")
    for i, image in enumerate(image_series):
        v2.functional.to_pil_image(image).save(save_path / f"series-{i}.jpg")
