# Standard library imports
import argparse
import glob
import math
import os
import random
import shutil
import sys
import time

# Third-party imports
import cv2
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from diffusers.schedulers.scheduling_ddim import DDIMScheduler
from diffusers.schedulers.scheduling_ddpm import DDPMScheduler
from PIL import Image
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm
from typing import Type

# Local imports
from model import DiT_models
from map_constructor import MapBuilder
from vae.autoencoder import AutoencoderKL

# Shared image -> tensor preprocessing: ToTensor followed by a per-channel
# Normalize with mean 0.5 and std 0.5 on each of the three channels,
# i.e. (value - 0.5) / 0.5.
to_tensor = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence-first tensor.

    A table ``pe`` of shape ``(max_len, 1, d_model)`` is precomputed once and
    registered as a buffer (so it follows ``.to(device)`` and is part of the
    state dict). Even feature columns hold ``sin(pos * freq)`` and odd feature
    columns hold ``cos(pos * freq)``.

    Args:
        d_model: Size of the feature dimension. Both even and odd sizes are
            supported.
        max_len: Number of positions to precompute.
    """

    def __init__(self, d_model, max_len=3000):
        super().__init__()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        # For an odd d_model there is one fewer odd-indexed column than
        # even-indexed ones, so drop the surplus frequency for the cosine half.
        # For an even d_model the slice is a no-op.
        pe[:, 0, 1::2] = torch.cos(position * div_term[: d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` plus the encodings of its first ``x.size(0)`` positions.

        Args:
            x: Tensor whose first dimension indexes the position and whose
                last dimension is ``d_model``; the singleton middle dimension
                of ``pe`` broadcasts over the second dimension of ``x``.
        """
        return x + self.pe[:x.size(0)]

