首页 > 其他分享 >tensorflow实现花分类

tensorflow实现花分类

时间:2023-06-12 20:22:12浏览次数:50  
标签:layers plt keras 实现 分类 tf tensorflow model size

1. 花数据集

  数据集来自kaggle官网下载。分为五类花,每类花有1000张图片。下载方式可以参考我的https://www.cnblogs.com/wancy/p/17446715.html

 

2. 图片大小分布图

  训练模型之前,我们会需要先分析数据集,由于此类数据集每类花的图片数量一样,是均衡的。训练模型之前,我们需要传入合适图片大小输入,选择合适的图片大小也会对模型最终性能有一定影响。所以训练模型之前先观察图片大小分布。

import os
import cv2
import numpy as np
import matplotlib.pyplot as plt

# 读取图片文件夹中所有图片的大小信息
def get_image_sizes(image_folder):
   sizes = []
   for folder_name in os.listdir(image_folder):
       # print("文件夹名"+folder_name)
       folder_path = os.path.join(image_folder, folder_name)
       # print(folder_path)
       if os.path.isdir(folder_path):
           for img_name in os.listdir(folder_path):
               image = cv2.imread(folder_path+"/"+img_name)
               sizes.append((image.shape[0], image.shape[1]))  # (height, width)
   return sizes

# 绘制散点图
def plot_size_scatter(sizes):
   x = [size[1] for size in sizes] # width
   y = [size[0] for size in sizes] # height
   plt.scatter(x, y, s=5, alpha=0.5, c='steelblue')
   plt.title('Image Size Scatter')
   plt.xlabel('Width (pixel)')
   plt.ylabel('Height (pixel)')
   plt.show()

# 代码
if __name__ == '__main__':
   sizes = get_image_sizes('./5-flower-types-classification-dataset/flower_images')
   print(f'Total images: {len(sizes)}')
   plot_size_scatter(sizes)

  可以发现,图片大部分大小范围在(200x200,1000,1000)以内,在此之内选择比较合适。

3. tensorflow训练模型

import tensorflow as tf
import os

from matplotlib import pyplot as plt
#from tensorflow.keras.preprocessing import image
from tensorflow.keras.callbacks import TensorBoard
# 创建TensorBoard回调函数
# tensorboard = TensorBoard(log_dir='./log', historytogram_freq=1)
data_dir = './5-flower-types-classification-dataset/flower_images'  # 数据路径

# 设置图像尺寸和批次大小
img_size = (224, 224)
batch_size = 32

# 将数据文件夹名称存储在列表中
data_folders = os.listdir(data_dir)

# 使用 ImageDataGenerator 类读取和增强数据
train_datagen = image.ImageDataGenerator(rescale=1. / 255,rotation_range=20,zoom_range=0.2,horizontal_flip=True,validation_split=0.2)

# 读取并划分训练集和测试集数据
train_generator = train_datagen.flow_from_directory(data_dir,target_size=img_size,batch_size=batch_size,class_mode='categorical',subset='training')
valid_generator = train_datagen.flow_from_directory(data_dir,target_size=img_size,batch_size=batch_size,class_mode='categorical',subset='validation')


#类别标签字典
class_indices = train_generator.class_indices
print("class_indices",class_indices)####{'Lilly': 0, 'Lotus': 1, 'Orchid': 2, 'Sunflower': 3, 'Tulip': 4}
# 构建模型
# model = tf.keras.Sequential([
#     tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(img_size[0], img_size[1], 3)),
#     tf.keras.layers.MaxPooling2D((2, 2)),
#     tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
#     tf.keras.layers.MaxPooling2D((2, 2)),
#     tf.keras.layers.Conv2D(128, (3, 3), activation='relu'),
#     tf.keras.layers.MaxPooling2D((2, 2)),
#     tf.keras.layers.Flatten(),
#     tf.keras.layers.Dense(512, activation='relu'),
#     tf.keras.layers.Dropout(0.5),
#     tf.keras.layers.Dense(len(data_folders), activation='softmax')
# ])
#定义模型
model=tf.keras.Sequential()
#conv_layer = layers.Conv2D(64, (3, 3), bias_initializer="zeros", kernel_initializer=GlorotUniform(seed=42))
model.add(tf.keras.layers.Conv2D(32,(3,3),input_shape=(img_size[0], img_size[1], 3),activation="relu",padding="same"))
model.add(tf.keras.layers.Conv2D(32,(3,3),activation="relu",padding="same"))
#tf.keras.layers.Dropout(0.5)
model.add(tf.keras.layers.MaxPool2D())#默认2*2,步长也为2

model.add(tf.keras.layers.Conv2D(64,(4,4),activation='relu',padding="same"))
model.add(tf.keras.layers.Conv2D(128,(3,3),activation='relu',padding="same"))
model.add(tf.keras.layers.MaxPool2D())#默认2*2,步长也为2
#4维转2维
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(512,activation='relu'))
model.add(tf.keras.layers.Dense(256,activation='relu'))
model.add(tf.keras.layers.Dense(5,activation='softmax'))
#dense_layer = layers.Dense(64, kernel_initializer=GlorotUniform(seed=42))

# 编译和训练模型
# optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
# model.compile(optimizer=optimizer, ...)
"""
在tensorflow.keras中,
如果没有手动指定优化器的学习率,那么model.compile默认使用的Adam优化器的学习率为0.001
"""

#sparse_categorical_crossentropy要求target为非onehot编码,函数内部进行onehot编码实现。categorical_crossentropy要求target为onehot编码。
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
# history = model.fit(train_generator,steps_per_epoch=len(train_generator),epochs=3,validation_data=valid_generator,validation_steps=len(valid_generator),callbacks=[tensorboard])
history = model.fit(train_generator,steps_per_epoch=len(train_generator),epochs=10,shuffle=True,validation_data=valid_generator,validation_steps=len(valid_generator))
#shuffle=True
#每个epoch后,将使用validation_steps个batch的数据进行评估
print(history.history)

