#%% Import packages
import glob
import numpy as np
import pandas as pd
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as T

from datetime import datetime
from torch.utils.data import DataLoader, Dataset
from torch.distributions import bernoulli
from PIL import Image
from PIL import ImageOps, ImageFilter

torch.backends.cudnn.benchmark = False

#%% Helper Functions    
# Read and transform KITTI depth images
def depth_read(filename):
    """Load a depth PNG and return it as a float32 array.

    The stored integer pixel values are divided by 256; a stored value of 0
    therefore stays 0 in the returned array.

    Args:
        filename: Path of the PNG file to open.

    Returns:
        numpy.ndarray of dtype float32 holding ``pixel_value / 256``.

    Raises:
        AssertionError: If no pixel value exceeds 255, i.e. the image does not
            look like a 16-bit depth map.
    """
    # The context manager closes the file even when the sanity check below
    # fails; the pixel data is fully copied into the array before leaving it.
    with Image.open(filename) as my_file:
        depth_png = np.array(my_file, dtype=int)

    assert np.max(depth_png) > 255, \
        "Expected a 16-bit depth image (max value > 255): %s" % (filename,)
    return depth_png.astype(np.float32) / 256

#%% Loss functions
def sparse_MAE(target_y, predicted_y):
    """Mean absolute error evaluated only where the target is non-zero.

    Pixels whose target value is exactly 0 are treated as "no ground truth":
    the prediction is zeroed there, so they contribute nothing to the error
    sum, and they are excluded from the averaging count.

    Args:
        target_y: Target tensor; zeros mark invalid pixels.
        predicted_y: Prediction tensor, broadcastable against ``target_y``.

    Returns:
        Scalar tensor: sum of absolute errors over valid pixels divided by the
        number of valid pixels. If there is no valid pixel at all the result
        is 0 rather than NaN.
    """
    valid = (target_y != 0).to(target_y.dtype)
    abs_error = torch.abs(target_y - valid * predicted_y)
    # Guard against 0/0: a target without any valid pixel would otherwise
    # yield a NaN loss (and NaN gradients). For any non-empty mask the count
    # is >= 1, so the clamp leaves the result unchanged.
    valid_count = torch.clamp(torch.sum(valid), min=1)
    return torch.sum(abs_error) / valid_count

