import torch
from torch import nn


class PartitionLoss(nn.Module):
    def __init__(self, alpha: float = 0.4, beta: float = 0.4, gamma: float = 0.2):
        """
        Loss function for masked partitioning.

        :param alpha: Weight for individual mask BCE losses (front + back)
        :param beta: Weight for partition constraint loss
        :param gamma: Weight for Overlap penalty loss
        """
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma

    def update_weights(self, alpha: float | None = None, beta: float | None = None, gamma: float | None = None):
        if alpha is not None:
            self.alpha = alpha
        if beta is not None:
            self.beta = beta
        if gamma is not None:
            self.gamma = gamma

    def forward(self, pred_logits, target_masks, original_mask):
        """
        :param pred_logits: (B, 2, H, W) - raw logits for front and back
        :param target_masks: (B, 2, H, W) - ground truth front and back masks
        :param original_mask: (B, 1, H, W) - original segmentation mask
        :return:
        """
        # pred: (B, 2, H, W), target: (B, 2, H, W), original_mask: (B, 1, H, W)
        pred_sig = torch.sigmoid(pred_logits)
        front_pred, back_pred = pred_sig[:, 0:1], pred_sig[:, 1:2]
        front_target, back_target = target_masks[:, 0:1], target_masks[:, 1:2]

        # 1. Individual mask losses - direct supervision
        front_loss = nn.functional.binary_cross_entropy_with_logits(pred_logits[:, 0:1], front_target)
        back_loss = nn.functional.binary_cross_entropy_with_logits(pred_logits[:, 1:2], back_target)
        individual_loss = (front_loss + back_loss) / 2

        # 2. Partition constraint: front + back should equal original mask
        total_pred = front_pred + back_pred
        partition_loss = nn.functional.mse_loss(total_pred, original_mask)

        # 3. Overlap penalty: discourage overlapping predictions
        overlap_penalty = (front_pred * back_pred).mean()  # Should be close to 0

        # 4. Coverage loss: ensure we don't miss parts of the original mask
        coverage = torch.clamp(total_pred - original_mask, min=0)  # Excess prediction
        undercoverage = torch.clamp(original_mask - total_pred, min=0)  # Missed areas
        coverage_loss = coverage.mean() + undercoverage.mean()

        total_loss = (
                self.alpha * individual_loss +
                self.beta * partition_loss +
                self.gamma * (overlap_penalty + coverage_loss)
        )

        return total_loss, {
            'individual_loss': individual_loss.item(),
            'partition_loss': partition_loss.item(),
            'overlap_penalty': overlap_penalty.item(),
            'coverage_loss': coverage_loss.item()
        }


class LossWeightScheduler:
    """
    Linearly move a loss function's (alpha, beta, gamma) weights from ``start_weights``
    to ``end_weights`` over ``total_epochs`` epochs.

    The loss function only needs an ``update_weights(alpha=..., beta=..., gamma=...)`` method.
    """

    def __init__(self, loss_fn, total_epochs: int, start_weights: tuple[float, float, float] = (0.9, 0.05, 0.05),
                 end_weights: tuple[float, float, float] = (0.4, 0.4, 0.2)):
        """
        :param loss_fn: Object exposing ``update_weights(alpha, beta, gamma)``
        :param total_epochs: Number of epochs over which the weights are interpolated; must be positive
        :param start_weights: (alpha, beta, gamma) applied at epoch 0
        :param end_weights: (alpha, beta, gamma) applied at epoch ``total_epochs`` and beyond
        :raises ValueError: if ``total_epochs`` is not positive
        """
        # A zero value would divide by zero in _interpolate and a negative one would
        # run the schedule backwards, so reject both up front.
        if total_epochs <= 0:
            raise ValueError(f"total_epochs must be positive, got {total_epochs}")
        self.loss_fn = loss_fn
        self.total_epochs = total_epochs
        self.start_weights = start_weights
        self.end_weights = end_weights

    def step(self, epoch):
        """
        Compute the weights for ``epoch`` and push them into the loss function.

        :param epoch: Current epoch; values outside [0, total_epochs] are clamped to that range
        """
        alpha = self._interpolate(epoch, self.start_weights[0], self.end_weights[0])
        beta = self._interpolate(epoch, self.start_weights[1], self.end_weights[1])
        gamma = self._interpolate(epoch, self.start_weights[2], self.end_weights[2])
        self.loss_fn.update_weights(alpha=alpha, beta=beta, gamma=gamma)

    def _interpolate(self, epoch, start, end):
        """Return the linear blend of ``start`` and ``end`` for ``epoch``, clamped to the schedule."""
        # Clamp so that stepping past total_epochs holds the end weights instead of
        # extrapolating beyond them (which could produce negative weights).
        progress = min(max(epoch / self.total_epochs, 0.0), 1.0)
        return start + (end - start) * progress
