# pylint: disable=too-few-public-methods
"""Data schemas."""
from typing import Any, Dict, List, Optional

from pydantic import BaseModel, ConfigDict, Field
from torch import Tensor

from src.utils.span_utils import SpanList


class AuxDetectorOutput(BaseModel):
    """
    Output schema of the AUX detector head.

    Every field is optional and defaults to ``None``.

    Attributes:
        cls_logits (Optional[List[Tensor]]): logits predicted by AUX head.
        bbox_regression (Optional[List[Tensor]]): spans predicted by AUX head.
        bbox_ctrness (Optional[List[Tensor]]): centerness predicted by AUX head.
        anchors (Optional[List[List[SpanList]]]): anchors generated by AUX head.
        matched_gts (Optional[List[Tensor]]): matched gts.
        anchors_spans (Optional[List[Tensor]]): positive anchors.
        selected_features (Optional[List[Tensor]]): matched encoder features.
    """

    # torch.Tensor has no built-in pydantic schema, so arbitrary types are enabled.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    cls_logits: Optional[List[Tensor]] = Field(
        default=None,
        description="Logits predicted by AUX head.",
    )
    bbox_regression: Optional[List[Tensor]] = Field(
        default=None,
        description="Spans predicted by AUX head.",
    )
    bbox_ctrness: Optional[List[Tensor]] = Field(
        default=None,
        description="Centerness predicted by AUX head.",
    )
    anchors: Optional[List[List[SpanList]]] = Field(  # noqa: WPS234
        default=None,
        description="Anchors generated by AUX head.",
    )
    matched_gts: Optional[List[Tensor]] = Field(
        default=None,
        description="Matched gts.",
    )
    # NOTE: annotated as a list of tensors (not SpanList); the annotation is what pydantic validates.
    anchors_spans: Optional[List[Tensor]] = Field(
        default=None,
        description="Positive anchors",
    )
    selected_features: Optional[List[Tensor]] = Field(
        default=None,
        description="Matched encoder features.",
    )


class QueryProposalsOutput(BaseModel):
    """
    Schema for the output of get_query_proposals method.

    All fields are required.

    Attributes:
        query_embs (Tensor): Selected content queries (batch_size, num_queries, d_model).
        refpoint_embed_detach (Tensor): Detached reference points embeddings with shape (batch_size, num_queries, 2).
        refpoint_embed_enc (Tensor): Undetached reference points embeddings with shape (batch_size, num_queries, 2).
        class_logit_enc (Tensor): Encoder top classes with shape (batch_size, num_queries, 1).
        iou_logit_enc (Tensor): Encoder top ious with shape (batch_size, num_queries, 1).
    """

    # torch.Tensor has no built-in pydantic schema, so arbitrary types are enabled.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    query_embs: Tensor = Field(
        default=...,
        description="Selected content queries (batch_size, num_queries, d_model).",
    )
    refpoint_embed_detach: Tensor = Field(
        default=...,
        description="Detached reference points embeddings with shape (batch_size, num_queries, 2).",
    )
    refpoint_embed_enc: Tensor = Field(
        default=...,
        description="Undetached reference points embeddings with shape (batch_size, num_queries, 2).",
    )
    class_logit_enc: Tensor = Field(
        default=...,
        description="Encoder top classes with shape (batch_size, num_queries, 1).",
    )
    iou_logit_enc: Tensor = Field(
        default=...,
        description="Encoder top ious with shape (batch_size, num_queries, 1).",
    )


class DetEncoderOutput(BaseModel):
    """
    Schema describing what the detection encoder returns.

    All three fields are required.

    Attributes:
        memory (Tensor): Output of the regular encoder.
        vid_pos (Tensor): Positional embeddings.
        vid_mask (Tensor): Padding mask for the source sequence.
    """

    # torch.Tensor has no built-in pydantic schema, so arbitrary types are enabled.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    memory: Tensor = Field(
        default=...,
        description="Output of the encoder.",
    )
    vid_pos: Tensor = Field(
        default=...,
        description="Positional embeddings.",
    )
    vid_mask: Tensor = Field(
        default=...,
        description="Padding mask.",
    )


