import os
import random
from pathlib import Path

import cv2
import numpy as np
import torch
from natsort import natsorted
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

# Default map pipeline: PIL image -> float tensor (ToTensor scales 8-bit
# images to [0, 1]), then per-channel (x - 0.5) / 0.5, i.e. roughly [-1, 1].
transform_map = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

class ContinuousFrameDataset(Dataset):
    """Dataset of fixed-length windows of consecutive frames.

    Each immediate subfolder of ``data_dir`` is one recording whose ``*.jpg``
    frames are ordered by natural sort. A sample is a window of
    ``sequence_length`` consecutive frames from a single folder, together with
    per-frame action labels, score-increase flags, parsed scores and the
    matching map images from ``support_data_dir``.

    Expected frame file names (from the parsing below):
    ``<index>_..._<score>.jpg``. The first ``_``-separated field is the integer
    frame index and the last one is the integer score. Names containing
    ``left`` / ``right`` (case-insensitive) get action 0 / 1, otherwise 2.
    The map for a frame is ``<support_data_dir>/<same subfolder>/<index-1>_map.jpg``.
    """

    # Number of load attempts in __getitem__ (the first one plus random-index
    # retries) before the last error is re-raised.
    _MAX_LOAD_ATTEMPTS = 10

    def __init__(self, data_dir, support_data_dir, sequence_length=200, transform=None, transform_map=None, data_fetch_mode=0):
        """
        Args:
            data_dir: Root directory path containing all subfolders
            support_data_dir: Directory path for support data; mirrors the
                subfolder layout of ``data_dir`` and holds the map images
            sequence_length: Number of continuous frames in each sample
            transform: Optional image transformation operations; when None
                frames are only converted with ``transforms.ToTensor()``
            transform_map: Optional map transformation operations; when None
                maps are only converted with ``transforms.ToTensor()``
            data_fetch_mode: Data loading mode (0: first sequence only, 1: all sequences)

        Raises:
            ValueError: if ``sequence_length`` < 1 or ``data_fetch_mode`` is
                not 0 or 1.
        """
        if sequence_length < 1:
            raise ValueError(f"sequence_length must be >= 1, got {sequence_length}")
        if data_fetch_mode not in (0, 1):
            raise ValueError(f"data_fetch_mode must be 0 or 1, got {data_fetch_mode}")

        self.data_dir_name = data_dir
        self.support_data_dir_name = support_data_dir
        self.data_dir = Path(data_dir)
        self.support_data_dir = Path(support_data_dir)
        self.sequence_length = sequence_length
        self.transform = transform
        self.transform_map = transform_map
        # Fallback so frames/maps are always tensors that torch.stack accepts.
        self._to_tensor = transforms.ToTensor()

        # One naturally sorted list of frame paths per subfolder.
        self.all_images = [
            natsorted(folder.glob("*.jpg"))
            for folder in natsorted(self.data_dir.glob("*"))
            if folder.is_dir()
        ]

        # Build valid sequences and corresponding actions
        self.sequences = []
        self.actions = []
        # Mode 0 keeps only the first window per folder; mode 1 keeps all.
        self._build_sequences(1 if data_fetch_mode == 0 else None)

    def _build_sequences(self, sequence_limit):
        """Append sliding windows and their action labels for every folder.

        Args:
            sequence_limit: Maximum number of windows taken per folder,
                starting from its first frame; ``None`` takes every window.

        Folders with fewer than ``sequence_length`` frames are skipped, with a
        warning. Every sample then has the same length, and no empty window
        (which ``torch.stack`` cannot handle) is produced.
        """
        skipped = 0
        for folder_images in self.all_images:
            n_windows = len(folder_images) - self.sequence_length + 1
            if n_windows <= 0:
                skipped += 1
                continue
            if sequence_limit is not None:
                n_windows = min(n_windows, sequence_limit)
            for start in range(n_windows):
                window = folder_images[start:start + self.sequence_length]
                self.sequences.append(window)
                self.actions.append(
                    torch.tensor([self._action_label(path) for path in window])
                )
        if skipped:
            import warnings
            warnings.warn(
                f"Skipped {skipped} folder(s) with fewer than "
                f"{self.sequence_length} frames",
                stacklevel=3,
            )

    @staticmethod
    def _action_label(frame_path):
        """Map a frame file name to an action id: 0 = left, 1 = right, 2 = other."""
        name = frame_path.name.lower()
        if 'left' in name:
            return 0
        if 'right' in name:
            return 1
        return 2

    @staticmethod
    def _parse_score(frame_path):
        """Return the integer score encoded after the last ``_`` of the file name."""
        return int(frame_path.name.split('_')[-1].split('.')[0])

    def __len__(self):
        return len(self.sequences)

    def __getitem__(self, idx):
        """Return ``(frames, actions, score_shift, digits, maps)`` for window ``idx``.

        ``frames`` and ``maps`` are the per-frame tensors stacked along a new
        first dimension. ``actions``, ``score_shift`` and ``digits`` are
        integer tensors with one entry per frame.

        A sample that fails to load is replaced, best effort, by a randomly
        chosen one, for at most ``_MAX_LOAD_ATTEMPTS`` attempts in total.

        Raises:
            IndexError: if ``idx`` is out of range. This is raised rather than
                recovered, so that plain iteration over the dataset stops.
            Exception: the last loading error once all attempts have failed.
        """
        import traceback

        self.sequences[idx]  # raises IndexError for an out-of-range idx
        for attempt in range(self._MAX_LOAD_ATTEMPTS):
            try:
                return self._load_item(idx)
            except Exception as e:
                print(f"Error processing item {idx}: {e}")
                traceback.print_exc()
                if attempt == self._MAX_LOAD_ATTEMPTS - 1:
                    raise
                idx = random.randrange(len(self.sequences))

    def _load_item(self, idx):
        """Load every frame and map of window ``idx`` and stack them into tensors."""
        frames, maps, digits, score_shift = [], [], [], []
        last_score = None
        for img_path in self.sequences[idx]:
            score = self._parse_score(img_path)
            digits.append(score)
            # 1 when the score rose versus the previous frame; the first frame
            # of a window has no predecessor and gets 0.
            score_shift.append(0 if last_score is None else int(score > last_score))
            last_score = score
            frames.append(self._load_frame(img_path))
            maps.append(self._load_map(img_path))

        return (
            torch.stack(frames),
            self.actions[idx],
            torch.tensor(score_shift).long(),
            torch.tensor(digits).long(),
            torch.stack(maps),
        )

    def _load_frame(self, img_path):
        """Open a frame, make it a 96x96 RGB image and apply the frame transform."""
        # ``with`` closes the file handle. Converting to RGB before resizing
        # means LANCZOS is used whatever the source mode.
        with Image.open(img_path) as img:
            image = img.convert('RGB').resize((96, 96), Image.Resampling.LANCZOS)
        transform = self.transform if self.transform is not None else self._to_tensor
        return transform(image)

    def _load_map(self, img_path):
        """Load the map image paired with frame ``img_path``.

        The map sits in the same relative subfolder under ``support_data_dir``
        and is named ``<frame_index - 1>_map.jpg``.
        """
        frame_index = int(img_path.name.split('_')[0])
        # img_path was globbed from data_dir, so relative_to cannot fail here.
        rel_folder = img_path.parent.relative_to(self.data_dir)
        map_path = self.support_data_dir / rel_folder / f'{frame_index - 1}_map.jpg'
        with Image.open(map_path) as img:
            map_image = img.convert('RGB')
        transform = self.transform_map if self.transform_map is not None else self._to_tensor
        return transform(map_image)


if __name__ == "__main__":
    # Demo frame pipeline: resize to 96x128, convert to a tensor, then
    # normalize each channel with mean 0.5 / std 0.5.
    frame_transform = transforms.Compose(
        [
            transforms.Resize((96, 128)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ]
    )

    dataset = ContinuousFrameDataset(
        data_dir='path/to/data/directory',
        support_data_dir='path/to/support/directory',
        sequence_length=32,
        transform=frame_transform,
        transform_map=transform_map,
        data_fetch_mode=0,
    )
    print(f"Dataset size: {len(dataset)}")

    # Unpack the first sample and report the shape of each returned tensor.
    frames, actions, score_shift, digits, maps = dataset[0]
    print(f"Frames shape: {frames.shape}")
    print(f"Actions shape: {actions.shape}")
    print(f"Score shift shape: {score_shift.shape}")
    print(f"Digits shape: {digits.shape}")
    print(f"Maps shape: {maps.shape}")