# -*- coding: utf-8 -*-
"""3D_LightDenseNet.ipynb

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1GL7TIivwXDxM8_bbVAPobftp10VnW77x
"""

import cv2
import numpy as np
from matplotlib import pyplot as plt
import torch
import os
from math import *
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.autograd import Variable
from sklearn.metrics import roc_curve, auc

import random
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from scipy import stats
from mpl_toolkits import mplot3d
import math

from torch.utils.data import TensorDataset

import torchvision
import torchvision.transforms as transforms

import warnings
# NOTE(review): this runs at import time and suppresses *all* warnings for the
# whole process (e.g. PyTorch/sklearn deprecation notices), not just for this
# module. Consider narrowing it by category or moving it into the notebook /
# script entry point.
warnings.filterwarnings(action='ignore')

class Conv_Block(nn.Module):
    """Bottleneck layer of a 3D dense block.

    Applies BN -> ReLU -> 1x1x1 Conv3d (to ``4 * out_planes`` channels), then
    BN -> ReLU -> 3x3x3 Conv3d (to ``out_planes`` channels; padding 1 with
    stride 1 keeps the spatial size), and concatenates the new features onto
    the input along dim 1. The output therefore has ``in_planes + out_planes``
    channels.

    Args:
        in_planes: channels of the incoming tensor.
        out_planes: channels this layer appends (the growth rate).
    """

    def __init__(self, in_planes, out_planes):
        super(Conv_Block, self).__init__()
        bottleneck_planes = 4 * out_planes

        # Attribute names and creation order are deliberately kept: they
        # define the state_dict keys and the order weights are initialised in.
        self.bn1 = nn.BatchNorm3d(in_planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv3d(
            in_channels=in_planes,
            out_channels=bottleneck_planes,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )
        self.bn2 = nn.BatchNorm3d(bottleneck_planes)
        self.conv2 = nn.Conv3d(
            in_channels=bottleneck_planes,
            out_channels=out_planes,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )

    def forward(self, x):
        """Return ``x`` with ``out_planes`` freshly computed channels appended."""
        # The in-place ReLU acts on the BatchNorm output, never on x itself.
        squeezed = self.conv1(self.relu(self.bn1(x)))
        new_features = self.conv2(self.relu(self.bn2(squeezed)))
        return torch.cat((x, new_features), dim=1)

class Dense_Block(nn.Module):
    """Sequential stack of ``nb_layers`` dense layers.

    Layer ``k`` (0-based) is constructed as
    ``block(in_planes + k * growth_rate, growth_rate)``. Each layer is
    presumably expected to append ``growth_rate`` channels to its input
    (as Conv_Block does), so the input width grows by that amount per layer.

    Args:
        nb_layers: how many layers to stack (0 yields an empty Sequential,
            which passes its input through unchanged).
        in_planes: channels of the tensor entering the first layer.
        growth_rate: channels each layer adds.
        block: layer constructor called as ``block(in_planes, out_planes)``.
    """

    def __init__(self, nb_layers, in_planes, growth_rate, block):
        super(Dense_Block, self).__init__()
        self.layer = self._make_layer(block, in_planes, growth_rate, nb_layers)

    def _make_layer(self, block, in_planes, growth_rate, nb_layers):
        """Instantiate the ``nb_layers`` blocks and wrap them in nn.Sequential."""
        widths = (in_planes + k * growth_rate for k in range(nb_layers))
        return nn.Sequential(*(block(width, growth_rate) for width in widths))

    def forward(self, x):
        """Run ``x`` through every layer in order."""
        return self.layer(x)

class Transition_Block(nn.Module):
    """Transition between dense blocks: BN -> ReLU -> 1x1x1 Conv3d -> 2x avg-pool.

    The 1x1x1 convolution maps ``in_planes`` to ``out_planes`` channels.
    The average pool (kernel 2, stride 2) halves each spatial dimension,
    with floor rounding.

    Args:
        in_planes: channels of the incoming tensor.
        out_planes: channels after the 1x1x1 convolution.
    """

    def __init__(self, in_planes, out_planes):
        super(Transition_Block, self).__init__()
        # Same attribute names / creation order as before (state_dict keys).
        self.bn1 = nn.BatchNorm3d(in_planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Conv3d(
            in_channels=in_planes,
            out_channels=out_planes,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=False,
        )

    def forward(self, x):
        """Compress channels of ``x`` and downsample it by 2 in every spatial dim."""
        activated = self.relu(self.bn1(x))
        compressed = self.conv1(activated)
        return F.avg_pool3d(compressed, kernel_size=2, stride=2)

class Light_DenseNet_3D(nn.Module):
    """Lightweight three-stage 3D DenseNet classifier.

    Layout: stem Conv3d (kernel 5, stride 1, no padding; doubles the channel
    count) -> BN -> ReLU -> dense1 -> trans1 -> dense2 -> trans2 -> dense3 ->
    global average pool -> Linear.

    Args:
        num_classes: number of scores produced by the final ``fc`` layer.
        growth_rate: channels each Conv_Block inside a dense block appends.
        reduction: fraction of channels kept by each transition block
            (``floor(channels * reduction)``).
        in_channels: channels of the input volume (dim 1 of ``x``). Defaults
            to 96, the value that was previously hard-coded.
        block_config: number of Conv_Block layers in each of the three dense
            blocks. Defaults to ``(6, 12, 8)``, the previously hard-coded value.

    Raises:
        ValueError: if ``block_config`` does not contain exactly 3 entries.
    """

    def __init__(self, num_classes=2, growth_rate=32, reduction=0.5,
                 in_channels=96, block_config=(6, 12, 8)):
        super(Light_DenseNet_3D, self).__init__()
        if len(block_config) != 3:
            raise ValueError(
                "block_config must contain exactly 3 entries, got %d"
                % len(block_config)
            )
        n1, n2, n3 = block_config

        def _reduced(planes):
            # Channel count produced by a transition block.
            return int(math.floor(planes * reduction))

        # Modules are created with the original names and in the original
        # order, so state_dict keys and random weight initialisation match.
        stem_planes = in_channels * 2
        self.conv3d_1 = nn.Conv3d(in_channels=in_channels, out_channels=stem_planes,
                                  kernel_size=5, stride=1)
        self.bn1 = nn.BatchNorm3d(stem_planes)
        self.relu = nn.ReLU(inplace=True)

        planes = stem_planes

        # 1st dense block + transition
        self.dense1 = Dense_Block(n1, planes, growth_rate, Conv_Block)
        planes = int(planes + n1 * growth_rate)
        self.trans1 = Transition_Block(planes, _reduced(planes))
        planes = _reduced(planes)

        # 2nd dense block + transition
        self.dense2 = Dense_Block(n2, planes, growth_rate, Conv_Block)
        planes = int(planes + n2 * growth_rate)
        self.trans2 = Transition_Block(planes, _reduced(planes))
        planes = _reduced(planes)

        # 3rd dense block (no transition; followed by global pooling)
        self.dense3 = Dense_Block(n3, planes, growth_rate, Conv_Block)
        planes = int(planes + n3 * growth_rate)

        self.fc = nn.Linear(planes, num_classes)

        # Width of the pooled feature vector fed to ``fc``.
        self.in_planes = planes

    def forward(self, x):
        """Return raw class scores (no softmax) of shape (batch, num_classes).

        ``x`` must be 5D (BatchNorm3d rejects other ranks), presumably
        (batch, in_channels, D, H, W). Each spatial dim must be at least 8:
        the kernel-5 stem removes 4, and each of the two transitions halves
        the size with floor rounding.
        """
        x = self.relu(self.bn1(self.conv3d_1(x)))

        x = self.trans1(self.dense1(x))
        x = self.trans2(self.dense2(x))
        x = self.dense3(x)

        # Global average pool to (batch, C, 1, 1, 1), then drop the trailing
        # singleton dims so fc receives (batch, C).
        x = F.adaptive_avg_pool3d(x, (1, 1, 1))
        x = torch.flatten(x, 1)
        return self.fc(x)