import torch
from torch import nn


class Convnet(nn.Module):
    """Four-stage convolutional encoder that returns one flat vector per sample.

    The encoder chains four ``conv_block`` stages. Each stage ends in a
    ``MaxPool2d(2)``, so the spatial size is floor-halved four times. A final
    5x5 max pool is then applied and the result is flattened per sample.
    """

    def __init__(self, in_dim=3, hid_dim=64, out_dim=64):
        """Build the four-stage encoder.

        Args:
            in_dim: Number of channels of the input images.
            hid_dim: Number of channels produced by the first three stages.
            out_dim: Number of channels produced by the last stage.
        """
        super().__init__()
        self.encoder = nn.Sequential(conv_block(in_dim, hid_dim),
                                     conv_block(hid_dim, hid_dim),
                                     conv_block(hid_dim, hid_dim),
                                     conv_block(hid_dim, out_dim))

    def forward(self, x):
        """Encode a batch of images into flat feature vectors.

        Args:
            x: Batched 4-D input ``(N, in_dim, H, W)``. ``H`` and ``W`` must
                be at least 80 so that the final 5x5 pool still has a full
                window after the four halvings.

        Returns:
            A 2-D tensor of shape ``(N, out_dim * (H // 16 // 5) * (W // 16 // 5))``.
        """
        x = self.encoder(x)
        # Stateless functional pooling: no module is allocated per call.
        x = nn.functional.max_pool2d(x, kernel_size=5)
        # flatten() yields the same values as view(N, -1) but also works when
        # the pooled tensor is not view-compatible (e.g. channels-last) or
        # when the batch is empty; it is still a view whenever possible.
        return x.flatten(start_dim=1)
    
        
def conv_block(in_dim: int, out_dim: int):
    """Return one encoder stage: 3x3 conv -> batch norm -> ReLU -> 2x2 max pool.

    Args:
        in_dim: Number of input channels of the convolution.
        out_dim: Number of output channels of the convolution, which is also
            the feature count of the batch normalisation.

    Returns:
        An ``nn.Sequential`` holding the four layers in that order. The
        convolution uses padding 1, so only the pooling changes the spatial
        size (it floor-halves height and width).
    """
    convolution = nn.Conv2d(in_dim, out_dim, kernel_size=3, padding=1)
    normalisation = nn.BatchNorm2d(out_dim)
    activation = nn.ReLU()
    downsample = nn.MaxPool2d(kernel_size=2)
    stage = [convolution, normalisation, activation, downsample]
    return nn.Sequential(*stage)