import torch
from torch.utils.data import Dataset
import json
import os
import io
from PIL import Image
from utils import transform
import google.auth
from google.cloud import storage
from torchvision import transforms
import sparseconvnet as scn


# Shared torchvision transforms: resize to 300x300 and convert a PIL image to a tensor.
# Both are applied to the segmentation map in IXRealDataset_sparse_local.__getitem__.
resize = transforms.Resize((300, 300))
to_tensor = transforms.ToTensor()

# Google Cloud Storage handle, created once when this module is imported.
# NOTE(review): because this runs at import time, importing the module presumably
# requires working application-default credentials and access to the bucket, even
# when only the local dataset classes are used -- confirm this is intended.
credentials, project = google.auth.default()
client = storage.Client(credentials=credentials)
bucket_name = "ft-annotation-analysis"
BUCKET= client.get_bucket(bucket_name)

# Use the GPU when one is available, otherwise the CPU.
# Read by IXRealDataset_sparse_local.compressImg when it builds the feature tensor.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def read_file_gcs(BUCKET, general_file_name):
    """
    Fetch the contents of one object stored in a bucket.

    :param BUCKET: bucket object exposing ``blob(name)``
    :param general_file_name: name (path) of the object inside the bucket
    :return: the downloaded payload, or ``False`` when the object does not exist
    """
    gcs_object = BUCKET.blob(general_file_name)
    # Only download when the object is present; downloading a missing object would raise NotFound.
    if gcs_object.exists():
        return gcs_object.download_as_string()
    return False

class IXRealDataset_local(Dataset):
    """
    A PyTorch Dataset that reads images from the local filesystem together with
    their annotated objects (bounding boxes, labels, difficulties).
    """

    def __init__(self, data_folder, split, keep_difficult=True):
        """
        :param data_folder: folder where the '<split>_images.json' and '<split>_objects.json' files are stored
        :param split: split, prefix of the two JSON file names
        :param keep_difficult: keep or discard objects that are considered difficult to detect?
        """

        self.split = split
        self.data_folder = data_folder
        self.keep_difficult = keep_difficult

        # Load the list of image paths and the per-image object annotations
        with open(os.path.join(data_folder, self.split + '_images.json'), 'r') as j:
            self.images = json.load(j)
        with open(os.path.join(data_folder, self.split + '_objects.json'), 'r') as j:
            self.objects = json.load(j)

        # Every image must have exactly one annotation entry
        assert len(self.images) == len(self.objects)

    def __getitem__(self, i):
        """
        :param i: index of the sample
        :return: transformed image, bounding boxes, labels, difficulties
        """
        # convert() returns a fully loaded copy, so the file can be closed right away
        with Image.open(self.images[i], mode='r') as img:
            image = img.convert('RGB')

        objects = self.objects[i]
        boxes = torch.FloatTensor(objects['boxes'])
        labels = torch.LongTensor(objects['labels'])
        difficulties = torch.ByteTensor(objects['difficulties'])

        # Discard difficult objects, if desired. An explicit boolean mask is used
        # because indexing with a uint8 tensor (1 - difficulties) is deprecated.
        if not self.keep_difficult:
            keep = difficulties == 0
            boxes = boxes[keep]
            labels = labels[keep]
            difficulties = difficulties[keep]

        # Apply transformations
        # NOTE(review): the split is hard-coded to 'notTRAIN' regardless of self.split --
        # presumably to skip training-time augmentation; confirm against utils.transform.
        image, boxes, labels, difficulties = transform(image, boxes, labels, difficulties, split='notTRAIN')

        return image, boxes, labels, difficulties

    def __len__(self):
        return len(self.images)

    def collate_fn(self, batch):
        """
        Since each image may have a different number of objects, we need a collate function (to be passed to the DataLoader).

        This describes how to combine these tensors of different sizes. We use lists.

        Note: this need not be defined in this Class, can be standalone.

        :param batch: an iterable of N sets from __getitem__()
        :return: a tensor of images, lists of varying-size tensors of bounding boxes, labels, and difficulties
        """

        images = list()
        boxes = list()
        labels = list()
        difficulties = list()

        for b in batch:
            images.append(b[0])
            boxes.append(b[1])
            labels.append(b[2])
            difficulties.append(b[3])

        # Images share one shape and can be stacked; the rest stay as per-image lists
        images = torch.stack(images, dim=0)

        return images, boxes, labels, difficulties  # tensor (N, 3, 300, 300), 3 lists of N tensors each