#%% Neural Network Architecture    
class SICNN(nn.Module):
    """Encoder-decoder CNN whose layers are mask-normalised ("sparse") convolutions.

    Every layer convolves the masked input, divides the response by the number
    of valid (mask == 1) pixels under the kernel window and adds a learnable
    per-channel bias. A single-channel validity mask is carried alongside the
    features and updated after every layer.

    Layout:
        * conv1-conv14: encoder; conv3, conv6, conv9 and conv12 down-sample
          with stride 2.
        * conv15-conv27: decoder; conv15, conv18, conv21 and conv24 are
          transposed convolutions that up-sample by 2.
        * The outputs of conv11, conv8, conv5 and conv2 are concatenated along
          the channel dimension into the inputs of conv16, conv19, conv22 and
          conv25 respectively.
        * conv27 produces two channels (channel 0: dense depth image,
          channel 1: depth contour image).

    Because of the four stride-2 stages and the skip concatenations, the input
    height and width have to be divisible by 16.

    The attribute names ``convN``, ``biasN`` and ``onesN`` (N = 1..27) are the
    ``state_dict`` keys and are kept stable so existing checkpoints load.
    """

    _NUM_LAYERS = 27
    # Decoder layer index -> encoder layer whose output is concatenated to the
    # decoder layer's input.
    _SKIP_SOURCES = {16: 11, 19: 8, 22: 5, 25: 2}
    _SKIP_LAYERS = frozenset(_SKIP_SOURCES.values())

    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels=1, out_channels=64, kernel_size=11, stride=1, padding=5, bias=False)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=9, stride=1, padding=4, bias=False)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=2, padding=1, bias=False) # Down sampling

        self.conv4 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv5 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv6 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=2, padding=1, bias=False) # Down sampling

        self.conv7 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv8 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv9 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=2, padding=1, bias=False) # Down sampling

        self.conv10 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv11 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv12 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=2, padding=1, bias=False) # Down sampling

        self.conv13 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv14 = nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv15 = nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=3, stride=2, padding=1, bias=False, output_padding=1) # Up sampling

        self.conv16 = nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False) # Feature concatenation
        self.conv17 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv18 = nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=3, stride=2, padding=1, bias=False, output_padding=1) # Up sampling

        self.conv19 = nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, stride=1, padding=1, bias=False) # Feature concatenation
        self.conv20 = nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv21 = nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=3, stride=2, padding=1, bias=False, output_padding=1) # Up sampling

        self.conv22 = nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False) # Feature concatenation
        self.conv23 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv24 = nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=3, stride=2, padding=1, bias=False, output_padding=1) # Up sampling

        self.conv25 = nn.Conv2d(in_channels=96, out_channels=32, kernel_size=3, stride=1, padding=1, bias=False) # Feature concatenation
        self.conv26 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1, bias=False)
        self.conv27 = nn.Conv2d(in_channels=32, out_channels=2, kernel_size=3, stride=1, padding=1, bias=False) # Output channel 0: Dense depth image
                                                                                                                # Output channel 1: Depth contour image

        self.weight_init()

        # Learnable bias parameters: one (out_channels, 1, 1) tensor per layer,
        # added after the normalisation step. They live outside the conv
        # modules (which are built with bias=False) so that the bias is not
        # scaled by the normalisation factor.
        for idx in range(1, self._NUM_LAYERS + 1):
            out_channels = getattr(self, "conv%d" % idx).out_channels
            setattr(self, "bias%d" % idx,
                    nn.Parameter(torch.zeros((out_channels, 1, 1)), requires_grad=True))

        # Fixed all-ones kernels with the same spatial size as the matching
        # conv layer; convolving the mask with them counts the valid pixels
        # under each kernel window. Kept as frozen Parameters (not buffers) so
        # named_parameters()/state_dict() stay identical to earlier versions.
        for idx in range(1, self._NUM_LAYERS + 1):
            kernel_size = tuple(getattr(self, "conv%d" % idx).kernel_size)
            setattr(self, "ones%d" % idx,
                    nn.Parameter(torch.ones((1, 1) + kernel_size), requires_grad=False))

    def weight_init(self):
        """Re-initialise every conv weight with Kaiming-normal (fan_in, ReLU)."""
        for idx in range(1, self._NUM_LAYERS + 1):
            layer = getattr(self, "conv%d" % idx)
            torch.nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')

    def sparse_conv(self, x, conv_layer, bias_layer, ones_layer, binary_mask):
        """Apply one mask-normalised convolution.

        Args:
            x: Input features of shape (N, C, H, W).
            conv_layer: ``nn.Conv2d`` providing weights, stride and padding.
            bias_layer: Tensor broadcastable to the conv output, added after
                normalisation.
            ones_layer: (1, 1, kH, kW) all-ones kernel matching ``conv_layer``.
            binary_mask: (N, 1, H, W) validity mask, or ``None`` to derive it
                from the first channel of ``x`` (non-zero == valid).

        Returns:
            Tuple ``(features, mask)``: the normalised, biased conv output and
            the mask propagated to the output resolution.
        """
        if binary_mask is None:  # first input has no binary mask, so create one
            binary_mask = (x[:, :1, :, :] != 0).to(x.dtype)

        # The mask statistics are constants of the computation; computing them
        # without autograd guarantees they never become part of the graph.
        with torch.no_grad():
            # Number of valid input pixels under each kernel window.
            norm = F.conv2d(binary_mask, ones_layer, bias=None,
                            stride=conv_layer.stride, padding=conv_layer.padding)
            # Reciprocal of the count; windows without any valid pixel get 0.
            norm = torch.where(torch.eq(norm, 0), torch.zeros_like(norm), torch.reciprocal(norm))
            # Propagate the mask: an output pixel is valid if any input pixel
            # under its kernel window was valid (max-pool with the conv geometry).
            next_mask = F.max_pool2d(binary_mask, kernel_size=conv_layer.kernel_size,
                                     stride=conv_layer.stride, padding=conv_layer.padding)

        # Zero out invalid inputs, convolve, normalise, then add the bias.
        x = conv_layer(x * binary_mask)
        x = (x * norm) + bias_layer

        return x, next_mask

    def t_sparse_conv(self, x, t_conv_layer, bias_layer, ones_layer, binary_mask):
        """Apply one mask-normalised transposed convolution (up-sampling).

        Args:
            x: Input features of shape (N, C, H, W).
            t_conv_layer: ``nn.ConvTranspose2d`` providing weights, stride,
                padding and output padding.
            bias_layer: Tensor broadcastable to the layer output, added after
                normalisation.
            ones_layer: (1, 1, kH, kW) all-ones kernel matching ``t_conv_layer``.
            binary_mask: (N, 1, H, W) validity mask of ``x``.

        Returns:
            Tuple ``(features, mask)``: the normalised, biased output and the
            mask up-sampled (nearest neighbour) by the layer's stride.
        """
        with torch.no_grad():
            # Number of valid input pixels contributing to each output pixel.
            # The output padding is read from the layer so the normaliser
            # always has the same spatial size as the layer output.
            norm = F.conv_transpose2d(binary_mask, ones_layer, bias=None,
                                      stride=t_conv_layer.stride,
                                      padding=t_conv_layer.padding,
                                      output_padding=t_conv_layer.output_padding)
            norm = torch.where(torch.eq(norm, 0), torch.zeros_like(norm), torch.reciprocal(norm))
            # Nearest-neighbour up-sampling of the mask by the layer stride
            # (F.upsample_nearest is deprecated in favour of F.interpolate).
            scale = tuple(float(s) for s in t_conv_layer.stride)
            next_mask = F.interpolate(binary_mask, scale_factor=scale, mode='nearest')

        # Zero out invalid inputs, convolve, normalise, then add the bias.
        x = t_conv_layer(x * binary_mask)
        x = (x * norm) + bias_layer

        return x, next_mask

    def _apply_layer(self, idx, x, mask):
        """Run layer ``idx`` (1-based) with its own conv, bias and ones kernel."""
        layer = getattr(self, "conv%d" % idx)
        bias = getattr(self, "bias%d" % idx)
        ones = getattr(self, "ones%d" % idx)
        if isinstance(layer, nn.ConvTranspose2d):
            return self.t_sparse_conv(x, layer, bias, ones, mask)
        return self.sparse_conv(x, layer, bias, ones, mask)

    def forward(self, x):
        """Run the network.

        Args:
            x: Tensor of shape (N, 1, H, W); zeros in it are treated as
                missing values. H and W must be divisible by 16.

        Returns:
            Tensor of shape (N, 2, H, W) after a final ReLU.
        """
        skips = {}
        mask = None  # sparse_conv derives the initial mask from the raw input
        for idx in range(1, self._NUM_LAYERS + 1):
            if idx > 1:
                # Every layer except the first one sees ReLU-activated features.
                x = F.relu(x)
            source = self._SKIP_SOURCES.get(idx)
            if source is not None:
                # Feature concatenation along the channel dimension.
                x = torch.cat((x, F.relu(skips[source])), dim=1)
            x, mask = self._apply_layer(idx, x, mask)
            if idx in self._SKIP_LAYERS:
                skips[idx] = x  # pre-activation output, reused by the decoder

        return F.relu(x)

