Variational Inference

Last updated at 03-06-2018

It took me more than two weeks to finally get the essence of variational inference. The painful but fulfilling process brought me to appreciate the difficult (at least for me) but beautiful math behind it.

:books: A useful tutorial I kept coming back to:

- Blei, Kucukelbir, and McAuliffe (2017), *Variational Inference: A Review for Statisticians*


Table of Contents

- Introduction
- Evidence Lower Bound (ELBO)
- Mean Field Variational Family
- Coordinate Ascent VI (CAVI)
- Applying VI on GMM
- Python Implementation

import numpy as np
import scipy as sp
import seaborn as sns
from matplotlib import pyplot as plt
%matplotlib inline

Introduction

A motivating example

As with expectation maximization, I start by describing a problem to motivate variational inference. Please refer to Prof. Blei's review (Blei et al., 2017), listed above, for more details.

Let's start by considering a problem where we have data points sampled from a mixture of Gaussian distributions. Specifically, there are $K$ univariate Gaussian distributions with means $\mathbf{\mu} = \{ \mu_1, \dots, \mu_K \}$ and unit variance ($\mathbf{\sigma}=\mathbf{1}$) for simplicity:

[Figure: sample data drawn from the mixture of Gaussians]

Please refer to my EM post for details on this sample data.

In a Bayesian setting, we can assume that all the means come from the same prior distribution, which is also a Gaussian $\mathcal{N}(0, \sigma^2)$, with variance $\sigma^2$ being a hyperparameter. Specifically, we can set up a very simple generative model:

\begin{align} \mu_k & \sim \mathcal{N}(0, \sigma^2)~\text{for } k = 1, ..., K \\ c^{(i)} & \sim \mathcal{U}(K)~\text{for } i = 1, ..., n \\ x^{(i)}~\vert~c^{(i)}, \mathbf{\mu} & \sim \mathcal{N}\big({c^{(i)}}^T \mathbf{\mu}, 1\big)~\text{for } i = 1, ..., n \end{align}

where each $c^{(i)}$ is the (one-hot encoded) component assignment of $x^{(i)}$.
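To make the model concrete, here is a minimal sketch of ancestral sampling from it. The values of `K`, `N`, and `sigma` below are made up for illustration; the sample data actually used later in this post is generated in the Python Implementation section.

```python
import numpy as np

rng = np.random.default_rng(42)
K, N, sigma = 3, 1000, 10.0           # hyperparameters chosen for illustration only

mu = rng.normal(0.0, sigma, size=K)   # mu_k ~ N(0, sigma^2)
c = rng.integers(0, K, size=N)        # c_i ~ Uniform{1, ..., K} (stored as indices, not one-hot)
x = rng.normal(mu[c], 1.0)            # x_i | c_i, mu ~ N(mu_{c_i}, 1)
```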

The generative model above gives us a straightforward view of how the joint probability can be written out:

\begin{align} p(\mathbf{c}, \mathbf{\mu}, \mathbf{x}) & = p(\mathbf{\mu})p(\mathbf{c})p(\mathbf{x} \vert \mathbf{c}, \mathbf{\mu}) \\ & = p(\mathbf{\mu}) \prod_{i} p(c^{(i)})p(x^{(i)} \vert c^{(i)}, \mathbf{\mu}) \end{align}

Summing/integrating out the latent variables, we can obtain the marginal likelihood (i.e., evidence):

\begin{align} p(\mathbf{x}) & = \int_{\mathbf{\mu}} p(\mathbf{\mu}) \prod_{i} \sum_{c^{(i)}} p(c^{(i)})p(x^{(i)} \vert c^{(i)}, \mathbf{\mu}) d \mathbf{\mu}\\ & [\text{We can switch the order of how we integrate/sum out the latent variables}] \\ & = \sum_{\mathbf{c}} p(\mathbf{c}) \int_{\mathbf{\mu}} p(\mathbf{\mu}) \prod_{i} p(x^{(i)} \vert c^{(i)}, \mathbf{\mu}) d \mathbf{\mu} \end{align}

Note that while it is possible to compute the individual terms within the integral (Gaussian prior and Gaussian likelihood), the sum over $\mathbf{c}$ contains $\mathcal{O}(K^n)$ terms (one for every possible configuration of the assignments). Therefore, we need to consider approximate inference due to this intractability.
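To make the blow-up concrete, here is a brute-force sketch that computes the evidence exactly by enumerating every assignment on a tiny made-up dataset. The sizes and data are chosen only so the loop finishes; each term is indeed a closed-form Gaussian integral, but there are $K^n$ of them.

```python
import itertools
import numpy as np
from scipy import stats

K, n, sigma = 2, 8, 3.0
x = np.random.normal(0.0, 1.0, size=n)   # placeholder data just to run the sum

total = 0.0
for c in itertools.product(range(K), repeat=n):       # all K^n assignments
    val = (1.0 / K) ** n                               # p(c) under the uniform prior
    for k in range(K):
        xk = x[np.asarray(c) == k]
        if len(xk) == 0:
            continue
        # For a fixed assignment, integrating out mu_k gives a Gaussian marginal
        # with mean 0 and covariance I + sigma^2 * 11^T for the points in component k.
        cov = np.eye(len(xk)) + sigma**2 * np.ones((len(xk), len(xk)))
        val *= stats.multivariate_normal.pdf(xk, mean=np.zeros(len(xk)), cov=cov)
    total += val

print(total)   # p(x): exact, but the loop already has 2^8 = 256 iterations
```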

General situation

The motivation of VI is very similar to that of EM: both deal with latent variables whose posterior is hard to work with directly. Instead of point estimates, however, VI tries to find variational distributions that serve as good proxies for the exact posterior.

Suppose we have $\mathbf{x}=\{ x^{(1)}, \dots, x^{(n)}\}$ as observed data and $\mathbf{z}=\{ z^{(1)}, \dots, z^{(n)}\}$ as latent variables. The inference problem is to find the posterior distribution of the latent variables given the observations:

\begin{align} p(\mathbf{z} \vert \mathbf{x}) = \dfrac{p(\mathbf{x}, \mathbf{z})}{p(\mathbf{x})} = \dfrac{p(\mathbf{x}, \mathbf{z})}{\int_{\mathbf{z}} p(\mathbf{x}, \mathbf{z}) d\mathbf{z}} \end{align}

Oftentimes, the evidence $p(\mathbf{x})$ in the denominator is intractable. Therefore, we need approximations to find a relatively good solution in a reasonable amount of time. VI is exactly what we need!


Evidence Lower Bound (ELBO)

In my EM post, we proved that the log evidence $ln~p(\mathbf{x})$ can be decomposed as follows (note that we use integrals this time):

\begin{align} ln~p(\mathbf{x}) & = \int_{\mathbf{z}} q(\mathbf{z}) d\mathbf{z}~~ln~p(\mathbf{x}) \\ & [\text{Recall that } \int_{\mathbf{z}} q(\mathbf{z}) d\mathbf{z} = 1] \\ & = \int_{\mathbf{z}} q(\mathbf{z}) ln~ \frac{p(\mathbf{x}, \mathbf{z})}{p(\mathbf{z} \vert \mathbf{x})} d\mathbf{z}\\ & = \int_{\mathbf{z}} q(\mathbf{z}) ln~ \frac{p(\mathbf{x}, \mathbf{z})~q(\mathbf{z})}{p(\mathbf{z} \vert \mathbf{x}) ~q(\mathbf{z})} d\mathbf{z}\\ & = \int_{\mathbf{z}} q(\mathbf{z}) ln~ \frac{p(\mathbf{x}, \mathbf{z})}{q(\mathbf{z})} d\mathbf{z} + \int_{\mathbf{z}} q(\mathbf{z}) ln~ \frac{q(\mathbf{z})}{p(\mathbf{z} \vert \mathbf{x})} d\mathbf{z}\\ & = \mathcal{L}(\mathbf{x}) + KL(q\vert \vert p) \end{align}

where $\mathcal{L}(\mathbf{x})$ is defined as the ELBO. Since the KL divergence is always nonnegative, $\mathcal{L}(\mathbf{x})$ is indeed a lower bound on the log evidence.
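As a quick sanity check of this identity, here is a sketch on a tiny discrete model where everything can be enumerated exactly. The joint table and $q$ below are made up.

```python
import numpy as np

# Toy discrete model: x is fixed, z has 4 states; p_xz[k] = p(x, z=k) (made-up values)
p_xz = np.array([0.10, 0.25, 0.05, 0.20])
p_x = p_xz.sum()                      # evidence p(x)
p_z_given_x = p_xz / p_x              # exact posterior

q = np.array([0.4, 0.3, 0.2, 0.1])    # an arbitrary variational distribution over z

elbo = np.sum(q * np.log(p_xz / q))
kl = np.sum(q * np.log(q / p_z_given_x))
print(np.log(p_x), elbo + kl)          # identical: ln p(x) = ELBO + KL(q || p)
```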

If we further decompose ELBO, we have:

\begin{align} \mathcal{L}(\mathbf{x}) & = \int_{\mathbf{z}} q(\mathbf{z}) ln~ \frac{p(\mathbf{x}, \mathbf{z})}{q(\mathbf{z})} d\mathbf{z} \\ & = \int_{\mathbf{z}} q(\mathbf{z}) ln~p(\mathbf{x} \vert \mathbf{z}) - q(\mathbf{z})ln~\frac{q(\mathbf{z})}{p(\mathbf{z})} d\mathbf{z}\\ & = E_q\big[ ln~p(\mathbf{x} \vert \mathbf{z}) \big] - KL(q(\mathbf{z})||p(\mathbf{z}))\\ & [\text{Alternatively, starting again from the first line}] \\ & = \int_{\mathbf{z}} q(\mathbf{z}) ln~p(\mathbf{x}, \mathbf{z}) - q(\mathbf{z})ln~q(\mathbf{z}) d\mathbf{z}\\ & = E_q\big[ ln~p(\mathbf{x}, \mathbf{z}) \big] + \mathcal{H}(q) ~~\text{(Entropy of } q\text{)}\\ \end{align}

The last equation above shows that the ELBO trades off between two terms: the expected log joint $E_q\big[ ln~p(\mathbf{x}, \mathbf{z}) \big]$, which pushes $q$ to place its mass on latent configurations that explain the data well, and the entropy $\mathcal{H}(q)$, which pushes $q$ to stay spread out.

Finally, we note that, in EM, we are able to compute $p(\mathbf{z}\vert \mathbf{x})$ exactly, so we can easily maximize the ELBO. VI is the way to go when we cannot.


Mean Field Variational Family

So far, we haven't said anything about what the $q$'s should be. In this note, we only look at a classical choice, called the mean field variational family. Specifically, it assumes that the latent variables are mutually independent, which means we can factorize the variational distribution into individual groups:

\begin{align} q(\mathbf{z}) = \prod_{j} q_j(z_j) \end{align}

By doing this, we give up the ability to capture interdependence between the latent variables; see Blei et al. (2017) for a nice visualization of this limitation.



Coordinate Ascent VI (CAVI)

By factorizing the variational distribution into individual factors, we can easily apply coordinate ascent optimization to each factor in turn. A common procedure for CAVI is (see the code-level sketch right after this list):

1. Initialize all variational factors $q_j(z_j)$.
2. Update each $q_j$ to its coordinate-wise optimum while holding the other factors fixed.
3. Compute the ELBO after each full sweep over the factors.
4. Repeat steps 2-3 until the ELBO converges.
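In code, the loop structure looks roughly like this. This is a sketch only: `update_factor` and `elbo` stand for the model-specific formulas derived below and are not real library functions.

```python
def cavi(x, q_factors, update_factor, elbo, max_iter=100, tol=1e-8):
    """Generic CAVI loop: sweep over the factors, then check the ELBO for convergence."""
    history = [elbo(x, q_factors)]
    for _ in range(max_iter):
        for j in range(len(q_factors)):
            # coordinate-wise optimum of the j-th factor, holding the others fixed
            q_factors[j] = update_factor(j, x, q_factors)
        history.append(elbo(x, q_factors))
        if abs(history[-1] - history[-2]) <= tol:
            break
    return q_factors, history
```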

Derivation of optimal var. dist.:

In fact, we can derive the optimal solutions without too much effort:

\begin{align} ELBO & = E_q[log~p(x,z)] - E_q[log~q(z)] \\ & [\text{Here we use the fact that } q(z) \text{ can be factorized}]\\ & = E_q[log~p(x, z_j, z_{-j})] - \sum_{l}E_{q_l}[log~q_l(z_l)] \\ & [\text{Iterated expectation: } E[A] = E[E[A|B]] \text{; all terms not involving } z_j \text{ are absorbed into } const]\\ & = E_j\Big[E_{-j}\big[ log~p(x, z_j, z_{-j}) \vert z_{j} \big] \Big] - E_{q_j}[log~q_j] + const \\ \end{align}

Now, according to the definition of expectation, we have:

\begin{align} E_{-j}\big[ log~p(x, z_j, z_{-j}) \vert z_{j} \big] &= \int_{z_{-j}} log~p(x, z_j, z_{-j})~q(z_{-j}|z_j) dz_{-j} \\ & = \int_{z_{-j}} log~p(x, z_j, z_{-j})~q(z_{-j}) dz_{-j} \\ & = E_{-j}\big[ log~p(x, z_j, z_{-j}) \big] \end{align}

Here we use the mean field assumption that the latent variables' variational distributions are independent, so $q(z_{-j} \vert z_j) = q(z_{-j})$.

Therefore we have:

\begin{align} ELBO & = E_{j}\Big[E_{-j}\big[ log~p(x, \mathbf{z}) \big]\Big] - E_{j}[log~q_j] + const \\ \end{align}

We can see that the first two terms can be combined into a negative KL divergence between $q_j$ and the distribution proportional to the exponentiated inner expectation. Maximizing the ELBO with respect to $q_j$ alone therefore means driving that KL divergence to zero, so we can write down the optimal solution as:

\begin{align} q_j^*(z_j) \propto exp\Big\{ E_{-j}\big[ log~p(x, z_j, z_{-j}) \big] \Big\} \end{align}
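Here is what this generic update looks like on a toy model with two discrete latent variables. The log-joint table is random and purely illustrative; $x$ is fixed and folded into it.

```python
import numpy as np

rng = np.random.default_rng(0)
# log_joint[a, b] = log p(x, z1=a, z2=b) for a toy model with |z1| = 2, |z2| = 3
log_joint = np.log(rng.dirichlet(np.ones(6))).reshape(2, 3)

q2 = np.full(3, 1.0 / 3.0)              # current factor q2(z2)

# q1*(z1) ∝ exp{ E_{q2}[ log p(x, z1, z2) ] }
log_q1 = log_joint @ q2                  # expectation over z2, for each value of z1
q1 = np.exp(log_q1 - log_q1.max())       # exponentiate (shift by the max for stability)
q1 /= q1.sum()                           # normalize
print(q1)
```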

Alternative way

While the derivation through iterated expectation seems simpler, I personally still prefer taking partial derivatives with respect to the parameters of the variational distributions, as in the example below, which feels more natural to me. After all, we will be using the ELBO to check convergence anyway.


Applying VI on GMM

Let’s get back to our original problem with the univariate Gaussian mixtures with unit variance. The full parameterization is as follows:

\begin{align} \mu_j & \sim \mathcal{N}(0, \sigma^2)~\text{for } j = 1, ..., K \\ c_i & \sim \mathcal{U}(K)~\text{for } i = 1, ..., N \\ x_i & \sim \mathcal{N}(c_i^T \mu, 1)~\text{for } i = 1, ..., N \end{align}

Note that $c_i$ is a vector of ones and zeros such that $c_{ij} = 1$ for the assigned component $j$ and $c_{il} = 0$ for all $l \neq j$ (a.k.a. a one-hot vector).
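For example, with $K = 3$ and made-up numbers, $c_i^T \mathbf{\mu}$ simply picks out the mean of the assigned component:

```python
import numpy as np

mu = np.array([-5.0, 0.0, 5.0])     # component means (made-up values)
c_i = np.array([0, 0, 1])           # one-hot: point i is assigned to component 3
print(c_i @ mu)                     # c_i^T mu selects mu_3 = 5.0
```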

By mean field VI, we can introduce variational distributions for the two latent variables $\mathbf{c}$ and $\mathbf{\mu}$ that factorize over every individual latent variable:

\begin{align} q(\mathbf{\mu}, \mathbf{c}) = \prod_{j=1}^{K} q(\mu_j; m_j, s_j^2) \prod_{i=1}^{N} q(c_i; \phi_i) \end{align}

Choose $q$

According to the factorization above, we choose the following variational distributions for $\mu$ and $c$:

\begin{align} \mu_j; m_j, s_j^2 & \sim \mathcal{N}(m_j, s_j^2) \\ c_i; \phi_i & \sim Multi(\phi_i) \end{align}

Therefore, $\phi_i$ is a vector of probabilities such that $q(c_i=j) = \phi_{ij}$.
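Concretely, the variational parameters we need to track look like this (shapes only; the initial values below are arbitrary):

```python
import numpy as np

N, K = 1000, 3
phi = np.full((N, K), 1.0 / K)      # phi[i, j] = q(c_i = j); each row sums to 1
m = np.zeros(K)                     # variational means of the mu_j's
s2 = np.ones(K)                     # variational variances of the mu_j's
```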

ELBO

The most important thing is to write down the ELBO, the evidence lower bound, which is needed for (i) parameter updates and (ii) convergence checks. That said, convergence can also be checked via the relative change of the parameter estimates: if the parameters do not change much between iterations, VI stops, considering itself converged.

Recall that $ELBO = E_q[log~p(x,z)] - E_q[log~q(z)] $. Let me split this task into two.

Full joint probability

The hidden/latent variables in this problem are $c$ and $\mu$.

\begin{align} log~p(x, c, \mu) & = log~p(\mu)p(c)p(x~\vert~c, \mu) \\ & = \sum_j log~p(\mu_j) + \sum_i \big[ log~p(c_i) + log~p(x_i~\vert~c_i, \mu) \big] \\ \end{align}

$p(c_i) = \dfrac{1}{K}$ is a constant, so we can drop it. We then expand $log~p(\mu_j)$:

\begin{align} log~p(\mu_j) & = log~\Big\{ \dfrac{1}{\sqrt{2\pi \sigma^2}} exp\big[ -\dfrac{\mu_j^2}{2\sigma^2} \big] \Big\} \\ & [log~\dfrac{1}{\sqrt{2\pi \sigma^2}} \text{ is a constant} ]\\ & \propto -\dfrac{\mu_j^2}{2\sigma^2} \end{align}

For $log~p(x_i~\vert~c_i, \mu)$, it is a bit trickier. Recall that $c_i$ is a one-hot vector, where exactly one of the elements is 1. We can make use of this property and rewrite:

\begin{align} log~p(x_i~\vert~c_i, \mathbf{\mu}) & = \sum_j c_{ij}~log~p(x_i~\vert~\mu_j) \\ & \propto \sum_j c_{ij} \Big(-\dfrac{(x_i-\mu_j)^2}{2}\Big) \end{align}

Combining all of the above, we can write the log of the full joint probability as:

\begin{align} log~p(x, c, \mu) \propto \sum_j -\dfrac{\mu_j^2}{2\sigma^2} + \sum_i\sum_j c_{ij} \Big(-\dfrac{(x_i-\mu_j)^2}{2}\Big) \end{align}

Entropy of variational distributions

Thanks to the mean field assumption, we can factorize the joint variational distribution easily:

\begin{align} log~q(\mathbf{\mu}, \mathbf{c}) = \sum_j log~q(\mu_j; m_j, s_j^2) + \sum_i log~q(c_i; \phi_i) \end{align}

Let's expand these two terms separately.

\begin{align} log~q(\mu_j; m_j, s_j^2) & = log~\Big\{ \dfrac{1}{\sqrt{2\pi s_j^2}} exp \big[ -\dfrac{(\mu_j-m_j)^2}{2 s_j^2} \big] \Big\} \\ & = -\dfrac{1}{2}log~(2\pi s_j^2) -\dfrac{(\mu_j-m_j)^2}{2 s_j^2} \end{align}

For the categorical factor, $log~q(c_i; \phi_i) = \sum_j c_{ij}~log~\phi_{ij}$. Taking expectations under $q$ (and using $E_q\big[(\mu_j-m_j)^2\big] = s_j^2$ and $E_q[c_{ij}] = \phi_{ij}$), we have:

\begin{align} -E_q\big[log~q(\mathbf{\mu}, \mathbf{c})\big] & \propto \sum_j \dfrac{1}{2}log~(s_j^2) - \sum_i\sum_j \phi_{ij}~log~\phi_{ij} \end{align}
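As a quick numerical sanity check of the Gaussian entropy term (with made-up values of $m_j$ and $s_j^2$), a Monte Carlo estimate of $E_q[log~q(\mu_j)]$ matches the closed form $-\tfrac{1}{2}log~(2\pi s_j^2) - \tfrac{1}{2}$:

```python
import numpy as np

rng = np.random.default_rng(0)
m_j, s2_j = 1.5, 0.4                                    # made-up variational parameters
mu_samples = rng.normal(m_j, np.sqrt(s2_j), size=200_000)

log_q = -0.5 * np.log(2 * np.pi * s2_j) - (mu_samples - m_j) ** 2 / (2 * s2_j)
print(log_q.mean())                                     # Monte Carlo estimate of E_q[log q(mu_j)]
print(-0.5 * np.log(2 * np.pi * s2_j) - 0.5)            # closed form
```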

Full ELBO

Merging the results back together (and using $E_q[c_{ij}] = \phi_{ij}$), we can write the ELBO as:

\begin{align} ELBO \propto & \sum_j -E_q\Big[\dfrac{\mu_j^2}{2\sigma^2}\Big] + \sum_i\sum_j \phi_{ij}~E_q\Big[-\dfrac{(x_i-\mu_j)^2}{2}\Big] \\ &- \sum_i \sum_j \phi_{ij}~log~\phi_{ij} + \sum_j \dfrac{1}{2}log~(s_j^2) \end{align}
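For reference, here is a minimal sketch that evaluates this expression directly with numpy. Additive constants are dropped, so only changes in its value (not the value itself) are meaningful.

```python
import numpy as np

def elbo_from_formula(X, phi, m, s2, sigma2=1.0):
    """ELBO for the univariate GMM above, up to additive constants (a sketch)."""
    # sum_j -E_q[mu_j^2] / (2 sigma^2), with E_q[mu_j^2] = m_j^2 + s_j^2
    t_prior = -np.sum((m ** 2 + s2) / (2.0 * sigma2))
    # sum_ij phi_ij * E_q[-(x_i - mu_j)^2 / 2], dropping the constant x_i^2 / 2 terms
    t_lik = np.sum(phi * (np.outer(X, m) - 0.5 * (m ** 2 + s2)))
    # entropy terms: -sum_ij phi_ij log phi_ij + sum_j (1/2) log s_j^2
    t_ent = -np.sum(phi * np.log(phi)) + 0.5 * np.sum(np.log(s2))
    return t_prior + t_lik + t_ent
```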

Parameter updates

$\phi_{ij}$

This is a constrained optimization because $\sum_j \phi_{ij} = 1~\forall i$. However, we do not need a Lagrange multiplier: we can work with unnormalized values and simply normalize the result at the end (we are using a lot of $\propto$ here!)

\begin{align} \dfrac{\partial}{\partial \phi_{ij}}~ELBO & \propto \dfrac{\partial}{\partial \phi_{ij}}\Big\{E_q\Big[-\dfrac{(x_i-\mu_j)^2}{2}\Big] \phi_{ij} - \phi_{ij}~log~\phi_{ij} \Big\}\\ & = E_q\Big[-\dfrac{(x_i-\mu_j)^2}{2}\Big] - log~\phi_{ij} - 1 = 0 \\ & [E_q[\mu_j] = m_j \text{ and } E_q[\mu_j^2] = V_q[\mu_j] + E_q[\mu_j]^2 = s_j^2 + m_j^2] \\ log~\phi_{ij} & \propto E_q\Big[-\dfrac{(x_i-\mu_j)^2}{2}\Big] \\ \phi_{ij}^* & \propto exp\{ -\tfrac{1}{2}(m_j^2+s_j^2) + x_i m_j \} \end{align}

$m_j$

\begin{align} \dfrac{\partial}{\partial m_{j}}~ELBO & \propto \dfrac{\partial}{\partial m_{j}}~\Big\{ -E\big[\dfrac{\mu_j^2}{2\sigma^2}\big] - \sum_i \phi_{ij} E[\dfrac{(x_i-\mu_j)^2}{2}] \Big\} \\ & \propto \dfrac{\partial}{\partial m_{j}}~\Big\{ -\dfrac{1}{2\sigma^2} m_j^2 + \sum_i \phi_{ij} \big[ -\dfrac{1}{2}m_j^2 + x_i m_j \big] \Big\} \\ & = -\dfrac{1}{\sigma^2}m_j - \sum_i\phi_{ij} m_j + \sum_i \phi_{ij} x_i = 0 \\ m_j^* &= \dfrac{\sum_i\phi_{ij}x_i}{\tfrac{1}{\sigma^2} + \sum_i\phi_{ij}} \end{align}

$s_j^2$

Note that we treat $s_j^2$ as a single variable when taking the derivative.

\begin{align} \dfrac{\partial}{\partial s_j^2}~ELBO & \propto \dfrac{\partial}{\partial s_j^2}~ \big\{ -E \big[\dfrac{\mu_j^2}{2\sigma^2}\big] - \sum_i \phi_{ij} E[\dfrac{(x_i-\mu_j)^2}{2}] +\dfrac{1}{2}log~s_j^2 \big\} \\ & \propto \dfrac{\partial}{\partial s_j^2}~ \Big\{ -\dfrac{1}{2\sigma^2}s_j^2 - \sum_i\phi_{ij}(\dfrac{1}{2}s_j^2) +\dfrac{1}{2}log~s_j^2 \Big\} \\ & = -\dfrac{1}{2\sigma^2} - \sum_i\dfrac{\phi_{ij}}{2} + \dfrac{1}{2s_j^2} = 0 \\ \dfrac{1}{s_j^2} & = \dfrac{1}{\sigma^2} + \sum_i\phi_{ij} \\ (s_j^2)^{*} & = \dfrac{1}{\frac{1}{\sigma^2} + \sum_i\phi_{ij}} \end{align}

Now that we have the ELBO and the parameter update formulas, we can set up our own VI algorithm for this simple Gaussian mixture!


Python Implementation

import numpy as np

class UGMM(object):
    '''Univariate GMM with CAVI'''
    def __init__(self, X, K=2, sigma=1):
        self.X = X
        self.K = K
        self.N = self.X.shape[0]
        self.sigma2 = sigma**2

    def _init(self):
        self.phi = np.random.dirichlet([np.random.random()*np.random.randint(1, 10)]*self.K, self.N)
        self.m = np.random.randint(int(self.X.min()), high=int(self.X.max()), size=self.K).astype(float)
        self.m += self.X.max()*np.random.random(self.K)
        self.s2 = np.ones(self.K) * np.random.random(self.K)
        print('Init mean')
        print(self.m)
        print('Init s2')
        print(self.s2)

    def get_elbo(self):
        t1 = np.log(self.s2) - self.m/self.sigma2
        t1 = t1.sum()
        t2 = -0.5*np.add.outer(self.X**2, self.s2+self.m**2)
        t2 += np.outer(self.X, self.m)
        t2 -= np.log(self.phi)
        t2 *= self.phi
        t2 = t2.sum()
        return t1 + t2

    def fit(self, max_iter=100, tol=1e-10):
        self._init()
        self.elbo_values = [self.get_elbo()]
        self.m_history = [self.m]
        self.s2_history = [self.s2]
        for iter_ in range(1, max_iter+1):
            self._cavi()
            self.m_history.append(self.m)
            self.s2_history.append(self.s2)
            self.elbo_values.append(self.get_elbo())
            if iter_ % 5 == 0:
                print(iter_, self.m_history[iter_])
            if np.abs(self.elbo_values[-2] - self.elbo_values[-1]) <= tol:
                print('ELBO converged with ll %.3f at iteration %d'%(self.elbo_values[-1],
                                                                     iter_))
                break

        if iter_ == max_iter:
            print('ELBO ended with ll %.3f'%(self.elbo_values[-1]))


    def _cavi(self):
        # One CAVI sweep: update the phi's, then the Gaussian factors for the mu's.
        self._update_phi()
        self._update_mu()

    def _update_phi(self):
        # phi_ij* ∝ exp{ x_i m_j - (m_j^2 + s_j^2)/2 }, then normalize each row
        t1 = np.outer(self.X, self.m)
        t2 = -(0.5*self.m**2 + 0.5*self.s2)
        exponent = t1 + t2[np.newaxis, :]
        self.phi = np.exp(exponent)
        self.phi = self.phi / self.phi.sum(1)[:, np.newaxis]

    def _update_mu(self):
        # m_j* = sum_i phi_ij x_i / (1/sigma^2 + sum_i phi_ij)
        self.m = (self.phi*self.X[:, np.newaxis]).sum(0) * (1/self.sigma2 + self.phi.sum(0))**(-1)
        assert self.m.size == self.K
        # s_j^2* = 1 / (1/sigma^2 + sum_i phi_ij)
        self.s2 = (1/self.sigma2 + self.phi.sum(0))**(-1)
        assert self.s2.size == self.K

Making data

num_components = 3
mu_arr = np.random.choice(np.arange(-10, 10, 2),
                          num_components) +\
         np.random.random(num_components)
mu_arr
array([ 8.79153551,  6.29803456, -5.7042636 ])
SAMPLE = 1000
X = np.random.normal(loc=mu_arr[0], scale=1, size=SAMPLE)
for i, mu in enumerate(mu_arr[1:]):
    X = np.append(X, np.random.normal(loc=mu, scale=1, size=SAMPLE))
fig, ax = plt.subplots(figsize=(15, 4))
sns.distplot(X[:SAMPLE], ax=ax, rug=True)
sns.distplot(X[SAMPLE:SAMPLE*2], ax=ax, rug=True)
sns.distplot(X[SAMPLE*2:], ax=ax, rug=True)

[Figure: histograms and rug plots of the three sampled components]

ugmm = UGMM(X, 3)
ugmm.fit()
Init mean
[9.62056838 2.48053419 8.95455044]
Init s2
[0.22102799 0.50256273 0.72923656]
5 [ 8.78575069 -5.69598804  6.32040619]
10 [ 8.77126102 -5.69598804  6.30384436]
15 [ 8.77083542 -5.69598804  6.30344752]
20 [ 8.77082412 -5.69598804  6.30343699]
25 [ 8.77082382 -5.69598804  6.30343671]
30 [ 8.77082381 -5.69598804  6.3034367 ]
35 [ 8.77082381 -5.69598804  6.3034367 ]
ELBO converged with ll -1001.987 at iteration 35
ugmm.phi.argmax(1)
array([0, 0, 0, ..., 1, 1, 1])
sorted(mu_arr)
[-5.704263600460798, 6.298034563379406, 8.791535506275245]
sorted(ugmm.m)
[-5.695988039984863, 6.303436701203107, 8.770823807705389]
fig, ax = plt.subplots(figsize=(15, 4))
sns.distplot(X[:SAMPLE], ax=ax, hist=True, norm_hist=True)
sns.distplot(np.random.normal(ugmm.m[0], 1, SAMPLE), color='k', hist=False, kde=True)
sns.distplot(X[SAMPLE:SAMPLE*2], ax=ax, hist=True, norm_hist=True)
sns.distplot(np.random.normal(ugmm.m[1], 1, SAMPLE), color='k', hist=False, kde=True)
sns.distplot(X[SAMPLE*2:], ax=ax, hist=True, norm_hist=True)
sns.distplot(np.random.normal(ugmm.m[2], 1, SAMPLE), color='k', hist=False, kde=True)

[Figure: component histograms overlaid with unit-variance Gaussians centered at the fitted variational means]