class IXRealDataset_cloud(Dataset):
    """
    A PyTorch Dataset whose annotation files live on the local filesystem while the
    images themselves are downloaded from the module-level GCS bucket on access.
    """

    def __init__(self, data_folder, split, keep_difficult=True):
        """
        :param data_folder: folder where the '<split>_images.json' and '<split>_objects.json' files are stored
        :param split: split, prefix of the two JSON file names
        :param keep_difficult: keep or discard objects that are considered difficult to detect?
        """

        self.split = split
        self.data_folder = data_folder
        self.keep_difficult = keep_difficult

        # Load the list of image paths (blob names in the bucket) and the per-image annotations
        with open(os.path.join(data_folder, self.split + '_images.json'), 'r') as j:
            self.images = json.load(j)
        with open(os.path.join(data_folder, self.split + '_objects.json'), 'r') as j:
            self.objects = json.load(j)

        # Every image must have exactly one annotation entry
        assert len(self.images) == len(self.objects)

    def __getitem__(self, i):
        """
        :param i: index of the sample
        :return: transformed image, bounding boxes, labels, difficulties
        :raises FileNotFoundError: if the image blob does not exist in the bucket
        """
        image_path = self.images[i]
        img_stream = read_file_gcs(BUCKET, image_path)
        # read_file_gcs signals a missing blob with False; fail with a clear message
        # instead of the opaque TypeError that io.BytesIO(False) would raise.
        if img_stream is False:
            raise FileNotFoundError("Image not found in bucket: {}".format(image_path))
        with Image.open(io.BytesIO(img_stream), mode='r') as img:
            image = img.convert('RGB')

        objects = self.objects[i]
        boxes = torch.FloatTensor(objects['boxes'])
        labels = torch.LongTensor(objects['labels'])
        difficulties = torch.ByteTensor(objects['difficulties'])

        # Discard difficult objects, if desired. An explicit boolean mask is used
        # because indexing with a uint8 tensor (1 - difficulties) is deprecated.
        if not self.keep_difficult:
            keep = difficulties == 0
            boxes = boxes[keep]
            labels = labels[keep]
            difficulties = difficulties[keep]

        # Apply transformations
        # NOTE(review): the split is hard-coded to 'notTRAIN' regardless of self.split --
        # presumably to skip training-time augmentation; confirm against utils.transform.
        image, boxes, labels, difficulties = transform(image, boxes, labels, difficulties, split='notTRAIN')

        return image, boxes, labels, difficulties

    def __len__(self):
        return len(self.images)

    def collate_fn(self, batch):
        """
        Since each image may have a different number of objects, we need a collate function (to be passed to the DataLoader).

        This describes how to combine these tensors of different sizes. We use lists.

        Note: this need not be defined in this Class, can be standalone.

        :param batch: an iterable of N sets from __getitem__()
        :return: a tensor of images, lists of varying-size tensors of bounding boxes, labels, and difficulties
        """

        images = list()
        boxes = list()
        labels = list()
        difficulties = list()

        for b in batch:
            images.append(b[0])
            boxes.append(b[1])
            labels.append(b[2])
            difficulties.append(b[3])

        # Images share one shape and can be stacked; the rest stay as per-image lists
        images = torch.stack(images, dim=0)

        return images, boxes, labels, difficulties  # tensor (N, 3, 300, 300), 3 lists of N tensors each



