{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Riemannian Optimization for Inference in MoG models" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The Mixture of Gaussians (MoG) model assumes that datapoints $\\mathbf{x}_i\\in\\mathbb{R}^d$ follow a distribution described by the following probability density function:\n", "\n", "$p(\\mathbf{x}) = \\sum_{m=1}^M \\pi_m p_\\mathcal{N}(\\mathbf{x};\\mathbf{\\mu}_m,\\mathbf{\\Sigma}_m)$ where $\\pi_m$ is the probability that the data point belongs to the $m^\\text{th}$ mixture component and $p_\\mathcal{N}(\\mathbf{x};\\mathbf{\\mu}_m,\\mathbf{\\Sigma}_m)$ is the probability density function of a multivariate Gaussian distribution with mean $\\mathbf{\\mu}_m \\in \\mathbb{R}^d$ and psd covariance matrix $\\mathbf{\\Sigma}_m \\in \\{\\mathbf{M}\\in\\mathbb{R}^{d\\times d}: \\mathbf{M}\\succeq 0\\}$.\n", "\n", "As an example consider the mixture of three Gaussians with means\n", "$\\mathbf{\\mu}_1 = \\begin{bmatrix} -4 \\\\ 1 \\end{bmatrix}$,\n", "$\\mathbf{\\mu}_2 = \\begin{bmatrix} 0 \\\\ 0 \\end{bmatrix}$ and\n", "$\\mathbf{\\mu}_3 = \\begin{bmatrix} 2 \\\\ -1 \\end{bmatrix}$, covariances\n", "$\\mathbf{\\Sigma}_1 = \\begin{bmatrix} 3 & 0 \\\\ 0 & 1 \\end{bmatrix}$,\n", "$\\mathbf{\\Sigma}_2 = \\begin{bmatrix} 1 & 1 \\\\ 1 & 3 \\end{bmatrix}$ and\n", "$\\mathbf{\\Sigma}_3 = \\begin{bmatrix} 0.5 & 0 \\\\ 0 & 0.5 \\end{bmatrix}$\n", "and mixture probability vector $\\boldsymbol{\\pi}=\\left[0.1, 0.6, 0.3\\right]^\\top$.\n", "Let's generate $N=1000$ samples of that MoG model and scatter plot the samples:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import autograd.numpy as np\n", "\n", "\n", "np.set_printoptions(precision=2)\n", "import matplotlib.pyplot as plt\n", "\n", "\n", "%matplotlib inline\n", "\n", "# Number of data points\n", "N = 1000\n", "\n", "# Dimension of each data point\n", "D = 2\n", "\n", "# Number of clusters\n", "K = 3\n", "\n", "pi = [0.1, 0.6, 0.3]\n", "mu = [np.array([-4, 1]), np.array([0, 0]), np.array([2, -1])]\n", "Sigma = [\n", " np.array([[3, 0], [0, 1]]),\n", " np.array([[1, 1.0], [1, 3]]),\n", " 0.5 * np.eye(2),\n", "]\n", "\n", "components = np.random.choice(K, size=N, p=pi)\n", "samples = np.zeros((N, D))\n", "# For each component, generate all needed samples\n", "for k in range(K):\n", " # indices of current component in X\n", " indices = k == components\n", " # number of those occurrences\n", " n_k = indices.sum()\n", " if n_k > 0:\n", " samples[indices, :] = np.random.multivariate_normal(\n", " mu[k], Sigma[k], n_k\n", " )\n", "\n", "colors = [\"r\", \"g\", \"b\", \"c\", \"m\"]\n", "for k in range(K):\n", " indices = k == components\n", " plt.scatter(\n", " samples[indices, 0],\n", " samples[indices, 1],\n", " alpha=0.4,\n", " color=colors[k % K],\n", " )\n", "plt.axis(\"equal\")\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Given a data sample the de facto standard method to infer the parameters is the [expectation maximisation](https://en.wikipedia.org/wiki/Expectation-maximization_algorithm) (EM) algorithm that, in alternating so-called E and M steps, maximises the log-likelihood of the data.\n", "In [arXiv:1506.07677](http://arxiv.org/pdf/1506.07677v1.pdf) Hosseini and Sra propose Riemannian optimisation as a powerful counterpart to EM. 
Importantly, they introduce a reparameterisation that leaves local optima of the log-likelihood unchanged while resulting in a geodesically convex optimisation problem over the product manifold $\\prod_{m=1}^M\\mathcal{PD}^{(d+1)\\times(d+1)}$ of $(d+1)\\times(d+1)$ symmetric positive definite matrices.\n", "The proposed method is on par with EM and shows less variability in running times.\n", "\n", "The reparameterised optimisation problem for augmented data points $\\mathbf{y}_i=[\\mathbf{x}_i^\\top, 1]^\\top$ can be stated as follows:\n", "\n", "$$\\min_{(\\mathbf{S}_1, ..., \\mathbf{S}_M, \\boldsymbol{\\nu}) \\in \\mathcal{D}}\n", "-\\sum_{n=1}^N\\log\\left(\n", "\\sum_{m=1}^M \\frac{\\exp(\\nu_m)}{\\sum_{k=1}^M\\exp(\\nu_k)}\n", "q_\\mathcal{N}(\\mathbf{y}_n;\\mathbf{S}_m)\n", "\\right)$$\n", "\n", "where\n", "\n", "* $\\mathcal{D} := \\left(\\prod_{m=1}^M \\mathcal{PD}^{(d+1)\\times(d+1)}\\right)\\times\\mathbb{R}^{M-1}$ is the search space\n", "* $\\mathcal{PD}^{(d+1)\\times(d+1)}$ is the manifold of symmetric positive definite\n", "$(d+1)\\times(d+1)$ matrices\n", "* $\\nu_m = \\log\\left(\\frac{\\pi_m}{\\pi_M}\\right), \\ m=1, ..., M-1$ and $\\nu_M=0$\n", "* $q_\\mathcal{N}(\\mathbf{y}_n;\\mathbf{S}_m) =\n", "2\\pi\\exp\\left(\\frac{1}{2}\\right)\n", "|\\operatorname{det}(\\mathbf{S}_m)|^{-\\frac{1}{2}}(2\\pi)^{-\\frac{d+1}{2}}\n", "\\exp\\left(-\\frac{1}{2}\\mathbf{y}_n^\\top\\mathbf{S}_m^{-1}\\mathbf{y}_n\\right)$\n", "\n", "**Optimisation problems like this can easily be solved using Pymanopt – even without the need to differentiate the cost function manually!**\n", "\n", "So let's infer the parameters of our toy example by Riemannian optimisation using Pymanopt:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import sys\n", "\n", "\n", "sys.path.insert(0, \"../..\")\n", "\n", "from autograd.scipy.special import logsumexp\n", "\n", "import pymanopt\n", "from pymanopt import Problem\n", "from pymanopt.manifolds import Euclidean, Product, SymmetricPositiveDefinite\n", "from pymanopt.optimizers import SteepestDescent\n", "\n", "\n", "# (1) Instantiate the manifold\n", "manifold = Product([SymmetricPositiveDefinite(D + 1, k=K), Euclidean(K - 1)])\n", "\n", "# (2) Define the cost function\n", "# The decorated cost function receives the point on the product manifold\n", "# unpacked into the stacked SPD matrices S and the vector v.\n", "@pymanopt.function.autograd(manifold)\n", "def cost(S, v):\n", "    # Append the fixed nu_M = 0 to the optimisation variable v\n", "    nu = np.append(v, 0)\n", "\n", "    logdetS = np.expand_dims(np.linalg.slogdet(S)[1], 1)\n", "    y = np.concatenate([samples.T, np.ones((1, N))], axis=0)\n", "\n", "    # Calculate log_q\n", "    y = np.expand_dims(y, 0)\n", "\n", "    # 'Probability' of y belonging to each cluster\n", "    log_q = -0.5 * (np.sum(y * np.linalg.solve(S, y), axis=1) + logdetS)\n", "\n", "    alpha = np.exp(nu)\n", "    alpha = alpha / np.sum(alpha)\n", "    alpha = np.expand_dims(alpha, 1)\n", "\n", "    loglikvec = logsumexp(np.log(alpha) + log_q, axis=0)\n", "    return -np.sum(loglikvec)\n", "\n", "\n", "problem = Problem(manifold, cost)\n", "\n", "# (3) Instantiate a Pymanopt optimizer\n", "optimizer = SteepestDescent(verbosity=1)\n", "\n", "# Let Pymanopt do the rest\n", "Xopt = optimizer.run(problem).point" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Once Pymanopt has finished the optimisation, we can obtain the inferred parameters as follows:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "mu1hat = Xopt[0][0][0:2, 2:3]\n", "Sigma1hat = Xopt[0][0][:2, :2] 
- mu1hat @ mu1hat.T\n", "mu2hat = Xopt[0][1][0:2, 2:3]\n", "Sigma2hat = Xopt[0][1][:2, :2] - mu2hat @ mu2hat.T\n", "mu3hat = Xopt[0][2][0:2, 2:3]\n", "Sigma3hat = Xopt[0][2][:2, :2] - mu3hat @ mu3hat.T\n", "pihat = np.exp(np.concatenate([Xopt[1], [0]], axis=0))\n", "pihat = pihat / np.sum(pihat)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And convince ourselves that the inferred parameters are close to the ground truth parameters.\n", "\n", "The ground truth parameters $\\mathbf{\\mu}_1, \\mathbf{\\Sigma}_1, \\mathbf{\\mu}_2, \\mathbf{\\Sigma}_2, \\mathbf{\\mu}_3, \\mathbf{\\Sigma}_3, \\pi_1, \\pi_2, \\pi_3$:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(mu[0])\n", "print(Sigma[0])\n", "print(mu[1])\n", "print(Sigma[1])\n", "print(mu[2])\n", "print(Sigma[2])\n", "print(pi[0])\n", "print(pi[1])\n", "print(pi[2])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And the inferred parameters $\\hat{\\mathbf{\\mu}}_1, \\hat{\\mathbf{\\Sigma}}_1, \\hat{\\mathbf{\\mu}}_2, \\hat{\\mathbf{\\Sigma}}_2, \\hat{\\mathbf{\\mu}}_3, \\hat{\\mathbf{\\Sigma}}_3, \\hat{\\pi}_1, \\hat{\\pi}_2, \\hat{\\pi}_3$:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "lines_to_next_cell": 2 }, "outputs": [], "source": [ "print(mu1hat)\n", "print(Sigma1hat)\n", "print(mu2hat)\n", "print(Sigma2hat)\n", "print(mu3hat)\n", "print(Sigma3hat)\n", "print(pihat[0])\n", "print(pihat[1])\n", "print(pihat[2])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Et voilà – this was a brief demonstration of how to do inference for MoG models by performing manifold optimisation using Pymanopt." ] }, { "cell_type": "markdown", "metadata": { "lines_to_next_cell": 2 }, "source": [ "## When Things Go Astray\n", "\n", "A well-known problem when fitting parameters of a MoG model is that one Gaussian may collapse onto a single data point, resulting in singular covariance matrices (cf. e.g. p. 434 in Bishop, C. M. \"Pattern Recognition and Machine Learning.\" 2006). This problem can be avoided by the following heuristic: if a component's covariance matrix is close to being singular, we reset its mean and covariance matrix. 
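In isolation, the reset heuristic might look like the sketch below; the threshold `eps`, the helper name and the re-initialisation choices are illustrative assumptions, whereas the line-search class that follows triggers the reset when evaluating the cost raises a numerical error.

```python
# Hypothetical sketch of the reset heuristic for estimated parameters
# mu_hat and Sigma_hat (lists of length K); eps is an assumed threshold
# on the smallest eigenvalue of a component's covariance matrix.
def reset_degenerate_components(mu_hat, Sigma_hat, eps=1e-8):
    for k in range(len(Sigma_hat)):
        if np.linalg.eigvalsh(Sigma_hat[k]).min() < eps:
            # Re-initialise the collapsed component: mean at a random sample,
            # isotropic unit covariance.
            mu_hat[k] = samples[np.random.choice(N)]
            Sigma_hat[k] = np.eye(D)
    return mu_hat, Sigma_hat
```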
Using Pymanopt this can be accomplished by using an appropriate line search rule (based on [BackTrackingLineSearcher](https://github.com/pymanopt/pymanopt/blob/master/pymanopt/optimizers/line_search.py)) -- here we demonstrate this approach:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class LineSearchMoG:\n", " \"\"\"\n", " Back-tracking line-search that checks for close to singular matrices.\n", " \"\"\"\n", "\n", " def __init__(\n", " self,\n", " contraction_factor=0.5,\n", " optimism=2,\n", " sufficient_decrease=1e-4,\n", " max_iterations=25,\n", " initial_step_size=1,\n", " ):\n", " self.contraction_factor = contraction_factor\n", " self.optimism = optimism\n", " self.sufficient_decrease = sufficient_decrease\n", " self.max_iterations = max_iterations\n", " self.initial_step_size = initial_step_size\n", "\n", " self._oldf0 = None\n", "\n", " def search(self, objective, manifold, x, d, f0, df0):\n", " \"\"\"\n", " Function to perform backtracking line-search.\n", " Arguments:\n", " - objective\n", " objective function to optimise\n", " - manifold\n", " manifold to optimise over\n", " - x\n", " starting point on the manifold\n", " - d\n", " tangent vector at x (descent direction)\n", " - df0\n", " directional derivative at x along d\n", " Returns:\n", " - step_size\n", " norm of the vector retracted to reach newx from x\n", " - newx\n", " next iterate suggested by the line-search\n", " \"\"\"\n", " # Compute the norm of the search direction\n", " norm_d = manifold.norm(x, d)\n", "\n", " if self._oldf0 is not None:\n", " # Pick initial step size based on where we were last time.\n", " alpha = 2 * (f0 - self._oldf0) / df0\n", " # Look a little further\n", " alpha *= self.optimism\n", " else:\n", " alpha = self.initial_step_size / norm_d\n", " alpha = float(alpha)\n", "\n", " # Make the chosen step and compute the cost there.\n", " newx, newf, reset = self._newxnewf(x, alpha * d, objective, manifold)\n", " step_count = 1\n", "\n", " # Backtrack while the Armijo criterion is not satisfied\n", " while (\n", " newf > f0 + self.sufficient_decrease * alpha * df0\n", " and step_count <= self.max_iterations\n", " and not reset\n", " ):\n", "\n", " # Reduce the step size\n", " alpha = self.contraction_factor * alpha\n", "\n", " # and look closer down the line\n", " newx, newf, reset = self._newxnewf(\n", " x, alpha * d, objective, manifold\n", " )\n", "\n", " step_count = step_count + 1\n", "\n", " # If we got here without obtaining a decrease, we reject the step.\n", " if newf > f0 and not reset:\n", " alpha = 0\n", " newx = x\n", "\n", " step_size = alpha * norm_d\n", "\n", " self._oldf0 = f0\n", "\n", " return step_size, newx\n", "\n", " def _newxnewf(self, x, d, objective, manifold):\n", " newx = manifold.retraction(x, d)\n", " try:\n", " newf = objective(newx)\n", " except np.linalg.LinAlgError:\n", " replace = np.asarray(\n", " [\n", " np.linalg.matrix_rank(newx[0][k, :, :])\n", " != newx[0][0, :, :].shape[0]\n", " for k in range(newx[0].shape[0])\n", " ]\n", " )\n", " x[0][replace, :, :] = manifold.random_point()[0][replace, :, :]\n", " return x, objective(x), True\n", " return newx, newf, False" ] } ], "metadata": { "jupytext": { "encoding": "# -*- coding: utf-8 -*-", "formats": "ipynb,py" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", 
"nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" } }, "nbformat": 4, "nbformat_minor": 1 }