# Standard library imports
import argparse
import glob
import math
import os
import random
import sys
from datetime import datetime, timedelta
from functools import partial
from typing import List, Optional, Type

# Third-party imports
import cv2
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import wandb
from PIL import Image
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm

# Local imports
from data import ContinuousFrameDataset
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from model import DiT_models
from vae.autoencoder import AutoencoderKL

print('cpu count', os.cpu_count())

# Reproducibility: seed Python, NumPy, Torch (CPU and every CUDA device) with one
# fixed value and force cuDNN onto deterministic, non-autotuned kernels.
# NOTE(review): every process uses the same seed, so DDP ranks likely draw
# identical noise/timesteps; offset by rank if that is not intended.
_SEED = 42
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed, random.seed):
    _seed_fn(_SEED)
del _seed_fn
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to a sequence-first tensor.

    The table is stored as a non-trainable buffer ``pe`` of shape
    ``(max_len, 1, d_model)``: even feature columns hold ``sin(pos * w_k)``,
    odd columns hold ``cos(pos * w_k)``.

    Args:
        d_model: Feature dimension to encode. Odd values are supported.
        max_len: Number of positions precomputed in the table.
    """

    def __init__(self, d_model, max_len=3000):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # There are only d_model // 2 odd columns -- one fewer than the even
        # columns when d_model is odd -- so trim the frequency vector to match.
        pe[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(0)`` positions.

        Because ``pe`` is ``(max_len, 1, d_model)``, ``x`` is expected to be laid
        out as ``(seq_len, batch, d_model)`` with ``seq_len <= max_len``.
        """
        return x + self.pe[:x.size(0)]

