import torch
import torch.nn as nn

class NormConv2d(nn.Conv2d):
    """Apply one shared bank of zero-mean 2-D filters to every input channel.

    The layer is built with ``in_channels=1``, so the same ``channels``
    single-channel kernels are run over each input channel independently.
    On every forward pass the kernels are re-centred to zero spatial mean.
    The absolute filter responses are then averaged over the input channels.

    Args:
        channels: Number of filters, i.e. the number of output channels.
        kernel_size: Spatial size of each filter.

    Shape:
        input: ``(batch, in_channels, H, W)`` for any ``in_channels``.
        output: ``(batch, channels, H, W)``. ``padding='same'`` with
        replicate padding keeps the spatial size unchanged.
    """

    def __init__(self, channels, kernel_size=2):
        super().__init__(1, channels, kernel_size,
                         padding='same',
                         padding_mode='replicate',
                         bias=False)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Filter each channel with zero-mean kernels and average |response|.

        Args:
            input: 4-D tensor ``(batch, in_channels, H, W)``.

        Returns:
            Tensor of shape ``(batch, out_channels, H, W)``. Each position
            holds the mean absolute filter response across input channels.

        Raises:
            ValueError: If ``input`` is not 4-dimensional.
        """
        # Explicit check rather than ``assert`` so validation survives ``-O``.
        if input.dim() != 4:
            raise ValueError(
                "NormConv2d expects a 4-D (batch, channels, H, W) input, "
                f"got shape {tuple(input.shape)}")
        batch_size, in_channels, *image_size = input.shape

        # Fold input channels into the batch dimension so every channel is
        # filtered independently by the single-input-channel kernels.
        folded = input.reshape(batch_size * in_channels, 1, *image_size)

        # Zero-mean kernels: each filter sums to zero, so constant regions
        # (including replicate-padded borders) produce no response.
        weight = self.weight - self.weight.mean([2, 3], keepdim=True)
        output = self._conv_forward(folded, weight, self.bias)

        # Unfold back to (batch, in_channels, out_channels, H, W).
        output = output.reshape(batch_size, in_channels,
                                self.out_channels, *image_size)

        # Average the absolute responses over the original input channels.
        return output.abs().mean(dim=1)