# 评估模型并输出结果
test_loss, test_acc = model.evaluate(valid_generator, verbose=2)
print('Test accuracy:', test_acc)
model.save('./model/my_model.h5')
#####################################################################
print(model.summary())
"""
model.summary()是TensorFlow中用于打印模型结构信息的函数。它会输出模型的各层名称、类型、输入和输出张量的形状等信息,以帮助我们了解模型的结构和参数数量。
"""

#画图 性能评估
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss =history.history['loss']
val_loss = history.history['val_loss']

plt.figure(figsize=(15, 5))
plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()
#plt.grid()
plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
#plt.grid()
plt.show()

   运行结果:

4. 测试效果

  网上找了几张图片测试。不是很准。

小结:模型训练用得是kaggle平台,训练了10个epoch。准确率达到了70%左右。增加epoch应该能提高准确率。另外,当我将模型从kaggle下载出来拷贝到本地测试分类时,报错了,后来发现,训练的版本为tf2.12.0,本地为2.6.0,相差较大,后改为2.11.0就不报错了。

 

 

 

参考资料:

https://cloud.tencent.com/developer/article/2094683?areaSource=102001.3&traceId=LonVetBuRWMw4aKCESBvu

 

  若存在不足或错误之处,欢迎评论与指正!

 

标签:layers,plt,keras,实现,分类,tf,tensorflow,model,size
From: https://www.cnblogs.com/wancy/p/17476009.html

相关文章

  • H5实现左右滑动手势
    使用已有的轮子简单实现H5左右滑动手势安装vue2-touch-eventsnpminstallvue2-touch-events在main.js中引入//main.jsimportVue2TouchEventsfrom'vue2-touch-events'Vue.use(Vue2TouchEvents)通过自定义指令使用<!--template--><!--需要监听左右滑动手势的......
  • 【Ubuntu22.04】安装MySQL数据库,修改root用户密码,实现远程访问,
    预备条件本次实验使用静态IP的地址192.168.1.81作为mysql-001服务器地址,并配置为本地域名mysql-001:打开Powershell(Window自带)使用SSH方式连接服务器,用户名test,密码:123456:安装Mysql:更新软件源aptupdate安装MySQL8.0,因为Ubuntu22.04不支持MySQL5.7sudoaptinst......
  • js实现复制粘贴
    在一些页面里,有时候会需要用户点击按钮或者控件需要把一些文字内容写入用户设备的剪切板里。在js中如何通过代码实现?接下来是两种实现方法!使用document.execCommandAPI注意document.execCommandAPI是同步执行,如果数据量大可能会阻塞页面加载,这种办法能兼容老版本浏览器和大......
  • Jenkins Pipeline 密钥实现远程部署
    前提:已配置jenkins秘钥凭证 一、配置流程1.1片段生成1、按如下图选择2、新增密钥信息1.2脚本配置以上配置完成后,接下来就可以在Jenkinsfile中配置了,:stages{stage('xx启动'){steps{echo"xx启动"dir("${SRC_PATH}")......
  • Redis高可用的三种实现方式
    Redis高可用的三种实现方式一、高可用的概念​高可用(HighAvailability,即HA),指的是通过尽量缩短日常维护操作和突发的系统崩溃所导致的停机时间,以提高系统和应用的可用性。一个业务系统如果全年无一时刻不在提供服务,它的可用性可达100%。那么什么样的系统可以称之为高可用呢,业......
  • 基于Tensorflow的Faster-Rcnn的断点续训
    一、前言最近在学习目标检测,到github上找了一个开源的Faster-RCNN项目(Tensorflow),项目地址是:https://github.com/dBeker/Faster-RCNN-TensorFlow-Python3根据网上的各种教程,模型训练还算顺利,不过这个项目缺少断点续训的功能。也就是中途误操作导致训练中止,就只能从头开始......
  • 如何实现不同服务器之间 大体量的数据自动同步?
    随着企业结构分散化的不断扩大,企业的数据中心、服务器节点、异地分支机构之间,会存在多种文件交换场景。传统的FTP、rsync、网盘等传输方式在数据体量较小、时效性要求不高的情况下,基本也可以满足需求。但随着数量爆发式增长,需要及时分析使用数据的情况下,就不太够用了,弊端也随之体......
  • .net core 微服务 集成Ocelot 和Nacos 之后使用grpc 如何实现服务与服务之间的调用
    在.NETCore微服务中使用gRPC调用其他服务,你需要完成以下步骤:1.定义服务契约:你需要定义你的服务、方法以及消息类型,以便客户端和服务端协商通信。2.生成代码:你需要使用gRPC工具生成客户端和服务端的代码,这样你就可以在应用程序中使用它们。3.实现服务:你需要实现......
  • java实现dwg转pdf
    提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档文章目录前言一、无奈选择第二种二、jar引入1.jar地址2.使用jar,完成dwg转为pdf总结前言由于公司需要最近研究一个cad文件需要在浏览器中展示,经过研究发现大致有两种方式:1将.dwg转换为vds文件,就可以在web端展示......
  • redis的消息发布订阅实现
    文章目录前言一、创建好springboot项目,引入核心依赖二、使用步骤1.自定义一个消息接受类2.声名一个消息配置类3.编写一个测试类总结前言一般项目中都会使用redis作为缓存使用,加速用户体验,实现分布式锁等等,redis可以说为项目中的优化,关键技术实现立下了汗马功劳.今天带来它......