Professional Documents
Culture Documents
Variational AutoEncoders (VAE) With PyTorch - Alexander Van de Kleut
Variational AutoEncoders (VAE) With PyTorch - Alexander Van de Kleut
import torch

# Seed PyTorch's global RNG so weight initialisation and sampling are
# reproducible from run to run.
torch.manual_seed(0)

# Use the GPU when one is present, otherwise fall back to the CPU; the rest of
# the file moves models and batches with `.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
import torch.nn as nn
import torch.nn.functional as F
import torch.utils
import torch.distributions
import torchvision
import numpy as np
import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200
We define the `Encoder` as a subclass of `torch.nn.Module`, which means implementing two methods: `__init__` and `forward`.
class Encoder(nn.Module):
    """Deterministic MLP encoder mapping a 28x28 image to a latent vector.

    Args:
        latent_dims: Size of the latent code produced by ``forward``.
    """

    def __init__(self, latent_dims):
        super(Encoder, self).__init__()
        self.linear1 = nn.Linear(784, 512)
        self.linear2 = nn.Linear(512, latent_dims)

    def forward(self, x):
        """Encode a batch of images.

        Without this method ``nn.Module.__call__`` raises NotImplementedError,
        so ``autoencoder.encoder(x)`` could never run.

        Args:
            x: Batch whose non-batch dimensions hold 784 values per sample,
                e.g. ``(batch, 1, 28, 28)`` MNIST images or ``(batch, 784)``.

        Returns:
            Tensor of shape ``(batch, latent_dims)``.
        """
        # Flatten everything except the batch axis so linear1 sees 784 features.
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.linear1(x))
        return self.linear2(x)
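The `Decoder` mirrors the encoder: it maps a latent vector back through a 512-unit hidden layer to 784 pixel values.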
class Decoder(nn.Module):
    """MLP decoder mapping a latent vector back to a 28x28 image.

    Args:
        latent_dims: Size of the latent code accepted by ``forward``.
    """

    def __init__(self, latent_dims):
        super(Decoder, self).__init__()
        self.linear1 = nn.Linear(latent_dims, 512)
        self.linear2 = nn.Linear(512, 784)

    def forward(self, z):
        """Decode a batch of latent vectors.

        Without this method ``autoencoder.decoder(z)`` raises
        NotImplementedError.

        Args:
            z: Tensor of shape ``(batch, latent_dims)``.

        Returns:
            Tensor of shape ``(batch, 1, 28, 28)`` with values in ``[0, 1]``.
        """
        z = F.relu(self.linear1(z))
        # Sigmoid keeps pixels in [0, 1], the range ToTensor produces for the
        # MNIST images the reconstruction loss compares against.
        z = torch.sigmoid(self.linear2(z))
        # Match the (batch, 1, 28, 28) layout of the input images so that
        # `x - x_hat` in the training loss is element-wise.
        return z.reshape((-1, 1, 28, 28))
The `Autoencoder` combines an `Encoder` and a `Decoder`.
class Autoencoder(nn.Module):
    """Plain autoencoder: an ``Encoder`` followed by a ``Decoder``.

    Args:
        latent_dims: Size of the latent code shared by encoder and decoder.
    """

    def __init__(self, latent_dims):
        super(Autoencoder, self).__init__()
        self.encoder = Encoder(latent_dims)
        self.decoder = Decoder(latent_dims)

    def forward(self, x):
        """Reconstruct ``x`` by encoding it and decoding the latent code.

        The training loop calls the model as ``autoencoder(x)``, which needs
        this method; without it nn.Module raises NotImplementedError.
        """
        z = self.encoder(x)
        return self.decoder(z)
# Build a plain autoencoder with a 2-D latent space and a shuffled MNIST
# loader that yields batches of 128 images.
latent_dims = 2
autoencoder = Autoencoder(latent_dims).to(device)  # GPU
data = torch.utils.data.DataLoader(
    dataset=torchvision.datasets.MNIST(
        root='./data',
        transform=torchvision.transforms.ToTensor(),
        download=True,
    ),
    batch_size=128,
    shuffle=True,
)
# NOTE(review): the model has not been trained at this point, and plot_latent
# is not defined anywhere in this text; confirm both against the full source.
plot_latent(autoencoder, data)
def plot_reconstructed(autoencoder, r0=(-5, 10), r1=(-10, 5), n=12):
    """Decode an n-by-n grid of 2-D latent points and show them as one image.

    The grid spans ``r0`` along the first latent coordinate (horizontal axis)
    and ``r1`` along the second (vertical axis). Each decoded 28x28 tile is
    placed in a single canvas, which is displayed with ``plt.imshow``.

    Args:
        autoencoder: Model exposing a ``decoder`` that accepts ``(1, 2)`` latents.
        r0: ``(low, high)`` range of the first latent coordinate.
        r1: ``(low, high)`` range of the second latent coordinate.
        n: Number of grid points per axis.
    """
    tile = 28
    canvas = np.zeros((n * tile, n * tile))
    ys = np.linspace(*r1, n)
    xs = np.linspace(*r0, n)
    for row_idx, y in enumerate(ys):
        # The smallest y goes in the bottom row so the picture's vertical axis
        # increases upward, consistent with the extent passed to imshow.
        top = (n - 1 - row_idx) * tile
        for col_idx, x in enumerate(xs):
            z = torch.Tensor([[x, y]]).to(device)
            decoded = autoencoder.decoder(z)
            decoded = decoded.reshape(28, 28).to('cpu').detach().numpy()
            canvas[top:top + tile, col_idx * tile:(col_idx + 1) * tile] = decoded
    plt.imshow(canvas, extent=[*r0, *r1])
# Decode a grid over the default latent window (x in [-5, 10], y in [-10, 5]).
plot_reconstructed(autoencoder=autoencoder)
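`plot_reconstructed` passes a regular grid of latent vectors through the `Decoder` and tiles the decoded images into a single picture.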
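For the variational autoencoder, the `Decoder` can be reused unchanged; only the `Encoder` has to become variational.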
class VariationalEncoder(nn.Module):
    """MLP encoder that returns a reparameterized sample from q(z|x).

    ``forward`` predicts a mean ``mu`` and a log standard deviation for each
    latent dimension and returns ``z = mu + sigma * eps`` with
    ``eps ~ N(0, I)``. As a side effect it stores
    KL(N(mu, sigma^2) || N(0, I)), summed over batch and latent dimensions, in
    ``self.kl`` so the training loss can add it.

    Args:
        latent_dims: Size of the latent code.
    """

    def __init__(self, latent_dims):
        super(VariationalEncoder, self).__init__()
        self.linear1 = nn.Linear(784, 512)
        self.linear2 = nn.Linear(512, latent_dims)  # mean head
        self.linear3 = nn.Linear(512, latent_dims)  # log-std head
        # Kept for backward compatibility. It is no longer forced onto the GPU:
        # the old unconditional `.cuda()` crashed on CPU-only machines, and
        # forward() draws noise with randn_like, which follows mu's device.
        self.N = torch.distributions.Normal(0, 1)
        self.kl = 0

    def forward(self, x):
        """Encode ``x`` and return a latent sample of shape ``(batch, latent_dims)``.

        Args:
            x: Batch whose non-batch dimensions hold 784 values per sample,
                e.g. ``(batch, 1, 28, 28)``.
        """
        x = torch.flatten(x, start_dim=1)
        x = F.relu(self.linear1(x))
        mu = self.linear2(x)
        log_sigma = self.linear3(x)
        sigma = torch.exp(log_sigma)
        # Reparameterization trick: sampling stays differentiable w.r.t. mu/sigma.
        z = mu + sigma * torch.randn_like(mu)
        # Closed form: KL(N(mu, s^2) || N(0, 1)) = 0.5 * (s^2 + mu^2 - 1 - 2 log s).
        self.kl = 0.5 * (sigma ** 2 + mu ** 2 - 1 - 2 * log_sigma).sum()
        return z
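The variational autoencoder is identical to the plain one, except that it uses a `VariationalEncoder` in place of the `Encoder`.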
class VariationalAutoencoder(nn.Module):
    """Variational autoencoder: a ``VariationalEncoder`` followed by a ``Decoder``.

    Args:
        latent_dims: Size of the latent code shared by encoder and decoder.
    """

    def __init__(self, latent_dims):
        super(VariationalAutoencoder, self).__init__()
        self.encoder = VariationalEncoder(latent_dims)
        self.decoder = Decoder(latent_dims)

    def forward(self, x):
        """Encode ``x`` to a latent sample and decode it back to image space.

        Without this method, calling the model as ``vae(x)`` raises
        NotImplementedError. The training loss reads
        ``autoencoder.encoder.kl`` after each call.
        """
        z = self.encoder(x)
        return self.decoder(z)
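The KL-divergence term is stored on the encoder, so the training loop adds `autoencoder.encoder.kl` to the reconstruction loss.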
def train(autoencoder, data, epochs=20, device=None):
    """Fit ``autoencoder`` to ``data`` with Adam and return it.

    The loss is the summed squared reconstruction error. If the encoder
    exposes a ``kl`` attribute (as ``VariationalEncoder`` does), that term is
    added. Otherwise it is treated as 0, so the same loop also trains the
    plain ``Autoencoder``, whose ``Encoder`` has no ``kl``.

    Args:
        autoencoder: Module with an ``encoder`` submodule, called as ``autoencoder(x)``.
        data: Iterable of ``(x, label)`` batches; labels are ignored.
        epochs: Number of full passes over ``data``.
        device: Device to move each batch to. Defaults to the device of the
            model's first parameter, so inputs always match the weights and
            no module-level ``device`` global is required.

    Returns:
        The same ``autoencoder`` object, trained in place.
    """
    if device is None:
        device = next(autoencoder.parameters()).device
    opt = torch.optim.Adam(autoencoder.parameters())
    for _ in range(epochs):
        for x, _label in data:
            x = x.to(device)
            opt.zero_grad()
            x_hat = autoencoder(x)
            kl = getattr(autoencoder.encoder, 'kl', 0)
            loss = ((x - x_hat) ** 2).sum() + kl
            loss.backward()
            opt.step()
    return autoencoder
# `vae` was previously used without ever being defined (NameError). Build it
# the same way `autoencoder` is built above, and fit it with `train`, which
# adds the encoder's KL term to the loss, before plotting.
vae = VariationalAutoencoder(latent_dims).to(device)  # GPU
vae = train(vae, data)
# NOTE(review): plot_latent is not defined anywhere in this text; confirm
# against the full source.
plot_latent(vae, data)
# Decode a grid over [-3, 3] x [-3, 3] in latent space.
plot_reconstructed(vae, r0=(-3, 3), r1=(-3, 3))
def interpolate(autoencoder, x_1, x_2, n=12):
    """Show decodings of n evenly spaced points on the latent segment x_1 -> x_2.

    Both inputs are encoded, the straight line between the two codes is
    sampled at ``n`` points (endpoints included), and the decoded images are
    laid side by side in one strip with axis ticks hidden.

    Args:
        autoencoder: Model exposing ``encoder`` and ``decoder`` submodules.
        x_1: Input at the start of the interpolation. Each decoded frame is
            reshaped to 28x28, so this is presumably a single-image batch such
            as ``(1, 1, 28, 28)``; TODO confirm.
        x_2: Input at the end of the interpolation, shaped like ``x_1``.
        n: Number of interpolation steps.
    """
    start = autoencoder.encoder(x_1)
    end = autoencoder.encoder(x_2)
    # Linear interpolation start + (end - start) * t for t in [0, 1].
    latents = torch.stack([start + (end - start)*t for t in np.linspace(0, 1, n)])
    decoded = autoencoder.decoder(latents).to('cpu').detach().numpy()
    tile = 28
    strip = np.zeros((tile, n*tile))
    for slot, frame in enumerate(decoded):
        left = slot * tile
        strip[:, left:left + tile] = frame.reshape(28, 28)
    plt.imshow(strip)
    plt.xticks([])
    plt.yticks([])
# NOTE(review): this fragment looks like the body of a GIF-export helper whose
# `def` line was lost from this text. `z`, `filename` and `images_list` are not
# defined anywhere visible here; confirm against the full source.
# Decode the interpolated latent batch and multiply by 255, presumably to turn
# [0, 1] decoder outputs into 8-bit pixel values (assumes a sigmoid-range
# decoder; TODO confirm).
interpolate_list = autoencoder.decoder(z)
interpolate_list = interpolate_list.to('cpu').detach().numpy()*255
# Write the frames as an animated GIF: the first frame, with the remaining
# frames appended. The save_all/append_images/loop options suggest PIL
# Image.save; verify that images_list holds PIL images.
images_list[0].save(
f'{filename}.gif',
save_all=True,
append_images=images_list[1:],
loop=1)
Name
T
Timilehin Ayanlade − ⚑
7 months ago edited
Great post, Alexander. I believe there is an oversight in the VAE architecture diagram, in particular the sigma
symbol. Here is an edited image you could easily replace it with:
https://drive.google.com/ ...
1 0 Reply • Share ›
fairlix − ⚑
6 months ago edited
Hey there,
0 0 Reply • Share ›
Boris Burkov − ⚑
8 months ago
I think I noticed a small mistake: the picture illustrating the VAE shows two vectors of expectations
instead of one vector of expectations and one vector of variances. Cheers!
0 0 Reply • Share ›
D Daniel Kleine − ⚑
10 months ago edited
Can you please fix the equation after "(...) which is given by" in the text?
0 0 Reply • Share ›
0 0 Reply • Share ›
Alan − ⚑
2 years ago
0 0 Reply • Share ›