# Source code for pymanopt.manifolds.stiefel

import numpy as np

from pymanopt.manifolds.manifold import RiemannianSubmanifold
from pymanopt.tools.multi import (
multiexpm,
multieye,
multiqr,
multisym,
multitransp,
)

class Stiefel(RiemannianSubmanifold):
    r"""The (product) Stiefel manifold.

    The Stiefel manifold :math:`\St(n, p)` is the manifold of orthonormal
    ``n x p`` matrices.
    A point :math:`\vmX \in \St(n, p)` therefore satisfies the condition
    :math:`\transp{\vmX}\vmX = \Id_p`.
    Points on the manifold are represented as arrays of shape ``(n, p)`` if
    ``k == 1``.
    For ``k > 1``, the class represents the product manifold of ``k`` Stiefel
    manifolds, in which case points on the manifold are represented as arrays
    of shape ``(k, n, p)``.

    The metric is the usual Euclidean metric on :math:`\R^{n \times p}` which
    turns :math:`\St(n, p)^k` into a Riemannian submanifold.

    Args:
        n: The number of rows.
        p: The number of columns.
        k: The number of elements in the product.
        retraction: The type of retraction to use.
            Possible choices are ``qr`` and ``polar``.

    Raises:
        ValueError: If ``n >= p >= 1`` or ``k >= 1`` is violated, or if
            ``retraction`` does not name a ``_retraction_<name>`` method.

    Note:
        The formula for the exponential map can be found in [ZH2021]_.

        The Weingarten map is taken from [AMT2013]_.

        The default retraction used here is a first-order one based on
        the QR decomposition.
        To switch to a second-order polar retraction, use
        ``Stiefel(n, p, k=k, retraction="polar")``.
    """

    def __init__(self, n: int, p: int, *, k: int = 1, retraction: str = "qr"):
        self._n = n
        self._p = p
        self._k = k

        # Check that n is greater than or equal to p.
        if n < p or p < 1:
            raise ValueError(
                f"Need n >= p >= 1. Values supplied were n = {n} and p = {p}"
            )
        if k < 1:
            raise ValueError(f"Need k >= 1. Value supplied was k = {k}")

        if k == 1:
            name = f"Stiefel manifold St({n},{p})"
        else:
            name = f"Product Stiefel manifold St({n},{p})^{k}"
        # p * (p + 1) is always even for integer p, so floor division is
        # exact and keeps the whole computation in integer arithmetic.
        dimension = int(k * (n * p - p * (p + 1) // 2))
        super().__init__(name, dimension)

        # Retractions are dispatched by name to the ``_retraction_<name>``
        # methods defined on this class.
        try:
            self._retraction = getattr(self, f"_retraction_{retraction}")
        except AttributeError:
            # The AttributeError is an implementation detail of the lookup;
            # suppress it so callers only see the ValueError.
            raise ValueError(
                f"Invalid retraction type '{retraction}'"
            ) from None

    @property
    def typical_dist(self):
        """Return ``sqrt(p * k)`` as the typical distance on the manifold."""
        return np.sqrt(self._p * self._k)

    def inner_product(self, point, tangent_vector_a, tangent_vector_b):
        """Return the Euclidean inner product of two tangent vectors.

        All axes of both arrays are contracted, so the result is a scalar.
        ``point`` is not used.
        """
        return np.tensordot(
            tangent_vector_a, tangent_vector_b, axes=tangent_vector_a.ndim
        )

    def projection(self, point, vector):
        """Project ``vector`` onto the tangent space at ``point``.

        Computes ``vector - point @ multisym(point^T @ vector)``, where the
        transpose is taken with ``multitransp``.
        """
        return vector - point @ multisym(multitransp(point) @ vector)

    # Mapping an ambient vector to the tangent space is the same operation
    # as the projection.
    to_tangent_space = projection

    def weingarten(self, point, tangent_vector, normal_vector):
        """Evaluate the Weingarten map at ``point``.

        Returns ``-U @ X^T @ V - X @ multisym(U^T @ V)`` with ``X = point``,
        ``U = tangent_vector`` and ``V = normal_vector``.
        """
        return -tangent_vector @ multitransp(
            point
        ) @ normal_vector - point @ multisym(
            multitransp(tangent_vector) @ normal_vector
        )

    def retraction(self, point, tangent_vector):
        """Retract ``tangent_vector`` at ``point`` back onto the manifold.

        Delegates to the retraction selected at construction time.
        """
        return self._retraction(point, tangent_vector)

    def _retraction_qr(self, point, tangent_vector):
        """Retraction that keeps the first factor of ``multiqr(point + tv)``."""
        a = point + tangent_vector
        q, _ = multiqr(a)
        return q

    def _retraction_polar(self, point, tangent_vector):
        """Retraction computing ``u @ vt`` from the thin SVD of ``point + tv``."""
        Y = point + tangent_vector
        u, _, vt = np.linalg.svd(Y, full_matrices=False)
        return u @ vt

    def norm(self, point, tangent_vector):
        """Return ``np.linalg.norm(tangent_vector)``; ``point`` is not used."""
        return np.linalg.norm(tangent_vector)

    def random_point(self):
        """Return a random point on the manifold.

        A Gaussian array of shape ``(k, n, p)`` is passed through ``multiqr``
        and the first returned factor is kept. For ``k == 1`` the leading
        axis is dropped so that the result has shape ``(n, p)``.
        """
        point, _ = multiqr(np.random.normal(size=(self._k, self._n, self._p)))
        if self._k == 1:
            return point[0]
        return point

    def random_tangent_vector(self, point):
        """Return a random unit-norm tangent vector at ``point``.

        A Gaussian array with the shape of ``point`` is projected onto the
        tangent space and divided by its norm.
        """
        vector = np.random.normal(size=point.shape)
        vector = self.projection(point, vector)
        return vector / np.linalg.norm(vector)

    def transport(self, point_a, point_b, tangent_vector_a):
        """Transport ``tangent_vector_a`` to ``point_b`` by projection.

        ``point_a`` is not used.
        """
        return self.projection(point_b, tangent_vector_a)

    def exp(self, point, tangent_vector):
        """Evaluate the exponential map at ``point`` along ``tangent_vector``.

        With ``X = point``, ``U = tangent_vector`` and ``A = X^T @ U`` this
        returns ``[X, U] @ expm([[A, -U^T @ U], [I, A]])[..., :p] @ expm(-A)``.
        """
        pt_tv = multitransp(point) @ tangent_vector
        if self._k == 1:
            identity = np.eye(self._p)
        else:
            identity = multieye(self._k, self._p)

        # A flat list passed to np.block concatenates along the last axis,
        # giving [X, U] with 2p columns.
        a = np.block([point, tangent_vector])
        # Only the first p columns of the 2p x 2p matrix exponential are
        # needed.
        b = multiexpm(
            np.block(
                [
                    [
                        pt_tv,
                        -multitransp(tangent_vector) @ tangent_vector,
                    ],
                    [identity, pt_tv],
                ]
            )
        )[..., : self._p]
        c = multiexpm(-pt_tv)
        return a @ (b @ c)

    def zero_vector(self, point):
        """Return an all-zero array of shape ``(n, p)`` or ``(k, n, p)``.

        ``point`` is not used; the shape depends only on ``k``.
        """
        if self._k == 1:
            return np.zeros((self._n, self._p))
        return np.zeros((self._k, self._n, self._p))