class IXRealDataset_sparse_local(Dataset):
    """
    A PyTorch Dataset class to be used in a PyTorch DataLoader to create batches.

    Besides the image and its annotated objects, each sample has a segmentation map.
    When ``sparse`` is set, only the pixels marked by that map are returned, as a
    (locations, features) pair suitable for ``collate_fn_sparse``.
    """

    def __init__(self, img_path, obj_path, seg_path, split, keep_difficult=False, sparse=True):
        """
        :param img_path: path of the JSON file listing the image paths
        :param obj_path: path of the JSON file listing the objects of each image
        :param seg_path: path of the JSON file listing the segmentation map paths
        :param split: split, one of 'TRAIN' or 'TEST'; forwarded to transform()
        :param keep_difficult: keep or discard objects that are considered difficult to detect?
        :param sparse: return (locations, features) instead of the dense image?
        """
        self.split = split
        assert self.split in {'TRAIN', 'TEST'}

        self.keep_difficult = keep_difficult
        self.sparse = sparse

        # Read data files
        with open(img_path, 'r') as j:
            self.images = json.load(j)
        # Read objects
        with open(obj_path, 'r') as j:
            self.objects = json.load(j)
        # Read segmentation map
        with open(seg_path, 'r') as j:
            self.segs = json.load(j)

        # Images, annotations and segmentation maps must line up one-to-one
        assert len(self.images) == len(self.objects) == len(self.segs)

    def __getitem__(self, i):
        """
        :param i: index of the sample
        :return: (locations, features, boxes, labels, difficulties) when ``self.sparse`` is set,
                 otherwise (image, boxes, labels, difficulties)
        """
        # Read image; convert() returns a fully loaded copy, so the file can be closed right away
        with Image.open(self.images[i], mode='r') as img:
            image = img.convert('RGB')

        # Read objects in this image (bounding boxes, labels, difficulties)
        objects = self.objects[i]
        boxes = torch.FloatTensor(objects['boxes'])  # (n_objects, 4)
        labels = torch.LongTensor(objects['labels'])  # (n_objects)
        difficulties = torch.ByteTensor(objects['difficulties'])  # (n_objects)

        # Read the segmentation map and bring it to the module-level 300x300 size
        with Image.open(self.segs[i], mode='r') as seg_img:
            seg = to_tensor(resize(seg_img))
        # change segmentation nonzero value to 1, otherwise these nonzero values (0,1] will be forced to 0 at compressImg step
        seg[seg != 0] = 1

        # Discard difficult objects, if desired. An explicit boolean mask is used
        # because indexing with a uint8 tensor (1 - difficulties) is deprecated.
        if not self.keep_difficult:
            keep = difficulties == 0
            boxes = boxes[keep]
            labels = labels[keep]
            difficulties = difficulties[keep]

        # Apply transformations
        # NOTE(review): the segmentation map is only resized, while the image goes through
        # transform(); if transform() crops/flips/expands for the 'TRAIN' split the two may
        # no longer be aligned -- confirm against utils.transform.
        image, boxes, labels, difficulties = transform(image, boxes, labels, difficulties, split=self.split)  # image.shape = [3, 300, 300]

        # Convert sparse image to compressed dense format
        if self.sparse:
            location, feature = self.compressImg(image, seg)
            return location, feature, boxes, labels, difficulties

        return image, boxes, labels, difficulties

    def compressImg(self, image, seg):
        """
        Keep only the pixels selected by the segmentation map.

        :param image: image tensor indexed as [channel, row, col]; the first three channels are used
        :param seg: segmentation tensor; pixels where seg[0] is nonzero (after int conversion) are kept
        :return: locations, an (n, 3) tensor of [row, col, 0] (the last column is filled with the
                 batch index by collate_fn_sparse), and features, an (n, 3) float tensor of the
                 channel values at those pixels, moved to the module-level device
        """
        locations = torch.nonzero(seg[0].int())  # (n, 2): row, col
        # pad locations third column with 0
        padding = torch.nn.ConstantPad1d((0, 1), 0)
        locations = padding(locations)  # (n, 3): row, col, 0

        # Gather the three channel values of every selected pixel in a single indexing
        # operation instead of a per-pixel Python loop: image[:3, rows, cols] is (3, n).
        rows, cols = locations[:, 0], locations[:, 1]
        features = image[:3, rows, cols].t().contiguous().float().to(device)

        return locations, features

    def __len__(self):
        return len(self.images)

    def collate_fn_sparse(self, batch):
        """
        Collate sparse samples from __getitem__() into one sparse input.

        :param batch: an iterable of N sets (locations, features, boxes, labels, difficulties)
        :return: an scn.InputLayerInput built from the concatenated locations and features,
                 and lists of varying-size tensors of bounding boxes, labels, and difficulties
        """
        locations = list()
        features = list()
        boxes = list()
        labels = list()
        difficulties = list()

        for b_id, b in enumerate(batch):
            # write the batch index into the third column of this image's locations (in place)
            b[0][:, 2] = b_id
            # locations -> b[0]   features -> b[1]
            locations.append(b[0])
            features.append(b[1])
            boxes.append(b[2])
            labels.append(b[3])
            difficulties.append(b[4])

        images = scn.InputLayerInput(torch.cat(locations, 0), torch.cat(features, 0))

        return images, boxes, labels, difficulties

