import gc
import os
import subprocess
import sys

import torch

# Best-effort install of the runtime dependencies into *this* interpreter.
# Running `sys.executable -m pip` (instead of a bare `pip` found on PATH) makes
# sure the packages land in the environment that is executing this script, and
# passing an argument list avoids going through a shell. A failed install is
# reported but not fatal, so an already-provisioned environment without network
# access can still continue to the imports that follow.
_pip_result = subprocess.run(
    [
        sys.executable, "-m", "pip", "install", "-q",
        "datasets", "diffusers", "transformers", "accelerate",
    ],
    check=False,
)
if _pip_result.returncode != 0:
    print(
        f"WARNING: pip install exited with code {_pip_result.returncode}; "
        "continuing with whatever packages are already installed."
    )
from datasets import load_dataset
from diffusers import StableDiffusionPipeline, DDIMScheduler

# Use the GPU when one is visible; otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using device: {device}")

# Output directory for the pattern/mask tensors and the generated images.
DATA_DIR = "./tree_ring_dataset"
os.makedirs(DATA_DIR, exist_ok=True)

# Tree-Ring watermark settings:
#   TR_SHAPE   - latent shape as (batch, channels, height, width)
#   TR_RADIUS  - outer radius of the ring pattern, in frequency bins
#   TR_CHANNEL - index of the latent channel that carries the key
TR_SHAPE = (1, 4, 64, 64)
TR_RADIUS = 10
TR_CHANNEL = 3

def get_watermarking_pattern(shape, w_radius=10, w_channel=3, device='cpu', seed=999999):
    """Build the Tree-Ring key pattern and its boolean mask.

    The geometry assumes an fftshift-ed spectrum, i.e. the zero-frequency bin
    sits at ``(h // 2, w // 2)``. On channel ``w_channel``:

    * every bin whose *rounded* distance from the centre equals ``r``
      (``1 <= r <= w_radius``) gets the same real value ``ring_1d[r]``;
    * the mask selects every bin whose *unrounded* distance is ``<= w_radius``.
      This includes the centre bin, where the pattern stays 0.

    A few bins just outside the mask (distance slightly above ``w_radius`` but
    rounding down to it) also carry a non-zero pattern value. The pattern is
    therefore meant to be read through the mask (``pattern[mask]``).

    Args:
        shape: Latent shape; indices 1, 2 and 3 are read as channels, height
            and width.
        w_radius: Outer radius of the ring pattern, in frequency bins.
        w_channel: Index of the channel that carries the key.
        device: Device of the random generator and of the returned tensors.
        seed: Seed for the generator that draws the ring values. The default
            reproduces the previously hard-coded key.

    Returns:
        ``(pattern, mask)``: a ``complex64`` tensor and a ``bool`` tensor, both
        of shape ``shape``.
    """
    # NOTE(review): the generator lives on `device`. CPU and CUDA generators give
    # different streams for the same seed, so the key depends on the device it
    # was built on. Persist it rather than regenerating it on another device.
    g_key = torch.Generator(device).manual_seed(seed)
    h, w = shape[2], shape[3]

    # Euclidean distance of every (row, col) bin from the centre bin.
    y, x = torch.meshgrid(
        torch.arange(h, device=device), torch.arange(w, device=device), indexing='ij'
    )
    center_y, center_x = h // 2, w // 2
    dist = torch.sqrt((y - center_y) ** 2 + (x - center_x) ** 2)
    dist_int = torch.round(dist).long()

    # Mask: the full disc of radius w_radius (centre included) on the key channel.
    mask = torch.zeros(shape, dtype=torch.bool, device=device)
    mask[:, w_channel, dist <= w_radius] = True

    # One real value per ring, scaled by sqrt(h * w). That is the RMS magnitude
    # of fft2 coefficients of i.i.d. unit-variance input.
    # ring_1d[0] is drawn but never used, so the centre bin stays 0. It must still
    # be drawn so that ring values 1..w_radius stay the same for a given seed.
    std = (h * w) ** 0.5
    ring_1d = torch.randn(w_radius + 1, generator=g_key, device=device) * std

    pattern = torch.zeros(shape, dtype=torch.complex64, device=device)
    for r in range(1, w_radius + 1):
        pattern[:, w_channel, dist_int == r] = ring_1d[r].to(torch.complex64)
    return pattern, mask

