#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch


class VGGUnet(nn.Module):
    """U-Net style encoder/decoder built on the conv layers of VGG16 blocks 1-4.

    The encoder reuses ``vgg16.features[0..21]`` conv layers (the last one is
    the final conv of block 4). Their weights are loaded from a local
    checkpoint. A four-stage decoder upsamples with nearest-neighbour
    interpolation and concatenates encoder skip connections. Each feature map
    returned by ``forward`` is paired with a single-channel confidence map.

    Args:
        args: Stored as ``self.args``; not otherwise read by this class.
        level: Selects which (features, confidences) pairs ``forward``
            returns. Must be one of ``_SUPPORTED_LEVELS``.
        vgg_weights_path: Path to a VGG16 ``state_dict`` checkpoint. Defaults
            to ``VGG16_WEIGHTS_PATH``. Pass ``None`` to skip loading and keep
            torchvision's random initialisation.
    """

    # Default VGG16 checkpoint location (the file name matches torchvision's
    # published VGG16 weights file — presumably ImageNet-pretrained).
    VGG16_WEIGHTS_PATH = '/ws/external/checkpoints/pretrained/vgg16-397923af.pth'

    # Values of ``level`` that ``forward`` knows how to serve.
    _SUPPORTED_LEVELS = (-1, -2, -3, 2, 3, 3.1, 3.2, 4, 5)

    def __init__(self, args, level, vgg_weights_path=VGG16_WEIGHTS_PATH):
        super(VGGUnet, self).__init__()
        self.args = args
        self.level = level

        # ``vgg16()`` with no arguments builds a randomly initialised network on
        # both old (``pretrained=False``) and new (``weights=None``) torchvision.
        vgg16 = torchvision.models.vgg16()
        if vgg_weights_path is not None:
            # map_location='cpu' makes loading independent of the device the
            # checkpoint was saved from; the module is moved later by the caller.
            vgg16.load_state_dict(torch.load(vgg_weights_path, map_location='cpu'))

        # Feature encoder: only the conv layers of VGG16 blocks 1-4 are reused;
        # ReLU and pooling are applied in ``forward``. Channel counts are VGG16's.
        self.conv0 = vgg16.features[0]    # block 1, 64 ch
        self.conv2 = vgg16.features[2]    # block 1, 64 ch
        self.conv5 = vgg16.features[5]    # block 2, 128 ch
        self.conv7 = vgg16.features[7]    # block 2, 128 ch
        self.conv10 = vgg16.features[10]  # block 3, 256 ch
        self.conv12 = vgg16.features[12]  # block 3, 256 ch
        self.conv14 = vgg16.features[14]  # block 3, 256 ch
        self.conv17 = vgg16.features[17]  # block 4, 512 ch
        self.conv19 = vgg16.features[19]  # block 4, 512 ch
        self.conv21 = vgg16.features[21]  # block 4, 512 ch

        # Feature decoder. Each stage takes [upsampled deeper feature, skip
        # feature] concatenated on channels. Module creation order is kept
        # fixed so seeded initialisation and state_dict keys stay stable.
        # dec0: 512 (bottleneck) + 256 (block-3 skip) -> 256
        self.conv_dec0 = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.Conv2d(self.conv21.out_channels + self.conv14.out_channels, self.conv14.out_channels, kernel_size=(3, 3),
                      stride=(1, 1), padding=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.conv14.out_channels, self.conv14.out_channels, kernel_size=(3, 3), stride=(1, 1), padding=1,
                      bias=False),
        )

        # dec1: 256 + 128 (block-2 skip) -> 128
        self.conv_dec1 = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.Conv2d(self.conv14.out_channels + self.conv7.out_channels, self.conv7.out_channels, kernel_size=(3, 3),
                      stride=(1, 1), padding=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.conv7.out_channels, self.conv7.out_channels, kernel_size=(3, 3), stride=(1, 1), padding=1,
                      bias=False),
        )

        # dec2: 128 + 64 (block-1 pooled skip) -> 64
        self.conv_dec2 = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.Conv2d(self.conv7.out_channels + self.conv2.out_channels, self.conv2.out_channels, kernel_size=(3, 3),
                      stride=(1, 1), padding=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(self.conv2.out_channels, self.conv2.out_channels, kernel_size=(3, 3), stride=(1, 1), padding=1,
                      bias=False)
        )

        # dec3: 64 + 64 (block-1 full-resolution skip) -> 32 -> 16
        self.conv_dec3 = nn.Sequential(
            nn.ReLU(inplace=True),
            nn.Conv2d(self.conv2.out_channels + self.conv2.out_channels, 32, kernel_size=(3, 3),
                      stride=(1, 1), padding=1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=1,
                      bias=False)
        )

        # Shared in-place ReLU and 2x2 max-pool used by the encoder.
        self.relu = nn.ReLU(inplace=True)
        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False,
                                     return_indices=True)

        # Feature confidence heads: ReLU -> 3x3 conv to 1 channel -> sigmoid.
        # Input channels match the feature each head reads in ``forward``.
        self.conf_1 = nn.Sequential(  # reads x15_8 (512 ch)
            nn.ReLU(),
            nn.Conv2d(512, 1, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
            nn.Sigmoid(),
        )
        self.conf0 = nn.Sequential(  # reads x15_11 (256 ch)
            nn.ReLU(),
            nn.Conv2d(256, 1, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
            nn.Sigmoid(),
        )
        self.conf1 = nn.Sequential(  # reads x18 (128 ch)
            nn.ReLU(),
            nn.Conv2d(128, 1, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
            nn.Sigmoid(),
        )
        self.conf2 = nn.Sequential(  # reads x21 (64 ch)
            nn.ReLU(),
            nn.Conv2d(64, 1, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
            nn.Sigmoid(),
        )
        self.conf3 = nn.Sequential(  # reads x24 (16 ch)
            nn.ReLU(),
            nn.Conv2d(16, 1, kernel_size=(3, 3), stride=(1, 1), padding=1, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Run encoder and decoder; return the outputs selected by ``self.level``.

        Args:
            x: 4-D image batch (B, C, H, W). C is presumably 3, since ``conv0``
               is VGG16's first conv. TODO confirm any size constraints on H, W.

        Returns:
            ``(features, confidences)``: two lists of equal length. Each
            confidence is a 1-channel map computed from a decoder/bottleneck
            feature.

        Raises:
            ValueError: if ``self.level`` is not in ``_SUPPORTED_LEVELS``.

        Side effects:
            Stores ``self.enc_feature = [x15_8, x15, x8, x3, x2]``.
        """
        # Fail fast, before spending compute, on levels that have no output
        # branch (previously this silently returned None).
        if self.level not in self._SUPPORTED_LEVELS:
            raise ValueError(
                'Unsupported level {!r}; expected one of {}'.format(self.level, self._SUPPORTED_LEVELS))

        # NOTE: self.relu is in-place. Applying it to the pooled tensors x3, x8
        # and x15 rectifies those tensors themselves, so the skip connections,
        # enc_feature and returned features below see post-ReLU values.

        # block0
        x0 = self.conv0(x)
        x1 = self.relu(x0)
        x2 = self.conv2(x1)
        x3, ind3 = self.max_pool(x2)  # [H/2, W/2]

        # block1
        x4 = self.relu(x3)
        x5 = self.conv5(x4)
        x6 = self.relu(x5)
        x7 = self.conv7(x6)
        x8, ind8 = self.max_pool(x7)  # [H/4, W/4]

        # block2
        x9 = self.relu(x8)
        x10 = self.conv10(x9)
        x11 = self.relu(x10)
        x12 = self.conv12(x11)
        x13 = self.relu(x12)
        x14 = self.conv14(x13)
        x15, ind15 = self.max_pool(x14)  # [H/8, W/8]

        # block3
        x15_1 = self.relu(x15)
        x15_2 = self.conv17(x15_1)
        x15_3 = self.relu(x15_2)
        x15_4 = self.conv19(x15_3)
        x15_5 = self.relu(x15_4)
        x15_6 = self.conv21(x15_5)
        x15_7 = self.relu(x15_6)
        x15_8, ind15_ = self.max_pool(x15_7)  # [H/16, W/16]

        self.enc_feature = [x15_8, x15, x8, x3, x2]

        # dec0
        x15_9 = F.interpolate(x15_8, [x15.shape[2], x15.shape[3]], mode="nearest")
        x15_10 = torch.cat([x15_9, x15], dim=1)
        x15_11 = self.conv_dec0(x15_10)  # [H/8, W/8]

        # dec1
        x16 = F.interpolate(x15_11, [x8.shape[2], x8.shape[3]], mode="nearest")
        x17 = torch.cat([x16, x8], dim=1)
        x18 = self.conv_dec1(x17)  # [H/4, W/4]

        # dec2
        x19 = F.interpolate(x18, [x3.shape[2], x3.shape[3]], mode="nearest")
        x20 = torch.cat([x19, x3], dim=1)
        x21 = self.conv_dec2(x20)  # [H/2, W/2]

        # dec3
        x22 = F.interpolate(x21, [x2.shape[2], x2.shape[3]], mode="nearest")
        x23 = torch.cat([x22, x2], dim=1)
        x24 = self.conv_dec3(x23)  # [H, W]

        # attention
        # NOTE(review): each conf head already ends in a sigmoid, so this is
        # sigmoid(-sigmoid(z)), which confines confidences to roughly
        # [0.27, 0.5]. Kept unchanged because trained checkpoints depend on it;
        # confirm it is intended.
        c_1 = torch.sigmoid(-self.conf_1(x15_8))
        c0 = torch.sigmoid(-self.conf0(x15_11))
        c1 = torch.sigmoid(-self.conf1(x18))
        c2 = torch.sigmoid(-self.conf2(x21))
        c3 = torch.sigmoid(-self.conf3(x24))

        # The membership check at the top guarantees one branch below matches.
        if self.level == -1:
            # NOTE(review): pairs the encoder feature x15 with c0, which is
            # computed from the decoder feature x15_11 (same 256 ch, H/8).
            # Levels 3.2 and 5 pair c0 with x15_11 instead. Confirm intended.
            return [x15], [c0]
        elif self.level == -2:
            return [x18], [c1]
        elif self.level == -3:
            return [x21], [c2]
        elif self.level == 2:
            return [x18, x21], [c1, c2]
        elif self.level == 3:   # 024
            return [x15_8, x18, x24], [c_1, c1, c3]
        elif self.level == 3.1:  # 023
            return [x15_8, x18, x21], [c_1, c1, c2]
        elif self.level == 3.2:  # 124
            return [x15_11, x18, x24], [c0, c1, c3]
        elif self.level == 4:
            # NOTE(review): same x15 / c0 pairing as level -1 — confirm intended.
            return [x15, x18, x21, x24], [c0, c1, c2, c3]
        elif self.level == 5:
            return [x15_8, x15_11, x18, x21, x24], [c_1, c0, c1, c2, c3]


