import numpy as np
import scipy.linalg
import scipy.version
# SciPy 1.9.0 added support for calling scipy.linalg.expm on stacked matrices.
# Compare the numeric (major, minor) components instead of the raw version
# strings: a lexicographic comparison ranks "1.10.0" below "1.9.0" and would
# needlessly select the slower fallback on newer SciPy releases.
_scipy_major_minor = tuple(
    int(part) for part in scipy.version.version.split(".")[:2]
)
if _scipy_major_minor >= (1, 9):
    scipy_expm = scipy.linalg.expm
else:
    # Older SciPy only accepts a single matrix, so loop over the stack.
    scipy_expm = np.vectorize(scipy.linalg.expm, signature="(m,m)->(m,m)")
def multitransp(A):
    """Vectorized matrix transpose.

    ``A`` is assumed to be an array containing ``M`` matrices, each of which
    has dimension ``N x P``.
    That is, ``A`` is an ``M x N x P`` array. Multitransp then returns an array
    containing the ``M`` matrix transposes of the matrices in ``A``, each of
    which will be ``P x N``.

    A single ``N x P`` matrix, as well as arrays with more than one leading
    batch dimension, are handled too: it is always the last two axes that get
    exchanged.

    Args:
        A: Array with at least two dimensions whose last two axes index the
            rows and columns of the matrices.

    Returns:
        ``A`` with its last two axes swapped.

    Raises:
        ValueError: If ``A`` has fewer than two dimensions.
    """
    # Swapping the trailing axes is the plain transpose for 2-D input and the
    # per-matrix transpose for any number of leading batch dimensions.
    return np.swapaxes(A, -2, -1)
def multihconj(A):
    """Vectorized matrix conjugate transpose.

    Args:
        A: A matrix or an array of stacked matrices, in any layout accepted
            by :func:`multitransp`.

    Returns:
        The element-wise complex conjugate of ``multitransp(A)``.
    """
    transposed = multitransp(A)
    return np.conjugate(transposed)
def multisym(A):
    """Vectorized matrix symmetrization.

    Takes an array ``A`` of square matrices (e.g. of shape ``(k, n, n)``) and
    averages every matrix with its own transpose, so that each matrix
    ``B[i]`` of the result ``B`` satisfies ``B[i] == B[i].T``.

    Args:
        A: A square matrix or an array of stacked square matrices.

    Returns:
        The array ``0.5 * (A + multitransp(A))``.
    """
    transposed = multitransp(A)
    return 0.5 * (A + transposed)
def multiherm(A):
    """Vectorized Hermitian part of a matrix.

    Complex counterpart of :func:`multisym`: every matrix is averaged with
    its conjugate transpose instead of its plain transpose.

    Args:
        A: A square matrix or an array of stacked square matrices.

    Returns:
        The array ``0.5 * (A + multihconj(A))``.
    """
    conjugate_transposed = multihconj(A)
    return 0.5 * (A + conjugate_transposed)
def multiskew(A):
    """Vectorized matrix skew-symmetrization.

    Counterpart of :func:`multisym` that keeps the antisymmetric part: each
    matrix ``B[i]`` of the result ``B`` is skew-symmetric, i.e., satisfies
    ``B[i] == -B[i].T``.

    Args:
        A: A square matrix or an array of stacked square matrices.

    Returns:
        The array ``0.5 * (A - multitransp(A))``.
    """
    transposed = multitransp(A)
    return 0.5 * (A - transposed)
def multiskewh(A):
    """Vectorized skew-Hermitian part of a matrix.

    Complex counterpart of :func:`multiskew`: the conjugate transpose of
    every matrix is subtracted instead of its plain transpose.

    Args:
        A: A square matrix or an array of stacked square matrices.

    Returns:
        The array ``0.5 * (A - multihconj(A))``.
    """
    conjugate_transposed = multihconj(A)
    return 0.5 * (A - conjugate_transposed)
def multieye(k, n):
    """Array of ``k`` ``n x n`` identity matrices.

    Args:
        k: Number of identity matrices to stack.
        n: Number of rows and columns of each identity matrix.

    Returns:
        An array of shape ``(k, n, n)`` whose slices along the first axis are
        all equal to the ``n x n`` identity matrix.
    """
    # Give the identity a leading length-one axis and repeat it k times.
    identity = np.eye(n)[np.newaxis, ...]
    return np.repeat(identity, k, axis=0)
