Skip to content

GANs en Keras#

Teoría GANs

Ejemplo de GANs con Keras#

Preparación de datos#

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from tensorflow.keras.datasets import mnist

# Cargamos los datos - Este dataset contiene imágenes de dígitos escritos a mano.
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# Filtramos solo los ceros
only_zeros = X_train[y_train == 0]

only_zeros.shape # (5923, 28, 28)

Crear modelo#

import tensorflow as tf
from tensorflow.keras.layers import Dense, Reshape, Flatten
from tensorflow.keras.models import Sequential

# Creamos el discriminador
discriminator = Sequential()
discriminator.add(Flatten(input_shape=[28, 28]))
discriminator.add(Dense(150, activation='relu'))
discriminator.add(Dense(100, activation='relu'))
# Capa final - Una neurona que devuelve 0 o 1 (falso o verdadero)
discriminator.add(Dense(1, activation='sigmoid'))

discriminator.compile(loss='binary_crossentropy', optimizer='adam')


# Creamos el generador
# Parecido a autoencoders
# 784 (=28*28) --> 150 --> 30 --> 150 --> 784
coding_size = 100
# 100 --> 150 --> 784

generator = Sequential()
generator.add(Dense(100, input_shape=[coding_size], activation='relu'))
generator.add(Dense(150, activation='relu'))
generator.add(Dense(784, activation='tanh'))
generator.add(Reshape((28, 28)))
# Nota: No se compila el generador porque su entrenamiento depende de la respuesta
#   del discriminador dentro del modelo GAN combinado, no de un entrenamiento directo 
#   e independiente.

# Creamos el modelo combinado
gan = Sequential([generator, discriminator])
# Al compilar y entrenar el modelo combinado (gan), solo queremos que se actualicen los pesos del generador y no los del discriminador.
discriminator.trainable = False
gan.compile(loss='binary_crossentropy', optimizer='adam')

Entrenar el modelo#

batch_size = 32
my_data = only_zeros
dataset = tf.data.Dataset.from_tensor_slices(my_data)
dataset = dataset.shuffle(buffer_size=1000)
dataset = dataset.batch(batch_size, drop_remainder=True).prefetch(1)
epochs = 1

generator, discriminator = gan.layers

for epoch in range(epochs):

    print(f"Currently on epoch {epoch}")

    for X_batch in dataset:
       i = i+1
       if i%100==0:
           print(i)
       print(f"Currently in batch number {i}")

       # Discriminator training
       noise = np.random.normal(size=(batch_size, coding_size))
       gen_images = generator(noise)
       # Concatenar imágenes reales y generadas
       X_fake_vs_real = tf.concat([gen_images, tf.dtypes.cast(X_batch, tf.float32)], axis=0)
       y1 = tf.constant(np.concatenate([[0.0]*batch_size] + [[1.0]*batch_size]))

       discriminator.trainable = True
       discriminator.train_on_batch(X_fake_vs_real, y1)

       # Generator training
       noise = np.random.normal(size=(batch_size, coding_size))
       y2 = tf.constant(np.concatenate([[1.0]*batch_size]))
       discriminator.trainable = False
       gan.train_on_batch(noise, y2)

Generar imágenes#

# Generamos un vector de ruido
noise = tf.random.normal(shape=[10, coding_size])
noise.shape # (10, 100)
plt.imshow(noise)
plt.show()

# pasamos el ruido por el generador
images = generator(noise)
plt.imshow(images[0], cmap='gray')
plt.show() # Debe aparecer una imagen parecida a cero