Lecture 24
Variational Inference
Philipp Hennig
13 July 2021
Faculty of Science
Department of Computer Science
Chair for the Methods of Machine Learning
log p(x | θ) = L(q, θ) + D_KL(q ∥ p(z | x, θ))

L(q, θ) = ∑_z q(z) log [p(x, z | θ) / q(z)]        D_KL(q ∥ p(z | x, θ)) = − ∑_z q(z) log [p(z | x, θ) / q(z)]
▶ For EM, we minimized the KL divergence to find q = p(z | x, θ) (E-step), then maximized L(q, θ) in θ (M-step).
▶ What if we treated the parameters θ as a probabilistic variable for full Bayesian inference? z ← z ∪ θ
▶ Then we could just maximize L(q) with respect to q (not z!) to implicitly minimize D_KL(q ∥ p(z | x)), because log p(x) is constant. This is an optimization in the space of distributions q, not (necessarily) in the parameters of such distributions, and thus a very powerful notion.
▶ In general, this will be intractable, because the optimal choice for q is exactly p(z | x). But maybe we can help out a bit with approximations. Amazingly, we often do not need to impose strong approximations: sometimes we can get away with restricting only the factorization of q, not its analytic form.
Probabilistic ML — P. Hennig, SS 2021 — Lecture 24: Variational Inference — © Philipp Hennig, 2021 CC BY-NC-SA 3.0 1
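The decomposition above can be verified numerically for a tiny discrete model (all numbers below are arbitrary illustration values):

```python
import numpy as np

# Toy model: one observed x, a latent z with three states.
# p_xz[z] holds the joint p(x, z) for the fixed x (arbitrary positive values).
p_xz = np.array([0.10, 0.25, 0.05])
log_px = np.log(p_xz.sum())                 # log p(x) = log sum_z p(x, z)
post = p_xz / p_xz.sum()                    # exact posterior p(z | x)

q = np.array([0.5, 0.3, 0.2])               # an arbitrary variational q(z)
L = np.sum(q * np.log(p_xz / q))            # ELBO L(q)
KL = -np.sum(q * np.log(post / q))          # D_KL(q || p(z | x))

assert np.isclose(log_px, L + KL)           # the decomposition is exact
assert KL >= 0.0                            # hence L(q) <= log p(x)
```

Since log p(x) is fixed, maximizing L over q is the same as driving the KL term to zero, i.e. pushing q toward the exact posterior.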
For continuous z, the same decomposition holds with sums replaced by integrals:

log p(x) = L(q) + D_KL(q ∥ p(z | x))

L(q) = ∫ q(z) log [p(x, z) / q(z)] dz        D_KL(q ∥ p(z | x)) = − ∫ q(z) log [p(z | x) / q(z)] dz
Mean Field Theory
Factorizing variational approximations
▶ The beauty is that we get to choose q, so one can nearly always find a tractable approximation.
▶ If we impose the mean-field approximation q(z) = ∏_i q(z_i), we get the coordinate-wise update
  log q*_j(z_j) = E_{∏_{i≠j} q(z_i)}(log p(x, z)) + const.
▶ For an exponential-family p, things are particularly simple: we only need the expectations under q of the sufficient statistics.
Variational Inference is an extremely flexible and powerful approximation method. Its downside is that
constructing the bound and update equations can be tedious. For a quick test, variational inference is
often not a good idea. But for a deployed product, it can be the most powerful tool in the box.
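As a minimal concrete instance of these coordinate-wise updates, consider the classic bivariate-Gaussian example (after Bishop, PRML 2006, §10.1.2); the precision matrix below is an arbitrary illustration:

```python
import numpy as np

# Target: p(z1, z2) = N(mu, Lambda^{-1}) with known precision Lambda.
mu = np.array([1.0, -1.0])
Lam = np.array([[2.0, 1.2],
                [1.2, 2.0]])        # arbitrary positive-definite precision

# Mean field q(z) = q(z1) q(z2): each factor is Gaussian with fixed
# precision Lambda_ii; only the means are updated coordinate-wise.
m = np.zeros(2)
for _ in range(100):
    m[0] = mu[0] - Lam[0, 1] / Lam[0, 0] * (m[1] - mu[1])
    m[1] = mu[1] - Lam[1, 0] / Lam[1, 1] * (m[0] - mu[0])

assert np.allclose(m, mu)  # for a Gaussian target, the means converge exactly
```

Note that the factor variances Λ_ii⁻¹ underestimate the true marginal variances: a well-known bias of minimizing D_KL(q ∥ p) with a factorized q.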
Example: The Gaussian Mixture Model
Returning to EM Exposition after Bishop, PRML 2006, Chapter 10.2
p(x, z | µ, Σ, π) = ∏_{n=1}^N ∏_{k=1}^K π_k^{z_nk} · N(x_n; µ_k, Σ_k)^{z_nk}
                  = ∏_{n=1}^N p(z_n: | π) · p(x_n | z_n:, µ, Σ)

[Graphical model: π → z_n → x_n ← (µ, Σ), with a plate over n = 1, …, N]
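A quick numerical sanity check that the two factorizations of the joint agree, for random one-hot assignments (toy sizes and parameters, chosen arbitrarily):

```python
import numpy as np
from scipy.stats import multivariate_normal

rng = np.random.default_rng(0)
N, K, Dim = 5, 3, 2
pi = rng.dirichlet(np.ones(K))               # mixture weights
mus = rng.normal(size=(K, Dim))              # component means
Sig = np.stack([np.eye(Dim)] * K)            # component covariances
X = rng.normal(size=(N, Dim))
Z = np.eye(K)[rng.integers(K, size=N)]       # one-hot indicators z_nk

# Product form: prod_n prod_k [pi_k N(x_n; mu_k, Sigma_k)]^{z_nk}, in logs
lp1 = sum(Z[n, k] * (np.log(pi[k])
                     + multivariate_normal.logpdf(X[n], mus[k], Sig[k]))
          for n in range(N) for k in range(K))

# Factorized form: prod_n p(z_n | pi) p(x_n | z_n, mu, Sigma), in logs
lp2 = sum(np.log(pi[Z[n].argmax()])
          + multivariate_normal.logpdf(X[n], mus[Z[n].argmax()],
                                       Sig[Z[n].argmax()])
          for n in range(N))

assert np.isclose(lp1, lp2)
```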
Constructing the Variational Approximation
Example: The Gaussian Mixture Model [Exposition from Bishop, PRML 2006, Chapter 10.2]
▶ We know that the full posterior p(z, π, µ, Σ | x) is intractable (check the graph!)
▶ But let us consider an approximation q(z, π, µ, Σ) with the factorization q(z, π, µ, Σ) = q(z) · q(π, µ, Σ)
Constructing the Variational Approximation
using q(z, π, µ, Σ) = q(z) · q(π, µ, Σ) [Exposition from Bishop, PRML 2006, Chapter 10.2]
log q*(π, µ, Σ) = log p(π) + ∑_{k=1}^K log p(µ_k, Σ_k) + E_{q(z)}(log p(z | π))
                + ∑_{n=1}^N ∑_{k=1}^K E(z_nk) log N(x_n; µ_k, Σ_k) + const.
Constructing the Variational Approximation
using q(z, π, µ, Σ) = q(z) · q(π, µ, Σ) [Exposition from Bishop, PRML 2006, Chapter 10.2]
▶ The bound exposes an induced factorization into q(π, µ, Σ) = q(π) · ∏_{k=1}^K q(µ_k, Σ_k),
where

log q(π) = log p(π) + E_{q(z)}(log p(z | π)) + const. = ∑_k (α − 1) log π_k + ∑_k ∑_n r_nk log π_k + const.

q(π) = D(π; α_k := α + N_k)    with    N_k = ∑_n r_nk
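In code, the update for q(π) is one line once the responsibilities r_nk are available (the symmetric prior α and the random responsibilities below are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
N, K = 100, 4
alpha = 1.0                              # symmetric Dirichlet prior parameter
r = rng.dirichlet(np.ones(K), size=N)    # responsibilities r_nk (rows sum to 1)

Nk = r.sum(axis=0)                       # effective counts N_k = sum_n r_nk
alpha_tilde = alpha + Nk                 # parameters of q(pi) = D(pi; alpha + N_k)

assert np.isclose(Nk.sum(), N)           # soft counts add up to the data size
```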
▶ Similarly (leaving out some tedious algebra),

q*(µ_k, Σ_k) = N(µ_k; m_k, Σ_k/β_k) · W(Σ_k⁻¹; W_k, ν_k)

with    β_k := β + N_k        m_k := (β m + N_k x̄_k) / β_k        ν_k := ν + N_k

W_k⁻¹ := W⁻¹ + N_k S_k + [β N_k / (β + N_k)] (x̄_k − m)(x̄_k − m)⊺
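These updates, written as a function of data and responsibilities (a sketch following Bishop, PRML 2006, eqs. 10.51–10.53 and 10.60–10.63; the function name and array layout are my own):

```python
import numpy as np

def normal_wishart_updates(X, r, beta0, m0, nu0, W0inv):
    """X: (N, D) data, r: (N, K) responsibilities.
    Returns beta_k, m_k, nu_k and W_k^{-1} of q*(mu_k, Sigma_k)."""
    N, D = X.shape
    K = r.shape[1]
    Nk = r.sum(axis=0)                               # N_k = sum_n r_nk
    xbar = (r.T @ X) / Nk[:, None]                   # weighted means x̄_k
    beta_k = beta0 + Nk
    m_k = (beta0 * m0 + Nk[:, None] * xbar) / beta_k[:, None]
    nu_k = nu0 + Nk
    Winv_k = np.empty((K, D, D))
    for k in range(K):
        diff = X - xbar[k]                           # (N, D) centered data
        Sk = (r[:, k, None, None]
              * np.einsum('ni,nj->nij', diff, diff)).sum(axis=0) / Nk[k]
        dm = (xbar[k] - m0)[:, None]
        Winv_k[k] = (W0inv + Nk[k] * Sk
                     + beta0 * Nk[k] / (beta0 + Nk[k]) * (dm @ dm.T))
    return beta_k, m_k, nu_k, Winv_k
```

The returned W_k⁻¹ stays symmetric positive definite by construction, so the Wishart factor is always well defined.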
Properties of the Dirichlet
Some tabulated identities, required for the concrete algorithm
p(x | α) = D(x; α) = [Γ(α̂) / ∏_d Γ(α_d)] ∏_d x_d^{α_d − 1} = (1 / B(α)) ∏_d x_d^{α_d − 1}        α̂ := ∑_d α_d

▶ E_p(x_d) = α_d / α̂
▶ mode(x_d) = (α_d − 1) / (α̂ − D)
▶ E_p(log x_d) = ψ(α_d) − ψ(α̂)
▶ H(p) = − ∫ p(x) log p(x) dx = − ∑_d (α_d − 1)(ψ(α_d) − ψ(α̂)) + log B(α)

where ψ(z) = d/dz log Γ(z) (the “digamma function”): scipy.special.digamma(z), https://dlmf.nist.gov/5
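The identity E_p(log x_d) = ψ(α_d) − ψ(α̂) is easy to check by Monte Carlo (the parameter vector below is arbitrary):

```python
import numpy as np
from scipy.special import digamma

rng = np.random.default_rng(0)
alpha = np.array([2.0, 3.0, 5.0])            # arbitrary Dirichlet parameters
samples = rng.dirichlet(alpha, size=200_000)

mc = np.log(samples).mean(axis=0)            # Monte Carlo estimate of E[log x_d]
exact = digamma(alpha) - digamma(alpha.sum())
assert np.allclose(mc, exact, atol=1e-2)
```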
Constructing the Variational Approximation
closing the loop [Exposition from Bishop, PRML 2006, Chapter 10.2]
E_{N(µ_k; m_k, Σ_k/β_k) W(Σ_k⁻¹; W_k, ν_k)} [(x_n − µ_k)⊺ Σ_k⁻¹ (x_n − µ_k)] = D β_k⁻¹ + ν_k (x_n − m_k)⊺ W_k (x_n − m_k)
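This tabulated expectation can likewise be checked by sampling µ_k and Σ_k⁻¹ from the Normal-Wishart (all concrete numbers below are arbitrary):

```python
import numpy as np
from scipy.stats import wishart

rng = np.random.default_rng(1)
D, beta_k, nu_k = 2, 4.0, 8.0
m_k = np.array([0.0, 1.0])
W_k = np.array([[0.5, 0.1],
                [0.1, 0.3]])
x = np.array([0.5, 1.5])

n = 20_000
# Sample precisions Lambda ~ W(W_k, nu_k), then mu | Lambda ~ N(m_k, (beta_k Lambda)^{-1})
Lam = wishart.rvs(df=nu_k, scale=W_k, size=n, random_state=rng)  # (n, D, D)
chol = np.linalg.cholesky(np.linalg.inv(beta_k * Lam))           # per-sample Cov(mu)^{1/2}
mu = m_k + np.einsum('nij,nj->ni', chol, rng.standard_normal((n, D)))
diff = x - mu
mc = np.einsum('ni,nij,nj->n', diff, Lam, diff).mean()

exact = D / beta_k + nu_k * (x - m_k) @ W_k @ (x - m_k)
assert abs(mc - exact) / exact < 0.05
```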
Constructing the Variational Approximation
connection to EM [Exposition from Bishop, PRML 2006, Chapter 10.2]
▶ Here, variational inference is the Bayesian version of EM: instead of maximizing the likelihood in θ = (µ, Σ, π), we maximize a variational bound.
▶ One advantage of this is that the posterior can actually “decide” to ignore components:
Example
The Old Faithful dataset, using α = 10⁻³ [from Bishop, PRML 2006, Fig. 10.6 / Ann-Kathrin Schalkamp]
Latent Dirichlet Allocation
Topic Models [Blei, D. M., Ng, A. Y. & Jordan, M. I. (2003) JMLR 3, 993–1022]
[Graphical model: α_d → π_d → c_di → w_di ← θ_k ← β_k, with plates i = 1, …, I_d (words), d = 1, …, D (documents), k = 1, …, K (topics)]
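The generative process this graphical model encodes, sampled directly (toy sizes; the symbols follow the figure, concrete values are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
D, K, V, I = 4, 3, 20, 50          # documents, topics, vocabulary size, words/doc

theta = rng.dirichlet(np.full(V, 0.1), size=K)   # theta_k: topic-word distributions
docs = []
for d in range(D):
    pi_d = rng.dirichlet(np.full(K, 0.5))        # pi_d: document-topic distribution
    c = rng.choice(K, size=I, p=pi_d)            # c_di: per-word topic assignments
    w = np.array([rng.choice(V, p=theta[k]) for k in c])   # w_di: observed words
    docs.append(w)

assert len(docs) == D and all(len(w) == I for w in docs)
```

Inference then inverts exactly this process: given only the words w_di, recover distributions over π_d, θ_k and the assignments c_di.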
A Variational Bound ∑
Find the best simple way to describe a complex thing (reminder: n_dkv = #{i : w_di = v, c_dik = 1}, n_dk· = ∑_v n_dkv, etc.)

p(C, Π, Θ, W) = ∏_{d=1}^D D(π_d; α_d) · ∏_{d=1}^D ∏_{i=1}^{I_d} ∏_{k=1}^K π_dk^{c_dik} · ∏_{d=1}^D ∏_{i=1}^{I_d} ∏_{k=1}^K θ_{k w_di}^{c_dik} · ∏_{k=1}^K D(θ_k; β_k)

▶ To find the best such approximation, the one that minimizes D_KL(q ∥ p(Π, Θ, C | W)), we maximize the ELBO (equivalently, minimize the variational free energy)

L(q) = ∫ q(C, Θ, Π) log [p(C, Π, Θ, W) / q(C, Θ, Π)] dC dΘ dΠ
Constructing the Bound ∑
Mean-Field Theory: Putting Lectures 18–20 to use reminder: ndkv = #{i : wdi = v, cdik = 1}. ndk· = v ndkv , etc.
With the joint p(C, Π, Θ, W) as above, recall: to maximize the ELBO of a factorized approximation, compute the mean field

log q*(z_i) = E_{z_j, j≠i}(log p(x, z)) + const.

log q*(C) = E_{q(Π,Θ)} ∑_{d,i,k} c_dik log(π_dk θ_{k w_di}) + const. = ∑_{d,i} ∑_{k=1}^K c_dik E_{q(Π,Θ)}(log π_dk θ_{k w_di}) + const.,    where log γ_dik := E_{q(Π,Θ)}(log π_dk θ_{k w_di})

Thus q(C) = ∏_{d,i} q(c_di) with q(c_di) = ∏_k γ̃_dik^{c_dik}, where γ̃_dik = γ_dik / ∑_k γ_dik.
(Note: thus E_q(c_dik) = γ̃_dik.)
Applying the same recipe to the parameters:

log q*(Π, Θ) = E_{∏_{d,i} q(c_di:)} [∑_{d,k} (α_dk − 1 + n_dk·) log π_dk + ∑_{k,v} (β_kv − 1 + n_·kv) log θ_kv] + const.
             = ∑_{d=1}^D ∑_{k=1}^K (α_dk − 1 + E_{q(C)}(n_dk·)) log π_dk + ∑_{k=1}^K ∑_{v=1}^V (β_kv − 1 + E_{q(C)}(n_·kv)) log θ_kv + const.

q*(Π, Θ) = ∏_{d=1}^D D(π_d; α̃_d: := α_d: + γ̃_d·:) · ∏_{k=1}^K D(θ_k; β̃_k: := [β_kv + ∑_d ∑_{i=1}^{I_d} γ̃_dik I(w_di = v)]_{v=1,…,V})
The Variational Approximation
closing the loop
q(π_d) = D(π_d; α̃_dk := [α_dk + ∑_{i=1}^{I_d} γ̃_dik]_{k=1,…,K})        ∀ d = 1, …, D

q(θ_k) = D(θ_k; β̃_kv := [β_kv + ∑_d ∑_{i=1}^{I_d} γ̃_dik I(w_di = v)]_{v=1,…,V})        ∀ k = 1, …, K

q(c_di) = ∏_k γ̃_dik^{c_dik}        ∀ d, i = 1, …, I_d

where γ̃_dik = γ_dik / ∑_k γ_dik (note that ∑_k α̃_dk = const.) and

γ_dik = exp(E_{q(π_d)}(log π_dk) + E_{q(θ_k)}(log θ_{k w_di}))
      = exp(ψ(α̃_dk) + ψ(β̃_{k w_di}) − ψ(∑_v β̃_kv))
The ELBO
useful for monitoring / bug-fixing
p(C, Π, Θ, W) = ∏_{d=1}^D [Γ(∑_k α_dk) / ∏_k Γ(α_dk)] ∏_{k=1}^K π_dk^{α_dk − 1 + n_dk·} · ∏_{k=1}^K [Γ(∑_v β_kv) / ∏_v Γ(β_kv)] ∏_{v=1}^V θ_kv^{β_kv − 1 + n_·kv}

We need

L(q, W) = E_q(log p(W, C, Θ, Π)) + H(q)
        = ∫ q(C, Θ, Π) log p(W, C, Θ, Π) dC dΘ dΠ − ∫ q(C, Θ, Π) log q(C, Θ, Π) dC dΘ dΠ
        = ∫ q(C, Θ, Π) log p(W, C, Θ, Π) dC dΘ dΠ + ∑_k H(D(θ_k; β̃_k)) + ∑_d H(D(π_d; α̃_d)) + ∑_{d,i} H(γ̃_di)
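The Dirichlet entropies in H(q) follow the identity from the Dirichlet slide; a small helper, checked against scipy's implementation:

```python
import numpy as np
from scipy.special import digamma, gammaln
from scipy.stats import dirichlet

def dirichlet_entropy(a):
    """H(D(a)) = log B(a) + (a_hat - D) psi(a_hat) - sum_d (a_d - 1) psi(a_d)."""
    a = np.asarray(a, dtype=float)
    a_hat = a.sum()
    log_B = gammaln(a).sum() - gammaln(a_hat)   # log of the multivariate Beta B(a)
    return (log_B + (a_hat - a.size) * digamma(a_hat)
            - ((a - 1.0) * digamma(a)).sum())

a = np.array([2.0, 3.0, 5.0])
assert np.isclose(dirichlet_entropy(a), dirichlet(a).entropy())
```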
Building the Algorithm
updating and evaluating the bound
procedure LDA(W, α, β)
    γ̃_dik ← DIRICHLET_RAND(α)                              ▷ initialize
    L ← −∞
    while L not converged do
        for d = 1, …, D; k = 1, …, K do
            α̃_dk ← α_dk + ∑_i γ̃_dik                        ▷ update document-topic distributions
        end for
        for k = 1, …, K; v = 1, …, V do
            β̃_kv ← β_kv + ∑_{d,i} γ̃_dik I(w_di = v)        ▷ update topic-word distributions
        end for
        for d = 1, …, D; k = 1, …, K; i = 1, …, I_d do
            γ̃_dik ← exp(ψ(α̃_dk) + ψ(β̃_{k w_di}) − ψ(∑_v β̃_kv))   ▷ update word-topic assignments
            γ̃_dik ← γ̃_dik / γ̃_di·
        end for
        L ← BOUND(γ̃, w, α̃, β̃)                             ▷ update bound
    end while
end procedure
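The procedure above translates almost line by line into numpy (a sketch: it runs a fixed number of sweeps instead of monitoring the bound, omits the BOUND evaluation, and all names and sizes are my own):

```python
import numpy as np
from scipy.special import digamma

def lda_variational(W, K, V, alpha=0.5, beta=0.1, sweeps=50, seed=0):
    """Mean-field VI for LDA. W: list of D integer arrays (word indices < V).
    Returns alpha_tilde (D, K), beta_tilde (K, V), responsibilities gamma_tilde."""
    rng = np.random.default_rng(seed)
    g = [rng.dirichlet(np.ones(K), size=len(w)) for w in W]   # gamma_tilde per doc
    for _ in range(sweeps):
        # alpha_tilde_dk = alpha + sum_i gamma_dik   (document-topic distributions)
        a_t = alpha + np.stack([gd.sum(axis=0) for gd in g])
        # beta_tilde_kv = beta + sum_{d,i} gamma_dik I(w_di = v)   (topic-word)
        b_t = np.full((K, V), beta)
        for w, gd in zip(W, g):
            np.add.at(b_t.T, w, gd)
        # gamma_dik ∝ exp(E[log pi_dk] + E[log theta_{k, w_di}])
        e_log_pi = digamma(a_t) - digamma(a_t.sum(axis=1, keepdims=True))
        e_log_th = digamma(b_t) - digamma(b_t.sum(axis=1, keepdims=True))
        for d, w in enumerate(W):
            log_g = e_log_pi[d] + e_log_th[:, w].T            # (I_d, K)
            log_g -= log_g.max(axis=1, keepdims=True)         # stabilize exp
            gd = np.exp(log_g)
            g[d] = gd / gd.sum(axis=1, keepdims=True)         # normalize over k
    return a_t, b_t, g
```

The extra −ψ(∑_k α̃_dk) term in e_log_pi is constant over k and cancels in the normalization, so this matches the γ̃ update in the pseudocode.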
L(q) = ∫ q(C, Θ, Π) log [p(C, Π, Θ, W) / q(C, Θ, Π)] dC dΘ dΠ

“Derive your variational bound in the time it takes for your Monte Carlo sampler to converge.”
The Toolbox
Framework:
∫ p(x₁, x₂) dx₂ = p(x₁)        p(x₁, x₂) = p(x₁ | x₂) p(x₂)        p(x | y) = p(y | x) p(x) / p(y)

Modelling:
▶ graphical models
▶ Gaussian distributions
▶ (deep) learnt representations
▶ Kernels
▶ Markov Chains
▶ Exponential Families / Conjugate Priors
▶ Factor Graphs & Message Passing

Computation:
▶ Monte Carlo
▶ Linear algebra / Gaussian inference
▶ maximum likelihood / MAP
▶ Laplace approximations
▶ EM (iterative maximum likelihood)
▶ variational inference / mean field