import torch
from torch import nn
import torchvision
import cv2
from PIL import Image
from torchvision.transforms import ToTensor, Normalize, Compose


class DeepLabV3(nn.Module):
    """Binary foreground segmenter built on torchvision's pretrained DeepLabV3-ResNet50.

    ``process_data`` takes an OpenCV-style BGR image and returns a mask in
    which every pixel whose predicted class index is non-zero is 255 and every
    other pixel is 0.
    """

    def __init__(self, device=None):
        """Load the pretrained model and build the input preprocessing pipeline.

        Args:
            device: torch device (or device string) to run the model on.
                Defaults to ``'cuda'`` when CUDA is available, otherwise
                ``'cpu'``.
        """
        super().__init__()
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # NOTE(review): ``pretrained=`` is deprecated in newer torchvision in
        # favour of ``weights=``; kept so older torchvision releases still work.
        self.model = torchvision.models.segmentation.deeplabv3_resnet50(pretrained=True)
        self.model.to(device).eval()
        # Convert HWC uint8 -> CHW float in [0, 1], then normalize with the
        # mean/std that torchvision's pretrained backbones expect.
        self.preprocess = Compose([
            ToTensor(),
            Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def process_data(self, img):
        """Segment a BGR image and return a binary foreground mask.

        Args:
            img: image in BGR channel order, as returned by ``cv2.imread``
                (presumably H x W x 3 uint8).

        Returns:
            numpy ``uint8`` array: 255 where the argmax class index is
            non-zero, 0 elsewhere.
        """
        rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # Follow the model wherever it currently lives (robust to a later
        # ``.to(...)`` call on this module).
        device = next(self.model.parameters()).device
        img_tensor = self.preprocess(rgb).unsqueeze(0).to(device)
        # Inference only: don't build an autograd graph.
        with torch.no_grad():
            logits = self.model(img_tensor)['out']
        labels = logits.squeeze(0).argmax(dim=0)
        # Class 0 -> 0, any other class -> 255.
        mask = (labels > 0).to(torch.uint8) * 255
        return mask.cpu().numpy()


if __name__ == '__main__':
    # Demo: segment one image and save the binary foreground mask as a PNG.
    # Usage: python <script> [input_image] [output_mask]
    import sys

    input_path = sys.argv[1] if len(sys.argv) > 1 else '86_input.jpg'
    output_path = sys.argv[2] if len(sys.argv) > 2 else 'mask.png'

    img = cv2.imread(input_path)
    # cv2.imread does not raise on a missing/unreadable file; it returns None.
    if img is None:
        raise SystemExit(f'could not read image: {input_path}')

    model = DeepLabV3()
    img_mask = model.process_data(img)
    Image.fromarray(img_mask).save(output_path)
