Computers can be creative, at least with AI — Here’s how (GAN Tutorial)
Generative Adversarial Networks: The artists of the AI world.
3 words. 11 syllables. Generative Adversarial Networks.
You may have heard the term before. It’s been spreading like wildfire among machine learning enthusiasts, especially in computer graphics. Companies like NVIDIA are all over this relatively new technology, which was introduced in 2014 by researchers at the University of Montreal.
NVIDIA has used GANs to create a game demo whose graphics were generated from real-life video footage, and it also developed StyleGAN, a model that can generate images of people who don’t even exist.
GANs have so much potential, and we’re only scratching the surface.
If you want to learn more about how GANs are structured and how they generate images, you can check out another article I wrote here.
In this article, I’ll get more technical and show you how to code a GAN that generates images of handwritten digits using the MNIST data set.
Generating Handwritten Digits Using GANs
This is some preliminary code to help our project run smoothly.
First, we load the MNIST data set using a custom function. The function loads the data, normalizes the pixel values, and reshapes the training images in “x_train” so that each 28x28 image becomes a flat vector.
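As a dependency-free sketch of that preprocessing, the hypothetical function below maps pixel values from 0-255 to [-1, 1] and flattens one 28x28 image into 784 values. Scaling to [-1, 1] is an assumption, chosen because it matches the generator’s tanh output. With NumPy, the same steps are usually applied to the whole array at once, e.g. `(x_train.astype("float32") - 127.5) / 127.5` followed by `x_train.reshape(60000, 784)`.

```python
def preprocess_image(image):
    """Flatten a 28x28 image of 0-255 pixels into 784 floats in [-1, 1].

    The [-1, 1] range is an assumption that matches a tanh generator output.
    """
    if len(image) != 28 or any(len(row) != 28 for row in image):
        raise ValueError("expected a 28x28 image")
    # 0 maps to -1.0, 127.5 maps to 0.0, and 255 maps to 1.0.
    return [(p - 127.5) / 127.5 for row in image for p in row]
```

For example, `preprocess_image([[0] * 28] * 28)` returns a list of 784 values, all equal to -1.0.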
Then we can declare our optimizer. We will be using the “Adam” optimizer from Keras with these specific parameters.
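For intuition about what those parameters control, here is the standard Adam update rule applied to minimizing f(x) = x² in plain Python. The values `learning_rate=0.0002` and `beta_1=0.5` are an assumption (they are popular for GAN training), not necessarily the article’s own; in Keras they would be passed as `Adam(learning_rate=0.0002, beta_1=0.5)`. The function name `adam_minimize` is hypothetical.

```python
import math

def adam_minimize(grad_fn, x0, steps, learning_rate=0.0002, beta_1=0.5,
                  beta_2=0.999, epsilon=1e-7):
    """Run `steps` Adam updates on a scalar parameter (assumed defaults)."""
    x, m, v = x0, 0.0, 0.0
    for t in range(1, steps + 1):
        g = grad_fn(x)
        m = beta_1 * m + (1 - beta_1) * g       # running mean of gradients
        v = beta_2 * v + (1 - beta_2) * g * g   # running mean of squared gradients
        m_hat = m / (1 - beta_1 ** t)           # bias correction for early steps
        v_hat = v / (1 - beta_2 ** t)
        x -= learning_rate * m_hat / (math.sqrt(v_hat) + epsilon)
    return x
```

Because of the bias correction, the first step moves the parameter by almost exactly `learning_rate`: starting from x = 1.0 with gradient 2x, one step lands at about 0.9998.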
Then we can put together the generator model. It consists of 4 dense layers and 3 Leaky ReLU layers. The input dimension is a variable that we declared earlier with a value of 100: the generator takes a random noise vector of that length and transforms it into an image. The final layer uses the “tanh” activation, which squashes each output into the range -1 to 1, so the generated pixels fall in the same range as training images normalized to [-1, 1]. The entire model is compiled with the optimizer and the “binary_crossentropy” loss function. We use this loss because it measures the performance of a classification model whose output is a probability value between 0 and 1. The generator’s own outputs are pixel values rather than probabilities, but it is trained through the discriminator’s real-or-fake prediction, which is a probability.
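To make the layer structure concrete, here is a plain-Python sketch of the generator’s forward pass: four dense layers, Leaky ReLU after each of the first three, and tanh on the output. The hidden widths (256, 512, 1024), the Leaky ReLU slope of 0.2, and the weight initialization are assumptions for illustration, and the helper names are hypothetical; in Keras the same computation is done by `Dense` and `LeakyReLU` layers using vectorized matrix multiplications.

```python
import math
import random

def leaky_relu(x, alpha=0.2):
    # Small slope for negative inputs (alpha = 0.2 is an assumption).
    return x if x > 0 else alpha * x

def dense(inputs, weights, biases):
    # Fully connected layer: weights[j] holds the incoming weights of unit j.
    return [sum(w * x for w, x in zip(row, inputs)) + b
            for row, b in zip(weights, biases)]

def init_layer(n_in, n_out, rng, scale=0.05):
    # Small random weights and zero biases.
    weights = [[rng.gauss(0.0, scale) for _ in range(n_in)]
               for _ in range(n_out)]
    return weights, [0.0] * n_out

def build_generator(noise_dim=100, hidden=(256, 512, 1024), out_dim=784, seed=0):
    # With three hidden widths this yields four dense layers:
    # noise_dim -> hidden[0] -> hidden[1] -> hidden[2] -> out_dim.
    rng = random.Random(seed)
    sizes = [noise_dim, *hidden, out_dim]
    return [init_layer(n_in, n_out, rng) for n_in, n_out in zip(sizes, sizes[1:])]

def generate(layers, noise):
    # Forward pass: Leaky ReLU after each hidden layer, tanh on the output.
    x = noise
    for i, (weights, biases) in enumerate(layers):
        x = dense(x, weights, biases)
        if i < len(layers) - 1:
            x = [leaky_relu(v) for v in x]
        else:
            x = [math.tanh(v) for v in x]  # every pixel lands in [-1, 1]
    return x
```

With the default widths this pure-Python pass is slow; smaller hidden widths such as `(8, 16, 32)` are enough to check that a 100-value noise vector comes out as 784 pixels in [-1, 1].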
We then build the discriminator model, which determines whether the images fed to it are authentic (from the MNIST data set) or fake (from the generator). The discriminator’s feedback is what drives the generator to produce better images over time. The discriminator consists of 4 dense layers, 3 Leaky ReLU layers, and 3 dropout layers (to prevent overfitting). The last dense layer uses the “sigmoid” activation, which outputs a value between 0 and 1 that can be read as the probability that the input image is real. We compile the discriminator the same way as the generator.
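The pieces that make the discriminator a probabilistic classifier can be sketched in plain Python: sigmoid on the final output, binary cross-entropy as the loss, and dropout between layers. This is a minimal illustration with hypothetical helper names; the dropout shown is inverted dropout, which, like Keras’s `Dropout` layer, rescales the surviving values during training and does nothing at inference time.

```python
import math
import random

def sigmoid(x):
    # Numerically stable logistic function: maps any real number into (0, 1).
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    # Mean binary cross-entropy; predictions are clipped to avoid log(0).
    total = 0.0
    for t, p in zip(y_true, y_pred):
        p = min(max(p, eps), 1.0 - eps)
        total += -(t * math.log(p) + (1.0 - t) * math.log(1.0 - p))
    return total / len(y_true)

def dropout(values, rate, rng, training=True):
    # Inverted dropout: zero each value with probability `rate` during
    # training and rescale the survivors; identity at inference time.
    if not training or rate == 0.0:
        return list(values)
    keep = 1.0 - rate
    return [v / keep if rng.random() < keep else 0.0 for v in values]
```

The loss rewards confident correct answers and punishes confident wrong ones: for a real image (label 1), predicting 0.9 costs about 0.105, while predicting 0.1 costs about 2.303.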
We bring the generator and the discriminator together using this custom function to build the full GAN model.
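What the combined model computes can be shown with a toy one-dimensional example in plain Python. The affine generator, logistic discriminator, finite-difference gradient, and all parameter values are hypothetical; the point is the structure. Noise goes through the generator and then the discriminator, the result is scored with binary cross-entropy against the label 1 (“real”), and only the generator’s parameters are updated. In Keras this freezing is typically done by setting `discriminator.trainable = False` before compiling the combined model.

```python
import math
import random

def sigmoid(x):
    # Numerically stable logistic function.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)

def generator(z, params):
    # Hypothetical toy generator: an affine map from noise z to a sample.
    a, b = params
    return a * z + b

def discriminator(x, params):
    # Hypothetical toy discriminator: probability that x is real.
    w, c = params
    return sigmoid(w * x + c)

def gan_loss(gen_params, disc_params, noise):
    # Combined model: noise -> generator -> discriminator, scored with
    # binary cross-entropy against the label 1 ("real").
    total = 0.0
    for z in noise:
        p = discriminator(generator(z, gen_params), disc_params)
        total += -math.log(max(p, 1e-7))
    return total / len(noise)

def train_generator_step(gen_params, disc_params, noise, lr=0.1, h=1e-5):
    # Finite-difference gradient step on the generator's parameters only;
    # disc_params are never modified (the discriminator is "frozen").
    base = gan_loss(gen_params, disc_params, noise)
    grads = []
    for i in range(len(gen_params)):
        bumped = list(gen_params)
        bumped[i] += h
        grads.append((gan_loss(bumped, disc_params, noise) - base) / h)
    return tuple(p - lr * g for p, g in zip(gen_params, grads))
```

With a discriminator such as `(2.0, -4.0)`, which considers larger values more likely to be real, repeated generator steps shift the generated samples toward that region and lower the combined loss, while the discriminator’s parameters stay unchanged.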
To track how our GAN is doing over time, we can plot the generated images using this function and matplotlib.
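Before plotting, each generated 784-value vector has to be turned back into a 28x28 image. The sketch below assumes the generator outputs values in [-1, 1] (tanh) and rescales them to [0, 1]; the helper names `to_image` and `grid_shape` are hypothetical.

```python
import math

def to_image(vector, size=28):
    """Map a flat vector in [-1, 1] back to a size x size grid in [0, 1]."""
    if len(vector) != size * size:
        raise ValueError(f"expected {size * size} values, got {len(vector)}")
    pixels = [(v + 1.0) / 2.0 for v in vector]  # undo the tanh range
    return [pixels[r * size:(r + 1) * size] for r in range(size)]

def grid_shape(n_images, n_cols=10):
    """Rows and columns for a subplot grid that shows n_images."""
    return math.ceil(n_images / n_cols), min(n_images, n_cols)
```

One way to display the result with matplotlib is to compute `rows, cols = grid_shape(len(images))`, then for each image call `plt.subplot(rows, cols, i + 1)` and `plt.imshow(to_image(vec), cmap="gray")`, and finally save the figure with `plt.savefig`.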
Finally, we can train the GAN model using this train function. We load the training and test data, split the data into batches of 128, and then build the GAN network using the function we created previously. Then, for each epoch, we repeat the same set of steps: generate fake MNIST images from random noise, label the real and generated data, train the discriminator, train the generator, and finally plot the generated images.
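The steps of one epoch can be sketched as a framework-agnostic loop. The three callables `generate`, `train_discriminator`, and `train_gan` are hypothetical hooks; in Keras they would typically wrap `generator.predict`, `discriminator.train_on_batch`, and the combined model’s `train_on_batch`. Labeling real images 1 and fake images 0 is the usual convention, and sampling each real batch without replacement is an assumption.

```python
import random

def sample_noise(batch_size, noise_dim, rng):
    # One standard-normal noise vector per image to generate.
    return [[rng.gauss(0.0, 1.0) for _ in range(noise_dim)]
            for _ in range(batch_size)]

def discriminator_batch(real_images, fake_images):
    # Stack real and generated images; label real as 1, fake as 0.
    x = list(real_images) + list(fake_images)
    y = [1.0] * len(real_images) + [0.0] * len(fake_images)
    return x, y

def train_one_epoch(x_train, generate, train_discriminator, train_gan,
                    batch_size=128, noise_dim=100, seed=0):
    """Run one epoch of GAN training; returns the number of steps taken."""
    rng = random.Random(seed)
    steps = len(x_train) // batch_size
    for _ in range(steps):
        # 1. Generate fake images from random noise.
        fake = generate(sample_noise(batch_size, noise_dim, rng))
        # 2. Sample a batch of real images.
        real = rng.sample(x_train, batch_size)
        # 3. Label the real and fake data and train the discriminator.
        x, y = discriminator_batch(real, fake)
        train_discriminator(x, y)
        # 4. Train the generator through the combined model, with target
        #    label 1 ("real") for every noise vector so it learns to fool D.
        noise = sample_noise(batch_size, noise_dim, rng)
        train_gan(noise, [1.0] * batch_size)
    return steps
```

Keeping the framework calls behind callables makes the order of operations easy to check with stand-in functions before plugging in real models.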
After 400 epochs of training (which took a long time), we can see how the quality of the generated images improved over time.
Results after 1 epoch:
Results after 40 epochs:
Results after 400 epochs:
Congratulations! You’ve just created a GAN that can generate handwritten digits that look pretty convincing.