#!/usr/bin/env python3
import cv2
import numpy as np
import argparse
from detect import detect_watermark
from eval import eval_m1
from entropy_utils import compute_entropy_threshold
import sys
import os

# Make the project root importable.
# NOTE(review): this runs *after* `detect`, `eval` and `entropy_utils` are
# imported at the top of the file, so it cannot help resolve those imports —
# confirm whether it is meant to precede them.
project_path: str = "..."  # placeholder value; set to the real project root
sys.path.append(project_path)

# ---- Parameters passed to detect_watermark / eval_m1 (adjust if needed) ----
DIAG: int = 2
interval_length: int = 8
METHOD: str = "DWT"
SCALE: int = 6000
Z_SCORE: float = -1.64
ROUND_PARAM: int = 30

# ---- Visualisation parameters ----
IMAGE_BLOCK_SIZE: int = 96  # edge length (pixels) of one square block
ALPHA: float = 0.3  # blend weight of the overlay colour (0 = invisible, 1 = opaque)

def create_full_block_mask(image_height, image_width, block_size):
    """Return a boolean grid telling which image tiles are complete squares.

    The image is tiled into ``block_size`` x ``block_size`` blocks starting at
    the top-left corner. The grid has one entry per tile, including the partial
    tiles along the bottom/right edges. An entry is True only when its tile lies
    entirely inside the image.
    """
    grid_rows = (image_height + block_size - 1) // block_size
    grid_cols = (image_width + block_size - 1) // block_size
    mask = np.zeros((grid_rows, grid_cols), dtype=bool)

    # Tile k along an axis is complete iff (k + 1) * block_size <= dim,
    # i.e. k < dim // block_size, so the complete tiles form the top-left
    # rectangle of the grid.
    mask[: image_height // block_size, : image_width // block_size] = True
    return mask

def postprocess_mask(mask, full_block_mask):
    """
    Flip isolated full blocks whose full-block neighbours all disagree with them.

    For every full block (``full_block_mask`` True), collect the values of its
    8-connected neighbours that are in range and are themselves full blocks.
    If there are at least three such neighbours and every one holds the
    opposite value, the block is flipped in the output. All decisions read the
    unmodified input, so a flip never influences a later decision. Non-full
    blocks are neither changed nor counted as neighbours.

    Returns a new array; ``mask`` itself is not modified.
    """
    # Eight (dy, dx) offsets around a cell, centre excluded.
    offsets = [(dy, dx) for dy in (-1, 0, 1) for dx in (-1, 0, 1) if dy or dx]
    rows, cols = mask.shape
    result = mask.copy()

    for r in range(rows):
        for c in range(cols):
            if not full_block_mask[r, c]:
                continue

            votes = []
            for dy, dx in offsets:
                nr, nc = r + dy, c + dx
                inside = 0 <= nr < rows and 0 <= nc < cols
                if inside and full_block_mask[nr, nc]:
                    votes.append(mask[nr, nc])

            # Too few full neighbours to make a reliable decision.
            if len(votes) < 3:
                continue

            opposite = not mask[r, c]
            if all(v == opposite for v in votes):
                result[r, c] = opposite

    return result

def _tint_full_blocks(overlay, mask, full_block_mask, block_size):
    """Alpha-blend green (mask truthy) or red (mask falsy) into each full block, in place."""
    h, w = overlay.shape[:2]
    green = np.array([0, 255, 0], dtype=np.float32)
    red = np.array([255, 0, 0], dtype=np.float32)
    n_blocks_y, n_blocks_x = mask.shape

    for by in range(n_blocks_y):
        for bx in range(n_blocks_x):
            if not full_block_mask[by, bx]:
                continue
            y0, x0 = by * block_size, bx * block_size
            y1, x1 = y0 + block_size, x0 + block_size
            # Guard against a mask that describes more blocks than the image holds.
            if y1 > h or x1 > w:
                continue
            color = green if mask[by, bx] else red
            overlay[y0:y1, x0:x1] = (1 - ALPHA) * overlay[y0:y1, x0:x1] + ALPHA * color


def _draw_block_grid(overlay, grid_shape, block_size, thickness=1):
    """Paint white borders around every block (full or partial), in place.

    The top/left border of a block is skipped on the image's outer top/left
    edge, while the bottom/right border is always drawn. A line between two
    neighbouring blocks is therefore ``2 * thickness`` pixels wide: one block's
    bottom/right border plus the next block's top/left border.
    """
    h, w = overlay.shape[:2]
    white = np.array([255, 255, 255], dtype=np.float32)
    n_blocks_y, n_blocks_x = grid_shape

    for by in range(n_blocks_y):
        for bx in range(n_blocks_x):
            y0, x0 = by * block_size, bx * block_size
            # Clip partial edge blocks to the image.
            y1 = min(y0 + block_size, h)
            x1 = min(x0 + block_size, w)
            if y0 >= thickness:  # top
                overlay[y0:y0 + thickness, x0:x1] = white
            if x0 >= thickness:  # left
                overlay[y0:y1, x0:x0 + thickness] = white
            overlay[y1 - thickness:y1, x0:x1] = white  # bottom
            overlay[y0:y1, x1 - thickness:x1] = white  # right


def save_detection_map(rgb_img, final_mask, full_block_mask, image_path,
                       apply_postprocessing=False, block_size=None):
    """Overlay semi-transparent green/red blocks on an image and save it as a PNG.

    Full blocks whose mask value is truthy are tinted green and the others red,
    blended with weight ``ALPHA``. Partial blocks at the right/bottom edges are
    not tinted. A white grid is then drawn over every block described by the
    mask.

    Args:
        rgb_img: Image in RGB channel order, indexed as ``[y, x, channel]``.
            It is not modified.
        final_mask: Per-block detection result, indexed as ``[by, bx]``.
        full_block_mask: Boolean grid marking full-size blocks. It must cover
            every index of ``final_mask``.
        image_path: Path of the source image. The map is written next to it as
            ``<base>_detection_map[_postprocessed].png``.
        apply_postprocessing: If True, run ``postprocess_mask`` first and
            print how many full blocks it changed.
        block_size: Block edge length in pixels. Defaults to
            ``IMAGE_BLOCK_SIZE``.

    Returns:
        The path of the written map, or None if ``cv2.imwrite`` reported failure.
    """
    if block_size is None:
        block_size = IMAGE_BLOCK_SIZE

    if apply_postprocessing:
        print("Applying postprocessing to fill isolated blocks...")
        mask_to_use = postprocess_mask(final_mask, full_block_mask)
        # Only full blocks can be changed by postprocessing.
        changes = np.sum((final_mask != mask_to_use) & full_block_mask)
        print(f"Postprocessing changed {changes} blocks")
    else:
        mask_to_use = final_mask

    # Work on a float copy so blending does not saturate and the input stays intact.
    overlay = rgb_img.astype(np.float32)
    _tint_full_blocks(overlay, mask_to_use, full_block_mask, block_size)
    _draw_block_grid(overlay, mask_to_use.shape, block_size, thickness=1)

    base_name = os.path.splitext(image_path)[0]
    suffix = "_postprocessed" if apply_postprocessing else ""
    map_path = f"{base_name}_detection_map{suffix}.png"

    out = np.clip(overlay, 0, 255).astype(np.uint8)
    # cv2.imwrite reports a failed write by returning False, not by raising.
    if not cv2.imwrite(map_path, cv2.cvtColor(out, cv2.COLOR_RGB2BGR)):
        print(f"Error: Could not write detection map to '{map_path}'")
        return None
    print(f"Detection map saved to: {map_path}")
    return map_path

def main():
    """Command-line entry point: detect a watermark in one image and save a detection map.

    Runs ``detect_watermark`` followed by ``eval_m1`` independently on the R, G
    and B channels, then prints the mean score. The per-channel block masks are
    combined by a 2-of-3 majority vote, and the overlay is written with
    ``save_detection_map``. Exits with status 1 if the input image cannot be
    read.
    """
    parser = argparse.ArgumentParser(
        description="Detect watermark on a single image and evaluate the M1 score."
    )
    parser.add_argument(
        "image_path", help="Path to the input image file."
    )
    parser.add_argument(
        "--dwt_lvl", type=int, default=3,
        help="DWT level to use (default: %(default)s)."
    )
    parser.add_argument(
        "--round", type=int, dest="round_param", default=ROUND_PARAM,
        # %(default)s keeps the help text in sync with ROUND_PARAM.
        help="Round parameter for detection (default: %(default)s)."
    )
    parser.add_argument(
        "--entropy_threshold", type=float, default=None,
        help="Entropy threshold for detection. If not specified, uses median entropy."
    )
    # Postprocessing is on by default, so passing --postprocess is a no-op;
    # --no-postprocess (same dest) is what actually switches it off.
    parser.add_argument(
        "--postprocess", action="store_true", default=True,
        help="Apply postprocessing to fill isolated blocks (default: True)."
    )
    parser.add_argument(
        "--no-postprocess", dest="postprocess", action="store_false",
        help="Disable postprocessing."
    )
    args = parser.parse_args()

    # cv2.imread returns None (instead of raising) for missing/unreadable files.
    bgr = cv2.imread(args.image_path)
    if bgr is None:
        print(f"Error: Could not read image '{args.image_path}'", file=sys.stderr)
        sys.exit(1)

    # OpenCV loads BGR; the rest of the pipeline works in RGB.
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    # The full-block grid depends only on the image size, not on the channel.
    h, w = rgb.shape[:2]
    full_block_mask = create_full_block_mask(h, w, IMAGE_BLOCK_SIZE)
    print(f"Full block mask shape: {full_block_mask.shape}")
    print(f"Number of full blocks: {np.sum(full_block_mask)}")

    if args.entropy_threshold is not None:
        entropy_thr = args.entropy_threshold
        print(f"Using specified entropy threshold: {entropy_thr}")
    else:
        entropy_thr = compute_entropy_threshold(rgb)
        print(f"Using median entropy threshold: {entropy_thr}")

    # Run detection + evaluation independently on each colour channel.
    channel_masks = []
    scores = []
    for channel_index in range(3):
        chan = rgb[:, :, channel_index]
        _, violations, pixels = detect_watermark(
            method=METHOD,
            image_array=chan,
            diag=DIAG,
            interval_length=interval_length,
            dwt_lvl=args.dwt_lvl,
            scale=SCALE,
            entropy_threshold=entropy_thr,
            round=args.round_param
        )
        score, mask = eval_m1(violations, pixels, Z_SCORE)
        scores.append(score)
        # Presumably a bool array of shape (n_blocks_y, n_blocks_x) — verify against eval_m1.
        channel_masks.append(mask)

    # Report the mean score over the three channels.
    avg_score = np.mean(scores)
    print(f"Average score: {avg_score:.4f}")

    # Majority vote: summing bool masks counts agreeing channels (True -> 1).
    sum_masks = sum(channel_masks)
    final_mask = sum_masks >= 2

    save_detection_map(rgb, final_mask, full_block_mask, args.image_path, args.postprocess)

if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not when imported.
    main()
