
Exercises 10

Variational inference & Expectation propagation

Exercise 10.1 (adapted from [2])


Consider the Kullback-Leibler divergence

    KL(p(x) ‖ q(x)) = −∫ p(x) ln ( q(x) / p(x) ) dx.

Evaluate KL(p(x) ‖ q(x)) where p(x) and q(x) are . . .

a) . . . two scalar Gaussians

    p(x) = N(x; µ, σ²)   and   q(x) = N(x; m, s²).

b) . . . two multivariate Gaussians

    p(x) = N(x; µ, Σ)   and   q(x) = N(x; m, S).

Exercise 10.2
Consider a distribution p(x) which we want to approximate with a scalar Gaussian

    q_{µ,σ}(x) = N(x; µ, σ²)

by minimizing the Kullback-Leibler divergence

    µ̂, σ̂ = arg min_{µ,σ} KL(p(x) ‖ q_{µ,σ}(x)).

Show that this results in moment matching

    µ̂ = E_p[x],
    σ̂² = E_p[(x − µ̂)²],

where E_p[f(x)] = ∫ f(x) p(x) dx, i.e., that µ̂ and σ̂² will be the mean and variance of p(x).




Exercise 10.3
a) Consider the variational linear regression example from the lecture,

    p(y | w, β) = N(y; xw, β⁻¹ I_N),

where in this problem we also treat β as a random variable and use the prior¹

    p(w, α, β) = p(w|α) p(α) p(β) = N(w; 0, α⁻¹) Gam(α; a₀, b₀) Gam(β; c₀, d₀),

where

    Gam(α; a, b) = (1/Γ(a)) b^a α^(a−1) e^(−bα),   α ∈ [0, ∞).

Assume the factorized variational distribution q(w, α, β) = q(w) q(α) q(β).
Use variational inference to derive the equations for updating the variational distribution q(w, α, β) approximating the posterior p(w, α, β | y).
b) Consider the case where w is a vector, i.e., the standard linear regression setting with the likelihood

    p(y | w, β) = N(y; Xw, β⁻¹ I_N).

Consider a diagonal prior on w and, as in (a), Gamma priors on α and β, i.e.,

    p(w, α, β) = p(w|α) p(α) p(β) = N(w; 0, α⁻¹ I) Gam(α; a₀, b₀) Gam(β; c₀, d₀).

Consider the same factorization of the variational distribution, q(w, α, β) = q(w) q(α) q(β).


Hint: For a multivariate Gaussian distribution x ∼ N(µ, Σ) we have

    E[xxᵀ] = µµᵀ + Σ.

You may also want to use that

    Tr(AB) = Tr(BA),

where A and B are two matrices and Tr is the trace operator.

¹ In the conjugate prior for linear regression, we need to have a₀ = c₀. We are thus treating a more general problem, where the true posterior may not have a closed-form expression.

Exercise 10.4
Consider the dynamical model

    x_{n+1} = γ x_n + v_n,   v_n ∼ N(0, β_v⁻¹),
    y_n = x_n + e_n,         e_n ∼ N(0, σ_e²),

with the initial state and prior

    x₀ ∼ N(x̄₀, Σ₀),
    γ ∼ N(0, σ_γ²).

We observe y_{1:N} = {y₁, . . . , y_N} and consider x_{0:N} = {x₀, . . . , x_N} and γ as our latent variables.

a) Derive an expression for the joint distribution

    p(y_{1:N}, x_{0:N}, γ).

b) Consider the variational approximation

    q(γ, x_{0:N}) = q(γ) ∏_{n=0}^{N} q(x_n).

Use variational inference to derive the equations for estimating the posterior q(γ) ≈ p(γ | y_{1:N}).

c) We can also solve this problem without factorizing over the individual x_n by considering

    q(γ, x_{0:N}) = q(γ) q(x_{0:N}).

Use variational inference to estimate the posterior q(γ) ≈ p(γ | y_{1:N}) using this variational approximation.
Solutions 10

Variational inference & Expectation propagation

Solution to Exercise 10.1
a) We have

    KL(p(x) ‖ q(x)) = −∫ p(x) ln q(x) dx + ∫ p(x) ln p(x) dx.

The first term can be written as

    −∫ p(x) ln q(x) dx = ∫ N(x; µ, σ²) [ ½ ln(2πs²) + (x − m)² / (2s²) ] dx
                       = ∫ N(x; µ, σ²) [ ½ ln(2πs²) + (x² − 2xm + m²) / (2s²) ] dx
                       = ∫ N(x; µ, σ²) [ ½ ln(2πs²) + ((x − µ)² + 2x(µ − m) + m² − µ²) / (2s²) ] dx
                       = ½ ln(2πs²) + (σ² + 2µ(µ − m) + m² − µ²) / (2s²)
                       = ½ ln(2πs²) + (σ² + (µ − m)²) / (2s²).

By replacing m and s with µ and σ, for the second term we get that

    ∫ p(x) ln p(x) dx = −½ (ln(2πσ²) + 1),

which gives

    KL(p(x) ‖ q(x)) = ½ [ ln(s²/σ²) + (σ² + (µ − m)²) / s² − 1 ].
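The closed-form expression can be checked numerically. The following sketch (not part of the original solution; the parameter values are arbitrary choices) compares the formula above with a direct numerical evaluation of the KL integral, using Python with NumPy and SciPy.

    # Numerical sanity check of the closed-form KL between two scalar Gaussians.
    import numpy as np
    from scipy import integrate
    from scipy.stats import norm

    mu, sigma = 0.5, 1.2   # parameters of p(x) (arbitrary)
    m, s = -0.3, 2.0       # parameters of q(x) (arbitrary)

    # Closed form derived above.
    kl_closed = 0.5 * (np.log(s**2 / sigma**2) + (sigma**2 + (mu - m)**2) / s**2 - 1)

    # Direct evaluation of -int p(x) ln(q(x)/p(x)) dx.
    integrand = lambda x: norm.pdf(x, mu, sigma) * (norm.logpdf(x, mu, sigma)
                                                    - norm.logpdf(x, m, s))
    kl_numeric, _ = integrate.quad(integrand, -20, 20)

    print(kl_closed, kl_numeric)   # the two values should agree to several decimals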




Solution to Exercise 10.2

    µ̂, σ̂ = arg min_{µ,σ} KL(p(x) ‖ q_{µ,σ}(x))
          = arg min_{µ,σ} −∫ p(x) ln q_{µ,σ}(x) dx + ∫ p(x) ln p(x) dx
          = arg min_{µ,σ} ∫ p(x) [ ln σ + (x − µ)² / (2σ²) ] dx
          = arg min_{µ,σ} [ ln σ + (σ_p² + (µ − µ_p)²) / (2σ²) ],

where µ_p = E_p[x] and σ_p² = E_p[(x − µ_p)²] denote the mean and variance of p(x), so that ∫ p(x)(x − µ)² dx = σ_p² + (µ − µ_p)².
For µ this is minimized by µ̂ = µ_p. What remains for σ is

    σ̂ = arg min_σ [ ln σ + σ_p² / (2σ²) ] = arg min_σ f(σ).

Setting the derivative equal to zero gives

    0 = d f(σ)/dσ = 1/σ − σ_p²/σ³   ⇒   σ² = σ_p².

This is a minimum point since

    d²f(σ)/dσ² |_{σ=σ_p} = −1/σ_p² + 3σ_p²/σ_p⁴ = 2/σ_p² > 0.

Hence µ̂ = µ_p = E_p[x] and σ̂² = σ_p² = E_p[(x − µ̂)²], i.e., moment matching.
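As an additional sanity check (not part of the original solution), the sketch below minimizes KL(p ‖ q_{µ,σ}) numerically when p(x) is a two-component Gaussian mixture, and compares the minimizer with the mean and variance of p(x). All mixture parameters and optimizer settings are arbitrary choices.

    # Numerical check of moment matching for a Gaussian mixture p(x).
    import numpy as np
    from scipy import integrate, optimize
    from scipy.stats import norm

    w1, m1, s1 = 0.3, -2.0, 0.7
    w2, m2, s2 = 0.7, 1.5, 1.1
    p = lambda x: w1 * norm.pdf(x, m1, s1) + w2 * norm.pdf(x, m2, s2)

    def objective(theta):
        # Only -E_p[ln q] depends on (mu, sigma); E_p[ln p] is a constant.
        mu, log_sigma = theta
        sigma = np.exp(log_sigma)
        f = lambda x: p(x) * norm.logpdf(x, mu, sigma)
        val, _ = integrate.quad(f, -15, 15)
        return -val

    res = optimize.minimize(objective, x0=np.array([0.0, 0.0]))
    mu_hat, sigma_hat = res.x[0], np.exp(res.x[1])

    # Mean and variance of the mixture, computed directly.
    mean_p = w1 * m1 + w2 * m2
    var_p = w1 * (s1**2 + m1**2) + w2 * (s2**2 + m2**2) - mean_p**2

    print(mu_hat, mean_p)        # should agree
    print(sigma_hat**2, var_p)   # should agree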

Solution to Exercise 10.3
a) Variational approximation:

    q(w, α, β) = q(w) q(α) q(β).

The three equations we will iterate are

    ln q̂(w) = E_{q̂(α) q̂(β)}[ln p(y, w, α, β)] + const.,
    ln q̂(α) = E_{q̂(w) q̂(β)}[ln p(y, w, α, β)] + const.,
    ln q̂(β) = E_{q̂(w) q̂(α)}[ln p(y, w, α, β)] + const.

The joint distribution of y, w, α and β is

    p(y, w, α, β) = p(y|w, β) p(w|α) p(α) p(β)   ⇒
    ln p(y, w, α, β) = ln p(y|w, β) + ln p(w|α) + ln p(α) + ln p(β),

where

    ln p(y|w, β) = (N/2) ln β − (β/2)(wx − y)ᵀ(wx − y) + const.,
    ln p(w|α) = (1/2) ln α − (α/2) w² + const.,
    ln p(α) = (a₀ − 1) ln α − b₀ α + const.,
    ln p(β) = (c₀ − 1) ln β − d₀ β + const.,

and where we have included in the constant terms everything that depends on neither w, α nor β.

We start with q̂(α). This will be the same as in the lecture:

    ln q̂(α) = E_{q̂(w) q̂(β)}[ln p(y, w, α, β)] + const.
             = E_{q̂(w) q̂(β)}[ln p(y|w, β) + ln p(w|α) + ln p(α) + ln p(β)] + const.
             = ln p(α) + E_{q̂(w)}[ln p(w|α)] + const.
             = (a₀ − 1) ln α − b₀ α + (1/2) ln α − (α/2) E_{q̂(w)}[w²] + const.

We recognize this as a Gamma distribution,

    ln q̂(α) = ln Gam(α; a_N, b_N) = (a_N − 1) ln α − b_N α,

with

    a_N = a₀ + 1/2,
    b_N = b₀ + (1/2) E_{q̂(w)}[w²].

Now we proceed with q̂(w):

    ln q̂(w) = E_{q̂(α) q̂(β)}[ln p(y, w, α, β)] + const.
             = E_{q̂(α) q̂(β)}[ln p(y|w, β) + ln p(w|α) + ln p(α) + ln p(β)] + const.
             = E_{q̂(β)}[ln p(y|w, β)] + E_{q̂(α)}[ln p(w|α)] + const.
             = −(1/2) E_{q̂(β)}[β] (wx − y)ᵀ(wx − y) − (1/2) E_{q̂(α)}[α] w² + const.
             = −(1/2) (E_{q̂(α)}[α] + E_{q̂(β)}[β] xᵀx) w² + E_{q̂(β)}[β] xᵀy w + const.
             = −(w − m_N)² / (2σ_N²) + const.,

where

    σ_N² = (E_{q̂(α)}[α] + E_{q̂(β)}[β] xᵀx)⁻¹,
    m_N = E_{q̂(β)}[β] σ_N² xᵀy.

So we have q̂(w) = N(w; m_N, σ_N²).
We proceed with q̂(β):

    ln q̂(β) = E_{q̂(w) q̂(α)}[ln p(y, w, α, β)] + const.
             = E_{q̂(w) q̂(α)}[ln p(y|w, β) + ln p(w|α) + ln p(α) + ln p(β)] + const.
             = ln p(β) + E_{q̂(w)}[ln p(y|w, β)] + const.
             = (c₀ − 1) ln β − d₀ β + (N/2) ln β − (β/2) E_{q̂(w)}[(wx − y)ᵀ(wx − y)] + const.,

where

    E_{q̂(w)}[(wx − y)ᵀ(wx − y)] = E_{q̂(w)}[w²] xᵀx − 2 E_{q̂(w)}[w] xᵀy + yᵀy
                                 = (m_N² + σ_N²) xᵀx − 2 m_N xᵀy + yᵀy
                                 = ‖y − m_N x‖² + σ_N² xᵀx.

This gives

    ln q̂(β) = (c₀ − 1) ln β − d₀ β + (N/2) ln β − (β/2)(‖y − m_N x‖² + σ_N² xᵀx) + const.

We recognize this also as a Gamma distribution,

    ln q̂(β) = ln Gam(β; c_N, d_N) = (c_N − 1) ln β − d_N β,

with

    c_N = c₀ + N/2,
    d_N = d₀ + (1/2)(‖y − m_N x‖² + σ_N² xᵀx).

Since

    q̂(α) = Gam(α; a_N, b_N),   q̂(β) = Gam(β; c_N, d_N),   q̂(w) = N(w; m_N, σ_N²),

we can compute

    E_{q̂(α)}[α] = a_N / b_N,   E_{q̂(β)}[β] = c_N / d_N,   E_{q̂(w)}[w²] = m_N² + σ_N².
Now we can state the equations we need to iterate.
Solution: Iterate the following three steps until convergence

• Compute

    a_N = a₀ + 1/2,
    b_N = b₀ + (1/2)(m_N² + σ_N²).

• Compute

    c_N = c₀ + N/2,
    d_N = d₀ + (1/2)(‖y − m_N x‖² + σ_N² xᵀx).

• Compute

    σ_N² = (a_N/b_N + (c_N/d_N) xᵀx)⁻¹,
    m_N = (c_N/d_N) σ_N² xᵀy.
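A minimal sketch of this iteration on synthetic data is given below (not part of the original solution; the data-generating values and the broad hyperparameters a₀ = b₀ = c₀ = d₀ = 10⁻² are arbitrary choices).

    # Coordinate-ascent VI for the scalar model y_n = w*x_n + noise,
    # with Gamma priors on alpha (prior precision of w) and beta (noise precision).
    import numpy as np

    rng = np.random.default_rng(0)
    N = 50
    x = rng.normal(size=N)                   # the vector x in the derivation
    w_true, beta_true = 2.0, 4.0
    y = w_true * x + rng.normal(scale=beta_true**-0.5, size=N)

    a0, b0, c0, d0 = 1e-2, 1e-2, 1e-2, 1e-2  # assumed broad hyperparameters
    mN, s2N = 0.0, 1.0                       # initial q(w)

    for _ in range(100):
        # Update q(alpha) = Gam(alpha; aN, bN)
        aN = a0 + 0.5
        bN = b0 + 0.5 * (mN**2 + s2N)
        # Update q(beta) = Gam(beta; cN, dN)
        cN = c0 + N / 2
        dN = d0 + 0.5 * (np.sum((y - mN * x)**2) + s2N * (x @ x))
        # Update q(w) = N(w; mN, s2N)
        s2N = 1.0 / (aN / bN + (cN / dN) * (x @ x))
        mN = (cN / dN) * s2N * (x @ y)

    print(mN, np.sqrt(s2N))   # posterior mean and std of w
    print(cN / dN)            # E[beta], an estimate of the noise precision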

b) Variational approximation:

    q(w, α, β) = q(w) q(α) q(β).

The three equations we will iterate are

    ln q̂(w) = E_{q̂(α) q̂(β)}[ln p(y, w, α, β)] + const.,
    ln q̂(α) = E_{q̂(w) q̂(β)}[ln p(y, w, α, β)] + const.,
    ln q̂(β) = E_{q̂(w) q̂(α)}[ln p(y, w, α, β)] + const.

The joint distribution of y, w, α and β is

    p(y, w, α, β) = p(y|w, β) p(w|α) p(α) p(β)   ⇒
    ln p(y, w, α, β) = ln p(y|w, β) + ln p(w|α) + ln p(α) + ln p(β),

where

    ln p(y|w, β) = (N/2) ln β − (β/2)(Xw − y)ᵀ(Xw − y) + const.,
    ln p(w|α) = (M/2) ln α − (α/2) wᵀw + const.,
    ln p(α) = (a₀ − 1) ln α − b₀ α + const.,
    ln p(β) = (c₀ − 1) ln β − d₀ β + const.,

where M is the dimension of w and where we have included in the constant terms everything that depends on neither w, α nor β.

We start with q̂(α). This will be the same as in the lecture (but now multivariate):

    ln q̂(α) = E_{q̂(w) q̂(β)}[ln p(y, w, α, β)] + const.
             = E_{q̂(w) q̂(β)}[ln p(y|w, β) + ln p(w|α) + ln p(α) + ln p(β)] + const.
             = ln p(α) + E_{q̂(w)}[ln p(w|α)] + const.
             = (a₀ − 1) ln α − b₀ α + (M/2) ln α − (α/2) E_{q̂(w)}[wᵀw] + const.

We recognize this as a Gamma distribution,

    ln q̂(α) = ln Gam(α; a_N, b_N) = (a_N − 1) ln α − b_N α,

with

    a_N = a₀ + M/2,
    b_N = b₀ + (1/2) E_{q̂(w)}[wᵀw].

Now we proceed with q̂(w):

    ln q̂(w) = E_{q̂(α) q̂(β)}[ln p(y, w, α, β)] + const.
             = E_{q̂(α) q̂(β)}[ln p(y|w, β) + ln p(w|α) + ln p(α) + ln p(β)] + const.
             = E_{q̂(β)}[ln p(y|w, β)] + E_{q̂(α)}[ln p(w|α)] + const.
             = −(1/2) E_{q̂(β)}[β] (Xw − y)ᵀ(Xw − y) − (1/2) E_{q̂(α)}[α] wᵀw + const.
             = −(1/2) wᵀ(E_{q̂(α)}[α] I_M + E_{q̂(β)}[β] XᵀX) w + E_{q̂(β)}[β] wᵀXᵀy + const.
             = −(1/2) (w − m_N)ᵀ S_N⁻¹ (w − m_N) + const.

This results in

    S_N = (E_{q̂(α)}[α] I_M + E_{q̂(β)}[β] XᵀX)⁻¹,
    m_N = E_{q̂(β)}[β] S_N Xᵀy.

So we have q̂(w) = N(w; m_N, S_N).


We proceed with q̂(β):

    ln q̂(β) = E_{q̂(w) q̂(α)}[ln p(y, w, α, β)] + const.
             = E_{q̂(w) q̂(α)}[ln p(y|w, β) + ln p(w|α) + ln p(α) + ln p(β)] + const.
             = ln p(β) + E_{q̂(w)}[ln p(y|w, β)] + const.
             = (c₀ − 1) ln β − d₀ β + (N/2) ln β − (β/2) E_{q̂(w)}[(Xw − y)ᵀ(Xw − y)] + const.,

where

    E_{q̂(w)}[(Xw − y)ᵀ(Xw − y)] = E_{q̂(w)}[wᵀXᵀXw − 2wᵀXᵀy + yᵀy]
                                 = E_{q̂(w)}[Tr(wᵀXᵀXw) − 2wᵀXᵀy] + yᵀy
                                 = E_{q̂(w)}[Tr(XᵀXwwᵀ) − 2wᵀXᵀy] + yᵀy
                                 = Tr(XᵀX E_{q̂(w)}[wwᵀ]) − 2 E_{q̂(w)}[wᵀ] Xᵀy + yᵀy
                                 = Tr(XᵀX(m_N m_Nᵀ + S_N)) − 2 m_Nᵀ Xᵀy + yᵀy
                                 = m_Nᵀ XᵀX m_N − 2 m_Nᵀ Xᵀy + Tr(XᵀX S_N) + yᵀy
                                 = ‖y − X m_N‖² + Tr(XᵀX S_N).

Here we have used

    E_{q̂(w)}[wwᵀ] = m_N m_Nᵀ + S_N.

This gives

    ln q̂(β) = (c₀ − 1) ln β − d₀ β + (N/2) ln β − (β/2)(‖y − X m_N‖² + Tr(XᵀX S_N)) + const.

We recognize this also as a Gamma distribution,

    ln q̂(β) = ln Gam(β; c_N, d_N) = (c_N − 1) ln β − d_N β,

with

    c_N = c₀ + N/2,
    d_N = d₀ + (1/2)(‖y − X m_N‖² + Tr(XᵀX S_N)).

Since

    q̂(α) = Gam(α; a_N, b_N),   q̂(β) = Gam(β; c_N, d_N),   q̂(w) = N(w; m_N, S_N),

we can compute

    E_{q̂(α)}[α] = a_N / b_N,   E_{q̂(β)}[β] = c_N / d_N,   E_{q̂(w)}[wᵀw] = m_Nᵀ m_N + Tr(S_N).

Now we can state the equations we need to iterate.
Solution: Iterate the following three steps until convergence

• Compute

    a_N = a₀ + M/2,
    b_N = b₀ + (1/2)(m_Nᵀ m_N + Tr(S_N)).

• Compute

    c_N = c₀ + N/2,
    d_N = d₀ + (1/2)(‖y − X m_N‖² + Tr(XᵀX S_N)).

• Compute

    S_N = ((a_N/b_N) I_M + (c_N/d_N) XᵀX)⁻¹,
    m_N = (c_N/d_N) S_N Xᵀy.
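The corresponding sketch for the vector-valued case (again not part of the original solution, with synthetic data and arbitrarily chosen broad hyperparameters) could look as follows.

    # Coordinate-ascent VI for vector-valued w (part b).
    import numpy as np

    rng = np.random.default_rng(1)
    N, M = 100, 3
    X = rng.normal(size=(N, M))
    w_true = np.array([1.0, -2.0, 0.5])
    beta_true = 10.0
    y = X @ w_true + rng.normal(scale=beta_true**-0.5, size=N)

    a0, b0, c0, d0 = 1e-2, 1e-2, 1e-2, 1e-2  # assumed broad hyperparameters
    mN, SN = np.zeros(M), np.eye(M)          # initial q(w)

    for _ in range(100):
        aN = a0 + M / 2
        bN = b0 + 0.5 * (mN @ mN + np.trace(SN))
        cN = c0 + N / 2
        dN = d0 + 0.5 * (np.sum((y - X @ mN)**2) + np.trace(X.T @ X @ SN))
        SN = np.linalg.inv((aN / bN) * np.eye(M) + (cN / dN) * (X.T @ X))
        mN = (cN / dN) * SN @ (X.T @ y)

    print(mN)        # should be close to w_true
    print(cN / dN)   # E[beta]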

Solution to Exercise 10.4
a) The joint distribution factorizes as

    p(y_{1:N}, x_{0:N}, γ) = p(γ) p(x₀) ∏_{n=1}^{N} p(x_n | x_{n−1}, γ) p(y_n | x_n),

where (writing the Gaussians in terms of precisions, with µ₀ = x̄₀, β₀ = Σ₀⁻¹, β_γ0 = σ_γ⁻² and β_e = σ_e⁻²)

    p(x₀) = N(x₀; µ₀, β₀⁻¹),
    p(γ) = N(γ; 0, β_γ0⁻¹),
    p(x_n | x_{n−1}, γ) = N(x_n; γ x_{n−1}, β_v⁻¹),
    p(y_n | x_n) = N(y_n; x_n, β_e⁻¹).

b) First specify the mean-field equations:

    ln q(γ) = E_{{q_n}_{n=0}^{N}}[ln p(y_{1:N}, x_{0:N}, γ)] + const.,
    ln q(x_m) = E_{q(γ), {q_n}_{n≠m}}[ln p(y_{1:N}, x_{0:N}, γ)] + const.
Solutions 10. Variational inference & Expectation propagation 40

We start with q(γ):

    ln q(γ) = E_{{q_n}_{n=0}^{N}}[ln p(y_{1:N}, x_{0:N}, γ)] + const.
            = ln p(γ) + ∑_{n=1}^{N} E_{q_n, q_{n−1}}[ln p(x_n | x_{n−1}, γ)] + const.
            = ln p(γ) − (β_v/2) ∑_{n=1}^{N} E_{q_n, q_{n−1}}[(x_n − γ x_{n−1})²] + const.
            = ln p(γ) − (β_v/2) ∑_{n=1}^{N} E_{q_n, q_{n−1}}[−2 x_n x_{n−1} γ + x_{n−1}² γ²] + const.
            = ln p(γ) − (β_v/2) ∑_{n=1}^{N} (−2 x̄_n x̄_{n−1} γ + E_{q_{n−1}}[x_{n−1}²] γ²) + const.
            = ln p(γ) − (β_v/2) ∑_{n=1}^{N} E_{q_{n−1}}[x_{n−1}²] (γ − x̄_n x̄_{n−1} / E_{q_{n−1}}[x_{n−1}²])² + const.
            = ln p(γ) + ∑_{n=1}^{N} ln N(γ; x̄_n x̄_{n−1} / E_{q_{n−1}}[x_{n−1}²], (β_v E_{q_{n−1}}[x_{n−1}²])⁻¹) + const.,

where x̄_n = E_{q_n}[x_n].

This gives

    q(γ) = N(γ; γ̄, β_γ⁻¹),

where

    β_γ = β_γ0 + β_v ∑_{n=1}^{N} E_{q_{n−1}}[x_{n−1}²],
    γ̄ = (1/β_γ) β_v ∑_{n=1}^{N} E_{q_{n−1}}[x_{n−1}²] (x̄_n x̄_{n−1} / E_{q_{n−1}}[x_{n−1}²]) = (β_v/β_γ) ∑_{n=1}^{N} x̄_n x̄_{n−1}.
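As an illustration (not part of the original solution), the sketch below evaluates this q(γ) update for given first and second moments of the q(x_n) factors; the precision values and the moment sequences are placeholders, since in a full implementation they would come from the current q(x_n) updates.

    # Evaluate q(gamma) = N(gamma; gamma_bar, 1/beta_gamma) from given moments.
    import numpy as np

    beta_v = 4.0        # process-noise precision (assumed value)
    beta_gamma0 = 1.0   # prior precision of gamma (assumed value)
    Nsteps = 20
    xbar = np.linspace(0.1, 1.0, Nsteps + 1)   # placeholder E_{q_n}[x_n], n = 0..N
    x2bar = xbar**2 + 0.05                     # placeholder E_{q_n}[x_n^2], n = 0..N

    beta_gamma = beta_gamma0 + beta_v * np.sum(x2bar[:-1])            # sum of E[x_{n-1}^2], n = 1..N
    gamma_bar = (beta_v / beta_gamma) * np.sum(xbar[1:] * xbar[:-1])  # sum of xbar_n * xbar_{n-1}

    print(gamma_bar, 1.0 / beta_gamma)   # mean and variance of q(gamma)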

Proceed similarly to derive the update equations for ln q(xm ).


c) The update equations for q(γ) will be the same, except that we need to replace x̄_n x̄_{n−1} with the cross-moment E_{q(x_n, x_{n−1})}[x_n x_{n−1}].
The update equations for ln q(x_{0:N}) will result in a Kalman smoother (not covered in this course, but you might be familiar with it from time-series modeling or control theory). By running the Kalman smoother on the augmented state x̃_n = [x_n, x_{n−1}]ᵀ, we can obtain the cross-moments required above.
Bibliography

[1] David Barber. Bayesian Reasoning and Machine Learning. Cambridge University Press, 2012.
[2] Christopher M. Bishop. Pattern Recognition and Machine Learning. Springer, 2006.
[3] Carl E. Rasmussen and Christopher K. I. Williams. Gaussian Processes for Machine Learning. MIT Press, 2006.

Version: September 27, 2021
