首页 > 其他分享 >Tensorflow2如何读取自制数据集并训练模型?-- Tensorflow自学笔记13

Tensorflow2如何读取自制数据集并训练模型?-- Tensorflow自学笔记13

时间:2024-09-08 11:49:58浏览次数:19  
标签:Tensorflow2 13 -- np savepath train test path mnist

一. 如何自制数据集?

1. 目录结构

以下是自制数据集-手写数字集, 保存在目录 mnist_image_label 下

2. 数据存储格式 

2.1. 目录mnist_train_jpeg_60000 下存放的是 60000张用于测试的手写数字

       如 : 0_5.jpg, 表示编号为0,标签为5的图片

              6_1.jpg, 表示编号为6,标签为1的图片

2.2. 目录mnist_test_jpeg_10000 下存放的是10000张用于测试的手写数字

        图片存储格式与1.1相同

2.3. txt文件 mnist_train_jpg_60000.txt,里面存放的是

     

        比如,第一行  28755_0.jpg   0     前面表示图片名称,后面的0表示该图片对应的标签,这里表示该图片是手写数字0.

2.4. txt文件 mnist_test_jpg_10000.txt   , 存放的是测试数据集的标签

二. 如何读取自制数据集并输入神经网络

以下是test.py 如何读取自制数据集代码

1. 导入需要的库

import tensorflow as tf

 from PIL import Image

import numpy as np

import os

2.设置数据集所在文件目录 

   (test.py, 需和mnist_image_label 目录在同一级目录下)

train_path = './mnist_image_label/mnist_train_jpg_60000/'

train_txt = './mnist_image_label/mnist_train_jpg_60000.txt'

x_train_savepath = './mnist_image_label/mnist_x_train.npy'

y_train_savepath = './mnist_image_label/mnist_y_train.npy'

test_path = './mnist_image_label/mnist_test_jpg_10000/'

test_txt = 'v/mnist_image_label/mnist_test_jpg_10000.txt'

x_test_savepath = './mnist_image_label/mnist_x_test.npy' #训练集输入特征存储文件npy,

y_test_savepath = './mnist_image_label/mnist_y_test.npy' #训练集标签存储文件

3.定义读取数据的函数

def generateds(path, txt):

    f = open(txt, 'r') # 以只读形式打开txt文件

    contents = f.readlines() # 读取文件中所有行

    f.close() # 关闭txt文件

    x, y_ = [], [] # 建立空列表

    for content in contents: # 逐行取出

        value = content.split() # 以空格分开,图片路径为value[0] , 标签为value[1] , 存入列表

        img_path = path + value[0] # 拼出图片路径和文件名

        print('image path....: '+img_path)

        img = Image.open(img_path) # 读入图片

        img = np.array(img.convert('L')) # 图片变为8位宽灰度值的np.array格式

        img = img / 255. # 数据归一化 (实现预处理)

        x.append(img) # 归一化后的数据,贴到列表x

        y_.append(value[1]) # 标签贴到列表y_

        print('loading : ' + content) # 打印状态提示



    x = np.array(x) # 变为np.array格式

    y_ = np.array(y_) # 变为np.array格式

    y_ = y_.astype(np.int64) # 变为64位整型

    return x, y_ # 返回输入特征x,返回标签y_


4.调用定义的函数

if os.path.exists(x_train_savepath) and os.path.exists(y_train_savepath) and os.path.exists(

x_test_savepath) and os.path.exists(y_test_savepath):

    print('-------------Load Datasets-----------------')

    x_train_save = np.load(x_train_savepath)

    y_train = np.load(y_train_savepath)

    x_test_save = np.load(x_test_savepath)

    y_test = np.load(y_test_savepath)

    x_train = np.reshape(x_train_save, (len(x_train_save), 28, 28))

    x_test = np.reshape(x_test_save, (len(x_test_save), 28, 28))

else:

    print('-------------Generate Datasets-----------------')

    x_train, y_train = generateds(train_path, train_txt)

    x_test, y_test = generateds(test_path, test_txt)



    print('-------------Save Datasets-----------------')

    x_train_save = np.reshape(x_train, (len(x_train), -1))

    x_test_save = np.reshape(x_test, (len(x_test), -1))

    np.save(x_train_savepath, x_train_save)

    np.save(y_train_savepath, y_train)

    np.save(x_test_savepath, x_test_save)

    np.save(y_test_savepath, y_test)

5. 搭建神经网络训练数据

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')
])

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1)
model.summary()

标签:Tensorflow2,13,--,np,savepath,train,test,path,mnist
From: https://blog.csdn.net/pisceshsu/article/details/141928421

相关文章

  • AI大语言模型LLM学习-基于Vue3的AI问答页面
    系列文章1.AI大语言模型LLM学习-入门篇2.AI大语言模型LLM学习-Token及流式响应3.AI大语言模型LLM学习-WebAPI搭建前言在上一篇博文中,我们使用Flask这一Web框架结合LLM模型实现了后端流式WebAPI接口,本篇将基于Vue3实现AI问答页面,本人习惯使用HBuilder进行前端页面......
  • [USACO3.2] 香甜的黄油 Sweet Butter(Dijkstra)
     FarmerJohn发现了做出全威斯康辛州最甜的黄油的方法:糖。把糖放在一片牧场上,他知道NNN只奶牛会过来舔它,这样就能做出能卖好价钱的超甜黄油。当然,他将付出额外的费用在奶牛上。FarmerJohn很狡猾。像以前的Pavlov,他知道他可以训练这些奶牛,让它们在听到铃声时去一个特......
  • CF 2008 H
    题目描述给定一个长度为\(N\)的序列\(A\),以及\(Q\)次询问,每次询问给定一个\(x\)。你可以执行以下操作任意次:选择一个\(1\lei\leN\)使得\(A_i\gex\)。令\(A_i\leftarrowA_i-x\)。求\(A\)的最小中位数。这里中位数是\(A\)排序后的第\(\lfloor\frac......
  • 西门子电机编码器参数设置
    SimotionPLC解释1FK70221FK70331FK7(AM20)1FK7(AM24)1FK7(AS20)1FK7(AS24)encoderMode模式PROFIDRIVEPROFIDRIVEPROFIDRIVEPROFIDRIVEPROFIDRIVEPROFIDRIVEABSResolutionIncrements单圈线数PROFIDRIVEPROFIDRIVEPROFIDRIVEPRO......
  • Bash中$10 和 ${10}的区别
    #!/bin/bashfunWithParam(){echo"第一个参数为$1!"echo"第二个参数为$2!"echo"第十个参数为$10!"echo"第十个参数为${10}!"echo"第十一个参数为${11}!"echo"参数总数有$#个!"echo"作为一个字符......
  • [网鼎杯 2020 朱雀组]phpweb
    仔细地话可以看到这题每个一段时间就会刷新一次页面,而且后面还会有一个时间,就很可疑,抓个包试试果然多了几个参数func=date&p=Y-m-d+h%3Ai%3As+a经过搜索发现这是一个函数(用来显示时间,也就证实了前面地图片为什么会出现时间地原因)于是试着就修改函数和参数来执行命令但是最......
  • 【有源码】基于python+爬虫的短视频数据分析与可视化分析flask短视频推荐系统的设计与
    注意:该项目只展示部分功能,如需了解,文末咨询即可。本文目录1.开发环境2系统设计2.1设计背景2.2设计内容3系统展示3.1功能展示视频3.2用户页面3.3管理员页面4更多推荐5部分功能代码1.开发环境开发语言:Python采用技术:flask、爬虫数据库:MySQL开发环境:P......