1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80
| import tensorflow as tf from tensorflow.keras.layers import Dense, Flatten, Reshape from tensorflow.keras.models import Sequential import matplotlib.pyplot as plt import numpy as np
def build_generator(latent_dim, output_shape):
    """Build the generator: a dense network mapping a latent vector to a sample.

    Args:
        latent_dim: Length of the 1-D latent (noise) vector fed to the model.
        output_shape: Shape of one generated sample, excluding the batch
            axis. The last dense layer is sized to its element count and the
            result is reshaped to this shape.

    Returns:
        An uncompiled ``Sequential`` model. Its final dense layer uses a
        sigmoid activation, so every output value lies between 0 and 1.
    """
    # Keras wants a plain Python int for the layer width; np.prod returns a
    # NumPy scalar.
    n_outputs = int(np.prod(output_shape))

    # Declare the input explicitly instead of passing ``input_dim`` to the
    # first Dense layer; current Keras releases warn about that keyword.
    return Sequential([
        tf.keras.Input(shape=(latent_dim,)),
        Dense(128, activation='relu'),
        Dense(256, activation='relu'),
        Dense(n_outputs, activation='sigmoid'),
        Reshape(output_shape),
    ])
def build_discriminator(input_shape):
    """Build the discriminator: a dense binary classifier over flattened input.

    Args:
        input_shape: Shape of one input sample, excluding the batch axis.

    Returns:
        An uncompiled ``Sequential`` model that flattens its input and ends
        in a single sigmoid unit, i.e. one score between 0 and 1 per sample.
    """
    # Declare the input explicitly instead of passing ``input_shape`` to the
    # Flatten layer; current Keras releases warn about that keyword.
    return Sequential([
        tf.keras.Input(shape=input_shape),
        Flatten(),
        Dense(256, activation='relu'),
        Dense(128, activation='relu'),
        Dense(1, activation='sigmoid'),
    ])
# --- Hyperparameters shared by both networks ------------------------------
latent_dim = 100
input_shape = (28, 28, 1)

# --- Models ----------------------------------------------------------------
generator = build_generator(latent_dim, input_shape)

# The discriminator is compiled on its own so it can be trained directly
# with train_on_batch.
discriminator = build_discriminator(input_shape)
discriminator.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy'],
)

# Mark the discriminator non-trainable before it is embedded in the stacked
# model, so compiling ``gan`` sees it as frozen.
# NOTE(review): whether the stand-alone discriminator still updates after
# this flag is flipped depends on how the installed Keras version handles
# ``trainable`` changes made after compile() -- confirm its loss moves.
discriminator.trainable = False

# Stacked model: latent vector -> generator -> discriminator score.
gan_input = tf.keras.Input(shape=(latent_dim,))
gan_output = discriminator(generator(gan_input))
gan = tf.keras.Model(inputs=gan_input, outputs=gan_output)
gan.compile(optimizer='adam', loss='binary_crossentropy')

# --- Data ------------------------------------------------------------------
# Only the training images are kept; labels and the test split are discarded.
(x_train, _), (_, _) = tf.keras.datasets.mnist.load_data()
# Divide by 255 (presumably mapping 8-bit pixels into [0, 1]) and append a
# trailing axis of length 1.
x_train = np.expand_dims(x_train / 255.0, axis=-1)

# --- Training schedule -----------------------------------------------------
epochs = 10000
batch_size = 32
# How often (in iterations) to print losses and to show sample images.
log_interval = 100
sample_interval = 1000

# The targets never change, so build them once instead of on every step.
real_labels = np.ones((batch_size, 1))
fake_labels = np.zeros((batch_size, 1))

# Each "epoch" here is a single update step on one random batch, not a full
# pass over x_train.
for epoch in range(epochs):
    # ---- Discriminator step: one real batch, one generated batch ----------
    noise = np.random.normal(0, 1, (batch_size, latent_dim))
    # predict_on_batch is the single-batch API; predict() is built for
    # multi-batch inference and adds avoidable per-call overhead in a loop.
    generated_images = generator.predict_on_batch(noise)
    real_images = x_train[np.random.randint(0, x_train.shape[0], batch_size)]

    d_loss_real = discriminator.train_on_batch(real_images, real_labels)
    d_loss_fake = discriminator.train_on_batch(generated_images, fake_labels)
    # Element-wise mean of the two results; index 0 is printed below as the
    # loss (the discriminator is compiled with an extra accuracy metric).
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

    # ---- Generator step: train the stacked model on fresh noise -----------
    # Targets are all ones: the generator is rewarded when the discriminator
    # scores its samples as real.
    noise = np.random.normal(0, 1, (batch_size, latent_dim))
    g_loss = gan.train_on_batch(noise, real_labels)

    if epoch % log_interval == 0:
        print(f"Epoch {epoch}, D Loss: {d_loss[0]}, G Loss: {g_loss}")

    if epoch % sample_interval == 0:
        # Convert a copy of the first four samples for display, leaving
        # generated_images itself untouched.
        preview = (generated_images[:4] * 255.0).astype(np.uint8)
        for i, image in enumerate(preview):
            plt.subplot(2, 2, i + 1)
            plt.imshow(image.reshape(28, 28), cmap='gray')
        plt.show()
|