"""
Point Transformer V1 for Object Classification

This implementation may differ slightly from the original paper.

Author: Xiaoyang Wu (xiaoyang.wu.cs@gmail.com)
Please cite our work if the code is helpful to you.
"""

import torch
import torch.nn as nn

from .utils.point_transformer_v1_utils import TransitionDown, Bottleneck
from .basic_model import BasicModel


class PointTransformerCls(BasicModel):
    """Point Transformer V1 classifier: five encoder stages, per-sample mean pooling, MLP head.

    Args:
        block: Block class stacked inside each encoder stage; it must expose an
            ``expansion`` attribute (the variants below pass ``Bottleneck``).
        blocks: Five per-stage depths. Every stage starts with one
            ``TransitionDown``, so a stage of depth ``d`` also contains
            ``d - 1`` instances of ``block``.
        in_channels: Per-point input width fed to the first stage. With 3, only
            the xyz coordinates are used; otherwise xyz is concatenated with
            ``in_channels - 3`` extra feature channels taken from ``"feat"``.
        num_classes: Number of output logits.
    """

    def __init__(self, block, blocks, in_channels=3, num_classes=40):
        super().__init__()
        self.in_channels = in_channels
        # Running input width; _make_enc reads it and advances it stage by stage.
        self.in_planes = in_channels
        planes = [32, 64, 128, 256, 512]
        share_planes = 8
        stride, nsample = [1, 4, 4, 4, 4], [8, 16, 16, 16, 16]
        self.enc1 = self._make_enc(block, planes[0], blocks[0], share_planes,
                                   stride=stride[0], nsample=nsample[0])  # N/1
        self.enc2 = self._make_enc(block, planes[1], blocks[1], share_planes,
                                   stride=stride[1], nsample=nsample[1])  # N/4
        self.enc3 = self._make_enc(block, planes[2], blocks[2], share_planes,
                                   stride=stride[2], nsample=nsample[2])  # N/16
        self.enc4 = self._make_enc(block, planes[3], blocks[3], share_planes,
                                   stride=stride[3], nsample=nsample[3])  # N/64
        self.enc5 = self._make_enc(block, planes[4], blocks[4], share_planes,
                                   stride=stride[4], nsample=nsample[4])  # N/256
        # MLP head on the pooled 512-d descriptor of each sample.
        self.cls = nn.Sequential(nn.Linear(planes[4], 256),
                                 nn.BatchNorm1d(256),
                                 nn.ReLU(inplace=True),
                                 nn.Dropout(p=0.5),
                                 nn.Linear(256, 128),
                                 nn.BatchNorm1d(128),
                                 nn.ReLU(inplace=True),
                                 nn.Dropout(p=0.5),
                                 nn.Linear(128, num_classes))

    def _make_enc(self, block, planes, blocks, share_planes=8, stride=1, nsample=16):
        """Build one encoder stage: a TransitionDown followed by ``blocks - 1`` blocks.

        Side effect: sets ``self.in_planes`` to the stage's output width so the
        next stage is built with the matching input width.
        """
        layers = [TransitionDown(self.in_planes, planes * block.expansion, stride, nsample)]
        self.in_planes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.in_planes, self.in_planes, share_planes, nsample=nsample))
        return nn.Sequential(*layers)

    @staticmethod
    def _mean_pool(feats, offsets):
        """Average the rows of ``feats`` within each sample.

        ``offsets`` holds cumulative end indices: sample ``i`` owns rows
        ``[offsets[i-1], offsets[i])``, and sample 0 starts at row 0. Returns one
        row per sample. An empty sample yields NaN, matching a 0/0 mean.
        """
        ends = offsets.tolist()  # one device->host copy instead of two per sample
        starts = [0] + ends[:-1]
        pooled = [feats[s:e].sum(0, True) / (e - s) for s, e in zip(starts, ends)]
        return torch.cat(pooled, 0)

    def forward(self, data_dict):
        """Return classification logits, one row per sample in the batch.

        ``data_dict`` keys read:
            * ``"xyz"``: point coordinates, flattened to (N, 3).
            * ``"feat"``: extra per-point features, read only when
              ``in_channels != 3`` and reshaped to N rows.
            * ``"offset"``: cumulatively summed here into end offsets.
              NOTE(review): this assumes it holds per-sample point counts rather
              than cumulative offsets; confirm against the data pipeline.

        Raises:
            ValueError: if xyz plus features do not add up to ``in_channels`` columns.
        """
        p0 = data_dict["xyz"].reshape(-1, 3)
        # Passing dtype casts to int32 *before* summing and keeps the result int32.
        # The old `.int()` followed by cumsum was undone because integer cumsum
        # promotes to int64.
        o0 = torch.cumsum(data_dict["offset"], dim=0, dtype=torch.int32)

        if self.in_channels == 3:
            x0 = p0
        else:
            # Reshape by point count so feature rows stay aligned with p0, then
            # prepend xyz. enc1 was built for exactly in_channels inputs.
            feat = data_dict["feat"].reshape(p0.shape[0], -1)
            x0 = torch.cat((p0, feat), 1)
            if x0.shape[1] != self.in_channels:
                raise ValueError(
                    f"expected {self.in_channels} input channels (3 xyz + "
                    f"{self.in_channels - 3} feat), got {x0.shape[1]}"
                )

        p1, x1, o1 = self.enc1([p0, x0, o0])
        p2, x2, o2 = self.enc2([p1, x1, o1])
        p3, x3, o3 = self.enc3([p2, x2, o2])
        p4, x4, o4 = self.enc4([p3, x3, o3])
        p5, x5, o5 = self.enc5([p4, x4, o4])
        x = self._mean_pool(x5, o5)
        return self.cls(x)

    def configure_optimizers(self): # type: ignore
        """Nesterov SGD with a 10x LR decay at 60% and 80% of ``trainer.max_epochs``."""
        optimizer = torch.optim.SGD(
            self.parameters(),
            lr=0.05,
            momentum=0.9,
            weight_decay=0.0001,
            nesterov=True
        )
        epochs = self.trainer.max_epochs  # type: ignore
        # MultiStepLR matches milestones against the integer epoch counter, so
        # fractional milestones (e.g. 4.2 for 7 epochs) never fire: cast to int.
        # Drop non-positive milestones so tiny or non-positive budgets do not
        # decay the LR right at construction (epoch 0).
        milestones = [m for m in (int(epochs * 0.6), int(epochs * 0.8)) if m > 0]  # type: ignore
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=milestones,
            gamma=0.1
        )
        return {'optimizer': optimizer, 'lr_scheduler': scheduler}


