首页 > 其他分享 >自编码器_【手写数字】

自编码器_【手写数字】

时间:2022-11-10 15:04:00浏览次数:44  
标签:src 编码器 Dense 数字 shape train test encoded 手写


自编码器

import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Model
import matplotlib.pyplot as plt
import numpy as np


print ("start")

def train_model():
mnist=tf.keras.datasets.mnist

#获取数据,训练集,测试集 60k训练,10K测试
(x_train,y_train),(x_test,y_test)=mnist.load_data()


#数据集格式转换
x_train = x_train.astype('float32')/255.0 - 0.5
x_test = x_test.astype('float32')/255.0 - 0.5


x_train=x_train.reshape(x_train.shape[0],-1)
x_test=x_test.reshape(x_test.shape[0],-1)
print(x_train.shape,x_test.shape)


# 输入是大小为28x28,灰度图像
img_shape = (784)
# batchsize 为16
batch_size = 16
# 输出的潜在空间的维度
latent_dim = 128

input_img = tf.keras.Input(shape=(784,))
input_img_ = tf.keras.Input(shape=(128,))



encoded = Dense(128,activation="relu")(input_img)
encoded = Dense(64,activation="relu")(encoded)
encoded = Dense(10,activation="relu")(encoded)
encoder_output = Dense(latent_dim,)(encoded)

dencoded = Dense(10,activation="relu")(encoder_output)
dencoded = Dense(64,activation="relu")(dencoded)
dencoded = Dense(128,activation="relu")(dencoded)
dencoded = Dense(784,activation="tanh")(dencoded)

autoencoder = Model(input_img,dencoded)
encoder = Model(input_img,encoder_output)

encoded_imgs = encoder.predict(x_test)


adam_optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001, beta_1=0.5)
autoencoder.compile(optimizer=adam_optimizer,loss="mse")
autoencoder.fit(x_train,x_train,epochs=5,batch_size=10,shuffle=True)

autoencoder.save("autoencoder.h5")
#encoder.save("encoder.h5")

encoded_imgs = encoder.predict(x_test)
print (encoded_imgs.shape)
plt.scatter(encoded_imgs[:,0],encoded_imgs[:,1],c=y_test)
plt.show()

train_model()
print ("end")

预测

#coding=utf-8

import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Model,load_model
import matplotlib.pyplot as plt
import numpy as np
import cv2

print ("start")
def cv2_display(src):
cv2.imshow('src',src)
cv2.waitKey(0)
cv2.destroyAllWindows()

def predict_model():
mnist=tf.keras.datasets.mnist
#获取数据,训练集,测试集 60k训练,10K测试
(x_train,y_train),(x_test,y_test)=mnist.load_data()
x_test = x_test[:10]
cv2.imwrite("test.png",x_test[0])
#数据集格式转换
x_train = x_train.astype('float32')/255.0 - 0.5
x_test = x_test.astype('float32')/255.0 - 0.5


x_train=x_train.reshape(x_train.shape[0],-1)
x_test=x_test.reshape(x_test.shape[0],-1)
print(x_train.shape,x_test.shape)

autoencoder = load_model("autoencoder.h5")
moto_img = autoencoder.predict(x_test)
print (moto_img.shape)
moto_src = tf.reshape(moto_img[0],(28,28))
moto_src = ((moto_src + 0.5)*255.0)
moto_src = np.asarray(moto_src)
cv2.imwrite("test_output.png",moto_src)


predict_model()
print ("end")

原始图片

自编码器_【手写数字】_tensorflow


预测图片(自编码器预测输出的图片)

自编码器_【手写数字】_h5_02

自己利用数据训练编码器解码器

编码器

import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Model,load_model
import matplotlib.pyplot as plt
import numpy as np
import cv2

print ("start")

def cv2_display(src):
cv2.imshow('src',src)
cv2.waitKey(0)
cv2.destroyAllWindows()

def train_model():
mnist=tf.keras.datasets.mnist

#获取数据,训练集,测试集 60k训练,10K测试
(x_train,y_train),(x_test,y_test)=mnist.load_data()


#数据集格式转换
x_train = x_train.astype('float32')/255.0 - 0.5
x_test = x_test.astype('float32')/255.0 - 0.5


x_train=x_train.reshape(x_train.shape[0],-1)
x_test=x_test.reshape(x_test.shape[0],-1)
print(x_train.shape,x_test.shape)


# 输入是大小为28x28,灰度图像
img_shape = (784)
# batchsize 为16
batch_size = 16
# 输出的潜在空间的维度
latent_dim = 128

input_img_1 = tf.keras.Input(shape=(784,))
input_img_2 = tf.keras.Input(shape=(128,))


encoded = Dense(128,activation="relu")(input_img_1)
encoded = Dense(64,activation="relu")(encoded)
encoded = Dense(10,activation="relu")(encoded)
encoder_output = Dense(latent_dim,)(encoded)

dencoded = Dense(10,activation="relu")(input_img_2)
dencoded = Dense(64,activation="relu")(dencoded)
dencoded = Dense(128,activation="relu")(dencoded)
dencoded = Dense(784,activation="tanh")(dencoded)

encoder = Model(input_img_1,encoder_output)
encoder.save("transform_128_encoder.h5")
Y_train = encoder.predict(x_train)
Y_test = encoder.predict(x_test)

np.save("Y_train.npy",Y_train)
np.save("Y_test.npy",Y_test)


train_model()
print ("end")

说明:可以将28*28的手写数字转换为128维,维度可以自定义。

解码器

import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Model,load_model
import matplotlib.pyplot as plt
import numpy as np
import cv2

print ("start")

def cv2_display(src):
cv2.imshow('src',src)
cv2.waitKey(0)
cv2.destroyAllWindows()

