import geoopt
from geoopt.manifolds import Sphere, Stiefel
import torch
import torch.nn  as nn
import torch.nn.functional as F


class StiefelLinear(nn.Module):
    """Bias-free linear layer whose weight is constrained to a manifold.

    The ``out_channels`` outputs are split into ``out_channels // n_components``
    groups of ``n_components`` outputs each.

    * ``n_components == 1``: the weight has shape
      ``(out_channels, in_channels)``, every row is initialised to unit L2
      norm, and the parameter is attached to geoopt's ``Sphere`` manifold.
    * ``n_components > 1``: the weight has shape
      ``(groups, in_channels, n_components)``; each ``(in_channels,
      n_components)`` slice is initialised with orthonormal columns (the Q
      factor of a random Gaussian matrix) and the parameter is attached to
      geoopt's ``Stiefel(canonical=False)`` manifold.

    Args:
        in_channels: Size of the last dimension of the input.
        out_channels: Size of the last dimension of the output. Must be a
            multiple of ``n_components``.
        n_components: Number of output channels per group (>= 1). When
            greater than 1 it must not exceed ``in_channels``.

    Raises:
        ValueError: If ``n_components`` is smaller than 1, if ``out_channels``
            is not divisible by ``n_components``, or if ``n_components > 1``
            and ``n_components > in_channels``.
    """

    def __init__(self, in_channels: int, out_channels: int, n_components: int) -> None:
        super().__init__()
        if n_components < 1:
            raise ValueError(
                f"n_components must be >= 1, got {n_components}"
            )
        if out_channels % n_components != 0:
            raise ValueError(
                f"out_channels ({out_channels}) must be divisible by "
                f"n_components ({n_components})"
            )
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.n_components = n_components
        groups = out_channels // n_components
        if n_components == 1:
            # One unit-norm row per output channel.
            ini = torch.randn(out_channels, in_channels)
            ini = F.normalize(ini, dim=1)
            params = geoopt.ManifoldParameter(ini, Sphere(), requires_grad=True)
        else:
            # Reduced QR of a (d, p) matrix with p > d returns a Q with only d
            # columns, which would silently produce a weight of the wrong
            # shape; reject that configuration up front.
            if n_components > in_channels:
                raise ValueError(
                    f"n_components ({n_components}) must not exceed "
                    f"in_channels ({in_channels})"
                )
            ini = torch.randn(groups, in_channels, n_components)
            # Q has orthonormal columns, i.e. each slice is a valid starting
            # point for an orthonormal-column constraint.
            ini, _ = torch.linalg.qr(ini)
            params = geoopt.ManifoldParameter(ini, Stiefel(canonical=False), requires_grad=True)
        self.weight = params

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the constrained linear map to the last dimension of ``x``.

        Args:
            x: Tensor of shape ``(..., in_channels)``.

        Returns:
            Tensor of shape ``(..., out_channels)``. For ``n_components > 1``
            the output channels are ordered group-major: channel
            ``g * n_components + k`` is component ``k`` of group ``g``.
        """
        if self.n_components == 1:
            return x @ self.weight.t()
        # Contract the feature axis against every group's (d, p) matrix while
        # keeping all leading batch dimensions.
        ret = torch.einsum('...d,cdp->...cp', x, self.weight)
        # reshape (not view): the einsum result is not guaranteed contiguous.
        return ret.reshape(*x.shape[:-1], self.out_channels)
