# Source code for pymanopt.manifolds.product

import functools
from typing import Sequence

import numpy as np

from pymanopt.manifolds.manifold import Manifold
from pymanopt.tools import ndarraySequenceMixin, return_as_class_instance


class Product(Manifold):
    """Cartesian product manifold.

    Points on the manifold and tangent vectors are represented as lists of
    points and tangent vectors of the individual manifolds.
    The metric is obtained by element-wise extension of the individual
    manifolds.

    Args:
        manifolds: The collection of manifolds in the product.

    Raises:
        ValueError: If any of the given manifolds is itself a ``Product``.
    """

    def __init__(self, manifolds: Sequence[Manifold]):
        # Materialize the argument exactly once: it is traversed several
        # times below, and a one-shot iterable (e.g. a generator) would
        # otherwise be exhausted by the first pass.
        factors = tuple(manifolds)
        if any(isinstance(manifold, Product) for manifold in factors):
            raise ValueError("Nested product manifolds are not supported")
        self.manifolds = factors

        manifold_names = " x ".join(str(manifold) for manifold in factors)
        name = f"Product manifold: {manifold_names}"
        dimension = np.sum([manifold.dim for manifold in factors])
        point_layout = tuple(manifold.point_layout for manifold in factors)
        super().__init__(name, dimension, point_layout=point_layout)

    @property
    def typical_dist(self):
        """Square root of the summed squared ``typical_dist`` of the factors."""
        return np.sqrt(
            np.sum([manifold.typical_dist**2 for manifold in self.manifolds])
        )

    def _dispatch(
        self,
        method_name,
        *,
        transform=lambda value: value,
        reduction=lambda values: values,
    ):
        """Wrapper to delegate method calls to individual manifolds.

        Args:
            method_name: Name of the method to call on every factor manifold.
            transform: Callable applied to each factor's return value.
                Defaults to the identity.
            reduction: Callable applied to the list of transformed return
                values. Defaults to the identity, i.e. a plain list is
                returned.

        Returns:
            A callable. Each positional argument it receives is iterated in
            lockstep with ``self.manifolds``, so the i-th component of every
            argument is passed to the i-th manifold.
        """

        @functools.wraps(getattr(self, method_name))
        def wrapper(*args, **kwargs):
            # NOTE: keyword arguments are accepted but not forwarded to the
            # factor manifolds; only positional arguments are split up.
            # ``zip`` stops at the shortest input, so components beyond the
            # shortest argument are not visited.
            return_values = [
                transform(getattr(manifold, method_name)(*arguments))
                for manifold, *arguments in zip(self.manifolds, *args)
            ]
            return reduction(return_values)

        return wrapper

    def norm(self, point, tangent_vector):
        """Return the square root of the vector's inner product with itself."""
        return np.sqrt(
            self.inner_product(point, tangent_vector, tangent_vector)
        )

    def inner_product(self, point, tangent_vector_a, tangent_vector_b):
        """Return the sum of the factor-wise inner products."""
        return self._dispatch("inner_product", reduction=np.sum)(
            point, tangent_vector_a, tangent_vector_b
        )

    def dist(self, point_a, point_b):
        """Return the square root of the summed squared factor distances."""
        return self._dispatch(
            "dist",
            transform=lambda value: value**2,
            reduction=lambda values: np.sqrt(np.sum(values)),
        )(point_a, point_b)

    def projection(self, point, vector):
        """Apply each factor's ``projection`` to the matching components."""
        return self._dispatch("projection", reduction=_ProductTangentVector)(
            point, vector
        )

    def to_tangent_space(self, point, vector):
        """Apply each factor's ``to_tangent_space`` component-wise."""
        return self._dispatch(
            "to_tangent_space", reduction=_ProductTangentVector
        )(point, vector)

    def euclidean_to_riemannian_gradient(self, point, euclidean_gradient):
        """Apply each factor's gradient conversion component-wise."""
        return self._dispatch(
            "euclidean_to_riemannian_gradient", reduction=_ProductTangentVector
        )(point, euclidean_gradient)

    def euclidean_to_riemannian_hessian(
        self, point, euclidean_gradient, euclidean_hessian, tangent_vector
    ):
        """Apply each factor's Hessian conversion component-wise."""
        return self._dispatch(
            "euclidean_to_riemannian_hessian", reduction=_ProductTangentVector
        )(point, euclidean_gradient, euclidean_hessian, tangent_vector)

    def exp(self, point, tangent_vector):
        """Return the list of factor-wise ``exp`` results."""
        return self._dispatch("exp")(point, tangent_vector)

    def retraction(self, point, tangent_vector):
        """Return the list of factor-wise ``retraction`` results."""
        return self._dispatch("retraction")(point, tangent_vector)

    def log(self, point_a, point_b):
        """Apply each factor's ``log`` to the matching point components."""
        return self._dispatch("log", reduction=_ProductTangentVector)(
            point_a, point_b
        )

    def random_point(self):
        """Return a list holding one ``random_point()`` result per factor."""
        return self._dispatch("random_point")()

    def random_tangent_vector(self, point):
        """Return factor-wise random tangent vectors, each rescaled.

        Every component returned by a factor's ``random_tangent_vector`` is
        multiplied by ``len(self.manifolds) ** (-1 / 2)``.
        """
        scale = len(self.manifolds) ** (-1 / 2)
        return self._dispatch(
            "random_tangent_vector",
            transform=lambda value: scale * value,
            reduction=_ProductTangentVector,
        )(point)

    def transport(self, point_a, point_b, tangent_vector_a):
        """Apply each factor's ``transport`` to the matching components."""
        return self._dispatch("transport", reduction=_ProductTangentVector)(
            point_a, point_b, tangent_vector_a
        )

    def pair_mean(self, point_a, point_b):
        """Return the list of factor-wise ``pair_mean`` results."""
        return self._dispatch("pair_mean")(point_a, point_b)

    def zero_vector(self, point):
        """Collect each factor's ``zero_vector`` at the matching component."""
        return self._dispatch("zero_vector", reduction=_ProductTangentVector)(
            point
        )
class _ProductTangentVector(ndarraySequenceMixin, list):
    """List of tangent-vector components with element-wise arithmetic.

    Each operator below builds a plain list of per-component results.
    NOTE(review): ``return_as_class_instance(unpack=False)`` is presumably
    what turns that list back into an instance of this class; confirm
    against ``pymanopt.tools``.
    """

    @return_as_class_instance(unpack=False)
    def __add__(self, other):
        """Add ``other`` to this vector component by component.

        Raises:
            ValueError: If the two operands differ in length.
        """
        if len(self) != len(other):
            raise ValueError("Arguments must be same length")
        summed = []
        for index, component in enumerate(self):
            summed.append(component + other[index])
        return summed

    @return_as_class_instance(unpack=False)
    def __sub__(self, other):
        """Subtract ``other`` from this vector component by component.

        Raises:
            ValueError: If the two operands differ in length.
        """
        if len(self) != len(other):
            raise ValueError("Arguments must be same length")
        differences = []
        for index, component in enumerate(self):
            differences.append(component - other[index])
        return differences

    @return_as_class_instance(unpack=False)
    def __mul__(self, other):
        """Multiply every component by ``other`` (``other`` on the left)."""
        return [other * component for component in self]

    # Multiplication from the right reuses the same implementation.
    __rmul__ = __mul__

    @return_as_class_instance(unpack=False)
    def __truediv__(self, other):
        """Divide every component by ``other``."""
        return [component / other for component in self]

    @return_as_class_instance(unpack=False)
    def __neg__(self):
        """Negate every component."""
        return [-component for component in self]