首页 > 其他分享 >tensorflow keras从入门到精通——DcGAN生成手写数字

tensorflow keras从入门到精通——DcGAN生成手写数字

时间:2022-11-01 18:04:44浏览次数:61  
标签:keras BATCH add train DcGAN images tensorflow model SIZE


导库

from keras.models import Sequential
from keras.layers import Dense, BatchNormalization
from keras.layers import Reshape
from keras.layers.core import Activation
from keras.layers.convolutional import UpSampling2D
from keras.layers.convolutional import Conv2D, MaxPooling2D
from keras.layers.core import Flatten
from keras.datasets import mnist
import numpy as np
from PIL import Image
import argparse
import math
from keras.optimizers import SGD

生成网络

输入数据为100维数据,不断上采样,最后生成1,1,28,28的图片,生成网络利用noise
生成的img 再利用判别器进行判别与真实标签做损失,可以得到g_loss

def generator_model():
'''
生成器代码

:return:
'''
model = Sequential()
# input and out
model.add(Dense(input_dim=100, output_dim=1024))
# 激活函数
model.add(Activation('tanh'))
#
model.add(Dense(128*7*7))
model.add(BatchNormalization())
model.add(Activation('tanh'))
model.add(Reshape((7, 7, 128), input_shape=(128*7*7,)))
model.add(UpSampling2D(size=(2, 2)))
model.add(Conv2D(64, (5, 5), padding='same'))
model.add(Activation('tanh'))
model.add(UpSampling2D(size=(2, 2)))
model.add(Conv2D(1, (5, 5), padding='same'))
model.add(Activation('tanh'))
return model

判别网络

def discriminator_model():
'''
判别器
:return:
'''
model = Sequential()
# 输入维度input_shape=(28, 28, 1)
model.add(
Conv2D(64, (5, 5),
padding='same',
input_shape=(28, 28, 1))
)
# tanh激活
model.add(Activation('tanh'))
# 最大池化
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Conv2D(128, (5, 5)))

model.add(Activation('tanh'))

model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Flatten())

model.add(Dense(1024))

model.add(Activation('tanh'))
model.add(Dense(1))
# sigmoid 激活
model.add(Activation('sigmoid'))
return model
def generator_containing_discriminator(g, d):
model = Sequential()
model.add(g)
# d 设置为不可训练
d.trainable = False
model.add(d)
return model
def combine_images(generated_images):
# 保存图片
num = generated_images.shape[0]

# 数量
width = int(math.sqrt(num))

height = int(math.ceil(float(num)/width))


shape = generated_images.shape[1:3]

# 生成全黑
image = np.zeros((height*shape[0], width*shape[1]),
dtype=generated_images.dtype)

for index, img in enumerate(generated_images):
i = int(index/width)
j = index % width
image[i*shape[0]:(i+1)*shape[0], j*shape[1]:(j+1)*shape[1]] = img[:, :, 0]
return image

训练

def train(BATCH_SIZE):

# 加载数据
(X_train, y_train), (X_test, y_test) = mnist.load_data()

# 图像归一化 60000,28,28
X_train = (X_train.astype(np.float32) - 127.5)/127.5


X_train = X_train[:, :, :, None]
X_test = X_test[:, :, :, None]

# X_train = X_train.reshape((X_train.shape, 1) + X_train.shape[1:])



d = discriminator_model()
g = generator_model()
d_on_g = generator_containing_discriminator(g, d)


# 牛顿动量
d_optim = SGD(lr=0.0005, momentum=0.9, nesterov=True)
g_optim = SGD(lr=0.0005, momentum=0.9, nesterov=True)


# 二进制交叉熵
g.compile(loss='binary_crossentropy', optimizer="SGD")

d_on_g.compile(loss='binary_crossentropy', optimizer=g_optim)
d.trainable = True
d.compile(loss='binary_crossentropy', optimizer=d_optim)


for epoch in range(100):
print("Epoch is", epoch)
print("Number of batches", int(X_train.shape[0]/BATCH_SIZE))

for index in range(int(X_train.shape[0]/BATCH_SIZE)):
# 生成noise
noise = np.random.uniform(-1, 1, size=(BATCH_SIZE, 100))
# 生成批次
image_batch = X_train[index*BATCH_SIZE:(index+1)*BATCH_SIZE]

# 不输出日志,预测
generated_images = g.predict(noise, verbose=0)

if index % 20 == 0:
image = combine_images(generated_images)
image = image*127.5+127.5
Image.fromarray(image.astype(np.uint8)).save("./images/"+
str(epoch)+"_"+str(index)+".png")
X = np.concatenate((image_batch, generated_images))
### label
y = [1] * BATCH_SIZE + [0] * BATCH_SIZE

# 判别器损失
d_loss = d.train_on_batch(X, y)
print("batch %d d_loss : %f" % (index, d_loss))
# 生成noise
noise = np.random.uniform(-1, 1, (BATCH_SIZE, 100))
d.trainable = False
# g_loss
g_loss = d_on_g.train_on_batch(noise, [1] * BATCH_SIZE)
d.trainable = True

