#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed April 1 13:39:00 2021
This repository is based on the repository at https://github.com/artest08/LateTemporalModeling3DCNN. We thank the authors for the repository.
This repository is authored by Jiajun Chen
We thank the authors for the repository.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import math
from functools import partial

from .BERT.self_attention import self_attention

__all__ = ['dark_light_resneXt3D']
    
#dark_light_resneXt3D
class dark_light_resneXt3D(nn.Module):
    def __init__(self, num_classes , length, both_flow, modelPath=''):
        super(dark_light_resneXt3D, self).__init__()
        self.hidden_size=512
        self.n_layers=1
        self.attn_heads=8
        self.num_classes=num_classes
        self.length=length
        self.dp = nn.Dropout(p=0.8)
        self.both_flow = both_flow
        
        self.avgpool = nn.AvgPool3d((1, 4, 4), stride=1)
        self.features=nn.Sequential(*list(_trained_resnext101(model_path=modelPath, sample_size=112, sample_duration=64).children())[:-2])
        
        downsample = nn.Sequential(
            nn.Conv3d(
                2048,
                512,
                kernel_size=1,
                stride=1,
                bias=False), nn.BatchNorm3d(512))

        mapper = ResNeXtBottleneck(2048, 256, cardinality = 32, stride = 1, downsample = downsample)

        for m in mapper.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight)
                #m.weight = nn.init.kaiming_normal(m.weight, mode='fan_out')
            elif isinstance(m, nn.BatchNorm3d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()   
                
        self.features[7][2] = mapper
        if self.both_flow == 'True':
            max_length = 8
        elif self.both_flow == 'False':
            max_length  = 4
        self.self_attention = self_attention(self.hidden_size, max_length , hidden=self.hidden_size, n_layers=self.n_layers, attn_heads=self.attn_heads)
        print(sum(p.numel() for p in self.self_attention.parameters() if p.requires_grad))
        
        self.fc_action = nn.Linear(self.hidden_size, num_classes)
      
        for param in self.features.parameters():
            param.requires_grad = True
  
        torch.nn.init.xavier_uniform_(self.fc_action.weight)
        self.fc_action.bias.data.zero_()
        
    def forward(self, x):
        if self.both_flow == 'True':
            x, x_light = x

            x = self.features(x)
            x_light = self.features(x_light)  # x(b,512,8,7,7)

            x = self.avgpool(x)
            x = x.view(x.size(0), self.hidden_size, 4)
            x = x.transpose(1, 2)

            x_light = self.avgpool(x_light)  # b,512,8,1,1
            x_light = x_light.view(x_light.size(0), self.hidden_size, 4)  # x(b,512,8)
            x_light = x_light.transpose(1, 2)  # x
            x_cat = torch.cat((x,x_light),1)
            output, maskSample = self.self_attention(x_cat)  # output(b,9,512),masksample(b,9)
        elif self.both_flow == 'False':
            _, x = x
            x = self.features(x)
            x = self.avgpool(x)
            x = x.view(x.size(0), self.hidden_size, 4)
            x = x.transpose(1,2)

            output , maskSample = self.self_attention(x)

        classificationOut = output[:,0,:]
        output=self.dp(classificationOut)
        x = self.fc_action(output)
        return x

def conv3x3x3(in_planes, out_planes, stride=1):
    # 3x3x3 convolution with padding
    return nn.Conv3d(
        in_planes,
        out_planes,
        kernel_size=3,
        stride=stride,
        padding=1,
        bias=False)


def downsample_basic_block(x, planes, stride):
    out = F.avg_pool3d(x, kernel_size=1, stride=stride)
    zero_pads = torch.Tensor(
        out.size(0), planes - out.size(1), out.size(2), out.size(3),
        out.size(4)).zero_()
    if isinstance(out.data, torch.cuda.FloatTensor):
        zero_pads = zero_pads.cuda()

    out = Variable(torch.cat([out.data, zero_pads], dim=1))

    return out


class ResNeXtBottleneck(nn.Module):
    expansion = 2

    def __init__(self, inplanes, planes, cardinality, stride=1,
                 downsample=None):
        super(ResNeXtBottleneck, self).__init__()
        mid_planes = cardinality * int(planes / 32)
        self.conv1 = nn.Conv3d(inplanes, mid_planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm3d(mid_planes)
        self.conv2 = nn.Conv3d(
            mid_planes,
            mid_planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            groups=cardinality,
            bias=False)
        self.bn2 = nn.BatchNorm3d(mid_planes)
        self.conv3 = nn.Conv3d(
            mid_planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm3d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNeXt(nn.Module):

    def __init__(self,
                 block,
                 layers,
                 sample_size,
                 sample_duration,
                 shortcut_type='B',
                 cardinality=32,
                 num_classes=400):
        self.inplanes = 64
        super(ResNeXt, self).__init__()
        self.conv1 = nn.Conv3d(
            3,
            64,
            kernel_size=7,
            stride=(1, 2, 2),
            padding=(3, 3, 3),
            bias=False)
        self.bn1 = nn.BatchNorm3d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=(3, 3, 3), stride=2, padding=1)
        self.layer1 = self._make_layer(block, 128, layers[0], shortcut_type,
                                       cardinality)
        self.layer2 = self._make_layer(
            block, 256, layers[1], shortcut_type, cardinality, stride=2)
        self.layer3 = self._make_layer(
            block, 512, layers[2], shortcut_type, cardinality, stride=2)
        self.layer4 = self._make_layer(
            block, 1024, layers[3], shortcut_type, cardinality, stride=2)
        last_duration = int(math.ceil(sample_duration / 16))
        last_size = int(math.ceil(sample_size / 32))
        self.avgpool = nn.AvgPool3d(
            (last_duration, last_size, last_size), stride=1)
        self.fc = nn.Linear(cardinality * 32 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                nn.init.kaiming_normal_(m.weight)
                #m.weight = nn.init.kaiming_normal(m.weight, mode='fan_out')
            elif isinstance(m, nn.BatchNorm3d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self,
                    block,
                    planes,
                    blocks,
                    shortcut_type,
                    cardinality,
                    stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            if shortcut_type == 'A':
                downsample = partial(
                    downsample_basic_block,
                    planes=planes * block.expansion,
                    stride=stride)
            else:
                downsample = nn.Sequential(
                    nn.Conv3d(
                        self.inplanes,
                        planes * block.expansion,
                        kernel_size=1,
                        stride=stride,
                        bias=False), nn.BatchNorm3d(planes * block.expansion))

        layers = []
        layers.append(
            block(self.inplanes, planes, cardinality, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, cardinality))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)

        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def get_fine_tuning_parameters(model, ft_begin_index):
    if ft_begin_index == 0:
        return model.parameters()

    ft_module_names = []
    for i in range(ft_begin_index, 5):
        ft_module_names.append('layer{}'.format(i))
    ft_module_names.append('fc')

    parameters = []
    for k, v in model.named_parameters():
        for ft_module in ft_module_names:
            if ft_module in k:
                parameters.append({'params': v})
                break
        else:
            parameters.append({'params': v, 'lr': 0.0})

    return parameters


def resnext3D50(**kwargs):
    """Constructs a ResNet-50 model.
    """
    model = ResNeXt(ResNeXtBottleneck, [3, 4, 6, 3], **kwargs)
    return model


def resnext3D101(**kwargs):
    """Constructs a ResNet-101 model.
    """
    model = ResNeXt(ResNeXtBottleneck, [3, 4, 23, 3], **kwargs)
    return model

def rgb_mars_resnext3D64f101(num_classes , length, modelPath=''):
    """Constructs a ResNet-101 model.
    """
    model = ResNeXt(ResNeXtBottleneck, [3, 4, 23, 3], sample_size=112, sample_duration=64, num_classes=51)
    return model

def flow_mars_resnext3D64f101(num_classes , length, modelPath=''):
    """Constructs a ResNet-101 model.
    """
    model = ResNeXt(ResNeXtBottleneck, [3, 4, 23, 3], sample_size=112, sample_duration=64, num_classes=51)
    model.conv1 = nn.Conv3d(
        2,
        64,
        kernel_size=7,
        stride=(1, 2, 2),
        padding=(3, 3, 3),
        bias=False)
    return model


def _trained_resnext101(model_path, **kwargs):
    """Constructs a ResNet-101 model.
    """
    model = ResNeXt(ResNeXtBottleneck, [3, 4, 23, 3], **kwargs)
    if model_path=='':
        return model
    params = torch.load(model_path)
    new_dict = {k[7:]: v for k, v in params['state_dict'].items()} 
    model_dict=model.state_dict() 
    model_dict.update(new_dict)
    model.load_state_dict(new_dict)
    return model

def _trained_resnext101_flow(model_path_flow, **kwargs):
    """Constructs a ResNet-101 model.
    """
    model = ResNeXt(ResNeXtBottleneck, [3, 4, 23, 3], **kwargs)
    model.conv1 = nn.Conv3d(
        2,
        64,
        kernel_size=7,
        stride=(1, 2, 2),
        padding=(3, 3, 3),
        bias=False)
    if model_path_flow=='':
        return model
    params = torch.load(model_path_flow)
    new_dict = {k[7:]: v for k, v in params['state_dict'].items()} 
    model_dict=model.state_dict() 
    model_dict.update(new_dict)
    model.load_state_dict(new_dict)
    return model

def resnext3D152(**kwargs):
    """Constructs a ResNet-101 model.
    """
    model = ResNeXt(ResNeXtBottleneck, [3, 8, 36, 3], **kwargs)
    return model