class LogicNet(nn.Module):
    """Predict one logit from a 32-channel feature map and an integer id.

    The feature map passes through two conv/ReLU/max-pool stages and is
    flattened; the integer id is looked up in an embedding table. Both are
    concatenated and fed through a ReLU/dropout MLP to a single output logit.

    Args:
        num_embeddings: Size of the integer vocabulary (ids in
            ``[0, num_embeddings)``).
        embedding_dim: Width of the integer embedding.

    Note:
        The fusion layer's input width is hard-wired to ``32 * 6 * 6 = 1152``
        image features, which is what ``conv_layers`` yields for a
        ``[batch, 32, 24, 24]`` input; other spatial sizes will not fit.
    """

    def __init__(self, num_embeddings, embedding_dim=64):
        super().__init__()

        # Image branch: [batch, 32, 24, 24] -> [batch, 32, 6, 6]
        self.conv_layers = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # -> [batch, 32, 12, 12]

            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),  # -> [batch, 32, 6, 6]
        )

        # Integer branch: [batch] -> [batch, embedding_dim]
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)

        # Fuse the flattened image features (32 * 6 * 6 = 1152) with the embedding.
        self.fusion_layer = nn.Sequential(
            nn.Linear(32 * 6 * 6 + embedding_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
        )

        # Single-logit head.
        self.classifier = nn.Linear(512, 1)

    def forward(self, image, integer):
        """Return logits of shape ``[batch, 1]``.

        Args:
            image: Feature map of shape ``[batch, 32, 24, 24]``.
            integer: Integer ids of shape ``[batch]`` (must be 1-D so its
                embedding is 2-D and can be concatenated on dim 1).
        """
        img_features = torch.flatten(self.conv_layers(image), 1)  # [batch, 1152]
        int_features = self.embedding(integer)  # [batch, embedding_dim]
        combined = torch.cat([img_features, int_features], dim=1)
        return self.classifier(self.fusion_layer(combined))

class ResBlock2d(nn.Module):
    """Residual block of two 3x3 convolutions plus an identity or 1x1 projection skip.

    Computes ``conv2(activation(conv1(x))) + shortcut(x)``. There is no
    normalisation and no activation after the addition. ``shortcut`` is a
    bias-free 1x1 convolution (wrapped in ``nn.Sequential``) whenever the
    stride or the channel count changes, and ``nn.Identity`` otherwise.

    Args:
        in_planes: Number of input channels.
        planes: Number of output channels (scaled by ``expansion``, which is 1).
        stride: Stride of the first convolution and of the projection skip.
        activation: Activation class; instantiated once with no arguments.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1, activation: Type[nn.Module] = nn.ReLU):
        super().__init__()
        out_planes = planes * self.expansion
        self.activation = activation()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
            )

    def forward(self, x):
        skip = self.shortcut(x)
        h = self.activation(self.conv1(x))
        return self.conv2(h) + skip

class MNISTDiffusionTransformer(nn.Module):
    """Recurrent latent-diffusion transformer with action/digit/map conditioning.

    One call processes a single frame. The (noised) frame ``x`` is concatenated
    channel-wise with a 32-channel recurrent ``hidden_state`` and passed through
    a DiT backbone (``in_channels=36``) together with:

    * a condition vector built from the timestep and action embeddings,
    * three digit tokens (hundreds/tens/ones of ``digits``), and
    * map tokens produced by the ``map_conv`` CNN.

    ``conv_out`` maps the backbone output to a new 32-channel state ``z``; the
    training loop in this file feeds ``z`` back as the next frame's
    ``hidden_state``. ``x_from_z`` maps ``z`` to a 4-channel prediction. A
    :class:`LogicNet` head also predicts a score-shift logit from the incoming
    hidden state and the action.
    """
    def __init__(self, model_name='I24_S_2'):
        """Build all sub-modules.

        Args:
            model_name: Key into ``DiT_models`` selecting the backbone variant.
        """
        super().__init__()
        
        # Diffusion hyperparameters
        # num_timesteps should stay equal to num_train_timesteps below, and
        # below the 1024 rows of time_embedding, because t is used directly as
        # an embedding index.
        self.num_timesteps = 1000
        self.noise_scheduler = DDPMScheduler(
            num_train_timesteps=1000,
            beta_start=1e-4,
            beta_end=0.02,
            clip_sample=False
        )

        # Core transformer model
        # 36 input channels = 32 hidden-state channels + (presumably) 4 channels
        # of x; see the concatenation in forward.
        self.model = DiT_models[model_name](in_channels=36)

        # Embedding layers
        # time (192) and action (192) tokens are concatenated into a 384-d
        # condition; digit and map tokens are also 384-d (presumably the
        # backbone width -- confirm against DiT_models).
        self.time_embedding = nn.Embedding(1024, 192)
        self.action_embedding = nn.Embedding(10, 192)
        self.digit_embedding = nn.Embedding(10, 384)

        # Convolutional layers
        # conv_out: 36-channel backbone output -> 32-channel hidden state z.
        self.conv_out = nn.Conv2d(36, 32, kernel_size=3, padding=1)
        # x_from_z: hidden state z -> 4-channel prediction.
        self.x_from_z = nn.Sequential(
            ResBlock2d(32, 4),
            nn.Conv2d(4, 4, 1, padding=0),
        )

        # Logic network for score prediction
        self.logic_net = LogicNet(num_embeddings=10, embedding_dim=64)
        
        # Learnable initial state
        # Broadcast over the batch and used as hidden_state when the caller
        # passes None (first frame).
        self.learnable_init_state = nn.Parameter(torch.randn(1, 32, 24, 24))

        # Map processing network
        # Four 2x max-pools downsample by 16 per side; the final conv lifts to
        # 384 channels so each remaining spatial cell becomes one map token.
        self.map_conv = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 384, kernel_size=3, padding=1),
        )

    def forward(self, x, t, action, digits, maps, hidden_state=None, inference=False):
        """
        Run one frame through the model.

        Args:
            x: Frame latent. In training mode it is noised here with fresh
                Gaussian noise at timestep ``t``; with ``inference=True`` it is
                used as-is. Presumably ``[B, 4, H, W]`` so that concatenating the
                32-channel hidden state gives the backbone's 36 channels --
                confirm.
            t: Integer timesteps, presumably shape ``[B]``; indexes
                ``time_embedding`` so values must be < 1024.
            action: Integer action ids in ``[0, 10)``, presumably shape ``[B]``.
            digits: Non-negative integer tensor, presumably shape ``[B]``. It is
                split into hundreds/tens/ones tokens, so values >= 1000 wrap.
            maps: 3-channel image batch fed to ``map_conv``.
            hidden_state: 32-channel recurrent state (e.g. the ``z`` returned by
                the previous call). ``None`` selects ``learnable_init_state``
                repeated over the batch.
            inference: If True, skip noising and return ``None`` for epsilon.

        Returns:
            Tuple ``(x_pred, epsilon, z, score_shift_logits)``: the 4-channel
            prediction, the noise that was added (``None`` when
            ``inference=True``), the new 32-channel hidden state, and the
            ``LogicNet`` logits computed from the *incoming* hidden state.
        """
        B = x.shape[0]

        # Process maps
        # [B, 384, h, w] -> [B, h*w, 384] map tokens.
        combined_maps = maps
        combined_maps = self.map_conv(combined_maps)
        combined_maps = combined_maps.flatten(2).transpose(1, 2)

        # Process digits
        # Split each number into three decimal digit tokens; for a 1-D
        # ``digits`` this yields [B, 3] ids -> [B, 3, 384] embeddings.
        digits = digits.unsqueeze(-1)
        hundreds = (digits // 100) % 10
        tens = (digits // 10) % 10 
        ones = digits % 10
        digit_tokens = torch.cat([hundreds, tens, ones], dim=1)
        digit_embeddings = self.digit_embedding(digit_tokens)

        # Initialize or use provided hidden state
        if hidden_state is None:
            hidden_state = self.learnable_init_state.repeat(B, 1, 1, 1)

        # Computed before noising; LogicNet's dropout consumes RNG in train
        # mode, so moving this line changes the sampled noise below.
        score_shift_result = self.logic_net(hidden_state, action)

        # Add noise during training
        if not inference:
            noise = torch.randn_like(x)
            x_noisy = self.noise_scheduler.add_noise(x, noise, t)
            epsilon = noise
        else:
            x_noisy = x

        # Combine input and hidden state
        x = torch.cat([x_noisy, hidden_state], dim=1) 

        # Process time and action embeddings
        # NOTE(review): for 1-D t/action the concatenation is [B, 384], so
        # squeeze(1) appears to be a no-op -- confirm whether it is still needed.
        time_token = self.time_embedding(t)
        action_token = self.action_embedding(action)
        conditions = torch.cat([time_token, action_token], dim=1)
        conditions = conditions.squeeze(1)

        # Forward through transformer
        x = self.model(x, conditions, digit_embeddings, combined_maps)

        # Final processing
        # z is the next hidden state; x is the 4-channel prediction derived from it.
        z = self.conv_out(x)
        x = self.x_from_z(z)

        if not inference:
            return x, epsilon, z, score_shift_result
        else:
            return x, None, z, score_shift_result

def _unwrap(model):
    """Return the module inside a DDP wrapper, or ``model`` itself if unwrapped."""
    return getattr(model, "module", model)


def _sample_sequence(net, vae, images, actions, digits, maps):
    """Autoregressively generate decoded frames for the first sequence of a batch.

    Frame 0 is encoded from the ground-truth image and run once at t=0 to
    obtain the recurrent hidden state. Every later frame is sampled from pure
    noise over the scheduler's current (inference) timesteps, conditioned on the
    hidden state produced at the final denoising step of the previous frame.
    Must be called under ``torch.no_grad()``.

    Returns:
        List of decoded frames (one tensor per time step, batch dim dropped).
    """
    scheduler = net.noise_scheduler
    eval_images = images[0:1]
    eval_actions = actions[0:1]
    eval_digits = digits[0:1]
    eval_maps = maps[0:1]

    generated_frames = []

    # First frame: encode the real image and take one t=0 pass to obtain the
    # hidden state. Action id 5 matches what the training loop feeds frame 0.
    first_frame = vae.encode(eval_images[:, 0]).sample() / 1.11
    device = first_frame.device
    t = torch.zeros(1, dtype=torch.long, device=device)
    action = torch.full((1,), 5, dtype=torch.long, device=device)
    predicted_frame, _, hidden, _ = net(
        first_frame, t, action, eval_digits[:, 0], eval_maps[:, 0], None, inference=True
    )
    predicted_frame = scheduler.step(predicted_frame, 0, first_frame).prev_sample
    generated_frames.append(vae.decode(predicted_frame * 1.11)[0])

    for frame_idx in range(1, eval_images.shape[1]):
        # NOTE(review): training conditions frame i on actions[:, i], whereas
        # sampling uses actions[:, i - 1]; confirm which alignment is intended.
        current_action = eval_actions[:, frame_idx - 1].long()
        current_digits = eval_digits[:, frame_idx]
        current_maps = eval_maps[:, frame_idx]

        sample = torch.randn_like(first_frame)
        hidden_pred = hidden
        for t_idx in tqdm(scheduler.timesteps):
            t = torch.full((1,), int(t_idx), dtype=torch.long, device=device)
            noise_pred, _, hidden_pred, _ = net(
                sample, t, current_action, current_digits, current_maps, hidden, inference=True
            )
            sample = scheduler.step(model_output=noise_pred, timestep=t_idx, sample=sample).prev_sample
        # Carry forward the hidden state from the last (least noisy) step. Using
        # the last step rather than testing ``t_idx == 0`` also works for
        # timestep spacings that never reach exactly 0.
        hidden = hidden_pred
        generated_frames.append(vae.decode(sample * 1.11)[0])

    return generated_frames


def _save_frames_and_video(frames, epoch, batch_idx, save_dir, fps=10):
    """Save each frame as a JPEG and all frames as one MP4 in ``save_dir``."""
    for idx, frame in enumerate(frames):
        save_path = os.path.join(save_dir, f"epoch_{epoch}_batch_{batch_idx}_gen_frame_{idx}.jpg")
        torchvision.utils.save_image(frame, save_path)

    if not frames:
        return
    # VideoWriter expects every frame to match the size it was opened with, so
    # take that size from the frames instead of hard-coding it.
    height, width = frames[0].shape[-2:]
    video_path = os.path.join(save_dir, f"epoch_{epoch}_batch_{batch_idx}_video.mp4")
    writer = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (int(width), int(height)))
    try:
        for frame in frames:
            # NOTE(review): values are clipped to [0, 1]; the image transform
            # normalises inputs to [-1, 1], so if the VAE decodes to that range
            # this should rescale with (x + 1) / 2 first -- confirm.
            array = frame.detach().cpu().permute(1, 2, 0).numpy()
            array = (np.clip(array, 0, 1) * 255).astype(np.uint8)
            writer.write(cv2.cvtColor(array, cv2.COLOR_RGB2BGR))
    finally:
        writer.release()


def generate_and_save_samples(model, vae, images, actions, digits, maps, epoch, batch_idx, save_dir,
                              num_inference_steps=50, fps=10):
    """Generate sample images and a video from the first sequence of a batch.

    Args:
        model: ``MNISTDiffusionTransformer``, optionally wrapped in DDP.
        vae: Autoencoder used to encode the first frame and decode samples.
        images, actions, digits, maps: Batched training tensors; only index 0
            of the batch is used.
        epoch, batch_idx: Used in the output file names.
        save_dir: Output directory.
        num_inference_steps: Number of scheduler steps per sampled frame.
        fps: Frame rate of the written video.

    The model's train/eval mode and the scheduler's timesteps are restored even
    if sampling fails.
    """
    # Call the unwrapped module directly: this runs on rank 0 only, and the DDP
    # wrapper's forward may issue collectives (e.g. buffer broadcasts) that the
    # other ranks never join.
    net = _unwrap(model)
    was_training = model.training
    model.eval()
    net.noise_scheduler.set_timesteps(num_inference_steps)
    try:
        with torch.no_grad():
            generated_frames = _sample_sequence(net, vae, images, actions, digits, maps)
            _save_frames_and_video(generated_frames, epoch, batch_idx, save_dir, fps=fps)
    finally:
        model.train(was_training)
        net.noise_scheduler.set_timesteps(net.num_timesteps)
    print(f"Generated sample images and videos for epoch {epoch}")


def _save_checkpoint(model, optimizer, epoch, path):
    """Save the epoch, unwrapped model weights and optimizer state to ``path``."""
    torch.save({
        'epoch': epoch,
        'model_state_dict': _unwrap(model).state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, path)


def _parse_args():
    """Parse command-line arguments.

    ``--local-rank`` falls back to the ``LOCAL_RANK`` environment variable,
    which torchrun sets instead of passing the flag.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--local-rank", type=int, default=int(os.environ.get("LOCAL_RANK", -1)))
    parser.add_argument("--batch_size", type=int, default=12)
    parser.add_argument("--num_epochs", type=int, default=50)
    parser.add_argument("--sequence_length", type=int, default=32)
    parser.add_argument("--data_fetch_mode", type=int, default=0)
    parser.add_argument("--save_dir", type=str, default="results/")
    parser.add_argument("--data_dir", type=str, default="data/dataset")
    parser.add_argument("--support_data_dir", type=str, default="data/")
    parser.add_argument("--lr", type=float, default=5e-4)
    parser.add_argument("--postfix", type=str, default='')
    parser.add_argument("--dynamic", type=int, default=1)
    parser.add_argument("--model_name", type=str, default='I24_S_2')
    parser.add_argument("--nomap", default=False, action='store_true')
    parser.add_argument("--nodigit", default=False, action='store_true')
    return parser.parse_args()


def _train_step(model, vae, optimizer, images, actions, score_shifts, digits, maps, nomap, score_weight):
    """Run one optimisation step, back-propagating through every frame of the batch.

    The hidden state returned for frame i is fed into frame i + 1, and a single
    backward pass is made on the mean per-frame loss.

    Returns:
        ``(mean_loss, pixel_loss, score_shift_loss)`` as Python floats; the last
        two are the values for the final frame.
    """
    net = _unwrap(model)
    device = images.device
    B, T = images.shape[:2]

    optimizer.zero_grad()
    frame_losses = []
    hidden = None  # None -> the model's learnable initial state for frame 0
    for frame_idx in range(T):
        current_frame = images[:, frame_idx]
        # Frame 0 has no preceding action; the fixed id 5 is used instead.
        current_action = actions[:, frame_idx] if frame_idx > 0 else torch.full_like(actions[:, 0], 5)
        current_digits = digits[:, frame_idx]
        current_maps = maps[:, frame_idx]
        if nomap:
            current_maps = torch.zeros_like(current_maps)

        # The VAE is frozen: encode without building a graph.
        with torch.no_grad():
            current_frame = vae.encode(current_frame).sample() / 1.11

        t = torch.randint(0, net.num_timesteps, (B,), device=device)

        outputs, epsilon, hidden, score_shift_result = model(
            current_frame, t, current_action.long(), current_digits, current_maps, hidden
        )

        score_shift_loss = F.binary_cross_entropy_with_logits(
            score_shift_result,
            score_shifts[:, frame_idx].unsqueeze(1).float(),
        )
        pixel_loss = F.mse_loss(outputs, epsilon)
        frame_losses.append(pixel_loss + score_shift_loss * score_weight)

    mean_loss = torch.stack(frame_losses).mean()
    mean_loss.backward()
    optimizer.step()
    return mean_loss.item(), pixel_loss.item(), score_shift_loss.item()


def main():
    """Train the model with DDP; launch with torchrun (one process per GPU)."""
    args = _parse_args()
    if args.local_rank < 0:
        raise SystemExit("No local rank: launch with torchrun (sets LOCAL_RANK) or pass --local-rank.")

    # Setup save directory with timestamp
    current_date = (datetime.now() + timedelta(hours=8)).strftime("%m%d")
    args.save_dir = f"{args.save_dir}/{current_date}_seq{args.sequence_length}_bs{args.batch_size}_fetch{args.data_fetch_mode}_model{args.model_name}_dynamic{args.dynamic}_lr{args.lr}_{args.postfix}"
    os.makedirs(args.save_dir, exist_ok=True)

    # Initialize distributed training
    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(backend='nccl')
    device = torch.device(f"cuda:{args.local_rank}")
    # Global rank 0 does all logging, checkpointing and sampling, so multi-node
    # runs do not get one writer per node.
    is_main = dist.get_rank() == 0

    # Initialize model
    model = MNISTDiffusionTransformer(model_name=args.model_name).to(device)
    if is_main:
        print(f'Total number of model parameters: {sum(p.numel() for p in model.parameters()):,}')
        # wandb.log raises unless a run has been initialised first.
        wandb.init(config=vars(args), name=os.path.basename(args.save_dir))
    model = DDP(model, device_ids=[args.local_rank])

    # Initialize VAE (frozen: eval mode, no gradients)
    vae = AutoencoderKL()
    state_dict = torch.load('vae_epoch_last_iter.pt', map_location='cpu')['model_state_dict']
    vae.load_state_dict({k.replace('module.', ''): v for k, v in state_dict.items()}, strict=True)
    vae = vae.to(device).eval().requires_grad_(False)

    # Setup optimizer and data transforms
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr)
    transform = transforms.Compose([
        transforms.Resize((96, 96)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    transform_map = transforms.Compose([
        transforms.Resize((96, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    # Setup dataset and dataloader
    dataset = ContinuousFrameDataset(
        data_dir=args.data_dir,
        support_data_dir=args.support_data_dir,
        sequence_length=args.sequence_length,
        transform=transform,
        transform_map=transform_map,
        data_fetch_mode=args.data_fetch_mode
    )
    train_sampler = DistributedSampler(dataset)
    train_loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,  # the DistributedSampler shuffles
        num_workers=16,
        sampler=train_sampler,
        pin_memory=True
    )

    # The score-shift head is trained with a tiny weight, or not at all with --nodigit.
    score_weight = 0.0 if args.nodigit else 1e-4

    for epoch in range(args.num_epochs):
        train_sampler.set_epoch(epoch)
        model.train()

        # Save checkpoint at epoch start
        if is_main:
            _save_checkpoint(model, optimizer, epoch, f'{args.save_dir}/last_epoch_{epoch}.pth')

        for batch_idx, (images, actions, score_shifts, digits, maps) in enumerate(tqdm(train_loader)):
            periodic = batch_idx % 500 == 0
            if is_main and periodic:
                _save_checkpoint(model, optimizer, epoch, f'{args.save_dir}/last_iter.pth')

            images = images.to(device)
            actions = actions.to(device)
            digits = digits.to(device)
            score_shifts = score_shifts.to(device)
            maps = maps.to(device)

            if args.nodigit:
                digits = torch.zeros_like(digits)

            # NOTE(review): other ranks block in the next all-reduce while rank 0
            # samples; long sampling can approach the NCCL timeout.
            if is_main and periodic:
                generate_and_save_samples(model, vae, images, actions, digits, maps, epoch, batch_idx, args.save_dir)

            mean_loss, pixel_loss, score_shift_loss = _train_step(
                model, vae, optimizer, images, actions, score_shifts, digits, maps,
                nomap=args.nomap, score_weight=score_weight,
            )

            if is_main:
                print(f'Epoch: {epoch}, Batch: {batch_idx}, Loss: {mean_loss}, pixel_loss: {pixel_loss}, score_shift_loss: {score_shift_loss}')
                if batch_idx % 10 == 0:
                    wandb.log({
                        "train/loss": mean_loss,
                        "train/pixel_loss": pixel_loss,
                        "train/score_shift_loss": score_shift_loss,
                        "epoch": epoch,
                        "batch": batch_idx
                    })

    # Cleanup
    if is_main:
        wandb.finish()
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
        
        