import torch
import torch.nn as nn


class NaiveCompressor(nn.Module):
    """Channel bottleneck: squeeze to ``input_dim // compress_raito`` and back.

    The encoder is a single 3x3 conv -> BatchNorm -> ReLU stage that reduces
    the channel count. The decoder is two such stages that restore
    ``input_dim`` channels. Every convolution uses stride 1 and padding 1, so
    the spatial size of the input is preserved.

    Args:
        input_dim: Number of input (and output) channels.
        compress_raito: Integer divisor giving the bottleneck width. The
            misspelled name is kept because callers may pass it by keyword.
    """

    def __init__(self, input_dim, compress_raito):
        super().__init__()
        bottleneck = input_dim // compress_raito
        # Layers are created in the same order as a hand-written flat
        # Sequential would create them, so parameter initialisation and
        # state_dict keys (encoder.0, encoder.1, decoder.0, ...) are unchanged.
        self.encoder = nn.Sequential(*self._conv_bn_relu(input_dim, bottleneck))
        self.decoder = nn.Sequential(
            *self._conv_bn_relu(bottleneck, input_dim),
            *self._conv_bn_relu(input_dim, input_dim),
        )

    @staticmethod
    def _conv_bn_relu(in_channels, out_channels):
        """Return ``[Conv2d(3x3, stride 1, pad 1), BatchNorm2d, ReLU]`` as a flat list."""
        return [
            nn.Conv2d(in_channels, out_channels, kernel_size=3,
                      stride=1, padding=1),
            nn.BatchNorm2d(out_channels, eps=1e-3, momentum=0.01),
            nn.ReLU(),
        ]

    def forward(self, x):
        """Encode ``x`` into the bottleneck, then decode it back to ``input_dim`` channels."""
        return self.decoder(self.encoder(x))


class HeteroCompressor(nn.Module):
    """Route each agent's features through a compressor selected by its type.

    One sub-network is kept per agent type. ``forward`` gathers all agents of
    a given type with a boolean mask, runs them through that type's network
    as a single batch, and scatters the results back to their original
    ``(batch, agent)`` positions.

    Args:
        input_dim: Channel count passed to each ``NaiveCompressor``.
        compress_raito: Channel reduction factor passed to each
            ``NaiveCompressor``. The spelling is kept for API compatibility.
        num_types: Number of agent types, i.e. sub-networks created up front.
    """

    def __init__(self, input_dim, compress_raito, num_types=2):
        super().__init__()
        self.num_types = num_types
        self.net = nn.ModuleList(
            [NaiveCompressor(input_dim, compress_raito)
             for _ in range(num_types)]
        )

    def add_net(self, fn):
        """Register module ``fn`` as the network for a new agent type.

        The new type gets the id equal to ``num_types`` before the call.
        ``num_types`` is then incremented so that ``forward`` and
        ``combine_features`` actually route agents with that id to ``fn``.
        """
        self.net.append(fn)
        # Previously num_types was not updated, so the appended network was
        # never reached by forward()/combine_features().
        self.num_types += 1

    def forward(self, x, mode, *args, **kwargs):
        """Apply the per-type networks and reassemble the results.

        Args:
            x: Features indexed per agent as ``x[mode == t]``. Presumably
                ``(B, L, C, H, W)``; TODO confirm against callers.
            mode: Integer type id for every agent, shape ``(B, L)``.
            *args, **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            Tensor of shape ``(B, L, *net_output.shape[1:])``.
        """
        per_type = [self.net[t](x[mode == t, ...])
                    for t in range(self.num_types)]
        return self.combine_features(per_type, mode)

    def combine_features(self, x_list, mode):
        """Scatter per-type outputs back into a dense ``(B, L, ...)`` tensor.

        Args:
            x_list: ``x_list[t]`` holds the outputs for every agent with
                ``mode == t``, ordered row-major over ``(batch, agent)``.
                This is the order produced by boolean-mask indexing.
            mode: Integer type ids, shape ``(B, L)``.

        Returns:
            Tensor of shape ``(B, L, *x_list[0].shape[1:])`` with the dtype
            and device of ``x_list[0]``.

        Raises:
            ValueError: If ``x_list`` is empty, or if ``mode`` contains an id
                outside ``[0, num_types)``. Such a position would otherwise
                have no feature to fill it.
        """
        if len(x_list) == 0:
            raise ValueError("combine_features needs at least one feature tensor")
        if bool(((mode < 0) | (mode >= self.num_types)).any()):
            raise ValueError(
                f"mode values must lie in [0, {self.num_types})")

        B, L = mode.shape[:2]
        ref = x_list[0]
        out = ref.new_empty((B, L) + tuple(ref.shape[1:]))
        for agent_type in range(self.num_types):
            # Boolean-mask assignment fills the selected positions in
            # row-major order, the same order forward() gathered them in. The
            # range check above guarantees every position is written exactly
            # once, so no uninitialised memory from new_empty survives.
            out[mode == agent_type] = x_list[agent_type]
        return out