#%% Dataloader        
class KittiDataset(Dataset):
    """Dataset of (sparse input depth, ground-truth depth) image pairs.

    Input files are found with a glob pattern; the matching target file is
    located by replacing ``"velodyne_raw"`` with ``"groundtruth"`` in the
    input path. Each sample is augmented with depth-proportional Gaussian
    noise (inputs only) and a random horizontal flip (inputs and targets).

    Args:
        root_dir: Glob pattern (``**`` is supported) matching the input depth
            PNG files.
        transform: Callable turning a depth array into a tensor. When
            ``None``, the array is converted to a tensor with a leading
            channel dimension and nothing else is applied.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        # glob returns files in filesystem order; sorting makes the index ->
        # file mapping (and therefore seeded shuffling) reproducible.
        matches = sorted(glob.glob(self.root_dir, recursive=True))
        self.input_files = pd.DataFrame(matches)

    def __len__(self):
        return len(self.input_files)

    def _to_tensor(self, depth_array):
        """Convert a depth array with ``self.transform`` (or a plain fallback)."""
        if self.transform is not None:
            return self.transform(depth_array)
        # Fallback for transform=None: (H, W) array -> (1, H, W) tensor.
        return torch.from_numpy(depth_array).unsqueeze(0)

    def __getitem__(self, idx):
        """Return ``{'inputs': tensor, 'targets': tensor}`` for sample ``idx``."""
        input_path = self.input_files.iloc[idx, 0]
        target_path = input_path.replace("velodyne_raw", "groundtruth")

        # Convert to tensors
        inputs = self._to_tensor(depth_read(input_path))
        targets = self._to_tensor(depth_read(target_path))

        # Data augmentation: Gaussian noise whose standard deviation is 0.5 %
        # of the depth value. torch.normal is called directly because
        # torch.distributions.Normal validates scale > 0 on newer PyTorch
        # releases and would reject the zero-valued (missing) pixels; with
        # std == 0 those pixels simply stay 0.
        inputs = torch.normal(mean=inputs, std=0.005 * inputs)

        # Data augmentation: Horizontal flipping (same flip for both images)
        if random.random() > 0.5:
            inputs = T.functional.hflip(inputs)
            targets = T.functional.hflip(targets)

        return {'inputs': inputs,
                'targets': targets}


#%% Main function
if __name__=="__main__":
    # Fine-tuning script: load the step-1 checkpoint, freeze everything except
    # the last three layers (conv25-conv27 and their biases) and train those
    # for 10 epochs with the sparse MAE loss on output channel 0.
    gen = torch.manual_seed(1515870810)
    if torch.cuda.is_available():
        device = "cuda:0"
        print(device)
    else:
        device = "cpu"
        
    dataset = KittiDataset(root_dir="../../dataset/kitti/train/*/*/velodyne_raw/*/*.png", 
                             transform=T.Compose([T.ToTensor(), T.CenterCrop((352,1216))]))
                             
    print("Length of dataset: ",len(dataset))
    if len(dataset) == 0:
        # Fail early with a clear message instead of a sampler error or a
        # division by zero when averaging the epoch loss.
        raise RuntimeError("No training images matched: " + dataset.root_dir)

    train_dataloader = DataLoader(dataset=dataset, 
                              batch_size=4,
                              shuffle=True,
                              pin_memory=True,
                              drop_last=True,
                              num_workers=2)
    
    print("Length of dataloader: ",len(train_dataloader))

#%% Checkpoint Loading
    start_epoch = 0
    
    # map_location lets a checkpoint saved on a GPU be restored on a CPU-only
    # machine (and vice versa).
    checkpoint = torch.load(r'../../model/ablation_study_test5_step1.pt', map_location=device)
    sicnn = SICNN().to(device=device, dtype=torch.float32)
    sicnn.load_state_dict(checkpoint['model_state_dict'])
    
    # Fine-tuning or transfer learning: freeze every parameter except the
    # weights and biases of the last three layers.
    trainable_names = {"conv25.weight", "bias25",
                       "conv26.weight", "bias26",
                       "conv27.weight", "bias27"}
    trainable_parameters_list = []
    for name, parameter in sicnn.named_parameters():
        parameter.requires_grad = name in trainable_names
        if parameter.requires_grad:
            trainable_parameters_list.append(parameter)
 
    optimizer = optim.Adam(trainable_parameters_list)
 
#%% Training
    sicnn = sicnn.train()
    print("Start Epoch: ",start_epoch)
    
    if device == "cuda:0":
        torch.cuda.empty_cache()

    for epoch in range(start_epoch,10):  # loop over the dataset multiple times
        running_loss = 0.0
        
        total_size = 0
        start_time = datetime.now()
        for i, data in enumerate(train_dataloader, 0):
            inputs, targets = data['inputs'].to(dtype=torch.float32, device=device), data['targets'].to(dtype=torch.float32, device=device)
            
            optimizer.zero_grad()

            # Forward pass
            outputs = sicnn(inputs)
            
            loss = sparse_MAE(targets[:,:1,:,:], outputs[:,:1,:,:]) # First output channel should contain depth
            loss.backward()
            optimizer.step()

            # Update statistics
            total_size += len(inputs)
            running_loss += loss.item()
                        
        end_time = datetime.now()
        print("Total time taken: ", end_time - start_time, end="\t")
        print("Epoch: %d \t Loss: %.3f" % (epoch+1, running_loss/len(train_dataloader)), end="\n")

        # Create a checkpoint (overwritten after every epoch)
        path = '../../model/ablation_study_test5_step2.pt'
        torch.save({
                    'description':'ablation_study_test5_step2',
                    'epoch': epoch, 
                    'model_state_dict': sicnn.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'mae_loss': running_loss/len(train_dataloader)
                    }, path)
    
    print('Finished Training')