class DetectorOutput(BaseModel):
    """
    Schema describing what the detector returns.

    Every field is required. ``quality_scores``, ``co_info`` and ``dn_info`` are
    declared without a fallback value, so they must always be passed explicitly,
    but ``None`` is an accepted value for each of them.

    Attributes:
        outputs_class (Tensor): predicted labels. Shape: [num_queries, batch_size, 2]
        outputs_coord (Tensor): reference points. Shape: [num_queries, batch_size, 2]
        offsets (Tensor): predicted offsets for anchors (need only for losses). Shape: [num_queries, batch_size, 2]
        quality_scores (Optional[Tensor]): predicted IOU score for spans. [num_queries, batch_size, 2]
        co_info (Optional[Dict[str, Any]]): support output for the colab postprocessing.
        dn_info (Optional[Dict[str, Any]]): support output for the denoise postprocessing.
    """

    # torch.Tensor has no built-in pydantic schema, so arbitrary types are enabled.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    outputs_class: Tensor = Field(
        default=...,
        description="predicted labels. Shape: [num_queries, batch_size, 2]",
    )
    outputs_coord: Tensor = Field(
        default=...,
        description="reference points. Shape: [num_queries, batch_size, 2]",
    )
    offsets: Tensor = Field(
        default=...,
        description="predicted offsets for anchors (need only for losses). Shape: [num_queries, batch_size, 2]",
    )

    # The Ellipsis default keeps these nullable fields mandatory at construction time.
    quality_scores: Optional[Tensor] = Field(
        default=...,
        description="predicted IOU score for spans. [num_queries, bs, 2]",
    )
    co_info: Optional[Dict[str, Any]] = Field(
        default=...,
        description="support output for the colab postprocessing.",
    )
    dn_info: Optional[Dict[str, Any]] = Field(
        default=...,
        description="support output for the denoise postprocessing.",
    )


class MomentEncoderOutput(BaseModel):
    """
    Moment Encoder output schema.

    Every field is optional and defaults to ``None``.

    Attributes:
        relevant_clips_mask (Optional[Tensor]): Mask for relevant clips
        irrelevant_clips_mask (Optional[Tensor]): Mask for irrelevant clips
        moment_token (Optional[Tensor]): Moment token enhanced with relevant clips representation
        moment_memory (Optional[Tensor]): Relevant clips representation (output from SA encoder)
        non_moment_token (Optional[Tensor]): Non-moment token enhanced with irrelevant clips representation
        non_moment_memory (Optional[Tensor]): Irrelevant clips representation (output from SA encoder)
    """

    # torch.Tensor has no built-in pydantic schema, so arbitrary types are enabled.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    relevant_clips_mask: Optional[Tensor] = Field(
        default=None,
        description="Mask for relevant clips",
    )
    irrelevant_clips_mask: Optional[Tensor] = Field(
        default=None,
        description="Mask for irrelevant clips",
    )
    moment_token: Optional[Tensor] = Field(
        default=None,
        description="Moment token inhanced with relevant clips representation",
    )
    moment_memory: Optional[Tensor] = Field(
        default=None,
        description="Relevant clips representation (output from SA encoder)",
    )
    non_moment_token: Optional[Tensor] = Field(
        default=None,
        description="Non-moment token inhanced with irrelevant clips representation",
    )
    non_moment_memory: Optional[Tensor] = Field(
        default=None,
        description="Irrelevant clips representation (output from SA encoder)",
    )


class SentenceEncoderOutput(BaseModel):
    """
    Schema describing what the sentence encoder returns.

    Every field is optional and defaults to ``None``.

    Attributes:
        sent_txt_token (Optional[Tensor]): sentence tokens enhanced with query representation
        sent_dummy_token (Optional[Tensor]): sentence tokens enhanced with dummy representation
        sent_words_memory (Optional[Tensor]): memory of words enhanced with query representation
        sent_dummy_memory (Optional[Tensor]): memory of words enhanced with dummy representation
    """

    # Pydantic model configuration: torch.Tensor has no built-in pydantic schema,
    # so arbitrary types are enabled.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    sent_txt_token: Optional[Tensor] = Field(
        default=None,
        description="Sentence tokens inhanced with query representation",
    )
    sent_dummy_token: Optional[Tensor] = Field(
        default=None,
        description="Sentence tokens inhanced with dummy representation",
    )
    sent_words_memory: Optional[Tensor] = Field(
        default=None,
        description="Memory of words inhanced with query representation",
    )
    sent_dummy_memory: Optional[Tensor] = Field(
        default=None,
        description="Memory of words inhanced with dummy representation",
    )


class InferenceOutputWrapper(BaseModel):
    """
    Schema bundling the tensors produced at inference time.

    All three fields are required.

    Attributes:
        pred_logits (Tensor): predicted logits.
        pred_spans (Tensor): predicted spans.
        saliency_scores (Tensor): predicted saliency scores.
    """

    # torch.Tensor has no built-in pydantic schema, so arbitrary types are enabled.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    pred_logits: Tensor = Field(
        default=...,
        description="Predicted logits",
    )
    pred_spans: Tensor = Field(
        default=...,
        description="Predicted spans",
    )
    saliency_scores: Tensor = Field(
        default=...,
        description="Predicted saliency scores",
    )
