一.原理说明
GAN 包括两个主要部分:
- 生成器(Generator)
- 鉴别器(Discriminator)
生成器负责从随机噪声生成新的图像,而鉴别器负责判断输入的图像是真实样本还是生成器生成的假样本;两者在对抗中共同提升。
我们要为 GAN 图像生成构建的整个架构如下图所示。
二.数据说明
MNIST 数据集是机器学习领域的一个经典数据集,共包含 70000 个 28×28 的单通道灰度图像样本,其中训练样本 60000 个、测试样本 10000 个。
三.代码实战
第一步:导入所需的库
import os
import time
import struct
import numpy as np
import glob
import matplotlib.pyplot as plt
import PIL
import imageio
import tensorflow as tf
import tensorflow.keras.layers as layers
import time
from IPython import display
# Print the installed TensorFlow version (the code below relies on the TF2 tf.keras / GradientTape API).
print(tf.__version__)
第二步:导入数据并进行预处理
def dense_to_one_hot(labels_dense, num_classes=10):
    """Convert integer class labels into one-hot rows.

    Args:
        labels_dense: array-like of integer class ids; it is flattened, so a
            1-D array or an (N, 1) column both work.
        num_classes: width of each one-hot row.

    Returns:
        A float64 array of shape (N, num_classes) with a single 1.0 per row.

    Raises:
        ValueError: if any label lies outside [0, num_classes).
    """
    labels = np.asarray(labels_dense).ravel()
    num_labels = labels.shape[0]
    # Reject out-of-range labels: with flat indexing they would silently set a
    # bit in the *next* row instead of failing.
    if labels.size and (labels.min() < 0 or labels.max() >= num_classes):
        raise ValueError(
            'labels must lie in [0, %d), got range [%d, %d]'
            % (num_classes, labels.min(), labels.max()))
    labels_one_hot = np.zeros((num_labels, num_classes))
    # Row i gets a 1.0 in the column given by its label.
    labels_one_hot[np.arange(num_labels), labels] = 1
    return labels_one_hot
def load_mnist(path, kind='train'):
    """Load one MNIST split from the raw IDX files in *path*.

    Args:
        path: directory containing '<kind>-labels-idx1-ubyte' and
            '<kind>-images-idx3-ubyte'.
        kind: file-name prefix of the split, e.g. 'train' or 't10k'.

    Returns:
        (images, labels): images is a uint8 array of shape
        (num_samples, rows * cols); labels is the one-hot array returned by
        dense_to_one_hot.

    Raises:
        ValueError: if a file's IDX magic number is wrong or the header counts
            and dimensions do not match the data that was read.
    """
    labels_path = os.path.join(path, '%s-labels-idx1-ubyte' % kind)
    images_path = os.path.join(path, '%s-images-idx3-ubyte' % kind)
    with open(labels_path, 'rb') as lbpath:
        # Label header: big-endian magic number (2049) and item count.
        magic, n = struct.unpack('>II', lbpath.read(8))
        if magic != 2049:
            raise ValueError('%s: bad label-file magic number %d' % (labels_path, magic))
        labels = np.fromfile(lbpath, dtype=np.uint8)
    if labels.shape[0] != n:
        raise ValueError('%s: header declares %d labels, read %d'
                         % (labels_path, n, labels.shape[0]))
    labels = dense_to_one_hot(labels)
    with open(images_path, 'rb') as imgpath:
        # Image header: magic number (2051), image count, rows, columns.
        magic, num, rows, cols = struct.unpack('>IIII', imgpath.read(16))
        if magic != 2051:
            raise ValueError('%s: bad image-file magic number %d' % (images_path, magic))
        images = np.fromfile(imgpath, dtype=np.uint8)
    if num != len(labels) or images.size != num * rows * cols:
        raise ValueError('%s: %d images of %dx%d do not match %d labels / %d bytes'
                         % (images_path, num, rows, cols, len(labels), images.size))
    # One flattened row of rows*cols pixels per image (784 for 28x28 MNIST).
    images = images.reshape(num, rows * cols)
    return images, labels