def train_model():
mnist=tf.keras.datasets.mnist

#获取数据,训练集,测试集 60k训练,10K测试
(x_train,y_train),(x_test,y_test)=mnist.load_data()


#数据集格式转换
x_train = x_train.astype('float32')/255.0 - 0.5
x_test = x_test.astype('float32')/255.0 - 0.5


x_train=x_train.reshape(x_train.shape[0],-1)
x_test=x_test.reshape(x_test.shape[0],-1)
print(x_train.shape,x_test.shape)


# 输入是大小为28x28,灰度图像
img_shape = (784)
# batchsize 为16
batch_size = 16
# 输出的潜在空间的维度
latent_dim = 128

input_img_1 = tf.keras.Input(shape=(784,))
input_img_2 = tf.keras.Input(shape=(128,))



encoded = Dense(128,activation="relu")(input_img_1)
encoded = Dense(64,activation="relu")(encoded)
encoded = Dense(10,activation="relu")(encoded)
encoder_output = Dense(latent_dim,)(encoded)

dencoded = Dense(10,activation="relu")(input_img_2)
dencoded = Dense(64,activation="relu")(dencoded)
dencoded = Dense(128,activation="relu")(dencoded)
dencoded = Dense(784,activation="tanh")(dencoded)

dencoder = Model(input_img_2,dencoded)
Y = np.load("Y_train.npy")
adam_optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001)
dencoder.compile(optimizer=adam_optimizer,loss="mse")
dencoder.fit(Y,x_train,epochs=100,batch_size=60,shuffle=True)
dencoder.save("transform_784_encoder.h5")
train_model()
print ("end")

说明:将128维的向量解码为手写数字,需要训练,相当于反操作。

预测还原数据

import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Model,load_model
import matplotlib.pyplot as plt
import numpy as np
import cv2

print ("start")

def cv2_display(src):
cv2.imshow('src',src)
cv2.waitKey(0)
cv2.destroyAllWindows()

def predict_model():
Y = np.load("Y_test.npy")
print (Y.shape)
dencoder = load_model("transform_784_encoder.h5")
encoded_imgs = dencoder.predict(Y)
print (encoded_imgs.shape)
predict_src = tf.reshape(encoded_imgs[0],(28,28))
predict_src = ((predict_src + 0.5)*255.0)
predict_src = np.asarray(predict_src)
cv2.imwrite("1_output.png",predict_src)

predict_model()
print ("end")

自编码器_【手写数字】_编码器_03


说明:可以看出来数据稍微有所不同,缺少了细节,清晰度也有所下降。

结尾

也可以将它迁移到彩色图片上去,但是虽然能够还原轮廓,但是细节部分相差太大,需要使用其他网络,达到更好的效果。

下面的是利用该方案的彩色图片输出效果。

彩色输入图片

自编码器_【手写数字】_机器学习_04


彩色输出图片

自编码器_【手写数字】_h5_05


寻找到更好的方案后会更新下一个。


标签:src,编码器,Dense,数字,shape,train,test,encoded,手写
From: https://blog.51cto.com/u_15872074/5841639

相关文章

  • 利用卷积反卷积实现图片自编码器
    手写数字fromtensorflow.keras.layersimportConv2D,MaxPooling2D,Input,Conv2DTranspose,Flatten,Densefromtensorflow.keras.optimizersimportAdamfromtensorflow.k......
  • 传统企业如何实现数字化转型?
    先就题主关注的问题大致说明下,传统生产制造型企业在数字经济时代,要保持领先不被时代所淘汰,数字化转型是最佳选择,数字化转型的核心就是要构建“业务数字化、数字资产化、资......
  • 能不能手写Vue响应式?前端面试进阶
    Vue视图更新原理Vue的视图更新原理主要涉及的是响应式相关APIObject.defineProperty的使用,它的作用是为对象的某个属性对外提供get、set方法,从而实现外部对该属性的......
  • 手写一个JS函数,实现数组深度扁平化
    要求:把数组arr=[12,34,[122,324],[222,[333]];扁平化思路:创建一个新数组,循环原数组判断每一项是否是数组是的话先递归,在调用const或push方法,不是直接const或push。方法一......
  • 001[Js修炼]手写深拷贝
    /**//编写一个深度克隆函数,满足以下需求(此题考察面较广,注意细节)functiondeepClone(obj){}//deepClone函数测试效果constobjA={name:'jack',birthday:......
  • 数字电路学习
    2022.11.91.状态机的分类一段式:只有一个alwaysblock,把所有的逻辑(输入、输出、状态)都在一个alwaysblock的时序逻辑中实现。这种写法看起来很简洁,但是不利于维护......
  • 数字孪生智能工厂三维可视化系统带动企业数字化变革-深圳华锐视点
    数字经济是未来主要的经济形态和发展方向,促进数字经济和实体经济深度融合是二十大报告做出的重要战略思想,企业数字化转型是积极迎战数字经济的必修课,工厂3D可视化管理......
  • 手写本地缓存实战2—— 打造正规军,构建通用本地缓存框架
    大家好,又见面了。本文是笔者作为掘金技术社区签约作者的身份输出的缓存专栏系列内容,将会通过系列专题,讲清楚缓存的方方面面。如果感兴趣,欢迎关注以获取后续更新。村......
  • 数字孪生的落地应用都有哪些?
    2019年,“数字孪生”热度不断攀升,备受行业内外关注。各大峰会论坛将其作为热议主题,全球最具权威的IT研究与顾问咨询机构Gartner在2019年报告中将其列为十大战略科技发展趋势......
  • 手写map和filter
    mapfunctionmyMap(arr,callback){if(Array.isArray(arr)){if(arr.length===0)returnarr;constbrr=[];for(letitemofarr){brr.pu......