Automatic Differentiation
Cost functions, gradients and Hessian-vector products (hvps) in Pymanopt must
be defined as Python callables annotated with one of the backend
decorators below.
Decorating a callable with a backend decorator wraps it in an instance of the
pymanopt.autodiff.Function class, which provides a backend-agnostic API that
the pymanopt.core.problem.Problem class uses to compute derivatives.
If an autodiff backend is used via one of the provided decorators, the
signature of the decorated callable must match the point layout of the manifold
it is defined on.
For instance, for memory efficiency, points on the
pymanopt.manifolds.fixed_rank.FixedRankEmbedded manifold are not represented
as m x n matrices in the ambient space but as the factors of a singular value
decomposition.
As such, a cost function defined on the manifold must accept three arguments:
u, s and vt. Refer to the documentation of the respective manifold for how
points are represented.
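To illustrate the signature requirement, the following is a minimal sketch of a cost function written against the (u, s, vt) layout. It uses plain NumPy and deliberately omits the backend decorator so that it runs without Pymanopt installed. The dimensions m, n, k and the target matrix A are assumptions chosen for the example. In practice the function would additionally be annotated with one of the backend decorators.

```python
import numpy as np

# Hypothetical problem sizes: m x n matrices of rank k.
m, n, k = 5, 4, 2
rng = np.random.default_rng(0)

# Hypothetical target matrix of exact rank k.
A = rng.standard_normal((m, k)) @ rng.standard_normal((k, n))


def cost(u, s, vt):
    """Squared Frobenius distance to A, taking the SVD factors as arguments.

    u has shape (m, k), s has shape (k,), vt has shape (k, n).
    """
    # (u * s) scales the columns of u, so this equals u @ diag(s) @ vt.
    X = (u * s) @ vt
    return 0.5 * np.linalg.norm(X - A) ** 2


# Build one point in the factored layout from a truncated SVD of A.
u_full, s_full, vt_full = np.linalg.svd(A, full_matrices=False)
point = (u_full[:, :k], s_full[:k], vt_full[:k, :])

# The factors are passed as three separate positional arguments.
value = cost(*point)
```

Because A has rank k here, the truncated SVD reproduces it and the cost is zero up to rounding. Forming the full m x n matrix inside the cost is done only for readability; a cost that needs the memory savings of the factored representation would work on the factors directly.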
New backends can be created by inheriting from the
pymanopt.autodiff.backends._backend.Backend class and creating a backend
decorator using pymanopt.autodiff.backend_decorator_factory().
- pymanopt.autodiff.backend_decorator_factory(backend_cls)
Create a function decorator for a backend.
The returned backend decorator is used to annotate a callable:
    decorator = backend_decorator_factory(backend_cls)

    @decorator(manifold)
    def function(x):
        ...
- Parameters
backend_cls – a class implementing the backend interface defined by
pymanopt.autodiff.backends._backend.Backend.
- Returns
A new backend decorator.
- Return type
Callable[[pymanopt.manifolds.manifold.Manifold], Callable[[Callable], pymanopt.autodiff.Function]]
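The nested return type can be easier to follow with a toy version of the pattern. The sketch below is not Pymanopt's implementation: ToyBackend, ToyFunction and toy_decorator_factory are hypothetical stand-ins that only mirror the shape of the API, namely a factory that takes a backend class and returns a decorator, which takes a manifold and returns a function wrapper.

```python
class ToyFunction:
    """Hypothetical stand-in for the wrapper a backend decorator returns."""

    def __init__(self, function, manifold, backend):
        self._function = function
        self.manifold = manifold
        self.backend = backend

    def __call__(self, *args):
        # Forward the call to the wrapped callable unchanged.
        return self._function(*args)


class ToyBackend:
    """Hypothetical backend; a real one would also supply derivatives."""


def toy_decorator_factory(backend_cls):
    """Return a decorator bound to the given backend class."""

    def decorator(manifold):
        def inner(function):
            # Instantiate the backend and wrap the user's callable.
            return ToyFunction(function, manifold=manifold, backend=backend_cls())

        return inner

    return decorator


toy = toy_decorator_factory(ToyBackend)
manifold = object()  # placeholder; a real manifold instance would go here


@toy(manifold)
def cost(x):
    return x * x
```

After decoration, cost is no longer a plain function but a ToyFunction instance that remembers both the manifold and the backend, which is the information a solver would need in order to request derivatives later.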
Backends
- pymanopt.function.autograd(manifold)
- Parameters
manifold (pymanopt.manifolds.manifold.Manifold) –
- Return type
Callable
- pymanopt.function.jax(manifold)
- Parameters
manifold (pymanopt.manifolds.manifold.Manifold) –
- Return type
Callable
- pymanopt.function.numpy(manifold)
- Parameters
manifold (pymanopt.manifolds.manifold.Manifold) –
- Return type
Callable
- pymanopt.function.pytorch(manifold)
- Parameters
manifold (pymanopt.manifolds.manifold.Manifold) –
- Return type
Callable
- pymanopt.function.tensorflow(manifold)
- Parameters
manifold (pymanopt.manifolds.manifold.Manifold) –
- Return type
Callable