import os
import sys

# Install runtime dependencies into the interpreter that is running this script;
# a bare `pip` on PATH may belong to a different Python installation.
# Best-effort on purpose: the exit status is not checked, so an offline run with
# the packages already installed still proceeds.
os.system(f'"{sys.executable}" -m pip install -q diffusers transformers accelerate pandas torch-dct')

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import pandas as pd

# Pretrained RivaGAN checkpoint (the training loop below uses 32-bit messages).
WEIGHTS_FILE = "rivagan_32bit_model.pt"

# Fetch the RivaGAN sources on first run and put them on the import path so the
# `rivagan` imports below resolve. A failed clone is reported here instead of
# surfacing later as a confusing ImportError.
if not os.path.exists("RivaGAN"):
    if os.system("git clone https://github.com/DAI-Lab/RivaGAN.git") != 0:
        raise RuntimeError("git clone of https://github.com/DAI-Lab/RivaGAN.git failed")
sys.path.append('RivaGAN')

from diffusers import StableDiffusionPipeline
from rivagan import RivaGAN
from rivagan.attention import AttentiveEncoder, AttentiveDecoder

class DummyNetwork(nn.Module):
    """Parameter-free placeholder module.

    Any positional or keyword constructor arguments are accepted and ignored,
    so it can stand in for a network class whose real definition is not needed.
    """

    def __init__(self, *args, **kwargs):
        super().__init__()

import __main__

# Pickled whole-model checkpoints resolve classes by module path; the weights
# file presumably references these names under __main__, so expose them there.
# Adversary/Critic get a no-op stand-in (presumably only needed so unpickling
# succeeds — confirm).
for _alias, _target in (
    ("RivaGAN", RivaGAN),
    ("AttentiveEncoder", AttentiveEncoder),
    ("AttentiveDecoder", AttentiveDecoder),
    ("Adversary", DummyNetwork),
    ("Critic", DummyNetwork),
):
    setattr(__main__, _alias, _target)
del _alias, _target

DATA_DIR = "./tree_ring_dataset"
device = "cuda" if torch.cuda.is_available() else "cpu"

class VAELatentProxyLoss(nn.Module):
    """Mean absolute difference of masked, centred latent-FFT coefficients.

    Both images are encoded with a frozen VAE (posterior mean times the VAE's
    ``scaling_factor``). The latents are 2-D FFT'd and ``fftshift``-ed over the
    last two axes, and the coefficients selected by ``mask`` are compared.

    Args:
        vae: VAE module, frozen in place (eval mode, ``requires_grad=False``).
        mask: selection mask over the shifted spectrum. After ``squeeze(0)`` it
            presumably is a boolean tensor matching the latent's (C, H, W) —
            confirm against the mask file.
    """

    def __init__(self, vae, mask):
        super().__init__()
        vae.eval()
        vae.requires_grad_(False)
        self.vae = vae
        self.mask_sq = mask.squeeze(0).to(device)

    def _masked_spectrum(self, images):
        """Encode *images* and return the mask-selected centred FFT coefficients."""
        scale = self.vae.config.scaling_factor
        latents = self.vae.encode(images).latent_dist.mean * scale
        centred = torch.fft.fftshift(torch.fft.fft2(latents), dim=(-2, -1))
        return centred[:, self.mask_sq]

    def forward(self, img_orig_scaled, img_wm_scaled):
        # Inputs are expected already scaled to [-1, 1].
        # The reference branch needs no autograd graph; gradients flow only
        # through the watermarked image.
        with torch.no_grad():
            reference = self._masked_spectrum(img_orig_scaled)
        candidate = self._masked_spectrum(img_wm_scaled)
        return (candidate - reference).abs().mean()

