"""
Modified from https://github.com/kenshohara/3D-ResNets-PyTorch
"""
import random
import math


class Compose(object):
    """Chain several temporal transforms into a single callable.

    Each transform receives the output of the previous one.  As soon as the
    running value is a list of clips (its first element is itself a list),
    the remaining transforms are applied to every clip independently and a
    list holding one result per clip is returned.

    Args:
        transforms: Sequence of callables, each taking a list of frame
            indices and returning either a list of frame indices or a list
            of such lists.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, frame_indices):
        """Apply the configured transforms in order.

        Args:
            frame_indices: List of frame indices, or a list of clips (each
                clip being a list of frame indices).

        Returns:
            The transformed frame indices; a list of per-clip results when
            the value became (or started as) a list of clips while
            transforms were still pending.
        """
        for i, t in enumerate(self.transforms):
            # Guard the peek at element 0: an empty list has no clips to fan
            # out over, so it is simply handed to the next transform instead
            # of raising IndexError.
            if len(frame_indices) > 0 and isinstance(frame_indices[0], list):
                remaining = Compose(self.transforms[i:])
                return [remaining(clip) for clip in frame_indices]
            frame_indices = t(frame_indices)
        return frame_indices


class LoopPadding(object):
    """Pad a clip to at least ``size`` entries by repeating it cyclically.

    Args:
        size: Minimum number of frame indices in the returned list.
    """

    def __init__(self, size):
        self.size = size

    def __call__(self, frame_indices):
        """Return a padded copy of ``frame_indices``.

        The input list is never modified; a new list is always returned.

        Args:
            frame_indices: List of frame indices.

        Returns:
            A new list.  If the input already holds ``size`` or more entries
            (or is empty, so there is nothing to repeat) it is returned
            unchanged in content; otherwise the entries are repeated from
            the start until exactly ``size`` entries are present.
        """
        # Work on a copy so the caller's list is not extended in place.
        out = list(frame_indices)
        if not out or len(out) >= self.size:
            return out

        # Ceiling division: number of whole repetitions needed to cover size.
        repeats = -(-self.size // len(out))
        return (out * repeats)[:self.size]


class TemporalBeginCrop(object):
    """Keep the first ``size`` frame indices, looping when the clip is short.

    Args:
        size: Number of frame indices the returned clip should contain.
    """

    def __init__(self, size):
        self.size = size

    def __call__(self, frame_indices):
        """Return the leading ``size`` entries of ``frame_indices``.

        When fewer than ``size`` entries are available, the cropped clip is
        repeated from its start until it reaches ``size`` entries.  An empty
        input yields an empty list.
        """
        target = self.size
        clip = frame_indices[:target]

        # Cyclic padding: re-append entries from the front of the clip,
        # one at a time, until the clip is long enough.
        cursor = 0
        while cursor < len(clip) and len(clip) < target:
            clip.append(clip[cursor])
            cursor += 1

        return clip


class TemporalEndCrop(object):
    """Keep the last ``size`` frame indices, looping when the clip is short.

    Args:
        size: Number of frame indices the returned clip should contain.
    """

    def __init__(self, size):
        self.size = size

    def __call__(self, frame_indices):
        """Return the trailing ``size`` entries of ``frame_indices``.

        When fewer than ``size`` entries are available, the cropped clip is
        repeated from its start until it reaches ``size`` entries.  An empty
        input yields an empty list.
        """
        target = self.size
        clip = frame_indices[-target:]

        # Cyclic padding: re-append entries from the front of the clip,
        # one at a time, until the clip is long enough.
        cursor = 0
        while cursor < len(clip) and len(clip) < target:
            clip.append(clip[cursor])
            cursor += 1

        return clip


class TemporalCenterCrop(object):
    """Keep ``size`` frame indices around the middle of the clip.

    Args:
        size: Number of frame indices the returned clip should contain.
    """

    def __init__(self, size):
        self.size = size

    def __call__(self, frame_indices):
        """Return a window of ``size`` entries centred in ``frame_indices``.

        The window starts ``size // 2`` positions before the middle position
        (clamped to 0) and is clamped to the end of the list.  If the window
        holds fewer than ``size`` entries it is repeated from its start
        until it reaches ``size`` entries.
        """
        n_frames = len(frame_indices)
        target = self.size

        start = max(0, n_frames // 2 - target // 2)
        stop = min(start + target, n_frames)
        clip = frame_indices[start:stop]

        # Cyclic padding: re-append entries from the front of the clip,
        # one at a time, until the clip is long enough.
        cursor = 0
        while cursor < len(clip) and len(clip) < target:
            clip.append(clip[cursor])
            cursor += 1

        return clip


class TemporalSpecificCrop(object):
    """Keep ``size`` frame indices starting at a fixed position.

    Args:
        begin_index: Position in the input list where the crop starts.
        size: Number of frame indices the returned clip should contain.
    """

    def __init__(self, begin_index, size):
        self.begin_index = begin_index
        self.size = size

    def __call__(self, frame_indices):
        """Return ``frame_indices[begin_index:begin_index + size]``, padded.

        If the slice holds fewer than ``size`` entries it is repeated from
        its start until it reaches ``size`` entries.  A slice that selects
        nothing stays empty.
        """
        first = self.begin_index
        target = self.size
        clip = frame_indices[first:first + target]

        # Cyclic padding: re-append entries from the front of the clip,
        # one at a time, until the clip is long enough.
        cursor = 0
        while cursor < len(clip) and len(clip) < target:
            clip.append(clip[cursor])
            cursor += 1

        return clip


class TemporalRandomCrop(object):
    """Crop ``size`` consecutive entries at a random position.

    Args:
        size: Maximum number of frame indices in the returned clip.
        start_index: Lowest start position the random draw may pick; it is
            capped at the latest start that still fits a full clip.
    """

    def __init__(self, size, start_index=0):
        self.size = size
        self.loop = LoopPadding(size)
        self.start_index = start_index

    def __call__(self, frame_indices):
        """Return a random contiguous slice of at most ``size`` entries.

        The start position is drawn uniformly with ``random.randint`` from
        ``[min(latest, start_index), latest]`` where ``latest`` is
        ``max(0, len(frame_indices) - size)``.  Inputs shorter than ``size``
        are returned unpadded (``self.loop`` is constructed but not applied).
        """
        n_frames = len(frame_indices)
        latest_start = max(0, n_frames - self.size)
        earliest_start = min(latest_start, self.start_index)

        begin = random.randint(earliest_start, latest_start)
        stop = min(begin + self.size, n_frames)
        return frame_indices[begin:stop]


class TemporalRandomCrop2xSpeed(object):
    """Crop a random span of ``2 * size`` entries and keep every second one.

    Args:
        size: Maximum number of frame indices in the returned clip.
        start_index: Lowest start position the random draw may pick; it is
            capped at the latest start that still fits the full span.
    """

    def __init__(self, size, start_index=0):
        self.size = size
        self.loop = LoopPadding(size)
        self.start_index = start_index

    def __call__(self, frame_indices):
        """Return every second entry of a random span of ``2 * size`` entries.

        The start position is drawn uniformly with ``random.randint`` from
        ``[min(latest, start_index), latest]`` where ``latest`` is
        ``max(0, len(frame_indices) - 2 * size)``.  Inputs too short for a
        full span are returned unpadded (``self.loop`` is constructed but
        not applied).
        """
        n_frames = len(frame_indices)
        span = 2 * self.size
        latest_start = max(0, n_frames - span)
        earliest_start = min(latest_start, self.start_index)

        begin = random.randint(earliest_start, latest_start)
        stop = min(begin + span, n_frames)
        return frame_indices[begin:stop:2]


class TemporalEvenCrop(object):
    """Sample up to ``n_samples`` clips of ``size`` frames spread over a video.

    Args:
        size: Number of frame indices in each returned clip.
        n_samples: Maximum number of clips to return.
    """

    def __init__(self, size, n_samples=1):
        self.size = size
        self.n_samples = n_samples
        self.loop = LoopPadding(size)

    def __call__(self, frame_indices):
        """Return a list of clips, each a list of consecutive frame indices.

        Clip start values are taken from ``frame_indices`` at a regular
        stride and each clip is rebuilt with ``range``.
        NOTE(review): this presumably expects ``frame_indices`` to be
        consecutive integers - confirm against the callers.

        A clip that would run past the last frame index is padded with
        ``self.loop`` and ends the sampling.

        Args:
            frame_indices: List of integer frame indices.

        Returns:
            List of at most ``n_samples`` clips; empty for empty input.
        """
        n_frames = len(frame_indices)
        if self.n_samples > 1:
            stride = max(
                1,
                math.ceil((n_frames - 1 - self.size) / (self.n_samples - 1)))
        else:
            # With at most one clip only the first start is ever used, so
            # the stride is irrelevant; this also avoids dividing by zero
            # for the default n_samples=1.
            stride = 1

        out = []
        for begin_index in frame_indices[::stride]:
            if len(out) >= self.n_samples:
                break
            # Clamp the clip end to one past the last available frame index.
            end_index = min(frame_indices[-1] + 1, begin_index + self.size)
            sample = list(range(begin_index, end_index))

            if len(sample) < self.size:
                # Short tail clip: pad it and stop, later starts would only
                # be shorter still.
                out.append(self.loop(sample))
                break
            out.append(sample)

        return out


class TemporalCenterFrame(object):
    """Select the middle position of a clip as a one-element list."""

    def __init__(self):
        pass

    def __call__(self, frame_indices):
        """Return ``[len(frame_indices) // 2]``.

        The value is the middle *position* of the input list, not the entry
        stored at that position.
        NOTE(review): the two only coincide when ``frame_indices`` is
        ``0..n-1`` - confirm this is what callers pass.
        """
        middle, _ = divmod(len(frame_indices), 2)
        return [middle]


class TemporalEndFrame(object):
    """Select the last position of a clip as a one-element list."""

    def __init__(self):
        pass

    def __call__(self, frame_indices):
        """Return ``[len(frame_indices) - 1]``, the last position of the input.

        The value is a *position* in the input list (the same convention
        ``TemporalCenterFrame`` uses), not the entry stored there.
        NOTE(review): the two only coincide when ``frame_indices`` is
        ``0..n-1`` - confirm this is what callers pass.

        Args:
            frame_indices: List of frame indices.

        Returns:
            A one-element list holding the last position; ``[0]`` for an
            empty input.
        """
        # Previously this computed len // 2 (the centre), duplicating
        # TemporalCenterFrame.  Clamp at 0 so an empty input still yields
        # [0] rather than a negative position.
        last_index = max(len(frame_indices) - 1, 0)
        return [last_index]


class SlidingWindow(object):
    """Split a video into windows of ``size`` frames taken at a fixed stride.

    Args:
        size: Number of frame indices in each window.
        stride: Step between consecutive window starts; ``0`` means
            ``size``, i.e. non-overlapping windows.
    """

    def __init__(self, size, stride=0):
        self.size = size
        self.stride = self.size if stride == 0 else stride
        self.loop = LoopPadding(size)

    def __call__(self, frame_indices):
        """Return the list of windows covering ``frame_indices``.

        Window start values are taken from ``frame_indices`` every
        ``stride`` entries and each window is rebuilt with ``range``.
        NOTE(review): this presumably expects ``frame_indices`` to be
        consecutive integers - confirm against the callers.

        The first window that would run past the last frame index is padded
        with ``self.loop`` and ends the iteration.
        """
        windows = []
        for start in frame_indices[::self.stride]:
            # Clamp the window end to one past the last available frame index.
            stop = min(frame_indices[-1] + 1, start + self.size)
            window = list(range(start, stop))

            if len(window) >= self.size:
                windows.append(window)
                continue

            # Short tail window: pad it and stop.
            windows.append(self.loop(window))
            break

        return windows


class TemporalSubsampling(object):
    """Keep every ``stride``-th frame index.

    Args:
        stride: Step used when slicing the input list.
    """

    def __init__(self, stride):
        self.stride = stride

    def __call__(self, frame_indices):
        """Return ``frame_indices`` sliced from its start with step ``stride``."""
        every_nth = slice(None, None, self.stride)
        return frame_indices[every_nth]


class Shuffle(object):
    """Shuffle a clip block-wise: block order is randomised, order within
    each block is kept.

    Args:
        block_size: Number of consecutive frame indices per block.
    """

    def __init__(self, block_size=2):  # TODO: make it configurable
        self.block_size = block_size

    def __call__(self, frame_indices):
        """Return the frame indices with their blocks in random order.

        The input is cut into consecutive blocks of ``block_size`` entries
        (the last block may be shorter), the blocks are permuted with
        ``random.shuffle`` and then concatenated back into one flat list.
        """
        step = self.block_size

        blocks = []
        for start in range(0, len(frame_indices), step):
            blocks.append(frame_indices[start:start + step])

        random.shuffle(blocks)

        shuffled = []
        for block in blocks:
            shuffled.extend(block)
        return shuffled


#temporal shuffle
