import os
import sys
import shutil
import warnings
import random
import numpy as np
import pandas as pd
from PIL import Image, ImageEnhance
from tqdm import tqdm

# --- Benchmark size ---------------------------------------------------------
BASE_COUNT: int = 1_000      # images generated per watermarking scheme
SAMPLE_SIZE: int = 100       # images evaluated per (attack, level) point
CURVE_POINTS: int = 30       # number of attack strengths swept per attack

# --- Embedding strength -----------------------------------------------------
# Multiplier applied to the RivaGAN residual (watermarked - original) in embed_rivagan.
RIVA_INTENSITY_BOOST: float = 1.15

# --- "Kill zone" thresholds -------------------------------------------------
# A sample is in the kill zone when CLIP similarity stays above KILLZONE_CLIP
# while the provenance score drops below KILLZONE_PROV.
KILLZONE_CLIP: float = 75.0
KILLZONE_PROV: float = 0.20

BASE_DIR = "/content/synergy_benchmark_results"
# Start from an empty results directory so files from a previous run never mix in.
if os.path.exists(BASE_DIR):
    shutil.rmtree(BASE_DIR)
os.makedirs(BASE_DIR, exist_ok=True)

# Install runtime dependencies into *this* interpreter: `python -m pip` targets
# sys.executable, whereas a bare `pip` on PATH may belong to a different
# Python. A failed install is reported instead of being silently ignored.
import subprocess

_PIP_PACKAGES = [
    "diffusers", "transformers", "accelerate", "invisible-watermark",
    "open_clip_torch", "scipy", "ftfy", "datasets", "matplotlib",
    "seaborn", "torch-dct",
]
_pip_result = subprocess.run(
    [sys.executable, "-m", "pip", "install", "-q", *_PIP_PACKAGES],
    check=False,
)
if _pip_result.returncode != 0:
    print(f"-> pip install exited with code {_pip_result.returncode}; later imports may fail")

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import seaborn as sns
from datasets import load_dataset
from diffusers import StableDiffusionPipeline, StableDiffusionInpaintPipeline, StableDiffusionImg2ImgPipeline, DPMSolverMultistepScheduler, DDIMScheduler
import open_clip

warnings.filterwarnings("ignore")
device = "cuda" if torch.cuda.is_available() else "cpu"

# Half precision only on CUDA; on CPU many fp16 kernels may be missing or very
# slow, so fall back to float32 there. The CUDA path is unchanged.
_sd_dtype = torch.float16 if device == "cuda" else torch.float32
_sd_kwargs = dict(torch_dtype=_sd_dtype, safety_checker=None, requires_safety_checker=False)

# NOTE(review): the "runwayml/*" Hub repos may no longer be hosted; if loading
# fails, point these IDs at a maintained mirror of the same weights.
pipe_gen = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", **_sd_kwargs).to(device)
pipe_gen.scheduler = DPMSolverMultistepScheduler.from_config(pipe_gen.scheduler.config)

pipe_inp = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting", **_sd_kwargs).to(device)
pipe_i2i = StableDiffusionImg2ImgPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", **_sd_kwargs).to(device)

clip_model, _, clip_preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
clip_model.to(device)
# open_clip hands back the model in training mode; CLIP similarity must be
# computed in inference mode.
clip_model.eval()

# Fetch the RivaGAN reference implementation once (git clones it into ./RivaGAN)
# and make that checkout importable.
_RIVAGAN_DIR = "RivaGAN"
if not os.path.exists(_RIVAGAN_DIR):
    os.system("git clone https://github.com/DAI-Lab/RivaGAN.git")
sys.path.append(_RIVAGAN_DIR)

from rivagan import RivaGAN
from rivagan.attention import AttentiveEncoder, AttentiveDecoder

