首页 > 其他分享 >实验16-使用GAN生成手写数字样本

实验16-使用GAN生成手写数字样本

时间:2024-04-27 15:01:43浏览次数:24  
标签:loss img 16 self GAN add import 手写 model

版本python3.7 tensorflow版本为tensorflow-gpu版本2.6

运行结果:

 代码:

from __future__ import print_function, division

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam

import matplotlib.pyplot as plt

import sys
import os
import numpy as np

class GAN():
    def __init__(self):
        # --------------------------------- #
        #   行28,列28,也就是mnist的shape
        # --------------------------------- #
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        # 28,28,1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 100
        # adam优化器
        optimizer = Adam(0.0002, 0.5)

        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
            optimizer=optimizer,
            metrics=['accuracy'])

        self.generator = self.build_generator()
        gan_input = Input(shape=(self.latent_dim,))
        img = self.generator(gan_input)
        # 在训练generate的时候不训练discriminator
        self.discriminator.trainable = False
        # 对生成的假图片进行预测
        validity = self.discriminator(img)
        self.combined = Model(gan_input, validity)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)


    def build_generator(self):
        # --------------------------------- #
        #   生成器,输入一串随机数字
        # --------------------------------- #
        model = Sequential()

        model.add(Dense(256, input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))

        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))

        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))

        model.add(Dense(np.prod(self.img_shape), activation='tanh'))
        model.add(Reshape(self.img_shape))

        noise = Input(shape=(self.latent_dim,))
        img = model(noise)

        return Model(noise, img)

    def build_discriminator(self):
        # ----------------------------------- #
        #   评价器,对输入进来的图片进行评价
        # ----------------------------------- #
        model = Sequential()
        # 输入一张图片
        model.add(Flatten(input_shape=self.img_shape))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        # 判断真伪
        model.add(Dense(1, activation='sigmoid'))

        img = Input(shape=self.img_shape)
        validity = model(img)

        return Model(img, validity)

    def train(self, epochs, batch_size=128, sample_interval=50):
        # 获得数据
        (X_train, _), (_, _) = mnist.load_data(r"F:\大学\大三\选修\机器学习\机械学习\实验\实验十四\mnist.npz")

        # 进行标准化
        X_train = X_train / 127.5 - 1.
        X_train = np.expand_dims(X_train, axis=3)

        # 创建标签
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):

            # --------------------------- #
            #   随机选取batch_size个图片
            #   对discriminator进行训练
            # --------------------------- #
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            gen_imgs = self.generator.predict(noise)

            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # --------------------------- #
            #  训练generator
            # --------------------------- #
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            g_loss = self.combined.train_on_batch(noise, valid)
            print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

            if epoch % sample_interval == 0:
                self.sample_images(epoch)

    def sample_images(self, epoch):

        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise)

        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig("images/%d.png" % epoch)
        plt.close()


if __name__ == '__main__':
    if not os.path.exists("./images"):
        os.makedirs("./images")
    gan = GAN()
    gan.train(epochs=30000, batch_size=256, sample_interval=200)

 

标签:loss,img,16,self,GAN,add,import,手写,model
From: https://www.cnblogs.com/daitu66/p/18162065

相关文章

  • 实验14-1使用cnn完成MNIST手写体识别(tf)+实验14-2使用cnn完成MNIST手写体识别(keras)
    版本python3.7tensorflow版本为tensorflow-gpu版本2.6实验14-1使用cnn完成MNIST手写体识别(tf)运行结果: 代码:importtensorflowastf#Tensorflow提供了一个类来处理MNIST数据fromtensorflow.examples.tutorials.mnistimportinput_dataimporttime#载入数据集mn......
  • 手写bind函数
    今天无事手写一个bind函数//重写bind函数Function.prototype.bindDemo=function(){//arguments可以获取到传的参数//1.先把获取到的数据转换为数组的格式letargs=Array.prototype.slice.call(arguments);//2.获取数组中第一个元素,即this即将指向的数据le......
  • 【vue3入门】-【16】表单输入绑定
    表单输入绑定在前端处理表单时,我们常常需要将表单输入框的内容同步给JavaScript中相应的变量。手动连接值绑定和更改事件监听器可能会比较麻烦,v-model指令帮我们简化了这一步骤。<template><h3>表单输入绑定</h3><form><!--v-model:在页面中输入信息的同时,下......
  • Ubuntu 16.04 LTS 升级到 Ubuntu 18.04 LTS
    Ubuntu从16.04升级到18.04版本_ubuntu16upgrade了18的库-CSDN博客......
  • (数据科学学习手札160)使用miniforge代替miniconda
    本文已收录至我的Github仓库https://github.com/CNFeffery/DataScienceStudyNotes1简介大家好我是费老师,conda作为Python数据科学领域的常用软件,是对Python环境及相关依赖进行管理的经典工具,通常集成在anaconda或miniconda等产品中供用户日常使用。但长久以来,conda......
  • 4.16
    本人最近在写一个小的安卓项目,开发app过程中用到了安卓自带的sqlite。本文主要对sqlite图片操作进行介绍,其他存入文本之类的操作和普通数据库一样,众所周知,sqlite是一款轻型的数据库,以下先简单介绍一下sqlite,为后续做铺垫,有了解的大佬可以跳过此部分:SQLite是一种轻量级、嵌入式的......
  • 16.匿名函数 与 部分内置函数
    【一】匿名函数1)语法lambda函数参数:表达式2)用法#单参数匿名函数lbd_sqr=lambdax:x**2#多参数匿名函数sumary_lba=lambdaarg1,arg2:arg1+arg2#多参数解包add_lba=lambda*args:sum(args)3)高阶函数#过滤函数(filter)odd=lambdax:x%2==1......
  • P3293 [SCOI2016] 美味
    经典题,\(\rm01Trie\)和主席树的结合。考虑一个没有偏移量的时候如何计算,其实就是一个裸的可持久化\(\rmTrie\)。但是有了偏移量就不一样了,这会导致直接改变\(\rmTrie\)的结构,十分不好做。套路的逐位考虑,从高位枚举到低位。假设当前找到的数为\(\rmret\),考虑到\(i\)......
  • OpenAI未至,Open-Sora再度升级!已支持生成16秒720p视频
    Open-Sora在开源社区悄悄更新了!现在支持长达16秒的视频生成,分辨率最高可达720p,并且可以处理任何宽高比的文本到图像、文本到视频、图像到视频、视频到视频和无限长视频的生成需求。我们来试试效果。生成个横屏圣诞雪景,发b站再生成个竖屏,发抖音还能生成16秒的长视频,这下人......
  • cf 1601B Frog Traveler Codeforces Round 751 (Div. 1)
     Problem-1601B-Codeforces BFS然后每次上升可以的范围是一个区间,然后每次都遍历这个区间的所有点,那么超时。用set等方式,合并这些区间,之前没遍历过的范围才更新(加入BFS需要遍历的队列里)。但是区间的更新特别容易写错…… 我的代码和造数据1/**2记录两个vi......