0%


title:机器学习-GAN

Generative Adversarial Network,就是大家耳熟能详的 GAN,由 Ian Goodfellow 首先提出,在这两年更是深度学习中最热门的东西,仿佛什么东西都能由 GAN 做出来。我最近刚入门 GAN,看了些资料,做一些笔记。

1.Generation
什么是生成(generation)?就是模型通过学习一些数据,然后生成类似的数据。让机器看一些动物图片,然后自己来产生动物的图片,这就是生成。

以前就有很多可以用来生成的技术了,比如 auto-encoder(自编码器)

你训练一个 encoder,把 input 转换成 code,然后训练一个 decoder,把 code 转换成一个 image,然后计算得到的 image 和 input 之间的 MSE(mean square error),训练完这个 model 之后,取出后半部分 NN Decoder,输入一个随机的 code,就能 generate 一个 image。

但是 auto-encoder 生成 image 的效果,当然看着很别扭啦,一眼就能看出真假。所以后来还提出了比如VAE这样的生成模型,我对此也不是很了解,在这就不细说。

上述的这些生成模型,其实有一个非常严重的弊端。比如 VAE,它生成的 image 是希望和 input 越相似越好,但是 model 是如何来衡量这个相似呢?model 会计算一个 loss,采用的大多是 MSE,即每一个像素上的均方差。loss 小真的表示相似嘛?

生成对抗网络(Generative Adversarial Network,GAN)是一种深度学习模型,用于生成与训练数据相似的新数据。GAN 包括两个主要组件:生成器(Generator)和判别器(Discriminator)。生成器试图生成类似于真实数据的样本,而判别器试图区分生成的样本和真实样本。这两个组件通过对抗性训练相互竞争,最终生成具有高质量的新数据。

以下是一个使用 Python 的 TensorFlow 库实现简单 GAN 的示例代码:

首先,确保已安装 tensorflow 库,可以通过以下命令安装:

1
pip install tensorflow

接下来,使用下面的代码示例:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Reshape
from tensorflow.keras.models import Sequential
import matplotlib.pyplot as plt
import numpy as np

# 生成器模型
def build_generator(latent_dim, output_shape):
model = Sequential([
Dense(128, input_dim=latent_dim, activation='relu'),
Dense(256, activation='relu'),
Dense(np.prod(output_shape), activation='sigmoid'),
Reshape(output_shape)
])
return model

# 判别器模型
def build_discriminator(input_shape):
model = Sequential([
Flatten(input_shape=input_shape),
Dense(256, activation='relu'),
Dense(128, activation='relu'),
Dense(1, activation='sigmoid')
])
return model

# 生成器和判别器的参数
latent_dim = 100
input_shape = (28, 28, 1)

# 构建生成器和判别器
generator = build_generator(latent_dim, input_shape)
discriminator = build_discriminator(input_shape)
discriminator.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
discriminator.trainable = False

# 构建 GAN 模型
gan_input = tf.keras.Input(shape=(latent_dim,))
gan_output = discriminator(generator(gan_input))
gan = tf.keras.Model(gan_input, gan_output)
gan.compile(loss='binary_crossentropy', optimizer='adam')

# 加载 MNIST 数据集
(x_train, _), (_, _) = tf.keras.datasets.mnist.load_data()
x_train = x_train / 255.0
x_train = np.expand_dims(x_train, axis=-1)

# 训练 GAN
epochs = 10000
batch_size = 32

for epoch in range(epochs):
# 生成随机噪声作为输入
noise = np.random.normal(0, 1, (batch_size, latent_dim))
# 生成假图片
generated_images = generator.predict(noise)
# 随机选择真实图片
real_images = x_train[np.random.randint(0, x_train.shape[0], batch_size)]

# 训练判别器
d_loss_real = discriminator.train_on_batch(real_images, np.ones((batch_size, 1)))
d_loss_fake = discriminator.train_on_batch(generated_images, np.zeros((batch_size, 1)))
d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

# 训练生成器
noise = np.random.normal(0, 1, (batch_size, latent_dim))
g_loss = gan.train_on_batch(noise, np.ones((batch_size, 1)))

# 打印损失
if epoch % 100 == 0:
print(f"Epoch {epoch}, D Loss: {d_loss[0]}, G Loss: {g_loss}")

# 绘制生成的图片
if epoch % 1000 == 0:
generated_images = generated_images * 255.0
generated_images = generated_images.astype(np.uint8)
for i in range(4):
plt.subplot(2, 2, i+1)
plt.imshow(generated_images[i].reshape(28, 28), cmap='gray')
plt.show()

这个示例展示了如何使用 TensorFlow 实现简单的 GAN,生成手写数字图片。代码中首先定义了生成器和判别器模型,然后构建 GAN 模型。通过循环训练生成器和判别器,使它们相互竞争和优化,最终生成具有相似特征的新图片。

gan