class PointTransformerCls26(PointTransformerCls):
    """Bottleneck-based variant with stage depths [1, 1, 1, 1, 1].

    Remaining keyword arguments (e.g. ``in_channels``, ``num_classes``) are
    forwarded unchanged to ``PointTransformerCls``.
    """

    def __init__(self, **kwargs):
        stage_depths = [1, 1, 1, 1, 1]
        super().__init__(Bottleneck, stage_depths, **kwargs)


class PointTransformerCls38(PointTransformerCls):
    """Bottleneck-based variant with stage depths [1, 2, 2, 2, 2].

    Remaining keyword arguments (e.g. ``in_channels``, ``num_classes``) are
    forwarded unchanged to ``PointTransformerCls``.
    """

    def __init__(self, **kwargs):
        stage_depths = [1, 2, 2, 2, 2]
        super().__init__(Bottleneck, stage_depths, **kwargs)


class PointTransformerCls50(PointTransformerCls):
    """Bottleneck-based variant with stage depths [1, 2, 3, 5, 2].

    Remaining keyword arguments (e.g. ``in_channels``, ``num_classes``) are
    forwarded unchanged to ``PointTransformerCls``.
    """

    def __init__(self, **kwargs):
        stage_depths = [1, 2, 3, 5, 2]
        super().__init__(Bottleneck, stage_depths, **kwargs)