# Load the training split and report its (rows, columns) for images and labels.
X_train, y_train = load_mnist('../data/MNIST/raw/', kind='train')
print('Rows: %d, columns: %d' % X_train.shape[:2])
print('Rows: %d, columns: %d' % y_train.shape[:2])
# Load the test split ('t10k' file prefix) and report the image matrix shape.
X_test, y_test = load_mnist('../data/MNIST/raw/', kind='t10k')
print('Rows: %d, columns: %d' % X_test.shape[:2])
# Build the image tensors: one (28, 28, 1) float32 array per sample.
train_images = X_train.reshape(X_train.shape[0], 28, 28, 1).astype('float32')
test_images = X_test.reshape(X_test.shape[0], 28, 28, 1).astype('float32')
# Scale raw uint8 pixel values from [0, 255] to [0, 1].
train_images /= 255.0
test_images /= 255.0
# Binarize at 0.5 so every pixel is exactly 0.0 or 1.0.
train_images[train_images >= 0.5] = 1.0
train_images[train_images < 0.5] = 0.0
test_images[test_images >= 0.5] = 1.0
test_images[test_images < 0.5] = 0.0
# Shift {0, 1} to {-1, 1}. The generator's last layer is tanh (outputs in
# [-1, 1]) and the plotting code maps [-1, 1] back to [0, 255], so real
# images must use that same range for the discriminator to compare like
# with like. In-place ops keep the float32 dtype.
train_images *= 2.0
train_images -= 1.0
test_images *= 2.0
test_images -= 1.0
# Hyperparameters: shuffle-buffer sizes match the MNIST split sizes
# (60000 train / 10000 test), so each split is shuffled in full.
TRAIN_BUF, BATCH_SIZE, TEST_BUF = 60000, 100, 10000
# Wrap each split as a tf.data pipeline: shuffle, then group into batches.
train_dataset = (
    tf.data.Dataset.from_tensor_slices(train_images)
    .shuffle(TRAIN_BUF)
    .batch(BATCH_SIZE)
)
test_dataset = (
    tf.data.Dataset.from_tensor_slices(test_images)
    .shuffle(TEST_BUF)
    .batch(BATCH_SIZE)
)
第三步:搭建模型网络
# Build the generator; it plays the same role as the decoder in a VAE.
def make_generator_model():
    """Build the generator mapping a 100-dim noise vector to a 28x28x1 image.

    The model has a bias-free Dense layer projected to 7x7x256, then two
    5x5 transposed convolutions. The first (stride 1) gives 7x7x128; the
    second (stride 2) gives 14x14x64. Each is followed by BatchNormalization
    and LeakyReLU. A final stride-2 transposed convolution with tanh produces
    28x28x1 in [-1, 1]. The asserts check the output shape after each stage.
    """
    net = tf.keras.Sequential()
    net.add(layers.Dense(7 * 7 * 256, use_bias=False, input_shape=(100,)))
    net.add(layers.BatchNormalization())
    net.add(layers.LeakyReLU())
    net.add(layers.Reshape((7, 7, 256)))
    assert net.output_shape == (None, 7, 7, 256)  # None is the batch dimension
    # (filters, stride, expected spatial side) for the two upsampling stages.
    for n_filters, step, side in ((128, 1, 7), (64, 2, 14)):
        net.add(layers.Conv2DTranspose(n_filters, (5, 5), strides=(step, step),
                                       padding='same', use_bias=False))
        assert net.output_shape == (None, side, side, n_filters)
        net.add(layers.BatchNormalization())
        net.add(layers.LeakyReLU())
    net.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same',
                                   use_bias=False, activation='tanh'))
    assert net.output_shape == (None, 28, 28, 1)
    return net
def make_discriminator_model():
    """Build the CNN discriminator that scores 28x28x1 images.

    The model has two 5x5 stride-2 convolutions (64, then 128 filters), each
    followed by LeakyReLU and Dropout(0.3), then Flatten and a single Dense
    unit. The Dense unit has no activation, so the model outputs one raw
    logit per image.
    """
    disc = tf.keras.Sequential()
    for idx, n_filters in enumerate((64, 128)):
        # Only the first convolution declares the input shape.
        shape_kwargs = {'input_shape': [28, 28, 1]} if idx == 0 else {}
        disc.add(layers.Conv2D(n_filters, (5, 5), strides=(2, 2), padding='same',
                               **shape_kwargs))
        disc.add(layers.LeakyReLU())
        disc.add(layers.Dropout(0.3))
    disc.add(layers.Flatten())
    disc.add(layers.Dense(1))
    return disc
