
TensorFlow/Keras in Practice, 100 Examples: Generating Images with a GAN


1. How GANs Work

A GAN consists of two main parts:

  • a generator
  • a discriminator

The generator creates new images, while the discriminator learns to tell real images apart from generated ones; its feedback is what pushes the generator to produce more and more convincing samples.

The complete architecture we are going to build for GAN image generation is assembled step by step in the code below.
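For reference, the two networks are trained against each other on the standard GAN minimax objective (this is the canonical formulation from the original GAN paper, not anything specific to this post):

\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z}\left[\log\bigl(1 - D(G(z))\bigr)\right]

The discriminator D tries to give high scores to real samples x and low scores to generated samples G(z), while the generator G tries to fool it.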

2. The Dataset

MNIST is a classic machine-learning dataset of 28x28 grayscale images of handwritten digits. It contains 70,000 samples in total: 60,000 training samples and 10,000 test samples.
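The code below reads the raw IDX files from disk. If those files are not available locally, a common alternative is to let Keras download MNIST itself; the sketch below shows that option (it is not used in the rest of this post, and the variable names are only for illustration):

import tensorflow as tf

# Downloads MNIST on first use and caches it under ~/.keras/datasets
(x_train_alt, y_train_alt), (x_test_alt, y_test_alt) = tf.keras.datasets.mnist.load_data()
print(x_train_alt.shape)  # (60000, 28, 28), uint8 pixel values in [0, 255]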

3. Hands-On Code

Step 1: Import dependencies
 

import os
import time
import struct
import numpy as np
import glob
import matplotlib.pyplot as plt
import PIL
import imageio
import tensorflow as tf
import tensorflow.keras.layers as layers
from IPython import display

print(tf.__version__)

Step 2: Load and preprocess the data

def dense_to_one_hot(labels_dense, num_classes=10):
  """将类标签从标量转换为一个独热向量"""
  num_labels = labels_dense.shape[0]
  index_offset = np.arange(num_labels) * num_classes
  labels_one_hot = np.zeros((num_labels, num_classes))
  labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
  return labels_one_hot

def load_mnist(path, kind='train'):
    """根据指定路径加载数据集"""
    labels_path = os.path.join(path, '%s-labels-idx1-ubyte' % kind)
    images_path = os.path.join(path, '%s-images-idx3-ubyte' % kind)

    with open(labels_path, 'rb') as lbpath:
        magic, n = struct.unpack('>II',lbpath.read(8))
        labels = np.fromfile(lbpath, dtype=np.uint8)
        labels=dense_to_one_hot(labels)

    with open(images_path, 'rb') as imgpath:
        magic, num, rows, cols = struct.unpack(">IIII",imgpath.read(16))
        images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(labels), 784)


    return images, labels

X_train, y_train = load_mnist('../data/MNIST/raw/', kind='train')
print('Rows: %d, columns: %d' % (X_train.shape[0], X_train.shape[1]))
print('Rows: %d, columns: %d' % ( y_train.shape[0],  y_train.shape[1]))

X_test, y_test = load_mnist('../data/MNIST/raw/', kind='t10k')
print('Rows: %d, columns: %d' % (X_test.shape[0], X_test.shape[1]))

# Reshape the flat 784-pixel vectors into 28x28x1 image tensors

train_images = X_train.reshape(X_train.shape[0], 28, 28, 1).astype('float32')
test_images = X_test.reshape(X_test.shape[0], 28, 28, 1).astype('float32')

# Scale pixel values to [0, 1]
train_images /= 255.0
test_images /= 255.0

# Binarize: pixels become 0 or 1 with a threshold of 0.5
train_images[train_images>=0.5] = 1.0
train_images[train_images<0.5] = 0.0
test_images[test_images>=0.5] = 1.0
test_images[test_images<0.5] = 0.0

# Hyperparameters: shuffle buffer sizes and batch size
TRAIN_BUF=60000
BATCH_SIZE = 100
TEST_BUF = 10000

train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(TRAIN_BUF).batch(BATCH_SIZE)
test_dataset = tf.data.Dataset.from_tensor_slices(test_images).shuffle(TEST_BUF).batch(BATCH_SIZE)
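One thing to watch out for: the real images are binarized to {0, 1} here, while the generator defined in the next step ends in a tanh activation whose outputs lie in [-1, 1], and the plotting code at the end rescales predictions with x * 127.5 + 127.5. A common alternative, used by the official TensorFlow DCGAN tutorial, is to normalize the real images to [-1, 1] so that real and generated images share the same range. A minimal sketch of that variant, as a drop-in replacement for the scaling and binarization above:

# Alternative preprocessing: map pixel values from [0, 255] to [-1, 1]
# so that real images match the tanh output range of the generator.
train_images = X_train.reshape(X_train.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5
test_images = X_test.reshape(X_test.shape[0], 28, 28, 1).astype('float32')
test_images = (test_images - 127.5) / 127.5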

Step 3: Build the networks

# Build the generator; it plays the same role as the decoder in a VAE
def make_generator_model():
    model = tf.keras.Sequential()
    model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
      
    model.add(layers.Reshape((7, 7, 256)))
    assert model.output_shape == (None, 7, 7, 256) # Note: None is the batch size
    
    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    assert model.output_shape == (None, 7, 7, 128)  
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 14, 14, 64)    
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    assert model.output_shape == (None, 28, 28, 1)
  
    return model
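Before wiring everything together, it helps to sanity-check the generator on a random noise vector. The snippet below is only a quick check (the names sanity_generator and generated_image are mine, and the untrained output will just look like noise):

# Shape check: one 100-dimensional noise vector in, one 28x28x1 image out
sanity_generator = make_generator_model()
noise = tf.random.normal([1, 100])
generated_image = sanity_generator(noise, training=False)
print(generated_image.shape)  # (1, 28, 28, 1)
plt.imshow(generated_image[0, :, :, 0], cmap='gray')
plt.show()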


def make_discriminator_model():
    model = tf.keras.Sequential()
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same', 
                                     input_shape=[28, 28, 1]))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))
      
    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))
       
    model.add(layers.Flatten())
    model.add(layers.Dense(1))
     
    return model
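The untrained discriminator can be checked the same way, reusing generated_image from the sketch above. It returns a single unbounded logit per image (positive leans "real", negative leans "fake"), because the final Dense(1) layer has no activation and the loss below is built with from_logits=True:

# Feed a generated image through an untrained discriminator: one raw logit out
sanity_discriminator = make_discriminator_model()
decision = sanity_discriminator(generated_image, training=False)
print(decision)  # shape (1, 1)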

Step 4: Training

generator = make_generator_model()
discriminator = make_discriminator_model()

# Binary cross-entropy, computed on raw logits (the discriminator has no sigmoid)
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

# Discriminator loss: real images should be scored as 1, generated images as 0
def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

# Generator loss: the generator wants the discriminator to score its fakes as 1
def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)
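Note that generator_loss is the non-saturating form that is standard in practice: it minimizes $-\mathbb{E}_z[\log D(G(z))]$ rather than the $\mathbb{E}_z[\log(1 - D(G(z)))]$ term from the minimax objective above, which gives the generator much stronger gradients early in training, when the discriminator still rejects its samples easily.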

generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

EPOCHS = 50
noise_dim = 100
num_examples_to_generate = 16

seed = tf.random.normal([num_examples_to_generate, noise_dim])


def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
      generated_images = generator(noise, training=True)

      real_output = discriminator(images, training=True)
      fake_output = discriminator(generated_images, training=True)

      gen_loss = generator_loss(fake_output)
      disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
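As written, train_step runs eagerly. A common optimization, used in the official TensorFlow DCGAN tutorial, is to compile it into a graph with tf.function; either decorating the function or wrapping it as sketched below works, and the speedup is usually noticeable on a GPU:

# Optional: compile the training step into a TensorFlow graph for speed
train_step = tf.function(train_step)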


def train(dataset, epochs):  
  for epoch in range(epochs):
    start = time.time()
    
    for image_batch in dataset:
      train_step(image_batch)

    # Produce images for the GIF as we go
    display.clear_output(wait=True)
    generate_and_save_images(generator,
                             epoch + 1,
                             seed)
        
    print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
    
  # Generate a final set of images after the last epoch
  display.clear_output(wait=True)
  generate_and_save_images(generator,
                           epochs,
                           seed)


def generate_and_save_images(model, epoch, test_input):
  # Notice `training` is set to False. 
  # This is so all layers run in inference mode (batchnorm).
  predictions = model(test_input, training=False)

  fig = plt.figure(figsize=(4,4))
  
  for i in range(predictions.shape[0]):
      plt.subplot(4, 4, i+1)
      plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
      plt.axis('off')
        
  plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
  plt.show()


train(train_dataset, EPOCHS)
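The imageio and glob imports at the top are not used in the training code itself, but they are handy for stitching the per-epoch PNGs into an animation, as the official TensorFlow DCGAN tutorial does. A sketch, assuming the image_at_epoch_*.png naming used in generate_and_save_images (the output name dcgan.gif is arbitrary):

# Combine the saved per-epoch snapshots into an animated GIF
with imageio.get_writer('dcgan.gif', mode='I') as writer:
    for filename in sorted(glob.glob('image_at_epoch_*.png')):
        writer.append_data(imageio.imread(filename))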

Visualizing the results: after training, the generated digits for each epoch can be inspected in the saved image_at_epoch_XXXX.png files (or in the GIF above); they should look increasingly like handwritten digits as training progresses.

From: https://blog.csdn.net/u013289254/article/details/143752724