# Build the watermark key and mask once and write both to DATA_DIR, presumably
# so that a later stage can reload exactly the same key.
TR_PATTERN, TR_MASK = get_watermarking_pattern(
    TR_SHAPE,
    w_radius=TR_RADIUS,
    w_channel=TR_CHANNEL,
    device=device,
)
for _fname, _tensor in (("tr_pattern.pt", TR_PATTERN), ("tr_mask.pt", TR_MASK)):
    torch.save(_tensor, os.path.join(DATA_DIR, _fname))
print("Saved Tree-Ring Pattern and Mask to disk.")

# Half precision only on CUDA. On the CPU fallback many fp16 kernels are missing
# or extremely slow, so load float32 there instead. The generation loop casts its
# latents to `pipe_gen.dtype`, so it follows this choice automatically.
pipe_dtype = torch.float16 if device == "cuda" else torch.float32
# NOTE(review): the `runwayml/stable-diffusion-v1-5` Hub repository has
# reportedly been taken down. If the download fails, the
# `stable-diffusion-v1-5/stable-diffusion-v1-5` mirror is the usual replacement;
# confirm before switching.
pipe_gen = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=pipe_dtype,
    safety_checker=None,
    requires_safety_checker=False,
).to(device)
# Replace the default scheduler with DDIM, reusing the existing scheduler config.
pipe_gen.scheduler = DDIMScheduler.from_config(pipe_gen.scheduler.config)
pipe_gen.set_progress_bar_config(disable=True)

NUM_IMAGES = 1000
prompts = []
# The count comes from NUM_IMAGES. No de-duplication is done, so the message
# does not promise unique prompts.
print(f"Fetching {NUM_IMAGES:,} prompts from DiffusionDB...")
try:
    ds = load_dataset("poloclub/diffusiondb", "2m_first_1k", split="train", streaming=True)
    # Append one record at a time so prompts streamed before a mid-stream
    # failure are kept.
    for record in ds:
        if len(prompts) >= NUM_IMAGES:
            break
        prompts.append(record['prompt'])
except Exception as e:
    # Best effort: any failure (network, missing dataset, library change) is
    # reported and the list is topped up with fallback prompts below.
    print(f"Failed to load dataset: {e}")

# Cycle through a small fixed set until there are exactly NUM_IMAGES prompts.
fallback_prompts = ["a photo of a cat", "a futuristic city", "a red sports car", "a deep forest", "a cup of coffee"]
num_fallback = NUM_IMAGES - len(prompts)
if num_fallback > 0:
    print(f"Using {num_fallback} fallback prompts to reach {NUM_IMAGES}.")
while len(prompts) < NUM_IMAGES:
    prompts.append(fallback_prompts[len(prompts) % len(fallback_prompts)])

print(f"Generating {NUM_IMAGES} Tree-Ring watermarked images. This will take some time...")
# The key values written into each spectrum are identical for every image.
_tr_key_values = TR_PATTERN[TR_MASK]
for idx in range(NUM_IMAGES):
    prompt = prompts[idx]

    # Per-image seed (1000 + idx) makes each image's starting noise reproducible.
    rng = torch.Generator(device).manual_seed(1000 + idx)
    noise = torch.randn(TR_SHAPE, generator=rng, device=device, dtype=pipe_gen.dtype)

    # Go to the centred (fftshift-ed) spectrum in float32, overwrite the masked
    # bins with the key, then return to the spatial domain. Keep the real part,
    # cast back to the pipeline's dtype.
    spectrum = torch.fft.fft2(noise.to(torch.float32))
    spectrum = torch.fft.fftshift(spectrum, dim=(-2, -1))
    spectrum[TR_MASK] = _tr_key_values.to(spectrum.dtype)
    spatial = torch.fft.ifft2(torch.fft.ifftshift(spectrum, dim=(-2, -1)))
    watermarked_latents = spatial.real.to(pipe_gen.dtype)

    image = pipe_gen(prompt, num_inference_steps=20, latents=watermarked_latents).images[0]
    image.save(os.path.join(DATA_DIR, f"tr_img_{idx}.png"))

    done = idx + 1
    if done % 50 == 0:
        print(f"  -> Generated {done}/{NUM_IMAGES} images...")

print("Data generation complete. Freeing up VRAM...")
# Drop this script's reference to the pipeline. Then run a GC pass, so objects
# kept alive only by reference cycles are actually freed. Finally, release the
# caching allocator's unused CUDA blocks; this is safe to call when CUDA is
# unavailable.
del pipe_gen
gc.collect()
torch.cuda.empty_cache()