首页 > 其他分享 >实验16-使用GAN生成手写数字样本

实验16-使用GAN生成手写数字样本

时间:2024-06-05 21:35:33浏览次数:8  
标签:loss img 16 self GAN add import 手写 model

from __future__ import print_function, division

from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout
from keras.layers import BatchNormalization, Activation, ZeroPadding2D
from keras.layers.advanced_activations import LeakyReLU
from keras.layers.convolutional import UpSampling2D, Conv2D
from keras.models import Sequential, Model
from keras.optimizers import Adam

import matplotlib.pyplot as plt

import sys
import os
import numpy as np

class GAN():
    def __init__(self):
        # --------------------------------- #
        #   行28,列28,也就是mnist的shape
        # --------------------------------- #
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        # 28,28,1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.latent_dim = 100
        # adam优化器
        optimizer = Adam(0.0002, 0.5)

        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
            optimizer=optimizer,
            metrics=['accuracy'])

        self.generator = self.build_generator()
        gan_input = Input(shape=(self.latent_dim,))
        img = self.generator(gan_input)
        # 在训练generate的时候不训练discriminator
        self.discriminator.trainable = False
        # 对生成的假图片进行预测
        validity = self.discriminator(img)
        self.combined = Model(gan_input, validity)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)


    def build_generator(self):
        # --------------------------------- #
        #   生成器,输入一串随机数字
        # --------------------------------- #
        model = Sequential()

        model.add(Dense(256, input_dim=self.latent_dim))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))

        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))

        model.add(Dense(1024))
        model.add(LeakyReLU(alpha=0.2))
        model.add(BatchNormalization(momentum=0.8))

        model.add(Dense(np.prod(self.img_shape), activation='tanh'))
        model.add(Reshape(self.img_shape))

        noise = Input(shape=(self.latent_dim,))
        img = model(noise)

        return Model(noise, img)

    def build_discriminator(self):
        # ----------------------------------- #
        #   评价器,对输入进来的图片进行评价
        # ----------------------------------- #
        model = Sequential()
        # 输入一张图片
        model.add(Flatten(input_shape=self.img_shape))
        model.add(Dense(512))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dense(256))
        model.add(LeakyReLU(alpha=0.2))
        # 判断真伪
        model.add(Dense(1, activation='sigmoid'))

        img = Input(shape=self.img_shape)
        validity = model(img)

        return Model(img, validity)

    def train(self, epochs, batch_size=128, sample_interval=50):
        # 获得数据
        (X_train, _), (_, _) = mnist.load_data()

        # 进行标准化
        X_train = X_train / 127.5 - 1.
        X_train = np.expand_dims(X_train, axis=3)

        # 创建标签
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):

            # --------------------------- #
            #   随机选取batch_size个图片
            #   对discriminator进行训练
            # --------------------------- #
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs = X_train[idx]

            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))

            gen_imgs = self.generator.predict(noise)

            d_loss_real = self.discriminator.train_on_batch(imgs, valid)
            d_loss_fake = self.discriminator.train_on_batch(gen_imgs, fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # --------------------------- #
            #  训练generator
            # --------------------------- #
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            g_loss = self.combined.train_on_batch(noise, valid)
            print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

            if epoch % sample_interval == 0:
                self.sample_images(epoch)

    def sample_images(self, epoch):

        r, c = 5, 5
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        gen_imgs = self.generator.predict(noise)

        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
                axs[i,j].axis('off')
                cnt += 1
        fig.savefig("images/%d.png" % epoch)
        plt.close()


if __name__ == '__main__':
    if not os.path.exists("./images"):
        os.makedirs("./images")
    gan = GAN()
    gan.train(epochs=3000, batch_size=256, sample_interval=200)

运行结果

 

标签:loss,img,16,self,GAN,add,import,手写,model
From: https://www.cnblogs.com/liucaizhi/p/18233844

相关文章

  • 实验14-使用cnn完成MNIST手写体识别
    实验14-1使用cnn完成MNIST手写体识别(tf).pyimporttensorflowastf#Tensorflow提供了一个类来处理MNIST数据fromtensorflow.examples.tutorials.mnistimportinput_dataimporttime#载入数据集mnist=input_data.read_data_sets('MNIST',one_hot=True)#设置批次......
  • P4785 [BalticOI 2016 Day2] 交换 题解
    看到\(i\)和\(\lfloor\frac{i}{2}\rfloor\),考虑一颗二叉树。题目的操作相当于按顺序交换当前节点和左右儿子的权值。假设当前考虑的节点为\(id\),左儿子为\(ls\),右儿子为\(rs\),当前这些点的值分别为\(A,B,C\)。因为\(id\)的位置最靠前,最终又要字典序最小,所以要尽可能......
  • 400、基于51单片机的电压表(1路,ADC0832,LCD1602)(程序+Proteus仿真+原理图+流程图+元器件
    毕设帮助、开题指导、技术解答(有偿)见文未目录方案选择单片机的选择显示器选择方案一、设计功能二、Proteus仿真图单片机模块设计三、原理图四、程序源码资料包括:需要完整的资料可以点击下面的名片加下我,找我要资源压缩包的百度网盘下载地址及提取码。方案选择......
  • P1654 OSU! 题解
    P1654OSU!题解题目链接好题!但不得不说早期洛谷的题解质量是真的差,感觉没有一篇题解是讲的特别清楚的,我看了好久才搞懂。下面是我认为的一种更规范的解题过程。首先,我们设随机变量\(X_i\)表示从\(i\)向左的极长1串的长度,并且对于任意的\(i\),我们要想办法求出\(E(X_i......
  • (数据科学学习手札161)高性能数据分析利器DuckDB在Python中的使用
    本文完整代码及附件已上传至我的Github仓库https://github.com/CNFeffery/DataScienceStudyNotes1简介大家好我是费老师,就在几天前,经过六年多的持续开发迭代,著名的开源高性能分析型数据库DuckDB发布了其1.0.0正式版本。DuckDB具有极强的单机数据分析性能表现,功能丰......
  • P7860 [COCI2015-2016#2] ARTUR
    原题链接教训1.计算几何,能用乘法就不用除法2.计算几何,开longlong3.计算几何,注意直线的特殊性code#include<bits/stdc++.h>#definelllonglongusingnamespacestd;structnode{llx1,y1,x2,y2;}sk[5005];intcheck(nodea,nodeb){if(a.x2<b.x1||a.x1>b.......
  • 1689D Lena and Matrix (曼哈顿距离转切比雪夫距离/随机化/线段树)
    记一道有趣的题:P题意这道题很有意思。给定地图上若干个黑色的点,求这样一个点的坐标,满足其到图中任何一个黑色点的最大曼哈顿距离最小。\(max(|a-x_i|+|b-y_i|),i=1,2..k\)方法一曼哈顿距离和且比雪夫距离可以互相转化,曼哈顿转切比雪夫如下:\((x,y)\to(x+y,x-y)\)转化后......
  • GPEN——使用GANs恢复对人脸图像进行修复
    1.简介盲目的面部修复(BlindFaceRestoration,BFR)是一个活跃的研究领域,它涉及到在没有任何先验信息的情况下改善低质量(LowQuality,LQ)图像的质量。这确实是一个具有挑战性的问题,因为模型需要能够处理多种未知的退化,例如模糊、噪声、压缩伪影等,这些退化可能在训练数据中......
  • 5.16
    Python面向对象基础训练班级:信2205-1学号:20224074 姓名:王晨宇一实验目的l 使学生掌握Python下类与对象的基本应用;l 使学生掌握Python下继承与多态的基本应用;l 使学生掌握Python接口的基本应用;l 使学生掌握Python异常处理的基本应用;二实验环境及实验准备l ......
  • 286、基于51单片机的温度报警(8路,DS18B20,热电偶,LCD1602)
    完整资料或定制滴滴我(有偿)见文末。目录一、设计功能二、Proteus仿真三、原理图四、程序源码五、资料包括一、设计功能多路温度采集系统1、刺激4路DS18B20温度和4路热电偶温度2、自动循环显示每路温度值3、设置温度上下限,温度过限报警二、Proteus仿真......