"""
Implementations of F-Cooper maxout (element-wise max) feature fusion.
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class SpatialMaxFusion(nn.Module):
    """Reduce a tensor by taking the element-wise maximum along dim 1.

    Dim 1 is retained with size 1 in the output (``keepdim=True``), so the
    result has the same rank as the input. What dim 1 represents (agents,
    channels, ...) is decided by the caller -- NOTE(review): confirm
    against call sites.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        reduced = x.max(dim=1, keepdim=True)
        # Only the maxima are used; the argmax indices are discarded.
        return reduced.values

class SpatialFusion(nn.Module):
    """Max-pool features over the rows that belong to each sample.

    The leading dimension of ``x`` stacks the rows of several samples back
    to back; ``record_len`` gives how many consecutive rows belong to each
    sample. Every group is reduced to a single row by element-wise max, and
    the reduced rows are concatenated along dim 0.
    """

    def __init__(self):
        super().__init__()

    def regroup(self, x, record_len):
        """Split ``x`` along dim 0 into consecutive chunks sized by ``record_len``.

        Args:
            x: tensor whose dim 0 is split.
            record_len: assumed 1-D integer tensor of group sizes -- the code
                only requires that ``torch.cumsum`` accepts it.

        Returns:
            Tuple of views of ``x``. Chunk ``i`` starts where the running
            total of the first ``i`` sizes ends; the last chunk runs to the
            end of ``x``.
        """
        # Drop the final running total: tensor_split wants interior split
        # points only, and it requires index tensors to live on the CPU.
        boundaries = torch.cumsum(record_len, dim=0)[:-1]
        return torch.tensor_split(x, boundaries.cpu())

    def forward(self, x, record_len):
        """Fuse each group of rows into one via element-wise max.

        Args:
            x: stacked features, laid out as (B, C, H, W) per the original
                note, where B is the total row count over all samples.
            record_len: per-sample row counts (see ``regroup``).

        Returns:
            Tensor with one row per group, e.g. (num_groups, C, H, W).
        """
        groups = self.regroup(x, record_len)
        # keepdim=True keeps a size-1 dim 0 so the results can be concatenated.
        pooled = [torch.max(group, dim=0, keepdim=True).values for group in groups]
        return torch.cat(pooled, dim=0)


class SpatialFusionMask(nn.Module):
    """Collapse dim 1 of the input by element-wise maximum.

    Unlike ``keepdim`` reductions, dim 1 is removed from the output. The
    expected layout, per the original note, is (B, L, H, W, C), giving a
    (B, H, W, C) result; the code itself only needs at least two dims.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # torch.max with a dim returns (values, indices); only values are kept.
        values, _ = torch.max(x, dim=1)
        return values