class DummyNetwork(nn.Module):
    """Parameter-free placeholder module.

    Accepts and ignores any constructor arguments and defines no layers. It is
    registered below as ``Adversary`` and ``Critic`` on ``__main__``,
    presumably so the RivaGAN checkpoint can be unpickled without the real
    training-only networks.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()

import __main__
# Publish these class names on __main__ so torch.load's unpickler can resolve
# them. NOTE(review): presumably the checkpoint was pickled from a script where
# they lived in __main__; Adversary/Critic get parameter-free stand-ins.
vars(__main__).update(
    RivaGAN=RivaGAN,
    AttentiveEncoder=AttentiveEncoder,
    AttentiveDecoder=AttentiveDecoder,
    Adversary=DummyNetwork,
    Critic=DummyNetwork,
)

# SECURITY: weights_only=False unpickles arbitrary Python objects and can run
# code from the file - only use a trusted "rivagan_32bit_model.pt".
base_rivagan = torch.load("rivagan_32bit_model.pt", map_location=device, weights_only=False)

if hasattr(base_rivagan, 'encoder'):
    # The checkpoint is a pickled object that already carries the networks.
    riva_encoder = base_rivagan.encoder.to(device).eval()
    riva_decoder = base_rivagan.decoder.to(device).eval()
else:
    # Otherwise build a fresh RivaGAN and try to load the checkpoint as a state dict.
    temp_rivagan = RivaGAN()
    if hasattr(temp_rivagan, 'load_state_dict'):
        load_result = temp_rivagan.load_state_dict(base_rivagan, strict=False)
        # strict=False ignores missing keys silently; surface them instead.
        missing = getattr(load_result, "missing_keys", None) or []
        if missing:
            print(f"-> WARNING: RivaGAN checkpoint lacks {len(missing)} parameter(s); they keep the values RivaGAN() built")
    else:
        print("-> WARNING: could not load RivaGAN weights from checkpoint; encoder/decoder keep the values RivaGAN() built")
    riva_encoder = temp_rivagan.encoder.to(device).eval()
    riva_decoder = temp_rivagan.decoder.to(device).eval()

class SynergisticAdapter(nn.Module):
    """Small residual CNN applied to a watermarked image batch.

    Three 3x3 convolutions (3 -> 32 -> 32 -> 3 channels, ReLU after the first
    two) predict a per-pixel correction. The correction is added to the input
    and the sum is clamped to [-1, 1], the value range embed_rivagan works in.
    """

    def __init__(self):
        super().__init__()
        # Attribute names are the state_dict keys matched by load_state_dict
        # when synergistic_adapter.pth is loaded - do not rename them.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(32, 3, kernel_size=3, padding=1)

    def forward(self, img_4d):
        features = self.conv2(self.conv1(img_4d).relu()).relu()
        correction = self.conv3(features)
        return (img_4d + correction).clamp(-1.0, 1.0)

adapter = SynergisticAdapter().to(device)
if not os.path.exists("synergistic_adapter.pth"):
    # NOTE(review): without a weights file the adapter keeps its fresh
    # initialisation, so the "(New)" models run through an untrained adapter.
    print("-> Adapter weights not found")
else:
    adapter.load_state_dict(
        torch.load("synergistic_adapter.pth", map_location=device, weights_only=True)
    )
    print("-> Synergistic Adapter loaded successfully")
adapter.eval()

# Seed torch globally (also affects later unseeded torch draws), then draw the
# fixed 32-bit payload on the default generator and move it to `device`.
torch.manual_seed(42)
TARGET_PAYLOAD = torch.randint(low=0, high=2, size=(1, 32), dtype=torch.float32).to(device)

def embed_rivagan(img_pil, use_adapter=False, boost=1.0):
    """Embed TARGET_PAYLOAD into an image with the RivaGAN encoder.

    Args:
        img_pil: PIL image; converted to RGB and resized to 512x512 (Lanczos).
        use_adapter: if True, pass the encoder output through ``adapter``.
        boost: factor applied to the residual (watermarked - original);
            1.0 leaves the encoder output unchanged.

    Returns:
        512x512 RGB PIL image carrying the watermark.
    """
    img_pil = img_pil.convert("RGB").resize((512, 512), Image.Resampling.LANCZOS)
    img_tensor = transforms.ToTensor()(img_pil).unsqueeze(0).to(device)
    img_tensor = img_tensor * 2.0 - 1.0  # [0, 1] -> [-1, 1]

    # Insert a singleton axis at dim 2 for the encoder (assumed to be the
    # frame axis of RivaGAN's video input - TODO confirm).
    img_5d = img_tensor.unsqueeze(2)
    with torch.no_grad():
        output = riva_encoder(img_5d, TARGET_PAYLOAD)
        wm_4d = (output[0] if isinstance(output, tuple) else output).squeeze(2)

        if use_adapter:
            wm_4d = adapter(wm_4d)

        if boost != 1.0:
            # Scale only the perturbation, keeping the host image fixed.
            wm_4d = torch.clamp(img_tensor + (wm_4d - img_tensor) * boost, -1.0, 1.0)

    wm_tensor = torch.clamp((wm_4d.squeeze(0) + 1.0) / 2.0, 0.0, 1.0)  # [-1, 1] -> [0, 1]
    # Quantise with rounding. ToPILImage on a float tensor truncates
    # (mul(255).byte()), which biases every pixel downward. That matters for a
    # watermark whose amplitude is only a few 8-bit levels.
    wm_uint8 = (wm_tensor * 255.0).round().to(torch.uint8)
    return transforms.ToPILImage()(wm_uint8.cpu())

def detect_rivagan(img_pil):
    """Score how much of TARGET_PAYLOAD the RivaGAN decoder recovers from an image.

    The image is converted to RGB, resized to 512x512, scaled to [-1, 1] and
    decoded. Bits are thresholded at sigmoid > 0.5. The score is the fraction
    of the payload's 8-bit groups that are recovered exactly.

    Returns:
        float in [0, 1]; 0.0 if decoding raises (best-effort scoring).
    """
    img_pil = img_pil.convert("RGB").resize((512, 512), Image.Resampling.LANCZOS)
    img_tensor = transforms.ToTensor()(img_pil).unsqueeze(0).to(device)
    img_tensor = (img_tensor * 2.0) - 1.0
    img_5d = img_tensor.unsqueeze(2)
    try:
        with torch.no_grad():
            decoded_msg = riva_decoder(img_5d)
            pred_bits = (torch.sigmoid(decoded_msg) > 0.5).float().squeeze()
            target_bits = TARGET_PAYLOAD.squeeze()
            # Group bits into bytes; a byte counts only if all 8 bits match.
            pred_bytes = pred_bits.view(-1, 8)
            target_bytes = target_bits.view(-1, 8)
            byte_matches = (pred_bytes == target_bytes).all(dim=1).float()
            return byte_matches.mean().item()
    except Exception:
        # Keep the benchmark running on a failed decode, but unlike a bare
        # `except:` do not swallow KeyboardInterrupt/SystemExit.
        return 0.0

# Tree-Ring key geometry. get_watermarking_pattern reads the last two entries
# of TR_SHAPE as height/width and indexes dim 1 with TR_CHANNEL. The same shape
# is used for the initial generation latents.
TR_SHAPE: tuple = (1, 4, 64, 64)
TR_RADIUS: int = 10    # outer radius (frequency bins from the centre) of the key
TR_CHANNEL: int = 3    # dim-1 index that carries the key

def get_watermarking_pattern(shape, w_radius=10, w_channel=3, device='cpu'):
    """Build the Tree-Ring key: a ring-structured Fourier pattern and its mask.

    Args:
        shape: 4-tuple; the last two entries are read as height and width.
        w_radius: outer radius, in bins from the centre (h//2, w//2), of the key.
        w_channel: index along dim 1 that carries the key.
        device: device for the returned tensors and the seeded generator.

    Returns:
        (pattern, mask):
        * pattern - complex64 tensor of ``shape``, zero except on ``w_channel``.
          There, every bin whose rounded distance r to the centre lies in
          1..w_radius holds one real value shared by its whole ring. Values are
          drawn from seed 999999 and scaled by sqrt(h * w).
        * mask - bool tensor of ``shape``, True on ``w_channel`` for every bin
          with unrounded distance <= w_radius. The centre bin (r = 0) is
          inside the mask but its pattern value stays 0.
    """
    key_gen = torch.Generator(device).manual_seed(999999)
    _, _, h, w = shape
    rows, cols = torch.meshgrid(
        torch.arange(h, device=device), torch.arange(w, device=device), indexing='ij'
    )
    radius = torch.sqrt((rows - h // 2) ** 2 + (cols - w // 2) ** 2)
    ring_index = torch.round(radius).long()

    mask = torch.zeros(shape, dtype=torch.bool, device=device)
    mask[:, w_channel] = radius <= w_radius

    # One random value per ring; entry 0 is drawn but never used.
    ring_values = torch.randn(w_radius + 1, generator=key_gen, device=device) * (h * w) ** 0.5
    in_ring = (ring_index >= 1) & (ring_index <= w_radius)
    plane = torch.where(
        in_ring,
        ring_values[ring_index.clamp(max=w_radius)],
        torch.zeros_like(radius),
    )

    pattern = torch.zeros(shape, dtype=torch.complex64, device=device)
    pattern[:, w_channel] = plane.to(torch.complex64)
    return pattern, mask

# Module-wide Tree-Ring key, built once on the compute device.
TR_PATTERN, TR_MASK = get_watermarking_pattern(
    TR_SHAPE,
    w_radius=TR_RADIUS,
    w_channel=TR_CHANNEL,
    device=device,
)

def inject_watermark(init_latents, pattern, mask):
    """Overwrite masked bins of the latents' centred 2-D spectrum with ``pattern``.

    The FFT is taken over the last two axes in float32 and centred with
    fftshift. Masked bins are replaced and the result is transformed back. The
    real part is returned in the dtype of ``init_latents``.
    """
    spatial = (-2, -1)
    spectrum = torch.fft.fftshift(torch.fft.fft2(init_latents.to(torch.float32)), dim=spatial)
    spectrum[mask] = pattern[mask].to(spectrum.dtype)
    restored = torch.fft.ifft2(torch.fft.ifftshift(spectrum, dim=spatial))
    return restored.real.to(init_latents.dtype)

def get_inverted_noise(pipe, img, num_inference_steps=20):
    """Approximately recover the initial diffusion latents of ``img`` by DDIM inversion.

    The image is VAE-encoded (posterior mean times the VAE scaling factor).
    The latents are then walked *up* the timestep grid that
    ``DDIMScheduler.from_config(pipe.scheduler.config)`` yields for
    ``num_inference_steps``, with the UNet conditioned on the empty prompt.

    Args:
        pipe: Stable Diffusion pipeline providing vae/unet/tokenizer/text_encoder.
        img: PIL image; converted to RGB and resized to 512x512.
        num_inference_steps: size of the DDIM grid to invert along.

    Returns:
        Latents at the highest timestep of that grid, i.e. the timestep a DDIM
        sampler with the same config starts from.
    """
    device = pipe.device
    dtype = pipe.dtype
    img = img.convert("RGB").resize((512, 512))
    img_tensor = torch.from_numpy(np.array(img).astype(np.float32) / 127.5 - 1.0)  # [0, 255] -> [-1, 1]
    img_tensor = img_tensor.permute(2, 0, 1).unsqueeze(0).to(device, dtype=dtype)  # HWC -> 1CHW
    with torch.no_grad():
        latents = pipe.vae.encode(img_tensor).latent_dist.mean * pipe.vae.config.scaling_factor
        text_inputs = pipe.tokenizer("", padding="max_length", max_length=pipe.tokenizer.model_max_length, truncation=True, return_tensors="pt")
        prompt_embeds = pipe.text_encoder(text_inputs.input_ids.to(device))[0]

        sched = DDIMScheduler.from_config(pipe.scheduler.config)
        sched.set_timesteps(num_inference_steps, device=device)
        # The sampler walks this grid top-down; inversion walks it bottom-up.
        timesteps = [int(t) for t in reversed(sched.timesteps.tolist())]
        alphas = sched.alphas_cumprod.to(device)

        # One deterministic DDIM step per consecutive pair, ending exactly at
        # timesteps[-1]. The previous loop added a final step to index 999.
        # That step was a wasted no-op UNet call when the grid already ends at
        # 999, and it overshot the sampler's start timestep when the grid ends
        # lower (e.g. "leading" spacing ends at 951 for 20 steps).
        for t_cur, t_next in zip(timesteps[:-1], timesteps[1:]):
            alpha_t = alphas[t_cur]
            alpha_next = alphas[t_next]
            model_output = pipe.unet(latents, t_cur, encoder_hidden_states=prompt_embeds).sample
            x0_pred = (latents - (1 - alpha_t).sqrt() * model_output) / alpha_t.sqrt()
            latents = alpha_next.sqrt() * x0_pred + (1 - alpha_next).sqrt() * model_output
    return latents

def detect_tr_oracle(img, baseline_seed=0):
    """Score Tree-Ring key presence in ``img`` via DDIM inversion through pipe_gen.

    The recovered latents' centred 2-D FFT is compared with TR_PATTERN on
    TR_MASK (mean absolute difference). That distance is normalised by the
    same distance measured for Gaussian-noise latents:
    ``score = clip(1.5 * max(0, 1 - d_img / d_noise), 0, 1)``.

    Args:
        img: PIL image to test.
        baseline_seed: seed for the noise-baseline latents. The baseline used to
            be drawn from the global RNG on every call, which made the score
            nondeterministic and consumed global torch RNG state.

    Returns:
        float in [0, 1]; 0.0 if inversion or scoring raises.
    """
    try:
        inv_latents = get_inverted_noise(pipe_gen, img, num_inference_steps=20)
        target = TR_PATTERN[TR_MASK].to(device)

        def _masked_distance(latents):
            # Mean |FFT - key| over the keyed bins of the centred spectrum.
            spectrum = torch.fft.fftshift(torch.fft.fft2(latents.to(torch.float32)), dim=(-2, -1))
            return torch.abs(spectrum[TR_MASK] - target).mean().item()

        distance = _masked_distance(inv_latents)

        baseline_gen = torch.Generator(device=inv_latents.device).manual_seed(baseline_seed)
        rand_latents = torch.randn(
            inv_latents.shape, generator=baseline_gen,
            device=inv_latents.device, dtype=inv_latents.dtype,
        )
        d_rand = _masked_distance(rand_latents)

        score = max(0.0, 1.0 - (distance / (d_rand + 1e-6)))
        # NOTE(review): the 1.5 gain is an uncalibrated rescaling - confirm intent.
        return min(1.0, score * 1.5)
    except Exception as err:
        # Best-effort: a failed inversion counts as "not detected", but is reported.
        print(f"-> Tree-Ring detection failed: {err}")
        return 0.0

def detect_combined(img):
    """Provenance for stacked watermarks: the stronger of the two detector scores.

    The image counts as traced if EITHER the Tree-Ring or the RivaGAN mark
    survives. Tree-Ring is scored first, then RivaGAN.
    """
    tree_ring_score = detect_tr_oracle(img)
    rivagan_score = detect_rivagan(img)
    return max(tree_ring_score, rivagan_score)

def get_prov(model_name, img):
    """Dispatch ``img`` to the detector that matches ``model_name``.

    "Tree-Ring" uses the Tree-Ring oracle. Any name containing "Combined" uses
    the max of both detectors. Everything else (the RivaGAN Old/New models)
    uses the RivaGAN decoder.
    """
    if model_name == "Tree-Ring":
        detector = detect_tr_oracle
    elif "Combined" in model_name:
        detector = detect_combined
    else:
        detector = detect_rivagan
    return detector(img)

# Watermarked base images per scheme, filled by the generation loop.
# Dict order fixes the order in which schemes are evaluated and reported.
_BASE_MODELS = (
    "RivaGAN (Old)",
    "RivaGAN (New)",
    "Tree-Ring",
    "Combined (TR + Old)",
    "Combined (TR + New)",
)
BASE_IMAGES = {name: [] for name in _BASE_MODELS}

# Generation prompts: up to BASE_COUNT DiffusionDB prompts, padded with a
# generic fallback so that prompts[i] is valid for every i < BASE_COUNT.
import itertools

_FALLBACK_PROMPT = "a high quality masterpiece"
prompts = [_FALLBACK_PROMPT] * BASE_COUNT
try:
    ds = load_dataset("poloclub/diffusiondb", "2m_first_1k", split="train", streaming=True)
    # islice stops the stream after BASE_COUNT rows instead of iterating the
    # rest of the split only to discard it. Empty/None prompts get the fallback.
    _ds_prompts = [row['prompt'] or _FALLBACK_PROMPT for row in itertools.islice(ds, BASE_COUNT)]
    # Pad if the split had fewer rows, instead of an IndexError later.
    prompts = _ds_prompts + prompts[len(_ds_prompts):]
except Exception as err:
    print(f"-> Could not load DiffusionDB prompts ({err}); using fallback prompt")

# DDIM scheduler for Tree-Ring sampling, built once from pipe_gen's default
# scheduler config (it was previously rebuilt identically on every iteration).
_tr_scheduler = DDIMScheduler.from_config(pipe_gen.scheduler.config)

for i in tqdm(range(BASE_COUNT), desc="Generating Data"):
    prompt = prompts[i]
    seed = 1000 + i

    # Clean image with pipe_gen's default scheduler, then RivaGAN-only variants.
    clean_gen = torch.Generator(device).manual_seed(seed)
    img_clean = pipe_gen(prompt, num_inference_steps=20, generator=clean_gen).images[0]
    BASE_IMAGES["RivaGAN (Old)"].append(embed_rivagan(img_clean, use_adapter=False, boost=RIVA_INTENSITY_BOOST))
    BASE_IMAGES["RivaGAN (New)"].append(embed_rivagan(img_clean, use_adapter=True, boost=RIVA_INTENSITY_BOOST))

    # Tree-Ring: key injected into the Fourier spectrum of seeded initial noise,
    # then sampled with DDIM (the scheduler the inversion detector assumes).
    g_tr = torch.Generator(device).manual_seed(seed)
    init_latents = torch.randn(TR_SHAPE, generator=g_tr, device=device, dtype=pipe_gen.dtype)
    watermarked_latents = inject_watermark(init_latents, TR_PATTERN, TR_MASK)

    default_scheduler = pipe_gen.scheduler
    pipe_gen.scheduler = _tr_scheduler
    try:
        img_tr = pipe_gen(prompt, num_inference_steps=20, latents=watermarked_latents).images[0]
    finally:
        # Always restore, so a failure cannot leave pipe_gen on DDIM.
        pipe_gen.scheduler = default_scheduler

    BASE_IMAGES["Tree-Ring"].append(img_tr)

    # Stacked schemes: RivaGAN embedded on top of the Tree-Ring image.
    BASE_IMAGES["Combined (TR + Old)"].append(embed_rivagan(img_tr, use_adapter=False, boost=RIVA_INTENSITY_BOOST))
    BASE_IMAGES["Combined (TR + New)"].append(embed_rivagan(img_tr, use_adapter=True, boost=RIVA_INTENSITY_BOOST))

def get_mask(img, r):
    """Build an inpainting mask that selects a border frame of the image.

    Args:
        img: PIL image; only its size is used.
        r: total border fraction per axis. Each side gets ``int(side * r / 2)``
            pixels (at least 1), computed separately for width and height so
            non-square images get proportional borders.

    Returns:
        Mode "L" mask of the image size: 255 on the border (presumably the
        region the inpainting pipeline regenerates), 0 in the kept centre.
    """
    import PIL.ImageDraw as D

    w, h = img.size
    mx = int((w * r) / 2) or 1
    my = int((h * r) / 2) or 1
    mask = Image.new("L", (w, h), 255)
    # PIL's rectangle box is inclusive of both corners. The "- 1" makes the
    # right/bottom borders exactly as wide as the left/top ones.
    x1, y1 = w - mx - 1, h - my - 1
    if x1 >= mx and y1 >= my:  # otherwise the border covers the whole image
        D.Draw(mask).rectangle((mx, my, x1, y1), fill=0)
    return mask

def crop_center(img, s):
    """Centre-crop away a total fraction ``s`` of each dimension (``s / 2`` per side)."""
    width, height = img.size
    half = s / 2
    box = (
        int(width * half),
        int(height * half),
        int(width * (1 - half)),
        int(height * (1 - half)),
    )
    return img.crop(box)

def _attack_img2img(img, s):
    """Regenerate the image with img2img at strength ``s`` (floored at 0.01)."""
    return pipe_i2i(prompt="art", image=img, strength=max(0.01, s), num_inference_steps=15).images[0]


def _attack_inpaint(img, s):
    """Inpaint the border frame selected by ``get_mask(img, s)``."""
    return pipe_inp(prompt="bg", image=img, mask_image=get_mask(img, s), num_inference_steps=15).images[0]


def _attack_crop(img, s):
    """Centre-crop away fraction ``s`` of each dimension."""
    return crop_center(img, s)


def _attack_brightness(img, s):
    """Scale brightness by factor ``s``."""
    return ImageEnhance.Brightness(img).enhance(s)


# attack name -> (callable(img, level), levels swept)
attacks = {
    "Img2Img": (_attack_img2img, np.linspace(0.01, 0.95, CURVE_POINTS)),
    "Inpainting": (_attack_inpaint, np.linspace(0.05, 0.6, CURVE_POINTS)),
    "Crop": (_attack_crop, np.linspace(0.05, 0.9, CURVE_POINTS)),
    "Brightness": (_attack_brightness, np.linspace(1.1, 3.0, CURVE_POINTS)),
}

# Unattacked baseline rows: provenance on the first min(SAMPLE_SIZE, BASE_COUNT)
# images of every scheme, with CLIP similarity fixed at 100.
_n_control = min(SAMPLE_SIZE, BASE_COUNT)
ALL_DATA_SINGLE = [
    {
        "Model": name,
        "Attack": "Control",
        "Level": 0.0,
        "Provenance": get_prov(name, imgs[k]),
        "CLIP": 100.0,
    }
    for name, imgs in BASE_IMAGES.items()
    for k in range(_n_control)
]

def _clip_embedding(image):
    """Return the L2-normalised CLIP image embedding of ``image`` (shape (1, D))."""
    with torch.no_grad():
        emb = clip_model.encode_image(clip_preprocess(image).unsqueeze(0).to(device))
    return emb / emb.norm(dim=-1, keepdim=True)


# Dedicated, seeded RNG so the sampled image indices are reproducible run to run.
_sweep_rng = random.Random(42)
# The original image's embedding does not depend on the attack level; compute it
# once per (model, index) instead of once per level.
_orig_clip_cache = {}

for atk_name, (func, levels) in attacks.items():
    print(f"Sweeping {atk_name}...")
    for level in tqdm(levels):
        # The same index sample is shared by every model at this level.
        indices = _sweep_rng.sample(range(BASE_COUNT), min(SAMPLE_SIZE, BASE_COUNT))
        for model_name, images in BASE_IMAGES.items():
            for idx in indices:
                orig = images[idx]
                try:
                    atk_img = func(orig, level)
                except Exception as err:
                    # Skip the sample but say so; a bare `except:` also hid Ctrl-C.
                    print(f"-> {atk_name} @ {level:.3f} failed on {model_name} #{idx}: {err}")
                    continue

                prov = get_prov(model_name, atk_img)

                key = (model_name, idx)
                if key not in _orig_clip_cache:
                    _orig_clip_cache[key] = _clip_embedding(orig)
                # Cosine similarity of normalised embeddings, as a percentage.
                clip_score = (_orig_clip_cache[key] @ _clip_embedding(atk_img).T).item() * 100

                ALL_DATA_SINGLE.append({"Model": model_name, "Attack": atk_name, "Level": level, "Provenance": prov, "CLIP": clip_score})

# One row per appended record (control or attacked sample), written as CSV.
df_single = pd.DataFrame(ALL_DATA_SINGLE)
_single_csv_path = os.path.join(BASE_DIR, "results_single_attacks.csv")
df_single.to_csv(_single_csv_path, index=False)

# Analytics directly on df_single.
# "Kill zone": the sample still looks like the original (CLIP > KILLZONE_CLIP)
# while provenance has dropped below KILLZONE_PROV.
df_single['In_Killzone'] = ((df_single['CLIP'] > KILLZONE_CLIP) & (df_single['Provenance'] < KILLZONE_PROV)).astype(int)

# Named aggregation instead of groupby().apply(lambda -> Series). It is
# vectorised and avoids pandas' deprecated behaviour of passing the grouping
# columns into apply(). The output columns are the same as before.
summary = df_single.groupby(['Attack', 'Model'], as_index=False).agg(
    Avg_Provenance=('Provenance', 'mean'),
    Avg_CLIP=('CLIP', 'mean'),
    **{'Killzone_Rate (%)': ('In_Killzone', 'mean')},
)
summary['Killzone_Rate (%)'] *= 100

pd.set_option('display.max_rows', 100)
pd.set_option('display.float_format', '{:.3f}'.format)
print(summary.sort_values(by=['Attack', 'Model']).to_string(index=False))

summary.to_csv(os.path.join(BASE_DIR, "killzone_analysis.csv"), index=False)

# Zip the results directory next to itself. make_archive appends ".zip" to the
# base name and returns the archive path, so the output follows BASE_DIR
# instead of a duplicated hard-coded path.
archive_path = shutil.make_archive(BASE_DIR, 'zip', BASE_DIR)
print(f"Archived to: {archive_path}")