class LogicNet(nn.Module):
    """Produce one logit from a 32-channel feature map and an integer id.

    The feature map goes through two conv/ReLU/max-pool stages, is flattened,
    concatenated with a learned embedding of the integer, and passed through
    a fused hidden layer and a single-output linear classifier.

    Args:
        num_embeddings: Number of distinct integer ids the embedding accepts.
        embedding_dim: Width of the integer embedding.
    """

    def __init__(self, num_embeddings, embedding_dim=64):
        super().__init__()

        # Feature-map branch, input [batch, 32, 24, 24]; each pooling stage
        # halves the spatial size: 24 -> 12 -> 6.
        self.conv_layers = nn.Sequential(
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        # Integer branch.
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)

        # 32 channels * 6 * 6 spatial positions = 1152 flattened features.
        flattened_features = 32 * 6 * 6
        self.fusion_layer = nn.Sequential(
            nn.Linear(flattened_features + embedding_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
        )

        self.classifier = nn.Linear(512, 1)

    def forward(self, image, integer):
        """Return logits of shape ``[batch, 1]``.

        Args:
            image: Feature map fed to the conv branch.
            integer: Integer ids looked up in the embedding table; one id per
                batch element so the result can be concatenated on dim 1.
        """
        batch = image.size(0)
        flat_image = self.conv_layers(image).view(batch, -1)
        embedded_int = self.embedding(integer)
        fused = self.fusion_layer(torch.cat([flat_image, embedded_int], dim=1))
        return self.classifier(fused)

class ResBlock2d(nn.Module):
    """Two-convolution residual block without normalisation layers.

    Computes ``conv2(activation(conv1(x))) + shortcut(x)``. The shortcut is
    the identity when input and output shapes already agree, otherwise a
    bias-free 1x1 convolution that matches channels and stride.

    Args:
        in_planes: Number of input channels.
        planes: Number of output channels (times ``expansion``, which is 1).
        stride: Stride of the first 3x3 convolution and of the 1x1 shortcut.
        activation: Module class instantiated once and applied after ``conv1``.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1, activation: Type[nn.Module] = nn.ReLU):
        super().__init__()
        out_planes = self.expansion * planes

        self.activation = activation()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)

        shapes_match = stride == 1 and in_planes == out_planes
        if shapes_match:
            self.shortcut = nn.Identity()
        else:
            # Project the input so it can be added to the residual branch.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False),
            )

    def forward(self, x):
        """Apply the residual branch and add the (possibly projected) input."""
        residual = self.conv2(self.activation(self.conv1(x)))
        return residual + self.shortcut(x)

class MNISTDiffusionTransformer(nn.Module):
    """DiT backbone wrapped with a DDPM scheduler, a recurrent hidden state,
    digit/map conditioning and a ``LogicNet`` score-shift head.

    Args:
        model_name: Key into ``DiT_models`` selecting the backbone, which is
            built with 36 input channels.
    """

    def __init__(self, model_name='I24_S_2'):
        super().__init__()
        self.num_timesteps = 1000
        self.noise_scheduler = DDPMScheduler(
            num_train_timesteps=1000,
            beta_start=1e-4,
            beta_end=0.02,
            clip_sample=False
        )

        # 36 input channels = the frame tensor concatenated with the
        # 32-channel hidden state (see the torch.cat in forward).
        self.model = DiT_models[model_name](in_channels=36)
        self.time_embedding = nn.Embedding(1024, 192)
        self.action_embedding = nn.Embedding(10, 192)

        # Backbone output -> next hidden state z (32 ch) -> 4-channel output.
        self.conv_out = nn.Conv2d(36, 32, kernel_size=3, padding=1)
        self.x_from_z = nn.Sequential(
            ResBlock2d(32, 4),
            nn.Conv2d(4, 4, 1, padding=0),
        )

        self.logic_net = LogicNet(num_embeddings=10, embedding_dim=64)
        # Hidden state used when the caller does not supply one.
        self.learnable_init_state = nn.Parameter(torch.randn(1, 32, 24, 24))
        self.digit_embedding = nn.Embedding(10, 384)

        # Map encoder: four conv/ReLU/pool stages (each pool halves the
        # spatial size) followed by a projection to 384 channels. Layer order
        # is kept so the Sequential indices in the state dict do not change.
        map_layers = []
        stage_in_channels = 3
        for _ in range(4):
            map_layers.extend([
                nn.Conv2d(stage_in_channels, 32, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ])
            stage_in_channels = 32
        map_layers.append(nn.Conv2d(32, 384, kernel_size=3, padding=1))
        self.map_conv = nn.Sequential(*map_layers)

    def forward(self, x, t, action, digits, maps, hidden_state=None, directly_return_score_shift_result=False, inference=False):
        """Run one conditioned pass of the backbone.

        Args:
            x: Frame tensor. When ``inference`` is False it is noised with the
                scheduler at timestep ``t``; otherwise it is used as given.
            t: Integer timesteps, embedded with ``time_embedding``.
            action: Integer action ids, embedded with ``action_embedding`` and
                also fed to the score-shift head.
            digits: Integer score per sample; split into hundreds, tens and
                ones digits which are embedded separately.
            maps: 3-channel map image encoded by ``map_conv`` into tokens.
            hidden_state: Previous 32-channel hidden state; defaults to the
                learned initial state repeated over the batch.
            directly_return_score_shift_result: If True, return only the
                score-shift logits and skip the backbone.
            inference: Selects between the noising (False) and pass-through
                (True) paths described for ``x``.

        Returns:
            Either the score-shift logits alone, or the tuple
            ``(output, noise, z, score_shift_result)`` where ``noise`` is the
            sampled noise when ``inference`` is False and ``None`` otherwise,
            and ``z`` is the ``conv_out`` result (used by the script below as
            the next hidden state).
        """
        batch_size = x.shape[0]

        if hidden_state is None:
            hidden_state = self.learnable_init_state.repeat(batch_size, 1, 1, 1)

        score_shift_result = self.logic_net(hidden_state, action)
        if directly_return_score_shift_result:
            return score_shift_result

        # [B, 384, H', W'] -> [B, H'*W', 384] map tokens.
        map_tokens = self.map_conv(maps).flatten(2).transpose(1, 2)

        # One token per decimal digit of the score: hundreds, tens, ones.
        score_column = digits.unsqueeze(-1)
        digit_tokens = torch.cat(
            [
                (score_column // 100) % 10,
                (score_column // 10) % 10,
                score_column % 10,
            ],
            dim=1,
        )
        digit_embeddings = self.digit_embedding(digit_tokens)

        if inference:
            epsilon = None
            x_noisy = x
        else:
            epsilon = torch.randn_like(x)
            x_noisy = self.noise_scheduler.add_noise(x, epsilon, t)

        backbone_input = torch.cat([x_noisy, hidden_state], dim=1)

        conditions = torch.cat(
            [self.time_embedding(t), self.action_embedding(action)],
            dim=1,
        ).squeeze(1)

        backbone_output = self.model(backbone_input, conditions, digit_embeddings, map_tokens)
        z = self.conv_out(backbone_output)
        output = self.x_from_z(z)

        return output, epsilon, z, score_shift_result

def _parse_action_list(text):
    """Convert a comma-separated CLI string such as ``"0,0,1"`` to a list of ints."""
    try:
        parsed = [int(token) for token in text.split(',') if token.strip()]
    except ValueError as err:
        raise argparse.ArgumentTypeError(f"invalid action list: {text!r}") from err
    if not parsed:
        raise argparse.ArgumentTypeError("action list must contain at least one integer")
    return parsed


def _parse_cli_args():
    """Parse command-line options.

    The defaults reproduce what the script previously hard-coded after
    parsing (input image ``data/output_0/7_left_6.jpg``, 32 zeros followed by
    32 ones as actions, one frame per action); flags given explicitly now
    take effect instead of being overwritten.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_image", type=str, required=False, help="Path to input image", default='data/output_0/7_left_6.jpg')
    parser.add_argument("--output_dir", type=str, default="inference_results", help="Output directory")
    parser.add_argument("--num_frames", type=int, default=None, help="Number of frames to generate (default: one per action)")
    parser.add_argument("--actions", type=_parse_action_list, default=None, help="Comma-separated list of actions (default: 32 zeros then 32 ones)")
    args = parser.parse_args()
    if args.num_frames is not None and args.num_frames < 1:
        parser.error("--num_frames must be at least 1")
    return args


def main():
    """Generate a frame sequence from one input image and save frames, maps and a video.

    The first frame is obtained by encoding the input image with the VAE and
    running the model once at timestep 0. Every following frame is sampled
    from noise with 20 scheduler steps, conditioned on the action for that
    frame, the running score and the map built from all previous frames.
    """
    args = _parse_cli_args()

    # Start from a clean output directory so stale frames never mix in.
    if os.path.exists(args.output_dir):
        shutil.rmtree(args.output_dir)
    os.makedirs(args.output_dir, exist_ok=True)

    # Action sequence; short sequences are padded with action 1.
    action_list = list(args.actions) if args.actions is not None else [0] * 32 + [1] * 32
    num_frames = args.num_frames if args.num_frames is not None else len(action_list)
    if len(action_list) < num_frames:
        action_list.extend([1] * (num_frames - len(action_list)))

    # Initialize model and move to GPU when one is available
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = MNISTDiffusionTransformer().to(device)

    # Load model checkpoint
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    ckpt_path = 'checkpoints/last_epoch.pth'
    state_dict = torch.load(ckpt_path, map_location='cpu')['model_state_dict']
    model.load_state_dict(state_dict, strict=True)
    print(f"Loaded model checkpoint from {ckpt_path}")

    # Load VAE; keys are stripped of the 'module.' prefix before loading.
    # NOTE(review): the VAE is left in its default (training) mode as before;
    # confirm whether vae.eval() is wanted.
    vae = AutoencoderKL().to(device)
    vae_path = 'vae_epoch_last_iter.pt'
    vae_state = torch.load(vae_path, map_location='cpu')['model_state_dict']
    vae.load_state_dict({k.replace('module.', ''): v for k, v in vae_state.items()}, strict=True)

    # Prepare input image
    transform = transforms.Compose([
        transforms.Resize((96, 96)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    # Close the file handle once the resized RGB copy exists.
    with Image.open(args.input_image) as raw_image:
        input_image = raw_image.resize((96, 96), Image.Resampling.LANCZOS).convert('RGB')
    input_tensor = transform(input_image).unsqueeze(0).to(device)

    actions = torch.tensor(action_list[:num_frames]).unsqueeze(0).to(device)

    model.noise_scheduler.set_timesteps(20)
    model.eval()

    map_builder = MapBuilder()

    # Factor applied between VAE latents and the model's latent space.
    latent_scale = 1.11

    with torch.no_grad():
        # Process first frame
        start_time = time.time()
        first_frame = vae.encode(input_tensor).sample() / latent_scale
        t = torch.zeros(1, device=device)
        # The first model call always uses action id 5 ("initialize"),
        # regardless of the first entry of the action sequence.
        action = torch.full_like(actions[:, 0], 5)

        # The initial score is the last '_'-separated field of the file name
        # (e.g. '7_left_6.jpg' -> 6).
        image_stem = os.path.splitext(os.path.basename(args.input_image))[0]
        init_score = int(image_stem.split('_')[-1])
        # Build directly on `device`; an unconditional .cuda() breaks CPU-only runs.
        digits = torch.full((1,), init_score, dtype=torch.long, device=device)

        # No map exists yet, so the first call sees a map filled with -1.
        init_maps = torch.full((1, 3, 96, 128), -1.0, device=device)

        predicted_noise, _, hidden, score_shift_result = model(first_frame, t.long(), action.long(), digits, init_maps, inference=True)

        first_frame = model.noise_scheduler.step(
            predicted_noise,
            0,
            first_frame
        ).prev_sample

        first_frame_decoded = vae.decode(first_frame * latent_scale)
        generated_frames = [first_frame_decoded[0]]

        first_frame_pil = transforms.ToPILImage()(torch.clamp((generated_frames[-1].cpu()+1)/2, 0, 1))
        first_frame_pil.save(os.path.join(args.output_dir, "observation_0.jpg"))

        frame_times = [time.time() - start_time]
        print(f"Frame 0 generated in {frame_times[0]:.2f} seconds")

        current_frame = first_frame

        map_pil, _, _, _ = map_builder.register_observation(np.array(first_frame_pil), 0)
        map_tensor = to_tensor(map_pil).unsqueeze(0).to(device)

        # Generate remaining frames
        for frame_idx in range(1, num_frames):
            frame_start_time = time.time()

            current_action = actions[:, frame_idx]
            current_noise = torch.randn_like(current_frame, device=device)

            for step_idx, t_idx in enumerate(tqdm(model.noise_scheduler.timesteps)):
                t = torch.ones(1, device=device) * t_idx

                if step_idx == 0:
                    # Once per frame: ask the score-shift head whether the
                    # score should be incremented before denoising.
                    score_shift = model(
                        current_noise,
                        t.long(),
                        current_action.long(),
                        digits,
                        map_tensor,
                        hidden,
                        directly_return_score_shift_result=True,
                        inference=True
                    )
                    print(f'index {step_idx} score_shift: {score_shift}')

                    # Any logit that is not negative bumps the score by one.
                    if not (score_shift < 0):
                        digits = digits + 1

                noise_pred, _, hidden_pred, score_shift_result = model(
                    current_noise,
                    t.long(),
                    current_action.long(),
                    digits,
                    map_tensor,
                    hidden,
                    inference=True
                )

                current_noise = model.noise_scheduler.step(
                    noise_pred,
                    t_idx,
                    current_noise
                ).prev_sample

                # Carry the hidden state forward only from the final step.
                if t_idx == 0:
                    hidden = hidden_pred

            current_frame = current_noise
            decoded_frame = vae.decode(current_frame * latent_scale)
            generated_frames.append(decoded_frame[0])

            current_frame_pil = transforms.ToPILImage()(torch.clamp((decoded_frame[0].cpu()+1)/2, 0, 1))
            current_frame_pil.save(os.path.join(args.output_dir, f"observation_{frame_idx}.jpg"))

            map_pil, _, _, _ = map_builder.register_observation(np.array(current_frame_pil), frame_idx)

            map_pil.save(os.path.join(args.output_dir, f"map_{frame_idx}.jpg"))
            map_tensor = to_tensor(map_pil).unsqueeze(0).to(device)

            frame_time = time.time() - frame_start_time
            frame_times.append(frame_time)
            print(f"Frame {frame_idx} generated in {frame_time:.2f} seconds")

        # Save frames as MP4 video: [-1, 1] CHW tensors -> uint8 HWC arrays.
        frames_array = []
        for frame in generated_frames:
            frame_np = torch.clamp((frame.cpu() + 1) / 2, 0, 1)
            frame_np = (frame_np * 255).numpy().astype('uint8')
            frame_np = frame_np.transpose(1, 2, 0)
            frames_array.append(frame_np)

        height, width = frames_array[0].shape[:2]
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        video_path = os.path.join(args.output_dir, 'generated_video.mp4')
        out = cv2.VideoWriter(video_path, fourcc, 24.0, (width, height))
        try:
            for frame in frames_array:
                # OpenCV expects BGR channel order.
                out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
        finally:
            # Always finalize the container, even if a write fails.
            out.release()
        print(f"Video saved to {video_path}")


if __name__ == "__main__":
    main()
