Variational Autoencoder with implementation in TensorFlow and Keras

In this article at OpenGenus, we will explore the variational autoencoder, a type of autoencoder, along with its implementation using TensorFlow and Keras.

Table of contents:

  1. What is an Autoencoder
  2. What is a Variational Autoencoder
  3. Its implementation in TensorFlow and Keras
  4. Major Drawback of a variational autoencoder

Alright, let's get started.

What is an Autoencoder

An autoencoder is a special kind of neural network that is trained to reproduce its input at its output. For instance, an autoencoder will encode a handwritten digit image into a lower-dimensional latent representation and then decode that latent representation back into an image resembling the original. In doing so, the autoencoder learns to compress the data while minimizing the reconstruction error.
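As a point of reference, here is a minimal sketch of a plain (non-variational) autoencoder on flattened 28x28 images. The layer sizes are illustrative choices only and not the architecture used later in this article.

import tensorflow as tf

# Encoder: compress 784 pixels down to a 64-dimensional code.
encoder = tf.keras.Sequential([
    tf.keras.layers.InputLayer(input_shape=(784,)),
    tf.keras.layers.Dense(64, activation='relu'),
])
# Decoder: reconstruct the 784 pixels from the code.
decoder = tf.keras.Sequential([
    tf.keras.layers.InputLayer(input_shape=(64,)),
    tf.keras.layers.Dense(784, activation='sigmoid'),
])
autoencoder = tf.keras.Sequential([encoder, decoder])
autoencoder.compile(optimizer='adam', loss='binary_crossentropy')
# autoencoder.fit(x, x, ...)  # trained to reproduce its own input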

What is a Variational Autoencoder

A variational autoencoder (VAE) is a probabilistic variant of the autoencoder that compresses high-dimensional input data into a more manageable representation. Instead of mapping the input onto a single latent vector, as a typical autoencoder does, a VAE maps the input to the parameters of a probability distribution, such as the mean and variance of a Gaussian. This produces a continuous, structured latent space that is useful for generating images.
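The difference is easiest to see at the encoder's final layer. The snippet below is purely illustrative; the names ae_head and vae_head are ours and not part of the model built later in this article.

import tensorflow as tf

latent_dim = 2
features = tf.random.normal([1, 16])   # stand-in for the encoder's intermediate features

# Typical autoencoder: the encoder ends in a single latent vector.
ae_head = tf.keras.layers.Dense(latent_dim)
code = ae_head(features)                # shape (1, latent_dim)

# VAE: the encoder ends in distribution parameters, here a mean and a
# log-variance per latent dimension, which are split apart afterwards.
vae_head = tf.keras.layers.Dense(latent_dim + latent_dim)
mean, logvar = tf.split(vae_head(features), num_or_size_splits=2, axis=1)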

Its implementation in TensorFlow and Keras

For this example, we'll use the MNIST dataset of handwritten digits.
Each MNIST image starts out as a vector of 784 integers, each of which indicates the intensity of a pixel and ranges from 0 to 255. We model each pixel with a Bernoulli distribution and statically binarize the data.

from IPython import display

import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import PIL
import tensorflow as tf
import tensorflow_probability as tfp
import time

Load the MNIST dataset

(train_images, _), (test_images, _) = tf.keras.datasets.mnist.load_data()

Binarizing function

def preprocess_images(images):
  images = images[..., tf.newaxis] / 255.
  return np.where(images > .5, 1.0, 0.0).astype('float32')

train_images = preprocess_images(train_images)
test_images = preprocess_images(test_images)
train_size = 60000
batch_size = 32
test_size = 10000

Batch and shuffle the dataset with tf.data

train_dataset = (tf.data.Dataset.from_tensor_slices(train_images)
                 .shuffle(train_size).batch(batch_size))
test_dataset = (tf.data.Dataset.from_tensor_slices(test_images)
                .shuffle(test_size).batch(batch_size))

The variational autoencoder comprises both an encoder and a decoder.

Encoder network

The encoder network takes an observation x as input and outputs a set of parameters characterizing the conditional distribution of the latent representation z; this defines the approximate posterior distribution q(z|x). In this example, the distribution is modelled as a simple diagonal Gaussian, so the network outputs the mean and log-variance parameters of a factorized Gaussian. For numerical stability, output the log-variance rather than the variance directly.

Decoder network

The decoder network takes a latent sample z as input and outputs the parameters of the conditional distribution of the observation, p(x|z). The prior p(z) over the latent space is modelled as a unit Gaussian.

Reparameterization trick

The sampling method used by the encoder creates a bottleneck because backpropagation cannot flow through a random node.

To address this, a reparameterization trick is used. In our example, z is approximated using the parameters output by the encoder (the mean and log-variance) together with an auxiliary noise variable epsilon.

This is implemented as z = mean + standard deviation * epsilon, where the multiplication is element-wise (not a matrix multiplication) and the standard deviation is obtained as exp(logvar / 2). Epsilon can be thought of as random noise that preserves the stochasticity of z while keeping the path from the distribution parameters to z differentiable.
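As a standalone sketch (separate from the CVAE class defined below), the trick amounts to a couple of lines; the zero-valued mean and logvar here are placeholders standing in for the encoder's outputs.

import tensorflow as tf

# Placeholder encoder outputs, for illustration only.
mean = tf.zeros((1, 2))     # mu
logvar = tf.zeros((1, 2))   # log(sigma^2)

# epsilon ~ N(0, I) keeps z stochastic, while mean and logvar stay differentiable.
eps = tf.random.normal(shape=mean.shape)
z = mean + tf.exp(logvar * 0.5) * eps   # z = mu + sigma * eps (element-wise)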

Network Architecture

For the encoder network, use two convolutional layers followed by a fully-connected layer. For the decoder network, mirror this architecture with a fully-connected layer followed by three convolution transpose (deconvolution) layers. Note that it is common practice to avoid batch normalization when training VAEs, since the additional stochasticity introduced by mini-batches can cause instability on top of the stochasticity from sampling.

class CVAE(tf.keras.Model):
  """Convolutional variational autoencoder."""

  def __init__(self, latent_dim):
    super(CVAE, self).__init__()
    self.latent_dim = latent_dim
    self.encoder = tf.keras.Sequential(
        [
            tf.keras.layers.InputLayer(input_shape=(28, 28, 1)),
            tf.keras.layers.Conv2D(
                filters=32, kernel_size=3, strides=(2, 2), activation='relu'),
            tf.keras.layers.Conv2D(
                filters=64, kernel_size=3, strides=(2, 2), activation='relu'),
            tf.keras.layers.Flatten(),
            # No activation
            tf.keras.layers.Dense(latent_dim + latent_dim),
        ]
    )

    self.decoder = tf.keras.Sequential(
        [
            tf.keras.layers.InputLayer(input_shape=(latent_dim,)),
            tf.keras.layers.Dense(units=7*7*32, activation=tf.nn.relu),
            tf.keras.layers.Reshape(target_shape=(7, 7, 32)),
            tf.keras.layers.Conv2DTranspose(
                filters=64, kernel_size=3, strides=2, padding='same',
                activation='relu'),
            tf.keras.layers.Conv2DTranspose(
                filters=32, kernel_size=3, strides=2, padding='same',
                activation='relu'),
            # No activation
            tf.keras.layers.Conv2DTranspose(
                filters=1, kernel_size=3, strides=1, padding='same'),
        ]
    )

  @tf.function
  def sample(self, eps=None):
    if eps is None:
      eps = tf.random.normal(shape=(100, self.latent_dim))
    return self.decode(eps, apply_sigmoid=True)
    
  # encodes the input image
  def encode(self, x):
    mean, logvar = tf.split(self.encoder(x), num_or_size_splits=2, axis=1)
    return mean, logvar
    
  # performs the reparameterization trick  
  def reparam(self, mean, logvar):
    eps = tf.random.normal(shape=mean.shape)
    return eps * tf.exp(logvar * .5) + mean
    
  # decodes the sample back to an image output
  def decode(self, z, apply_sigmoid=False):
    logits = self.decoder(z)
    if apply_sigmoid:
      probs = tf.sigmoid(logits)
      return probs
    return logits

Loss and Optimizer

During training, VAEs maximize the evidence lower bound (ELBO) on the marginal log-likelihood.
In practice, optimize a single-sample Monte Carlo estimate of this expectation:

log p(x|z) + log p(z) - log q(z|x)

where z is sampled from q(z|x).
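For reference, this single-sample objective corresponds to the usual decomposition of the ELBO into a reconstruction term and a KL regularization term:

ELBO = E[log p(x|z)] - KL(q(z|x) || p(z)), where the expectation is over z ~ q(z|x)

The compute_loss function below evaluates the same quantity, except that the KL term is estimated by Monte Carlo (via log p(z) - log q(z|x)) rather than computed in closed form.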

optimizer = tf.keras.optimizers.Adam(1e-4)


def log_normal_pdf(sample, mean, logvar, raxis=1):
  log2pi = tf.math.log(2. * np.pi)
  return tf.reduce_sum(
      -.5 * ((sample - mean) ** 2. * tf.exp(-logvar) + logvar + log2pi),
      axis=raxis)


def compute_loss(model, x):
  mean, logvar = model.encode(x)
  z = model.reparam(mean, logvar)
  x_logit = model.decode(z)
  cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(logits=x_logit, labels=x)
  logpx_z = -tf.reduce_sum(cross_ent, axis=[1, 2, 3])
  logpz = log_normal_pdf(z, 0., 0.)
  logqz_x = log_normal_pdf(z, mean, logvar)
  return -tf.reduce_mean(logpx_z + logpz - logqz_x)


@tf.function
def train_step(model, x, optimizer):
  """Executes one training step and returns the loss.

  This function computes the loss and gradients, and uses the latter to
  update the model's parameters.
  """
  with tf.GradientTape() as tape:
    loss = compute_loss(model, x)
  gradients = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(gradients, model.trainable_variables))

Training and Generating Images

For the training process, we use a training loop that calls an image-generating function after each epoch to visualize the model's progress.

epochs = 10
# set the dimensionality of the latent space to a plane for visualization later
latent_dim = 2
num_examples_to_generate = 16

# keeping the random vector constant for generation (prediction) so
# it will be easier to see the improvement.
random_vector_for_generation = tf.random.normal(
    shape=[num_examples_to_generate, latent_dim])
model = CVAE(latent_dim)

Image generating function

def generate_and_save_images(model, epoch, test_sample):
  mean, logvar = model.encode(test_sample)
  z = model.reparam(mean, logvar)
  predictions = model.sample(z)
  fig = plt.figure(figsize=(4, 4))

  for i in range(predictions.shape[0]):
    plt.subplot(4, 4, i + 1)
    plt.imshow(predictions[i, :, :, 0], cmap='gray')
    plt.axis('off')

  plt.tight_layout()  # minimize the overlap between sub-plots
  plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
  plt.show()

for test_batch in test_dataset.take(1):
  test_sample = test_batch[0:num_examples_to_generate, :, :, :]

Training loop

generate_and_save_images(model, 0, test_sample)

for epoch in range(1, epochs + 1):
  start_time = time.time()
  for train_x in train_dataset:
    train_step(model, train_x, optimizer)
  end_time = time.time()

  loss = tf.keras.metrics.Mean()
  for test_x in test_dataset:
    loss(compute_loss(model, test_x))
  elbo = -loss.result()
  display.clear_output(wait=False)
  print('Epoch: {}, Test set ELBO: {}, time elapse for current epoch: {}'
        .format(epoch, elbo, end_time - start_time))
  generate_and_save_images(model, epoch, test_sample)

Major Drawback of a variational autoencoder

The major flaw of variational autoencoders is their propensity to produce blurry, unrealistic outputs. This is commonly attributed to how VAEs compute their loss: the pixel-wise reconstruction term effectively averages over the many plausible images consistent with a given latent code, which washes out fine detail.

With this article at OpenGenus, you should now have a complete understanding of variational autoencoders.