第四步:训练
# Instantiate both networks (generator first, then discriminator).
generator, discriminator = make_generator_model(), make_discriminator_model()
# Binary cross-entropy applied to raw logits: the discriminator's final
# Dense(1) has no activation, hence from_logits=True.
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
# Discriminator loss.
def discriminator_loss(real_output, fake_output):
    """Return the discriminator's BCE loss.

    Real-image logits are pushed toward label 1 and generated-image logits
    toward label 0; the result is the sum of the two cross-entropy terms.
    """
    return (cross_entropy(tf.ones_like(real_output), real_output)
            + cross_entropy(tf.zeros_like(fake_output), fake_output))
# Generator loss.
def generator_loss(fake_output):
    """Return the generator's BCE loss.

    The generator is rewarded when the discriminator labels its images as
    real, so the target for every fake logit is 1.
    """
    all_real = tf.ones_like(fake_output)
    return cross_entropy(all_real, fake_output)
# One Adam optimizer per network (learning rate 1e-4), created generator first.
generator_optimizer, discriminator_optimizer = (
    tf.keras.optimizers.Adam(learning_rate=1e-4) for _ in range(2)
)
# Training length, noise size (must equal the generator's input_shape=(100,)),
# and how many samples to render per snapshot.
EPOCHS, noise_dim, num_examples_to_generate = 50, 100, 16
# Fixed noise batch reused at every snapshot so epochs can be compared.
seed = tf.random.normal(shape=(num_examples_to_generate, noise_dim))
@tf.function
def train_step(images):
    """Run one adversarial update of generator and discriminator.

    Args:
        images: a batch of real images as yielded by the training dataset.

    Returns:
        (gen_loss, disc_loss): scalar loss tensors for this batch, useful for
        logging (callers that ignore the return value are unaffected).
    """
    # Draw one noise vector per real image so a smaller final batch does not
    # pit BATCH_SIZE fakes against fewer reals.
    batch_size = tf.shape(images)[0]
    noise = tf.random.normal([batch_size, noise_dim])
    # Two tapes: each network is differentiated only w.r.t. its own loss.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)
        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)
        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(
        zip(gradients_of_discriminator, discriminator.trainable_variables))
    return gen_loss, disc_loss
def train(dataset, epochs):
    """Train the GAN for `epochs` full passes over `dataset`.

    After each epoch the notebook output is cleared, a sample grid for the
    fixed `seed` is rendered and saved via generate_and_save_images, and the
    epoch's wall-clock time is printed. After training finishes, one more
    grid is rendered, labelled with `epochs`.
    """
    for epoch_no in range(1, epochs + 1):
        t0 = time.time()
        for batch in dataset:
            train_step(batch)
        # Snapshot the generator on the fixed seed (frames for a GIF).
        display.clear_output(wait=True)
        generate_and_save_images(generator, epoch_no, seed)
        print('Time for epoch {} is {} sec'.format(epoch_no, time.time() - t0))
    # Render the final generator once more after the last epoch.
    display.clear_output(wait=True)
    generate_and_save_images(generator, epochs, seed)
def generate_and_save_images(model, epoch, test_input):
    """Render `model`'s outputs for `test_input` as a grid, save it, and show it.

    Args:
        model: called as model(test_input, training=False); its output is
            indexed as [i, :, :, 0], so an (N, H, W, C) batch is assumed.
            Values are assumed to lie in [-1, 1] (the generator ends in tanh).
        epoch: number used in the saved file name image_at_epoch_XXXX.png.
        test_input: batch of noise vectors to feed the model.
    """
    # training=False runs every layer in inference mode (e.g. BatchNormalization).
    predictions = model(test_input, training=False)
    num_images = predictions.shape[0]
    # Smallest square grid that fits every sample (4x4 for 16 samples);
    # a hard-coded 4x4 grid would fail for more than 16 inputs.
    grid = max(1, int(np.ceil(np.sqrt(num_images))))
    fig = plt.figure(figsize=(4, 4))
    for i in range(num_images):
        plt.subplot(grid, grid, i + 1)
        # Map [-1, 1] to [0, 255]; pin vmin/vmax so imshow does not stretch
        # each sample to its own min/max (which would make the scaling moot).
        plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray',
                   vmin=0, vmax=255)
        plt.axis('off')
    plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
    plt.show()
    # Release the figure; with non-interactive backends figures otherwise
    # stay open and accumulate one per call.
    plt.close(fig)
# Launch training on the shuffled, batched training split.
train(dataset=train_dataset, epochs=EPOCHS)
可视化结果:
标签:layers,Keras,GAN,add,train,output,images,TensorFlow,model From: https://blog.csdn.net/u013289254/article/details/143752724