Understanding a Recurrent Neural Network for Image Generation

Shugert Marketing & Analytics

The purpose of this post is to implement and understand Google DeepMind’s paper, DRAW: A Recurrent Neural Network For Image Generation. The code is based on the work of Eric Jang, whose original implementation takes only 158 lines of Python.

Let’s Begin by Explaining What DRAW Stands for…

Deep Recurrent Attentive Writer (DRAW) is a neural network architecture for image generation. DRAW networks combine a novel spatial attention mechanism that mimics the foveation of the human eye, with a sequential variational auto-encoding framework that allows for the iterative construction of complex images.

The system substantially improves on the state of the art for generative models on MNIST, and, when trained on the Street View House Numbers dataset, it generates images that cannot be distinguished from real data with the naked eye.
The core of the DRAW architecture is a pair of recurrent neural networks: an encoder network that compresses the real images presented during training, and a decoder that reconstitutes images after receiving codes. The combined system is trained end-to-end with stochastic gradient descent, where the loss function is a variational upper bound on the negative log-likelihood of the data.

DRAW Architecture

The DRAW network is similar to other variational auto-encoders: it contains an encoder network that determines a distribution over latent codes that capture salient information about the input data, and a decoder network that receives samples from the code distribution and uses them to condition its own distribution over images.

Key Differences Between DRAW and Auto-Encoders

There are three key differences. First, both the encoder and the decoder are recurrent networks in DRAW. Second, the decoder’s outputs are successively added to the distribution that will ultimately generate the data, instead of emitting this distribution in a single step. Third, a dynamically updated attention mechanism is used to restrict both the input region observed by the encoder and the output region modified by the decoder. In simple terms, the network decides at each time-step “where to read” and “where to write” as well as “what to write”.

Left: Conventional Variational Auto-Encoder

During generation, a sample z is drawn from a prior P(z) and passed through the feedforward decoder network to compute the probability of the input P(x|z) given the sample.
During inference the input x is passed to the encoder network, producing an approximate posterior Q(z|x) over latent variables. During training, z is sampled from Q(z|x) and then used to compute the total description length KL(Q(Z|x) || P(Z)) − log P(x|z), which is minimized with stochastic gradient descent.
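To make the single-step objective concrete, here is a minimal NumPy sketch of the description length for one input. It is only an illustration: the function name vae_description_length, the toy decoder weights W, and the choice of a diagonal Gaussian posterior with a Bernoulli decoder are assumptions of the sketch, not something taken from the paper's code.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def vae_description_length(x, mu, log_sigma, decode, rng):
    """Single-sample estimate of KL(Q(Z|x) || P(Z)) - log P(x|z).

    Assumes Q(Z|x) = N(mu, sigma^2) (diagonal), P(Z) = N(0, 1),
    and P(x|z) is a Bernoulli with means decode(z).
    """
    sigma = np.exp(log_sigma)
    # Reparameterised sample z ~ Q(z|x).
    z = mu + sigma * rng.standard_normal(mu.shape)
    # Closed-form KL between N(mu, sigma^2) and N(0, 1), summed over dimensions.
    kl = 0.5 * np.sum(mu**2 + sigma**2 - 2.0 * log_sigma - 1.0)
    # Bernoulli negative log-likelihood of x under the decoder means.
    p = np.clip(decode(z), 1e-10, 1.0 - 1e-10)
    nll = -np.sum(x * np.log(p) + (1.0 - x) * np.log(1.0 - p))
    return kl + nll, kl, nll

rng = np.random.default_rng(0)
W = rng.normal(0.0, 0.1, (4, 16))          # hypothetical decoder weights
x = (rng.random(16) > 0.5).astype(float)   # toy binary "image"
total, kl, nll = vae_description_length(
    x, mu=np.zeros(4), log_sigma=np.zeros(4),
    decode=lambda z: sigmoid(z @ W), rng=rng)
```

When the posterior equals the prior (zero mean, unit variance), the KL term vanishes and only the reconstruction cost remains.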

Right: DRAW Network

At each time-step a sample z_t from the prior P(z_t) is passed to the recurrent decoder network, which then modifies part of the canvas matrix. The final canvas matrix c_T is used to compute P(x|z_1:T).
During inference the input is read at every time-step and the result is passed to the encoder RNN. The RNNs at the previous time-step specify where to read. The output of the encoder RNN is used to compute the approximate posterior over the latent variables at that time-step.
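The generation side of this description can be sketched in a few lines of NumPy. This is a toy under stated assumptions: a plain tanh recurrence stands in for the LSTM decoder, the weights W_zh, W_hh and W_write are random rather than trained, and the write step covers the whole canvas instead of an attended region.

```python
import numpy as np

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def generate(T, n_z, n_hidden, n_pixels, rng):
    """Toy DRAW-style generation: T latent samples written onto one canvas."""
    # Hypothetical random weights; a trained model would learn these.
    W_zh = rng.normal(0.0, 0.1, (n_z, n_hidden))
    W_hh = rng.normal(0.0, 0.1, (n_hidden, n_hidden))
    W_write = rng.normal(0.0, 0.1, (n_hidden, n_pixels))

    h_dec = np.zeros(n_hidden)
    canvas = np.zeros(n_pixels)
    canvases = []
    for _ in range(T):
        z_t = rng.standard_normal(n_z)               # z_t ~ P(z_t), standard Gaussian
        h_dec = np.tanh(z_t @ W_zh + h_dec @ W_hh)   # recurrent decoder update
        canvas = canvas + h_dec @ W_write            # additive write onto the canvas
        canvases.append(canvas.copy())
    # The final canvas parametrizes the Bernoulli means of the image.
    return sigmoid(canvas), canvases

image, canvases = generate(T=10, n_z=10, n_hidden=32, n_pixels=784,
                           rng=np.random.default_rng(0))
```

The point of the sketch is the accumulation: every time-step adds to the same canvas, and only the final canvas is turned into pixel probabilities.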

Loss Function

The final canvas matrix c_T is used to parametrize a model D(X | c_T) of the input data. If the input is binary, the natural choice for D is a Bernoulli distribution with means given by σ(c_T). The reconstruction loss L^x is defined as the negative log probability of x under D:

$$L^x = -\log D(x \mid c_T)$$

The latent loss L^z for a sequence of latent distributions Q(Z_t | h_t^enc) is defined as the summed Kullback-Leibler divergence of some latent prior P(Z_t) from Q(Z_t | h_t^enc):

$$L^z = \sum_{t=1}^{T} KL\big(Q(Z_t \mid h_t^{enc}) \,\|\, P(Z_t)\big)$$

Note that this loss depends upon the latent samples z_t drawn from Q(Z_t | h_t^enc), which depend in turn on the input x. If the latent distribution is a diagonal Gaussian with μ_t, σ_t where:

$$\mu_t = W(h_t^{enc}), \qquad \sigma_t = \exp\big(W(h_t^{enc})\big)$$

a simple choice for P(Z_t) is a standard Gaussian with mean zero and standard deviation one, in which case the equation becomes:

$$L^z = \frac{1}{2}\left(\sum_{t=1}^{T} \mu_t^2 + \sigma_t^2 - \log \sigma_t^2\right) - \frac{T}{2}$$

The total loss L for the network is the expectation of the sum of the reconstruction and latent losses:

$$L = \left\langle L^x + L^z \right\rangle_{z \sim Q}$$

which we optimize using a single sample of z for each stochastic gradient descent step.
L^z can be interpreted as the number of nats required to transmit the latent sample sequence z_1:T to the decoder from the prior, and (if x is discrete) L^x is the number of nats required for the decoder to reconstruct x given z_1:T. The total loss is therefore equivalent to the expected compression of the data by the decoder and prior.
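The two loss terms can be sketched in NumPy as follows. This is an illustration rather than the TensorFlow code used later in this post; the function name draw_loss and the array shapes are assumptions. It assumes a standard Gaussian prior and binary inputs, and it subtracts 1/2 per latent dimension so that the KL term is exactly zero when Q matches the prior (a constant offset, so the gradients are unchanged).

```python
import numpy as np

def draw_loss(x, c_T, mus, log_sigmas):
    """Reconstruction and latent losses, both measured in nats.

    x, c_T:          arrays of shape (batch, pixels)
    mus, log_sigmas: arrays of shape (T, batch, n_z)
    """
    eps = 1e-10
    p = 1.0 / (1.0 + np.exp(-c_T))   # Bernoulli means sigma(c_T)
    # Reconstruction loss: negative log probability of x under the Bernoulli model.
    L_x = -np.sum(x * np.log(p + eps) + (1.0 - x) * np.log(1.0 - p + eps), axis=1)
    # Latent loss: KL(N(mu, sigma^2) || N(0, 1)) summed over time-steps and dimensions.
    sigmas2 = np.exp(2.0 * log_sigmas)
    L_z = 0.5 * np.sum(mus**2 + sigmas2 - 2.0 * log_sigmas - 1.0, axis=(0, 2))
    return np.mean(L_x + L_z), np.mean(L_x), np.mean(L_z)

# Toy check: an empty canvas and a posterior equal to the prior.
x = np.ones((2, 4))
total, L_x, L_z = draw_loss(x, np.zeros((2, 4)),
                            np.zeros((3, 2, 5)), np.zeros((3, 2, 5)))
```

With an all-zero canvas every pixel has probability 0.5, so the reconstruction cost is log 2 nats per pixel, while the latent cost is zero because nothing has to be transmitted beyond the prior.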

Improving Images

As Eric Jang mentions in his post, it’s easier to ask our neural network to merely “improve the image” rather than “finish the image in one shot”. Human artists work by iterating on their canvas, inferring from the drawing so far what to fix and what to paint next. Improving an image, or progressive refinement, simply means breaking up our joint distribution P(C) over and over again, resulting in a chain of latent variables C_1, C_2, …, C_{T−1} leading to a new observed variable distribution P(C_T).

The trick is to sample from the iterative refinement distribution P(C_t|C_{t−1}) several times rather than straight-up sampling from P(C).
In the DRAW model, P(C_t|C_{t−1}) is the same distribution for all t, so we can compactly represent this as a recurrence relation (if not, then we have a Markov chain instead of a recurrent network).

The DRAW Model Applied

Imagine you are trying to encode an image of the number 8. Every handwritten number is drawn differently: some portions may be thicker, while others may be longer. Without attention, the encoder would be forced to try and capture all these small variations at the same time.
But what if the encoder could choose a small crop of the image on every frame and examine each portion of the number one at a time? That would make the work easier, right? The same logic applies to generating the number. The attention unit determines where to draw the next portion of the number 8, or any other, while the latent vector passed to the decoder determines whether it generates a thicker area or a thinner area.

Basically, if we think of the latent code in a VAE (variational auto-encoder) as a vector that represents the entire image, the latent codes in DRAW can be thought of as vectors that each represent a pen stroke. Eventually, a sequence of these vectors produces a reconstruction of the original image.

Ok, but How Does It Really Work?

In a recurrent VAE model, the encoder takes in the entire input image at every single timestep. In DRAW, an attention gate sits between the input image and the encoder, so the encoder only receives the portion of the image that the network deems important at that timestep. This first attention gate is called the read attention. It consists of two parts: choosing the important portion and cropping the image.

Choosing the Important Portion of an Image

In order to determine which part of the image to focus on, we need some sort of observation to make a decision. In DRAW, we use the previous timestep’s decoder hidden state. Using a simple fully-connected layer, we can map the hidden state to three parameters that represent our square crop: center x, center y, and the scale.
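As a rough sketch of that mapping, the NumPy snippet below turns a hidden state into the three crop parameters. The function name crop_params and the weights W and b are hypothetical; the rescaling of the raw center mirrors the one used in attn_window later in this post, and the exponential is simply one way to keep the scale positive.

```python
import numpy as np

def crop_params(h_dec, W, b, img_size):
    """Map a decoder hidden state to (center_x, center_y, scale) of a square crop."""
    # Fully-connected layer with three outputs.
    gx_raw, gy_raw, log_scale = h_dec @ W + b
    # Rescale the raw centers so that -1 maps to 0 and +1 maps to img_size + 1.
    center_x = (img_size + 1) / 2.0 * (gx_raw + 1.0)
    center_y = (img_size + 1) / 2.0 * (gy_raw + 1.0)
    # Exponentiate so the scale is always positive.
    scale = np.exp(log_scale)
    return center_x, center_y, scale

n_hidden = 256
h_dec = np.random.default_rng(0).standard_normal(n_hidden)
# With zero weights and biases the crop sits at the image center with unit scale.
cx, cy, scale = crop_params(h_dec, np.zeros((n_hidden, 3)), np.zeros(3), img_size=28)
```

Because the parameters are a differentiable function of the hidden state, the network can learn where to look by ordinary backpropagation.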

Cropping the Image

Now, instead of encoding the entire image, we crop it so only a small part of the image is encoded. This code is then passed through the system and decoded back into a small patch. We now arrive at the second part of our attention gate, the write attention, which has the same setup as the read section, except that the write attention gate uses the current decoder hidden state instead of the previous timestep’s.

Wait… Is That Really Done in Practice?

While describing the attention mechanism as a crop makes sense intuitively, in practice a different method is used. The model structure described above is still accurate, but a matrix of Gaussian filters is used instead of a crop. In DRAW, we take an array of Gaussian filters whose centers are evenly spaced.
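A small NumPy sketch may help picture the filters. It handles a single image rather than a batch, and the names filterbank, read and write, as well as the example values for the center, stride and variance, are assumptions made for illustration.

```python
import numpy as np

def filterbank(g, delta, sigma2, n, img_size):
    """Return an (n, img_size) matrix of 1-D Gaussian filters.

    The n filter centers are spaced delta pixels apart and are symmetric about g.
    """
    i = np.arange(n)
    mu = g + (i - (n - 1) / 2.0) * delta
    a = np.arange(img_size)
    F = np.exp(-(a[None, :] - mu[:, None]) ** 2 / (2.0 * sigma2))
    # Normalize each filter so that its weights sum to one.
    return F / np.maximum(F.sum(axis=1, keepdims=True), 1e-8)

def read(img, Fx, Fy):
    # Extract an (n, n) glimpse from an (img_size, img_size) image.
    return Fy @ img @ Fx.T

def write(patch, Fx, Fy):
    # Project an (n, n) patch back onto the (img_size, img_size) canvas.
    return Fy.T @ patch @ Fx

Fx = filterbank(g=14.0, delta=3.0, sigma2=1.0, n=5, img_size=28)
Fy = filterbank(g=14.0, delta=3.0, sigma2=1.0, n=5, img_size=28)
glimpse = read(np.ones((28, 28)), Fx, Fy)   # each entry is a weighted average of pixels
canvas_update = write(glimpse, Fx, Fy)
```

Unlike a hard crop, every glimpse entry is a smooth weighted average of pixels, so the filter centers, stride and variance all receive gradients.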

Show Me the Money… or the Code Instead

We will use Eric Jang’s code as a base, but we will clean it up a bit and comment it to make it easier to understand.

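# First we import our libraries.
# Note: the code below targets the pre-1.0 TensorFlow API (tf.batch_matmul,
# tf.concat(axis, values), tf.split(axis, num, value)) and scipy.misc.toimage;
# newer releases of both libraries have removed these calls.
import os

import numpy as np
import scipy.misc
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data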

Eric provided us with some great functions that will help us build our read and write attention gates, as well as a function to filter the initial state that we will use below. But first, we need to add new functions that will allow us to create a dense layer, merge the images, and save them to our local machine.

# fully-connected layer
def dense(x, inputFeatures, outputFeatures, scope=None, with_w=False):
    with tf.variable_scope(scope or "Linear"):
        matrix = tf.get_variable("Matrix", [inputFeatures, outputFeatures], tf.float32,
                                 tf.random_normal_initializer(stddev=0.02))
        bias = tf.get_variable("bias", [outputFeatures], initializer=tf.constant_initializer(0.0))
        if with_w:
            return tf.matmul(x, matrix) + bias, matrix, bias
        else:
            return tf.matmul(x, matrix) + bias

# merge a batch of images into a single size[0] x size[1] grid
def merge(images, size):
    h, w = images.shape[1], images.shape[2]
    img = np.zeros((h * size[0], w * size[1]))
    for idx, image in enumerate(images):
        i = idx % size[1]   # column in the grid
        j = idx // size[1]  # row in the grid (integer division)
        img[j*h:j*h+h, i*w:i*w+w] = image
    return img

# save image on local machine
def ims(name, img):
    # toimage rescales pixel values from [cmin, cmax] to [0, 255]
    scipy.misc.toimage(img, cmin=0, cmax=1).save(name)

Let’s now put the code all together for the sake of completeness.

# DRAW implementation
class draw_model():
    def __init__(self):

        # First we download the MNIST dataset into our local machine.
        self.mnist = input_data.read_data_sets("data/", one_hot=True)
        print("------------------------------------")
        print("MNIST Dataset Successfully Imported")
        print("------------------------------------")
        self.n_samples = self.mnist.train.num_examples

        # We set up the model parameters
        # ------------------------------
        # image width,height
        self.img_size = 28
        # read glimpse grid width/height
        self.attention_n = 5
        # number of hidden units / output size in LSTM
        self.n_hidden = 256
        # QSampler output size
        self.n_z = 10
        # MNIST generation sequence length
        self.sequence_length = 10
        # training minibatch size
        self.batch_size = 64
        # workaround for variable_scope(reuse=True)
        self.share_parameters = False

        # Build our model
        self.images = tf.placeholder(tf.float32, [None, 784])  # input [batch_size, img_size**2]
        self.e = tf.random_normal((self.batch_size, self.n_z), mean=0, stddev=1)  # Qsampler noise
        self.lstm_enc = tf.nn.rnn_cell.LSTMCell(self.n_hidden, state_is_tuple=True)  # encoder Op
        self.lstm_dec = tf.nn.rnn_cell.LSTMCell(self.n_hidden, state_is_tuple=True)  # decoder Op

        # Define our state variables
        self.cs = [0] * self.sequence_length  # sequence of canvases
        self.mu, self.logsigma, self.sigma = [0] * self.sequence_length, [0] * self.sequence_length, [0] * self.sequence_length

        # Initial states
        h_dec_prev = tf.zeros((self.batch_size, self.n_hidden))
        enc_state = self.lstm_enc.zero_state(self.batch_size, tf.float32)
        dec_state = self.lstm_dec.zero_state(self.batch_size, tf.float32)

        # Construct the unrolled computational graph
        x = self.images
        for t in range(self.sequence_length):
            # error image + original image
            c_prev = tf.zeros((self.batch_size, self.img_size**2)) if t == 0 else self.cs[t-1]
            x_hat = x - tf.sigmoid(c_prev)
            # read the image (read_basic uses no attention)
            r = self.read_basic(x, x_hat, h_dec_prev)
            # sanity check
            print(r.get_shape())
            # encode to Gaussian distribution
            self.mu[t], self.logsigma[t], self.sigma[t], enc_state = self.encode(enc_state, tf.concat(1, [r, h_dec_prev]))
            # sample from the distribution to get z
            z = self.sampleQ(self.mu[t], self.sigma[t])
            # sanity check
            print(z.get_shape())
            # retrieve the hidden layer of RNN
            h_dec, dec_state = self.decode_layer(dec_state, z)
            # sanity check
            print(h_dec.get_shape())
            # map from hidden layer and add the result to the canvas (write_basic uses no attention)
            self.cs[t] = c_prev + self.write_basic(h_dec)
            h_dec_prev = h_dec
            self.share_parameters = True  # from now on, share variables

        # Loss function
        self.generated_images = tf.nn.sigmoid(self.cs[-1])
        # reconstruction loss: Bernoulli negative log-likelihood of the input
        self.generation_loss = tf.reduce_mean(-tf.reduce_sum(
            self.images * tf.log(1e-10 + self.generated_images)
            + (1 - self.images) * tf.log(1e-10 + 1 - self.generated_images), 1))

        kl_terms = [0] * self.sequence_length
        for t in range(self.sequence_length):
            mu2 = tf.square(self.mu[t])
            sigma2 = tf.square(self.sigma[t])
            logsigma = self.logsigma[t]
            # each kl term is (1 x minibatch); the constant is 0.5 per step, so the sum
            # over all steps subtracts T/2 as in the paper's closed form (being a
            # constant, it does not affect the gradients)
            kl_terms[t] = 0.5 * tf.reduce_sum(mu2 + sigma2 - 2*logsigma, 1) - 0.5
        self.latent_loss = tf.reduce_mean(tf.add_n(kl_terms))
        self.cost = self.generation_loss + self.latent_loss

        # Optimization
        optimizer = tf.train.AdamOptimizer(1e-3, beta1=0.5)
        grads = optimizer.compute_gradients(self.cost)
        # clip each gradient to a maximum L2 norm of 5
        for i, (g, v) in enumerate(grads):
            if g is not None:
                grads[i] = (tf.clip_by_norm(g, 5), v)
        self.train_op = optimizer.apply_gradients(grads)

        self.sess = tf.Session()
        self.sess.run(tf.initialize_all_variables())

    # Our training function
    def train(self):
        # make sure the output directory exists before saving images
        if not os.path.exists("results"):
            os.makedirs("results")
        for i in range(20000):
            xtrain, _ = self.mnist.train.next_batch(self.batch_size)
            cs, gen_loss, lat_loss, _ = self.sess.run(
                [self.cs, self.generation_loss, self.latent_loss, self.train_op],
                feed_dict={self.images: xtrain})
            print("iter %d genloss %f latloss %f" % (i, gen_loss, lat_loss))
            if i % 500 == 0:
                cs = 1.0/(1.0+np.exp(-np.array(cs)))  # x_recons=sigmoid(canvas)

                # save the canvas at every time-step as an 8x8 grid of images
                for cs_iter in range(self.sequence_length):
                    results = cs[cs_iter]
                    results_square = np.reshape(results, [-1, self.img_size, self.img_size])
                    print(results_square.shape)
                    ims("results/" + str(i) + "-step-" + str(cs_iter) + ".jpg", merge(results_square, [8, 8]))

    # Eric Jang's main functions
    # --------------------------
    # locate where to put attention filters on hidden layers
    def attn_window(self, scope, h_dec):
        with tf.variable_scope(scope, reuse=self.share_parameters):
            parameters = dense(h_dec, self.n_hidden, 5)
        # center of 2d gaussian on a scale of -1 to 1
        gx_, gy_, log_sigma2, log_delta, log_gamma = tf.split(1, 5, parameters)

        # move gx/gy to pixel coordinates: -1 maps to 0 and +1 maps to img_size+1
        # (2.0 forces float division under Python 2)
        gx = (self.img_size+1)/2.0 * (gx_ + 1)
        gy = (self.img_size+1)/2.0 * (gy_ + 1)

        sigma2 = tf.exp(log_sigma2)
        # distance between patches
        delta = (self.img_size - 1) / ((self.attention_n-1) * tf.exp(log_delta))
        # returns [Fx, Fy, gamma]
        return self.filterbank(gx, gy, sigma2, delta) + (tf.exp(log_gamma),)

    # Construct patches of gaussian filters
    def filterbank(self, gx, gy, sigma2, delta):
        # 1 x N, look like [[0,1,2,3,4]]
        grid_i = tf.reshape(tf.cast(tf.range(self.attention_n), tf.float32), [1, -1])
        # individual patch centers, spaced delta apart and symmetric about gx/gy
        # (grid_i is 0-based, so the offset is -(N-1)/2)
        mu_x = gx + (grid_i - (self.attention_n - 1) / 2.0) * delta
        mu_y = gy + (grid_i - (self.attention_n - 1) / 2.0) * delta
        mu_x = tf.reshape(mu_x, [-1, self.attention_n, 1])
        mu_y = tf.reshape(mu_y, [-1, self.attention_n, 1])
        # 1 x 1 x imgsize, looks like [[[0,1,2,3,4,...,27]]]
        im = tf.reshape(tf.cast(tf.range(self.img_size), tf.float32), [1, 1, -1])
        # list of gaussian curves for x and y: exp(-(a - mu)^2 / (2 * sigma^2))
        sigma2 = tf.reshape(sigma2, [-1, 1, 1])
        Fx = tf.exp(-tf.square(im - mu_x) / (2*sigma2))
        Fy = tf.exp(-tf.square(im - mu_y) / (2*sigma2))  # Fy must use mu_y, not mu_x
        # normalize area-under-curve
        Fx = Fx / tf.maximum(tf.reduce_sum(Fx, 2, keep_dims=True), 1e-8)
        Fy = Fy / tf.maximum(tf.reduce_sum(Fy, 2, keep_dims=True), 1e-8)
        return Fx, Fy

    # read operation without attention
    def read_basic(self, x, x_hat, h_dec_prev):
        return tf.concat(1, [x, x_hat])

    # read operation with attention
    def read_attention(self, x, x_hat, h_dec_prev):
        Fx, Fy, gamma = self.attn_window("read", h_dec_prev)
        # apply parameters for patch of gaussian filters
        def filter_img(img, Fx, Fy, gamma):
            Fxt = tf.transpose(Fx, perm=[0, 2, 1])
            img = tf.reshape(img, [-1, self.img_size, self.img_size])
            # apply the gaussian patches
            glimpse = tf.batch_matmul(Fy, tf.batch_matmul(img, Fxt))
            glimpse = tf.reshape(glimpse, [-1, self.attention_n**2])
            # scale using the gamma parameter
            return glimpse * tf.reshape(gamma, [-1, 1])
        x = filter_img(x, Fx, Fy, gamma)
        x_hat = filter_img(x_hat, Fx, Fy, gamma)
        return tf.concat(1, [x, x_hat])

    # encoder function for attention patch
    def encode(self, prev_state, image):
        # update the RNN with our image
        with tf.variable_scope("encoder", reuse=self.share_parameters):
            hidden_layer, next_state = self.lstm_enc(image, prev_state)

        # map the RNN hidden state to latent variables
        with tf.variable_scope("mu", reuse=self.share_parameters):
            mu = dense(hidden_layer, self.n_hidden, self.n_z)
        with tf.variable_scope("sigma", reuse=self.share_parameters):
            logsigma = dense(hidden_layer, self.n_hidden, self.n_z)
            sigma = tf.exp(logsigma)
        return mu, logsigma, sigma, next_state

    # sample z = mu + sigma * noise (the reparameterization trick)
    def sampleQ(self, mu, sigma):
        return mu + sigma*self.e

    # decoder function
    def decode_layer(self, prev_state, latent):
        # update decoder RNN using our latent variable
        with tf.variable_scope("decoder", reuse=self.share_parameters):
            hidden_layer, next_state = self.lstm_dec(latent, prev_state)

        return hidden_layer, next_state

    # write operation without attention
    def write_basic(self, hidden_layer):
        # map RNN hidden state to image
        with tf.variable_scope("write", reuse=self.share_parameters):
            decoded_image_portion = dense(hidden_layer, self.n_hidden, self.img_size**2)
        return decoded_image_portion

    # write operation with attention
    def write_attention(self, hidden_layer):
        with tf.variable_scope("writeW", reuse=self.share_parameters):
            w = dense(hidden_layer, self.n_hidden, self.attention_n**2)
        w = tf.reshape(w, [self.batch_size, self.attention_n, self.attention_n])
        Fx, Fy, gamma = self.attn_window("write", hidden_layer)
        Fyt = tf.transpose(Fy, perm=[0, 2, 1])
        wr = tf.batch_matmul(Fyt, tf.batch_matmul(w, Fx))
        wr = tf.reshape(wr, [self.batch_size, self.img_size**2])
        return wr * tf.reshape(1.0/gamma, [-1, 1])

model = draw_model()
model.train()

You can see the full notebook on my GitHub page.

About the author: Samuel Noriega is a Master of Data Science graduate from the University of Barcelona. He is the head of Data Science at Shugert Analytics and city lead at Saturdays.ai. He recently co-founded Roomies.es.