# -*- coding: utf-8 -*-
import torch
import pandas as pd
import os 
from PIL import Image
from torchvision import transforms

class VideoDataset(torch.utils.data.Dataset):
    """Dataset yielding triplets of frames (t0, t1, t2) read from ``path``.

    Files read by this class:
      * ``<path>/frame_loader.csv`` with columns ``t0``, ``t1``, ``t2`` holding
        frame file names, resolved relative to ``<path>/frames``;
      * when ``mode != "train"``: ``<path>/mask_loader.csv`` and
        ``<path>/label_loader.csv`` (the latter with ``frame`` and ``label``
        columns).

    Args:
        path: Root directory of the dataset.
        img_size: Every frame is resized to ``(img_size, img_size)``.
        mode: ``"train"`` yields only frames; any other value also yields the
            labels looked up for each frame.
    """

    def __init__(self, path, img_size=128, mode="train"):
        self.path = path
        self.transform = self.transform_func(img_size)
        self.frame_loader = pd.read_csv(os.path.join(self.path, "frame_loader.csv"))
        self.mode = mode
        if mode != "train":
            # mask_loader is not read anywhere in this class; it is kept as a
            # public attribute in case external code relies on it.
            self.mask_loader = pd.read_csv(os.path.join(self.path, "mask_loader.csv"))
            self.label_loader = pd.read_csv(os.path.join(self.path, "label_loader.csv"))

    def transform_func(self, img_size):
        """Build the per-frame transform: square resize followed by ToTensor."""
        return transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
        ])

    def __len__(self):
        """Return the number of frame triplets (rows of ``frame_loader.csv``)."""
        return len(self.frame_loader)

    def _load_frame(self, name):
        """Open ``<path>/frames/<name>``, apply ``self.transform``, release the file.

        The image is closed when the ``with`` block exits, so the transform
        must return an object independent of the PIL image (the default
        transform ends in ``ToTensor``, which does).
        """
        with Image.open(os.path.join(self.path, "frames", name)) as img:
            return self.transform(img)

    def _lookup_label(self, frame_name):
        """Return the ``label`` column of label_loader rows whose ``frame`` matches.

        The result is a pandas Series: it is empty when the frame is absent
        and has several entries if the name is duplicated in the CSV.
        """
        labels = self.label_loader
        return labels[labels["frame"].isin([frame_name])]["label"]

    def __getitem__(self, index):
        """Return the transformed frames for row ``index`` of frame_loader.

        Returns:
            ``(t0, t1, t2)`` in ``"train"`` mode, otherwise
            ``(t0, t1, t2, l0, l1, l2)``, where each ``l*`` is the pandas
            Series produced by ``_lookup_label`` for the matching frame.
        """
        # .loc is label-based; pd.read_csv gives a RangeIndex, so labels
        # coincide with positions here.
        names = [self.frame_loader.loc[index, col] for col in ("t0", "t1", "t2")]
        frames = [self._load_frame(name) for name in names]

        if self.mode == "train":
            return tuple(frames)

        # NOTE(review): labels are pandas Series, not scalars; torch's default
        # collate function may reject them when a DataLoader batches samples.
        # Confirm whether callers need a custom collate_fn or scalar labels.
        labels = [self._lookup_label(name) for name in names]
        return (*frames, *labels)
    
