#!/usr/bin/env python3
import cv2
import numpy as np
import argparse
import os
import sys

project_path = "..."
sys.path.append(project_path)

from embed import embed_watermark
from entropy_utils import compute_entropy_threshold

# Default parameters (adjust as needed).
#
# Fixed embedding settings. The entry point further down this file forwards
# each one to embed_watermark() under the keyword shown in the trailing comment.
METHOD: str = "DWT"            # method=
DIAG: int = 2                  # diag=
BLOCK_SIZE: int = 96           # image_block_size=
interval_length: int = 8       # interval_length= (lowercase name is referenced as-is; kept)
SCALE: int = 6000              # scale=
MAX_CHANGE: int = 3            # max_change=
ADJUST_TO_CENTER: bool = True  # adjust_to_center=

# Defaults for command-line options; the CLI can override these at run time.
ROUND_PARAM: int = 30          # --round    -> round=
NUM_KEYS_DEFAULT: int = 0      # --num_keys -> num_keys=

def build_parser():
    """Return the argument parser for this script.

    Help texts use ``%(default)s`` so the documented defaults always match the
    module-level constants they come from.
    """
    parser = argparse.ArgumentParser(
        description="Embed watermark into a single image and save the result."
    )
    parser.add_argument("image_path", help="Path to the input image file.")
    parser.add_argument(
        "output_path", help="Path where the watermarked image will be saved."
    )
    parser.add_argument(
        "--dwt_lvl", type=int, default=3,
        help="DWT level to use (default: %(default)s).",
    )
    parser.add_argument(
        "--round", type=int, dest="round_param", default=ROUND_PARAM,
        help="Round parameter for embedding (default: %(default)s).",
    )
    parser.add_argument(
        "--entropy_threshold", type=float, default=None,
        help="Entropy threshold for embedding. If not specified, uses median entropy.",
    )
    parser.add_argument(
        "--num_keys", type=int, default=NUM_KEYS_DEFAULT,
        help="Number of keys for embedding (default: %(default)s).",
    )
    return parser


def quantize_to_uint8(image):
    """Round *image* to the nearest integer and saturate it into ``uint8``.

    A bare ``astype(np.uint8)`` truncates toward zero, so a float value that
    round-off left at 99.9999999 would be stored as 99 instead of 100.
    Rounding first avoids that systematic downward bias.

    Args:
        image: Array-like of real pixel values (any shape).

    Returns:
        ``np.ndarray`` of dtype ``uint8`` with the same shape as *image*.
    """
    return np.clip(np.rint(image), 0, 255).astype(np.uint8)


def embed_channels(rgb, args, entropy_thr):
    """Watermark every channel of *rgb* independently.

    Args:
        rgb: Multi-channel ``uint8`` image as returned by ``cv2.cvtColor``.
        args: Parsed CLI namespace (reads ``dwt_lvl``, ``round_param``,
            ``num_keys``).
        entropy_thr: Entropy threshold forwarded to ``embed_watermark``.

    Returns:
        List with one ``embed_watermark`` result per input channel, in channel
        order.
    """
    return [
        embed_watermark(
            method=METHOD,
            image_array=channel.astype(np.float64),
            diag=DIAG,
            interval_length=interval_length,
            dwt_lvl=args.dwt_lvl,
            entropy_threshold=entropy_thr,
            image_block_size=BLOCK_SIZE,
            max_change=MAX_CHANGE,
            adjust_to_center=ADJUST_TO_CENTER,
            scale=SCALE,
            round=args.round_param,
            num_keys=args.num_keys,
        )
        for channel in cv2.split(rgb)
    ]


def main(argv=None):
    """Embed a watermark into one image and write the result.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        Process exit status: 0 on success, 1 if the input cannot be read or
        the output cannot be written.
    """
    args = build_parser().parse_args(argv)

    bgr = cv2.imread(args.image_path)
    if bgr is None:
        print(f"Error: Could not read image '{args.image_path}'", file=sys.stderr)
        return 1

    # OpenCV decodes to BGR; the embedder is fed the channels in RGB order.
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    if args.entropy_threshold is not None:
        entropy_thr = args.entropy_threshold
        print(f"Using specified entropy threshold: {entropy_thr}")
    else:
        entropy_thr = compute_entropy_threshold(rgb)
        print(f"Using median entropy threshold: {entropy_thr}")

    watermarked_rgb = cv2.merge(embed_channels(rgb, args, entropy_thr))
    watermarked_bgr = cv2.cvtColor(
        quantize_to_uint8(watermarked_rgb), cv2.COLOR_RGB2BGR
    )

    # exist_ok=True already covers "directory exists"; no separate pre-check
    # (which would also be racy) is needed.
    out_dir = os.path.dirname(args.output_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    # The encoder is chosen from the output file extension. Lossy formats
    # (e.g. JPEG) re-quantize pixel values and may disturb the embedded mark.
    # imwrite signals failure either by returning False or, for an unknown
    # extension, by raising cv2.error.
    try:
        saved = cv2.imwrite(args.output_path, watermarked_bgr)
    except cv2.error as err:
        print(
            f"Error: Could not write image '{args.output_path}': {err}",
            file=sys.stderr,
        )
        return 1
    if not saved:
        print(f"Error: Could not write image '{args.output_path}'", file=sys.stderr)
        return 1

    print(f"Watermarked image saved to: {args.output_path}")
    return 0


if __name__ == "__main__":
    sys.exit(main())