class TreeRingDataset(Dataset):
    """PNG images from one directory, resized to 512x512 and scaled to [-1, 1].

    The file list is built once at construction and sorted by name, so index
    order is deterministic across filesystems (``os.listdir`` order is
    arbitrary). Only names ending in ``.png`` (case-sensitive) are included.
    """

    def __init__(self, data_dir):
        self.image_paths = sorted(
            os.path.join(data_dir, name)
            for name in os.listdir(data_dir)
            if name.endswith('.png')
        )
        self.transform = transforms.Compose([
            transforms.Resize((512, 512)),
            transforms.ToTensor(),
            # ToTensor yields [0, 1]; (x - 0.5) / 0.5 maps that to [-1, 1].
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # convert() returns an independent, fully loaded image, so the source
        # file can be closed by the context manager before transforming.
        with Image.open(self.image_paths[idx]) as img:
            rgb = img.convert('RGB')
        return self.transform(rgb)

class SynergisticAdapter(nn.Module):
    """Small residual CNN that adjusts a 3-channel image and keeps it in [-1, 1].

    The last convolution starts at exactly zero, so a freshly built adapter is
    the identity map (apart from clamping) and training starts from the
    unmodified input.
    """

    def __init__(self):
        super().__init__()
        # Creation order and attribute names are kept fixed: they determine
        # parameter-initialisation RNG use and the state_dict keys.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(32, 3, kernel_size=3, padding=1)
        for tensor in (self.conv3.weight, self.conv3.bias):
            nn.init.zeros_(tensor)

    def forward(self, img_4d):
        """Return ``clamp(img_4d + residual(img_4d), -1, 1)``."""
        features = F.relu(self.conv1(img_4d))
        features = self.conv2(features).relu()
        residual = self.conv3(features)
        # RivaGAN operates in [-1, 1], so the corrected image is clamped back
        # into that range.
        return (img_4d + residual).clamp(-1.0, 1.0)

pipe_eval = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", safety_checker=None, requires_safety_checker=False).to(device)
# NOTE(review): weights_only=False unpickles arbitrary Python objects; only
# point this at files you trust.
loaded_mask = torch.load(os.path.join(DATA_DIR, "tr_mask.pt"), weights_only=False).to(device)
proxy_criterion = VAELatentProxyLoss(pipe_eval.vae, loaded_mask).to(torch.float32)


def _extract_rivagan_networks(checkpoint):
    """Return the ``(encoder, decoder)`` modules held by a loaded RivaGAN checkpoint.

    Two checkpoint forms are handled:
      * an object exposing ``.encoder`` (a whole pickled model), used as-is;
      * anything else, treated as a state dict for a fresh ``RivaGAN()``.

    Raises:
        RuntimeError: if the fresh ``RivaGAN()`` has no ``load_state_dict``.
            The weights cannot be applied in that case, and continuing would
            silently train the adapter against untrained networks.
    """
    if hasattr(checkpoint, 'encoder'):
        return checkpoint.encoder, checkpoint.decoder

    model = RivaGAN()
    if not hasattr(model, 'load_state_dict'):
        raise RuntimeError(
            f"{WEIGHTS_FILE} is not a whole RivaGAN model and RivaGAN() has no "
            "load_state_dict(); refusing to continue with untrained encoder/decoder."
        )
    result = model.load_state_dict(checkpoint, strict=False)
    # strict=False tolerates key mismatches; surface them rather than hide them.
    missing = getattr(result, 'missing_keys', None)
    if missing:
        print(f"WARNING: {len(missing)} RivaGAN parameters missing from {WEIGHTS_FILE}; "
              "they keep whatever RivaGAN() initialised them to.")
    return model.encoder, model.decoder


# NOTE(review): weights_only=False is required to unpickle a whole model object;
# only load trusted checkpoint files this way.
real_rivagan = torch.load(WEIGHTS_FILE, map_location=device, weights_only=False)
encoder_module, decoder_module = _extract_rivagan_networks(real_rivagan)
riva_encoder = encoder_module.to(device).eval()
riva_decoder = decoder_module.to(device).eval()

# RivaGAN is frozen; only the adapter is trained.
riva_encoder.requires_grad_(False)
riva_decoder.requires_grad_(False)

adapter = SynergisticAdapter().to(device)
EPOCHS = 40
ACCUMULATION_STEPS = 8
optimizer = torch.optim.Adam(adapter.parameters(), lr=1e-4)
# One scheduler step per epoch, annealing over the full run.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

dataset = TreeRingDataset(DATA_DIR)
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

def get_lambda(epoch):
    """Return the proxy-interference loss weight for a 0-based *epoch*.

    Step schedule: 0.0 before epoch 2, 0.02 before 6, 0.05 before 12, then 0.1.
    """
    # (exclusive epoch upper bound, weight) pairs, checked in ascending order.
    for upper, weight in ((2, 0.0), (6, 0.02), (12, 0.05)):
        if epoch < upper:
            return weight
    return 0.1

training_logs = []
num_batches = len(dataloader)


def _bit_and_byte_accuracy(logits, target_bits):
    """Return ``(bit_acc, byte_acc)`` in percent for decoder *logits* vs 0/1 *target_bits*.

    A bit is predicted as 1 when ``sigmoid(logit) > 0.5``. A byte counts as
    correct only when all 8 of its bits match, so the message length must be a
    multiple of 8.
    """
    with torch.no_grad():
        pred_bits = (torch.sigmoid(logits) > 0.5).float()
        match_bits = pred_bits == target_bits  # [Batch, n_bits]
        bit_acc = match_bits.float().mean().item() * 100.0
        byte_match = match_bits.view(match_bits.size(0), -1, 8).all(dim=2).float()  # [Batch, n_bytes]
        byte_acc = byte_match.mean().item() * 100.0
    return bit_acc, byte_acc


for epoch in range(EPOCHS):
    adapter.train()
    current_lambda = get_lambda(epoch)
    print(f"\n--- EPOCH {epoch+1}/{EPOCHS} | LR: {scheduler.get_last_lr()[0]:.6f} | Proxy Penalty: {current_lambda} ---")
    optimizer.zero_grad()

    for batch_idx, img_tr_scaled in enumerate(dataloader):
        img_tr_scaled = img_tr_scaled.to(device)
        # The RivaGAN encoder/decoder are fed 5-D tensors: a singleton axis is
        # inserted at dim 2 (presumably the frame axis — confirm).
        img_tr_5d = img_tr_scaled.unsqueeze(2)

        # Fresh random 32-bit payload per image, as 0.0/1.0 floats for BCE.
        target_msg = torch.randint(0, 2, (img_tr_scaled.size(0), 32), dtype=torch.float32).to(device)

        # The encoder is frozen and sits upstream of the adapter: no graph needed.
        with torch.no_grad():
            output = riva_encoder(img_tr_5d, target_msg)
            img_combined_5d = output[0] if isinstance(output, tuple) else output
            img_combined_4d = img_combined_5d.squeeze(2)

        img_corrected_4d = adapter(img_combined_4d)  # Stays in [-1, 1]

        # Differentiable brightness jitter + Gaussian noise in [0, 1] space, so the
        # adapter is trained to keep the message decodable under mild distortion.
        img_01 = (img_corrected_4d + 1.0) / 2.0
        rand_brightness = 0.8 + 0.4 * torch.rand((img_01.size(0), 1, 1, 1), device=device)
        rand_noise = torch.randn_like(img_01) * 0.015
        img_aug_01 = torch.clamp((img_01 * rand_brightness) + rand_noise, 0.0, 1.0)
        img_aug_5d = ((img_aug_01 * 2.0) - 1.0).unsqueeze(2)  # Back to [-1, 1]

        decoded_msg = riva_decoder(img_aug_5d)  # logits, consumed by BCE-with-logits

        loss_msg = F.binary_cross_entropy_with_logits(decoded_msg, target_msg)
        loss_img = F.mse_loss(img_corrected_4d, img_tr_scaled) * 2.0

        if current_lambda > 0.0:
            loss_interf = proxy_criterion(img_tr_scaled, img_corrected_4d)
        else:
            loss_interf = torch.tensor(0.0, device=device)

        # Average gradients over the batches actually in this accumulation group;
        # the last group of an epoch may be shorter than ACCUMULATION_STEPS.
        group_start = batch_idx - batch_idx % ACCUMULATION_STEPS
        accum_div = min(ACCUMULATION_STEPS, num_batches - group_start)
        loss_total = ((loss_msg * 3.0) + loss_img + (current_lambda * loss_interf)) / accum_div
        loss_total.backward()

        if (batch_idx + 1) % ACCUMULATION_STEPS == 0 or (batch_idx + 1) == num_batches:
            torch.nn.utils.clip_grad_norm_(adapter.parameters(), max_norm=1.0)
            optimizer.step()
            optimizer.zero_grad()

        if (batch_idx + 1) % 50 == 0:
            bit_acc, byte_acc = _bit_and_byte_accuracy(decoded_msg, target_msg)
            # loss_msg is the raw per-batch BCE (only loss_total is divided for
            # accumulation), so it is logged unscaled.
            l_msg = loss_msg.item()
            l_img = loss_img.item()

            training_logs.append({
                "Epoch": epoch + 1, "Batch": batch_idx + 1,
                "Bit_Acc": bit_acc, "Byte_Acc": byte_acc,
                "L_Msg": l_msg,
                "L_Img": l_img,
                "Proxy_Interf": loss_interf.item()
            })
            print(f"Batch {batch_idx+1:03d} | Bit Acc: {bit_acc:3.0f}% | Byte Acc: {byte_acc:3.0f}% | L_Msg: {l_msg:.3f} | L_Img: {l_img:.4f}")

    scheduler.step()

torch.save(adapter.state_dict(), "synergistic_adapter.pth")
pd.DataFrame(training_logs).to_csv("training_log_synergistic.csv", index=False)