print("batch %d g_loss : %f" % (index, g_loss))
if index % 10 == 9:
print("save successed!")
g.save_weights('generator', True)
d.save_weights('discriminator', True)

test

def generate(BATCH_SIZE, nice=False):

g = generator_model()
g.compile(loss='binary_crossentropy', optimizer="SGD")
g.load_weights('generator')
if nice:
### 获取生成效果较好的图片
d = discriminator_model()
d.compile(loss='binary_crossentropy', optimizer="SGD")
d.load_weights('discriminator')
noise = np.random.uniform(-1, 1, (BATCH_SIZE*20, 100))
generated_images = g.predict(noise, verbose=1)
d_pret = d.predict(generated_images, verbose=1)
index = np.arange(0, BATCH_SIZE*20)
index.resize((BATCH_SIZE*20, 1))
pre_with_index = list(np.append(d_pret, index, axis=1))
pre_with_index.sort(key=lambda x: x[0], reverse=True)
nice_images = np.zeros((BATCH_SIZE,) + generated_images.shape[1:3], dtype=np.float32)
nice_images = nice_images[:, :, :, None]

for i in range(BATCH_SIZE):
idx = int(pre_with_index[i][1])
nice_images[i, :, :, 0] = generated_images[idx, :, :, 0]
image = combine_images(nice_images)

else:
### 直接生成
noise = np.random.uniform(-1, 1, (BATCH_SIZE, 100))
generated_images = g.predict(noise, verbose=1)
image = combine_images(generated_images)

image = image*127.5+127.5
Image.fromarray(image.astype(np.uint8)).save(
"generated_image.png")

参数配置

def get_args():
'''
参数配置
:return:
'''
parser = argparse.ArgumentParser()
parser.add_argument("--mode", type=str)
parser.add_argument("--batch_size", type=int, default=96)
parser.add_argument("--nice", dest="nice", action="store_true")
parser.set_defaults(nice=True)

args = parser.parse_args()
return args
if __name__ == "__main__":
args = get_args()
args.mode ="generate"
if args.mode == "train":
train(BATCH_SIZE=args.batch_size)
elif args.mode == "generate":
generate(BATCH_SIZE=args.batch_size, nice=args.nice)


标签:keras,BATCH,add,train,DcGAN,images,tensorflow,model,SIZE
From: https://blog.51cto.com/u_13859040/5814618

相关文章

  • tensorflow 实现Logistic Regression
    importlibimportnumpyasnpimporttensorflowastffromsklearnimportdatasetsimportpandasaspdfromsklearn.cross_validationimporttrain_test_splitfrommatp......
  • 生成对抗网络实战入门到精通——DCGAN算法以及训练增强手法
    DCGAN​​DCGAN​​​​原理​​​​改进​​​​网络架构​​​​ImprovedDCGAN​​​​训练增强手法​​​​特征匹配​​​​批次判别​​​​常见问题​​​​原理​​......
  • 生成对抗网络从入门到精通——DCGAN
    这里只给技术实现,以minst数据集作为测试不懂得看论文噢~importargparseimportosimportnumpyasnpimportmathimporttorchvision.transformsastransformsfromtorchvi......
  • 20221031&20221101 Keras
    周末长安杯加上组网实验信安数基上机计网翻转课堂核酸S12半决赛,小摆几天......
  • Keras可视化神经网络架构的4种方法
    我们在使用卷积神经网络或递归神经网络或其他变体时,通常都希望对模型的架构可以进行可视化的查看,因为这样我们可以在定义和训练多个模型时,比较不同的层以及它们放置的顺序......
  • tensorflow2中以复制方式扩展tensor —— tf.tile()
    tensorflow2.0环境下,以复制方式扩展tensor,可以使用​​tf.tile()​​函数。该函数定义如下(图自官网):https://www.tensorflow.org/api_docs/python/tf/keras/backend/tiletf......
  • 2、tensorflow
    1、已经安装了scipy,但是无法调用'ImagetransformationsrequireSciPy.'、'InstallSciPy.'name'scipy'isnotdefined解决方法:pip3-Vpython3-mscipypytho......
  • 安装TensorFlow CPU版本
    TensorFlow1.1TensorFlow介绍  TensorFlow就是谷歌公司推出的一款高效的人工智能开源框架,自从2015年11月发布以来,已经成为全世界最广泛使用的深度学习库。很多以前难......
  • TensorFlow.NET机器学习入门【0】前言与目录
    TensorFlow.NET机器学习入门【0】前言与目录    曾经学习过一段时间ML.NET的知识,ML.NET是微软提供的一套机器学习框架,相对于其他的一些机器学习框架,ML.NET侧重于......
  • 【解决错误】AttributeError: module 'tensorflow.compat.v2.__internal__' has no at
    原因一般为tensorflow和keras版本不匹配。解决方法以下是tensorflow版本对应关系我最开始使用的为tensorflow=2.4.0,keras=2.4.3,但是问题仍然没有解决,我就安装了te......