Import libraries
from keras.models import Sequential
from keras.layers import Dense, BatchNormalization, Reshape, Activation
from keras.layers import Conv2D, MaxPooling2D, UpSampling2D, Flatten
from keras.datasets import mnist
from keras.optimizers import SGD
import numpy as np
from PIL import Image
import argparse
import math
import os
Generator network
The input is a 100-dimensional noise vector that is progressively upsampled into a (28, 28, 1) image. The generator turns the noise into an image; that image is then fed to the discriminator, and the loss against the "real" label gives g_loss.
def generator_model():
    '''
    Generator: maps a 100-dim noise vector to a (28, 28, 1) image
    :return: the generator model
    '''
    model = Sequential()
    # fully connected: 100 -> 1024
    model.add(Dense(1024, input_dim=100))
    # activation
    model.add(Activation('tanh'))
    # fully connected: 1024 -> 128*7*7, then reshape into a 7x7x128 feature map
    model.add(Dense(128*7*7))
    model.add(BatchNormalization())
    model.add(Activation('tanh'))
    model.add(Reshape((7, 7, 128)))
    # upsample 7x7 -> 14x14
    model.add(UpSampling2D(size=(2, 2)))
    model.add(Conv2D(64, (5, 5), padding='same'))
    model.add(Activation('tanh'))
    # upsample 14x14 -> 28x28
    model.add(UpSampling2D(size=(2, 2)))
    model.add(Conv2D(1, (5, 5), padding='same'))
    # final tanh keeps pixel values in [-1, 1], matching the normalized MNIST data
    model.add(Activation('tanh'))
    return model
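A quick sanity check (a minimal sketch, not part of the training script, assuming the Keras 2 API used above) confirms the generator maps 100-dim noise to a (28, 28, 1) image whose values stay in [-1, 1] because of the final tanh:

# hypothetical check of the generator's output shape and value range
g = generator_model()
print(g.output_shape)   # expected: (None, 28, 28, 1)
sample = g.predict(np.random.uniform(-1, 1, size=(1, 100)))
print(sample.shape, sample.min(), sample.max())   # values within [-1, 1]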
Discriminator network
def discriminator_model():
    '''
    Discriminator: maps a (28, 28, 1) image to a real/fake probability
    :return: the discriminator model
    '''
    model = Sequential()
    # input shape (28, 28, 1)
    model.add(
        Conv2D(64, (5, 5),
               padding='same',
               input_shape=(28, 28, 1))
    )
    # tanh activation
    model.add(Activation('tanh'))
    # max pooling
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Conv2D(128, (5, 5)))
    model.add(Activation('tanh'))
    model.add(MaxPooling2D(pool_size=(2, 2)))
    model.add(Flatten())
    model.add(Dense(1024))
    model.add(Activation('tanh'))
    model.add(Dense(1))
    # sigmoid gives the probability that the input image is real
    model.add(Activation('sigmoid'))
    return model
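The same kind of check for the discriminator (again a sketch, not part of the script): it should map a (28, 28, 1) image to a single sigmoid probability.

# hypothetical check of the discriminator's output
d = discriminator_model()
print(d.output_shape)   # expected: (None, 1)
print(d.predict(np.random.uniform(-1, 1, size=(1, 28, 28, 1))))   # one value in (0, 1)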
def generator_containing_discriminator(g, d):
    '''Stack the generator and the (frozen) discriminator so g can be trained through d.'''
    model = Sequential()
    model.add(g)
    # freeze d: its weights are not updated when the stacked model is trained
    d.trainable = False
    model.add(d)
    return model
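Because d.trainable is set to False before the stacked model is compiled, only the generator's weights are updated when d_on_g is trained; the frozen state is fixed at compile time. A hedged way to verify this wiring (assuming standard Keras 2 trainable_weights behavior):

g = generator_model()
d = discriminator_model()
d_on_g = generator_containing_discriminator(g, d)
d_on_g.compile(loss='binary_crossentropy', optimizer='SGD')
# the stacked model should expose only the generator's weights as trainable
print(len(d_on_g.trainable_weights) == len(g.trainable_weights))   # expected: True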
def combine_images(generated_images):
    '''Tile a batch of generated images into a single grid image for saving.'''
    # number of images in the batch
    num = generated_images.shape[0]
    # grid is roughly square: width = floor(sqrt(num)), height = ceil(num / width)
    width = int(math.sqrt(num))
    height = int(math.ceil(float(num)/width))
    shape = generated_images.shape[1:3]
    # all-black canvas
    image = np.zeros((height*shape[0], width*shape[1]),
                     dtype=generated_images.dtype)
    for index, img in enumerate(generated_images):
        i = int(index/width)
        j = index % width
        image[i*shape[0]:(i+1)*shape[0], j*shape[1]:(j+1)*shape[1]] = img[:, :, 0]
    return image
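combine_images tiles a batch into a grid that is roughly sqrt(num) images wide. A small example with dummy data:

dummy = np.random.uniform(-1, 1, size=(16, 28, 28, 1)).astype(np.float32)
grid = combine_images(dummy)
print(grid.shape)   # (112, 112): a 4x4 grid of 28x28 tiles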
Training
def train(BATCH_SIZE):
    # load MNIST
    (X_train, y_train), (X_test, y_test) = mnist.load_data()
    # normalize images to [-1, 1]; X_train has shape (60000, 28, 28)
    X_train = (X_train.astype(np.float32) - 127.5)/127.5
    # add the channel dimension: (60000, 28, 28, 1)
    X_train = X_train[:, :, :, None]
    X_test = X_test[:, :, :, None]
    d = discriminator_model()
    g = generator_model()
    d_on_g = generator_containing_discriminator(g, d)
    # SGD with Nesterov momentum
    d_optim = SGD(lr=0.0005, momentum=0.9, nesterov=True)
    g_optim = SGD(lr=0.0005, momentum=0.9, nesterov=True)
    # binary cross-entropy losses
    g.compile(loss='binary_crossentropy', optimizer="SGD")
    d_on_g.compile(loss='binary_crossentropy', optimizer=g_optim)
    d.trainable = True
    d.compile(loss='binary_crossentropy', optimizer=d_optim)
    # make sure the output directory for sample images exists
    os.makedirs("./images", exist_ok=True)
    for epoch in range(100):
        print("Epoch is", epoch)
        print("Number of batches", int(X_train.shape[0]/BATCH_SIZE))
        for index in range(int(X_train.shape[0]/BATCH_SIZE)):
            # sample noise
            noise = np.random.uniform(-1, 1, size=(BATCH_SIZE, 100))
            # take a batch of real images
            image_batch = X_train[index*BATCH_SIZE:(index+1)*BATCH_SIZE]
            # generate fake images without logging
            generated_images = g.predict(noise, verbose=0)
            # save a sample grid every 20 batches
            if index % 20 == 0:
                image = combine_images(generated_images)
                image = image*127.5+127.5
                Image.fromarray(image.astype(np.uint8)).save("./images/"+
                    str(epoch)+"_"+str(index)+".png")
            # train the discriminator on real (label 1) and fake (label 0) images
            X = np.concatenate((image_batch, generated_images))
            y = [1] * BATCH_SIZE + [0] * BATCH_SIZE
            d_loss = d.train_on_batch(X, y)
            print("batch %d d_loss : %f" % (index, d_loss))
            # fresh noise, then train the generator through the frozen discriminator
            noise = np.random.uniform(-1, 1, (BATCH_SIZE, 100))
            d.trainable = False
            g_loss = d_on_g.train_on_batch(noise, [1] * BATCH_SIZE)
            d.trainable = True
            print("batch %d g_loss : %f" % (index, g_loss))
            # periodically save weights
            if index % 10 == 9:
                g.save_weights('generator', True)
                d.save_weights('discriminator', True)
                print("weights saved")
Image generation (test)
def generate(BATCH_SIZE, nice=False):
    g = generator_model()
    g.compile(loss='binary_crossentropy', optimizer="SGD")
    g.load_weights('generator')
    if nice:
        # keep only the generated images that the discriminator scores highest
        d = discriminator_model()
        d.compile(loss='binary_crossentropy', optimizer="SGD")
        d.load_weights('discriminator')
        # over-generate by a factor of 20, then pick the best BATCH_SIZE images
        noise = np.random.uniform(-1, 1, (BATCH_SIZE*20, 100))
        generated_images = g.predict(noise, verbose=1)
        d_pret = d.predict(generated_images, verbose=1)
        index = np.arange(0, BATCH_SIZE*20)
        index.resize((BATCH_SIZE*20, 1))
        # pair each discriminator score with its image index and sort by score
        pre_with_index = list(np.append(d_pret, index, axis=1))
        pre_with_index.sort(key=lambda x: x[0], reverse=True)
        nice_images = np.zeros((BATCH_SIZE,) + generated_images.shape[1:3], dtype=np.float32)
        nice_images = nice_images[:, :, :, None]
        for i in range(BATCH_SIZE):
            idx = int(pre_with_index[i][1])
            nice_images[i, :, :, 0] = generated_images[idx, :, :, 0]
        image = combine_images(nice_images)
    else:
        # generate a batch directly
        noise = np.random.uniform(-1, 1, (BATCH_SIZE, 100))
        generated_images = g.predict(noise, verbose=1)
        image = combine_images(generated_images)
    # rescale from [-1, 1] back to [0, 255] and save
    image = image*127.5+127.5
    Image.fromarray(image.astype(np.uint8)).save(
        "generated_image.png")
Argument parsing
def get_args():
    '''
    Parse command-line arguments
    :return: the parsed arguments
    '''
    parser = argparse.ArgumentParser()
    # "train" or "generate"
    parser.add_argument("--mode", type=str)
    parser.add_argument("--batch_size", type=int, default=96)
    # pass --nice to keep only the images the discriminator rates highest
    parser.add_argument("--nice", dest="nice", action="store_true")
    parser.set_defaults(nice=False)
    args = parser.parse_args()
    return args
if __name__ == "__main__":
    args = get_args()
    # args.mode = "generate"  # uncomment to force generation regardless of --mode
    if args.mode == "train":
        train(BATCH_SIZE=args.batch_size)
    elif args.mode == "generate":
        generate(BATCH_SIZE=args.batch_size, nice=args.nice)
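Typical invocations, assuming the script is saved as dcgan.py (the filename is arbitrary):

python dcgan.py --mode train --batch_size 128
python dcgan.py --mode generate --batch_size 128
python dcgan.py --mode generate --batch_size 128 --nice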