def multilogm(A, *, positive_definite=False):
    """Vectorized matrix logarithm.

    Args:
        A: A square matrix or an array of stacked square matrices (the
            matrices live along the last two axes).
        positive_definite: If ``True``, the logarithm is computed from an
            eigendecomposition obtained with ``np.linalg.eigh`` by taking the
            logarithm of the eigenvalues. Otherwise ``scipy.linalg.logm`` is
            applied to every matrix individually.

    Returns:
        An array of the same shape as ``A`` containing the matrix logarithm
        of every matrix in ``A``. In the general case the result is complex
        as soon as the logarithm of any matrix in the stack is complex.
    """
    if not positive_definite:
        matrices = np.asarray(A)
        batch_shape = matrices.shape[:-2]
        # Compute every logarithm separately and let ``np.stack`` promote to
        # a common dtype. ``np.vectorize`` would fix the output dtype from
        # the first matrix alone and silently discard the imaginary part of
        # any later complex-valued logarithm.
        logs = [
            scipy.linalg.logm(matrices[index])
            for index in np.ndindex(*batch_shape)
        ]
        if not logs:
            # Empty stack: there is nothing to compute.
            return np.empty(
                matrices.shape,
                dtype=np.result_type(matrices.dtype, np.float64),
            )
        return np.stack(logs).reshape(batch_shape + logs[0].shape)

    eigenvalues, eigenvectors = np.linalg.eigh(A)
    # Trailing axis so that the product below scales the rows of V^H.
    log_eigenvalues = np.expand_dims(np.log(eigenvalues), axis=-1)
    logmA = eigenvectors @ (log_eigenvalues * multihconj(eigenvectors))
    if np.isrealobj(A):
        return np.real(logmA)
    return logmA
def multiexpm(A, *, symmetric=False):
    """Vectorized matrix exponential.

    Args:
        A: A square matrix or an array of stacked square matrices.
        symmetric: If ``True``, the exponential is computed from an
            eigendecomposition obtained with ``np.linalg.eigh`` by
            exponentiating the eigenvalues. Otherwise the computation is
            delegated to ``scipy_expm``.

    Returns:
        The matrix exponential of every matrix in ``A``. On the ``symmetric``
        path the real part is returned when ``A`` is a real array.
    """
    if symmetric:
        eigenvalues, eigenvectors = np.linalg.eigh(A)
        # Trailing axis so that the product below scales the rows of V^H.
        scaling = np.exp(eigenvalues)[..., np.newaxis]
        exponential = eigenvectors @ (scaling * multihconj(eigenvectors))
        return np.real(exponential) if np.isrealobj(A) else exponential
    return scipy_expm(A)
def multiqr(A):
    """Vectorized QR decomposition.

    The factors returned by ``np.linalg.qr`` are rescaled so that the
    diagonal entries of every ``r`` factor are real and non-negative, while
    the product ``q @ r`` is left unchanged.

    Args:
        A: A matrix (2-D array) or a stack of matrices (3-D array).

    Returns:
        A tuple ``(q, r)`` holding the rescaled factors of every matrix in
        ``A``.

    Raises:
        ValueError: If ``A`` is neither 2- nor 3-dimensional.
    """
    if A.ndim != 2 and A.ndim != 3:
        raise ValueError("Input must be a matrix or a stacked matrix")

    batched_qr = np.vectorize(np.linalg.qr, signature="(m,n)->(m,k),(k,n)")
    q, r = batched_qr(A)

    # Signs (or unit-modulus phases) of the diagonal entries of r; zero
    # entries are replaced by one so that they leave the factors untouched.
    diagonal = np.diagonal(r, axis1=-2, axis2=-1)
    phases = np.where(diagonal == 0, 1, diagonal)
    phases = phases / np.abs(phases)

    # Scale column j of q by its phase and row j of r by the conjugate phase;
    # the two factors cancel in q @ r because every phase has modulus one.
    q = q * phases[..., np.newaxis, :]
    r = r * np.conjugate(phases)[..., :, np.newaxis]
    return q, r