
import torch; torch.manual_seed(0)
import torch.nn as nn
import torch.nn.functional as F
import torch.utils
import torch.distributions
import torchvision
import numpy as np
import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200

device = 'cuda' if torch.cuda.is_available() else 'cpu'

The encoder maps a flattened 28×28 MNIST image down to a point in the latent space. It is implemented as a subclass of torch.nn.Module, with the layers declared in __init__ and the computation defined in forward.

class Encoder(nn.Module):
    def __init__(self, latent_dims):
        super(Encoder, self).__init__()
        self.linear1 = nn.Linear(784, 512)
        self.linear2 = nn.Linear(512, latent_dims)

    def forward(self, x):
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.linear1(x))
        return self.linear2(x)

The decoder mirrors the encoder: it maps a latent vector back up to 784 pixels and reshapes the result to a 1×28×28 image. A final sigmoid keeps the outputs in [0, 1], matching the range of the MNIST pixel values.

class Decoder(nn.Module):
    def __init__(self, latent_dims):
        super(Decoder, self).__init__()
        self.linear1 = nn.Linear(latent_dims, 512)
        self.linear2 = nn.Linear(512, 784)

    def forward(self, z):
        z = F.relu(self.linear1(z))
        z = torch.sigmoid(self.linear2(z))
        return z.reshape((-1, 1, 28, 28))

The autoencoder simply composes the two modules: encode the input to a latent code z, then decode z back into an image.
class Autoencoder(nn.Module):
    def __init__(self, latent_dims):
        super(Autoencoder, self).__init__()
        self.encoder = Encoder(latent_dims)
        self.decoder = Decoder(latent_dims)

    def forward(self, x):
        z = self.encoder(x)
        return self.decoder(z)
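Before training anything, a quick sanity check on tensor shapes can save time. The snippet below is an illustrative addition (it is not part of the original code): it pushes a random batch through an untrained autoencoder and confirms the output has the same shape as the input.

# Illustrative shape check on an untrained model
model = Autoencoder(latent_dims=2)
dummy = torch.rand(16, 1, 28, 28)   # a fake batch of 16 images with values in [0, 1]
out = model(dummy)
assert out.shape == dummy.shape     # the decoder reshapes back to (N, 1, 28, 28)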

def train(autoencoder, data, epochs=20):
    opt = torch.optim.Adam(autoencoder.parameters())
    for epoch in range(epochs):
        for x, y in data:
            x = x.to(device)  # GPU
            opt.zero_grad()
            x_hat = autoencoder(x)
            loss = ((x - x_hat)**2).sum()  # summed squared reconstruction error
            loss.backward()
            opt.step()
    return autoencoder

latent_dims = 2
autoencoder = Autoencoder(latent_dims).to(device) # GPU

data = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('./data',
        transform=torchvision.transforms.ToTensor(),
        download=True),
    batch_size=128,
    shuffle=True)

autoencoder = train(autoencoder, data)
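As a rough check that training converged to something useful, we can look at the per-image reconstruction error on a single batch. This snippet is an illustrative addition rather than part of the original post; it reuses the data loader defined above.

with torch.no_grad():
    x, y = next(iter(data))
    x = x.to(device)
    x_hat = autoencoder(x)
    per_image_mse = ((x - x_hat)**2).flatten(1).mean(dim=1)  # mean squared error per image
    print(f'mean reconstruction MSE over one batch: {per_image_mse.mean().item():.4f}')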


def plot_latent(autoencoder, data, num_batches=100):
    for i, (x, y) in enumerate(data):
        z = autoencoder.encoder(x.to(device))
        z = z.to('cpu').detach().numpy()
        plt.scatter(z[:, 0], z[:, 1], c=y, cmap='tab10')
        if i > num_batches:
            plt.colorbar()
            break

plot_latent(autoencoder, data)
def plot_reconstructed(autoencoder, r0=(-5, 10), r1=(-10, 5), n=12):
    w = 28
    img = np.zeros((n*w, n*w))
    for i, y in enumerate(np.linspace(*r1, n)):
        for j, x in enumerate(np.linspace(*r0, n)):
            z = torch.Tensor([[x, y]]).to(device)
            x_hat = autoencoder.decoder(z)
            x_hat = x_hat.reshape(28, 28).to('cpu').detach().numpy()
            img[(n-1-i)*w:(n-1-i+1)*w, j*w:(j+1)*w] = x_hat
    plt.imshow(img, extent=[*r0, *r1])

plot_reconstructed(autoencoder)
plot_reconstructed decodes a grid of points from the 2-D latent space, and the result exposes the weakness of the plain autoencoder: away from the encoded training data the latent space has gaps, and decoded points there do not look like digits. There is also no distribution we can sample from to generate new images. A variational autoencoder addresses this by making the encoder output a distribution over latent codes and regularizing it towards a standard normal. The Decoder stays exactly the same; only the Encoder changes.
class VariationalEncoder(nn.Module):
    def __init__(self, latent_dims):
        super(VariationalEncoder, self).__init__()
        self.linear1 = nn.Linear(784, 512)
        self.linear2 = nn.Linear(512, latent_dims)  # produces mu
        self.linear3 = nn.Linear(512, latent_dims)  # produces log sigma

        self.N = torch.distributions.Normal(0, 1)
        self.N.loc = self.N.loc.to(device)      # hack to get sampling on the same device as the model
        self.N.scale = self.N.scale.to(device)
        self.kl = 0

    def forward(self, x):
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.linear1(x))
        mu = self.linear2(x)
        sigma = torch.exp(self.linear3(x))
        z = mu + sigma*self.N.sample(mu.shape)  # reparameterization trick
        # KL divergence between N(mu, sigma) and N(0, 1), summed over the batch
        self.kl = ((sigma**2 + mu**2)/2 - torch.log(sigma) - 1/2).sum()
        return z
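The closed-form expression for self.kl can be cross-checked against torch.distributions.kl_divergence. The snippet below is a small verification sketch added for illustration; the tensors mu and sigma here are arbitrary stand-ins, not outputs of the encoder.

mu = torch.randn(4, 2)
sigma = torch.rand(4, 2) + 0.1   # keep sigma strictly positive
q = torch.distributions.Normal(mu, sigma)
p = torch.distributions.Normal(torch.zeros_like(mu), torch.ones_like(sigma))
kl_library = torch.distributions.kl_divergence(q, p).sum()
kl_manual = ((sigma**2 + mu**2)/2 - torch.log(sigma) - 1/2).sum()
print(torch.allclose(kl_library, kl_manual))  # True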

The VariationalAutoencoder is identical to the Autoencoder above, except that the Encoder is replaced by the VariationalEncoder.

class VariationalAutoencoder(nn.Module):
    def __init__(self, latent_dims):
        super(VariationalAutoencoder, self).__init__()
        self.encoder = VariationalEncoder(latent_dims)
        self.decoder = Decoder(latent_dims)

    def forward(self, x):
        z = self.encoder(x)
        return self.decoder(z)

Training is almost the same as before; the only change is that the KL term, stored on the encoder as autoencoder.encoder.kl, is added to the reconstruction loss.
def train(autoencoder, data, epochs=20):
    opt = torch.optim.Adam(autoencoder.parameters())
    for epoch in range(epochs):
        for x, y in data:
            x = x.to(device)  # GPU
            opt.zero_grad()
            x_hat = autoencoder(x)
            loss = ((x - x_hat)**2).sum() + autoencoder.encoder.kl
            loss.backward()
            opt.step()
    return autoencoder

vae = VariationalAutoencoder(latent_dims).to(device) # GPU


vae = train(vae, data)

plot_latent(vae, data)
plot_reconstructed(vae, r0=(-3, 3), r1=(-3, 3))
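Since the KL term pushes the encoded distribution towards a standard normal, the trained VAE can also be used generatively: sample z from N(0, 1) and decode it. This generation snippet is an illustrative addition, not part of the original post.

with torch.no_grad():
    z = torch.randn(8, latent_dims).to(device)   # latent codes sampled from the prior
    samples = vae.decoder(z).to('cpu')           # decode into 1x28x28 images
    for k in range(8):
        plt.subplot(1, 8, k + 1)
        plt.imshow(samples[k, 0].numpy(), cmap='gray')
        plt.axis('off')
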
def interpolate(autoencoder, x_1, x_2, n=12):
    z_1 = autoencoder.encoder(x_1)
    z_2 = autoencoder.encoder(x_2)
    z = torch.stack([z_1 + (z_2 - z_1)*t for t in np.linspace(0, 1, n)])
    interpolate_list = autoencoder.decoder(z)
    interpolate_list = interpolate_list.to('cpu').detach().numpy()

    w = 28
    img = np.zeros((w, n*w))
    for i, x_hat in enumerate(interpolate_list):
        img[:, i*w:(i+1)*w] = x_hat.reshape(28, 28)
    plt.imshow(img)
    plt.xticks([])
    plt.yticks([])

x, y = next(iter(data))  # grab a single batch


x_1 = x[y == 1][1].to(device) # find a 1
x_2 = x[y == 0][1].to(device) # find a 0

interpolate(vae, x_1, x_2, n=20)

interpolate(autoencoder, x_1, x_2, n=20)


from PIL import Image

def interpolate_gif(autoencoder, filename, x_1, x_2, n=100):
    z_1 = autoencoder.encoder(x_1)
    z_2 = autoencoder.encoder(x_2)

    z = torch.stack([z_1 + (z_2 - z_1)*t for t in np.linspace(0, 1, n)])

    interpolate_list = autoencoder.decoder(z)
    interpolate_list = interpolate_list.to('cpu').detach().numpy()*255

    # convert each frame to an 8-bit grayscale image so it can be written to a GIF
    images_list = [Image.fromarray(img.reshape(28, 28).astype(np.uint8)).resize((256, 256))
                   for img in interpolate_list]
    images_list = images_list + images_list[::-1]  # loop back to the beginning

    images_list[0].save(
        f'{filename}.gif',
        save_all=True,
        append_images=images_list[1:],
        loop=1)

interpolate_gif(vae, "vae", x_1, x_2)

