
Machine Learning: Mammal Recognition

Posted: 2022-12-21 18:45:52

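(I) Background

Mammals are the most structurally advanced and physiologically developed group in the animal kingdom. Compared with other animals, their most distinctive traits are that nearly all of them give birth to live young and that the young are raised on milk secreted by the mother. All mammals have hair, which helps keep body temperature constant and lets them adapt to a wide range of habitats. Mammals also have comparatively well-developed brains, so they can produce more complex behavior than other animals and keep adjusting that behavior as their environment changes.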

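(II) Motivation

I started out with handwritten digit recognition and later moved on to cat-vs-dog classification. I wanted to build something animal-related, so I first wrote a monkey recognizer with roughly 10,000 images. Recognizing only monkeys felt too narrow, though, so I widened the scope to mammals in general.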

(III) Implementation steps

1. Import the required packages

# Import the required libraries
import os

import numpy as np
import matplotlib.pyplot as plt
import plotly.express as px

from tensorflow.keras.applications import ResNet50V2
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.layers import GlobalAveragePooling2D as GAP
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.preprocessing.image import ImageDataGenerator as IDG

2. Read the dataset folders

# Training set
root_path = './animals10/raw-img/'

# Each sub-folder name is a class name
class_names = sorted(os.listdir(root_path))
n_classes = len(class_names)

print(f"Class Names : {class_names}")
print(f"Number of Classes  : {n_classes}")

# Number of images per class
class_dis = [len(os.listdir(os.path.join(root_path, name))) for name in class_names]
print(class_dis)

# Visualization
fig = px.pie(names=class_names, values=class_dis, title="Mammal class distribution (training set)")
fig.update_layout({'title': {'x': 0.45}})
fig.show()


# Test set
pred_root_path = './animals10/test/'

# Each sub-folder name is a class name
testclass_names = sorted(os.listdir(pred_root_path))
n_testclasses = len(testclass_names)

print(f"Class Names : {testclass_names}")
print(f"Number of Classes  : {n_testclasses}")

# Number of images per class (iterate over the test folders, not the training ones)
testclass_dis = [len(os.listdir(os.path.join(pred_root_path, name))) for name in testclass_names]
print(testclass_dis)

# Visualization
figtest = px.pie(names=testclass_names, values=testclass_dis, title="Mammal class distribution (test set)")
figtest.update_layout({'title': {'x': 0.45}})
figtest.show()
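
The counts above treat every directory entry as a class or an image, so a stray hidden file (for example `.DS_Store`) would skew them. Here is a minimal sketch of a stricter count, assuming images can be identified by file extension. The helper name `count_images_per_class` and the extension set are my own illustration, not part of the original code, and the demo runs on a throwaway directory rather than the real dataset.

```python
import tempfile
from pathlib import Path

# Assumed set of image extensions; adjust to match the dataset.
IMAGE_EXTS = {".jpg", ".jpeg", ".png"}

def count_images_per_class(root):
    """Return {class_folder_name: image_count} for each sub-directory of root."""
    counts = {}
    for class_dir in sorted(Path(root).iterdir()):
        if not class_dir.is_dir():
            continue  # skip stray files such as .DS_Store
        counts[class_dir.name] = sum(
            1 for f in class_dir.iterdir()
            if f.is_file() and f.suffix.lower() in IMAGE_EXTS
        )
    return counts

# Tiny throwaway directory tree to exercise the helper.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "cat").mkdir()
    (Path(tmp) / "dog").mkdir()
    (Path(tmp) / "cat" / "a.jpg").touch()
    (Path(tmp) / "cat" / "b.PNG").touch()
    (Path(tmp) / "cat" / "notes.txt").touch()
    (Path(tmp) / ".DS_Store").touch()
    demo_counts = count_images_per_class(tmp)

print(demo_counts)
```

On the real data, the keys and values of `count_images_per_class(root_path)` could be passed to `px.pie` as `names` and `values`, and comparing the key sets for the training and test folders is a quick way to confirm that both contain the same classes.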


3. Preprocessing

# Preprocessing: rescale pixels to [0, 1]; only the training generator augments
train_gen = IDG(rescale=1./255, rotation_range=10, horizontal_flip=True, validation_split=0.1)
test_gen = IDG(rescale=1./255)

# Load the data. class_mode='sparse' yields integer class indices,
# which is what sparse_categorical_crossentropy expects.
train_ds = train_gen.flow_from_directory(root_path,
                                         target_size=(256,256),
                                         shuffle=True,
                                         batch_size=32,
                                         subset='training',
                                         class_mode='sparse')
valid_ds = train_gen.flow_from_directory(root_path,
                                         target_size=(256,256),
                                         shuffle=True,
                                         batch_size=32,
                                         subset='validation',
                                         class_mode='sparse')
# The test set goes through test_gen, so it is rescaled but not augmented
test_ds = test_gen.flow_from_directory(pred_root_path,
                                       target_size=(256,256),
                                       shuffle=True,
                                       batch_size=32,
                                       class_mode='sparse')
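
The generators label each image with the integer index of its class folder (folders are indexed in alphanumeric order), and those integer labels pair with the `sparse_categorical_crossentropy` loss used in the training step. As a rough illustration of what that loss computes for a single sample, here is a plain-Python sketch. The function names and the probabilities are made up for the example; this is not Keras code.

```python
import math

def sparse_categorical_crossentropy(label, probs):
    """Loss for one sample: negative log of the probability given to the true class."""
    return -math.log(probs[label])

def one_hot(label, n_classes):
    """Equivalent one-hot encoding of an integer label."""
    return [1.0 if i == label else 0.0 for i in range(n_classes)]

def categorical_crossentropy(target, probs):
    """The same loss computed from a one-hot target vector."""
    return -sum(t * math.log(p) for t, p in zip(target, probs))

# Made-up softmax output over 3 classes; the true class is index 1.
probs = [0.1, 0.7, 0.2]
sparse_loss = sparse_categorical_crossentropy(1, probs)
dense_loss = categorical_crossentropy(one_hot(1, 3), probs)
print(sparse_loss, dense_loss)
```

Both values equal -ln(0.7), which is why integer labels can be fed to the sparse variant of the loss without one-hot encoding them first.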

4. Image display function

def show_images(data, GRID=(6, 5), size=(20, 25), model=None, class_names=class_names):

    # Rows and columns of the grid
    n_rows, n_cols = GRID

    # Take one batch from the generator
    images, labels = next(iter(data))

    # Pick distinct random images; a batch may hold fewer than the grid needs
    n_images = min(n_rows * n_cols, len(images))
    ids = np.random.choice(len(images), size=n_images, replace=False)

    # Predict the whole batch once instead of once per image
    preds = None
    if model is not None:
        preds = np.argmax(model.predict(images, verbose=0), axis=1)

    # Figure size
    plt.figure(figsize=size)

    for i, idx in enumerate(ids, start=1):
        image, label = images[idx], class_names[int(labels[idx])]

        plt.subplot(n_rows, n_cols, i)
        plt.imshow(image)
        plt.axis('off')

        # Title: true class, plus the prediction when a model is given
        if preds is None:
            title = f"Class : {label}"
        else:
            title = f"Class : {label} \nPred : {class_names[preds[idx]]}"
        plt.title(title)

    # Show plot
    plt.show()


# Quick check on the training data
show_images(data=train_ds)
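
A grid of distinct images needs indices sampled without replacement. Drawing each index independently, even with a single retry on a collision, can still show the same image twice. Here is a minimal standard-library sketch of the idea; `pick_unique_indices` is a hypothetical helper name used only for illustration.

```python
import random

def pick_unique_indices(batch_size, n_wanted, seed=None):
    """Pick up to n_wanted distinct indices from range(batch_size)."""
    rng = random.Random(seed)
    k = min(n_wanted, batch_size)  # the last batch may hold fewer images
    return rng.sample(range(batch_size), k)

ids = pick_unique_indices(batch_size=32, n_wanted=30, seed=0)
print(len(ids), len(set(ids)))  # 30 30: every index is distinct
```

With NumPy, `np.random.choice(batch_size, size=k, replace=False)` does the same job.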


5. Train the model

# Pretrained backbone, frozen so that only the new classification head is trained
base_model = ResNet50V2(input_shape=(256,256,3), include_top=False)
base_model.trainable = False

# Model name
name = "animals"

model = Sequential([
    base_model,
    GAP(),
    Dense(256, activation='relu', kernel_initializer='he_normal'),
    Dropout(0.2),
    Dense(n_classes, activation='softmax')  # one output per class folder
], name=name)

# Compile the model
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

# Callbacks: stop once val_loss has not improved for 3 epochs, and save the best model to disk
cbs = [EarlyStopping(patience=3, restore_best_weights=True), ModelCheckpoint(name + ".h5", save_best_only=True)]

# Train
h2 = model.fit(train_ds, validation_data=valid_ds, callbacks=cbs, epochs=10)
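
Because of `EarlyStopping(patience=3, ...)`, training can end before the requested 10 epochs, so the history may hold fewer entries than `epochs`. Below is a simplified sketch of the stopping rule in plain Python. It assumes the callback's default monitored quantity (`val_loss`) and ignores options such as `min_delta`; the loss values are invented, and the function is an illustration rather than the Keras implementation.

```python
def early_stopping_epochs(val_losses, patience=3):
    """Return (epochs_run, best_epoch) under a simplified patience rule.

    Training stops once val_loss has failed to improve for `patience`
    consecutive epochs. Epoch numbers are 1-based.
    """
    best_loss = float("inf")
    best_epoch = 0
    wait = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            best_loss, best_epoch, wait = loss, epoch, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch, best_epoch
    return len(val_losses), best_epoch

# Invented validation losses for a run that asks for 10 epochs.
history = [0.60, 0.45, 0.40, 0.38, 0.39, 0.41, 0.40, 0.42, 0.43, 0.44]
epochs_run, best_epoch = early_stopping_epochs(history, patience=3)
print(epochs_run, best_epoch)  # 7 4
```

In this invented run the history would hold 7 entries, and with `restore_best_weights=True` the model would end up with the weights from epoch 4. Any plot of the history should therefore take its x-axis length from the history itself.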


6. Plot the loss and accuracy curves

# Plot accuracy and loss curves
accuracy = h2.history['accuracy']
loss = h2.history['loss']
val_loss = h2.history['val_loss']
val_accuracy = h2.history['val_accuracy']

# EarlyStopping may end training early, so take the epoch count from the history
epochs_range = range(len(accuracy))

plt.figure(figsize=(17, 7))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, accuracy, 'bo', label='Training Accuracy')
plt.plot(epochs_range, val_accuracy, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Accuracy: Training vs. Validation')
plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, 'bo', label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.title('Loss: Training vs. Validation')
plt.legend(loc='upper right')
plt.show()


7. Load, evaluate, and test the model

# Load the saved model
model_path = './animals.h5'
model = load_model(model_path)
model.summary()


# Evaluate the model on the test set
model.evaluate(test_ds)
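
`model.evaluate` reports an aggregate loss and accuracy. A per-class breakdown can show which animals get confused most often. Here is a minimal plain-Python sketch, assuming the true labels and predicted labels have already been collected as two lists of integers. The helper names and the sample values are invented for illustration.

```python
from collections import Counter

def overall_accuracy(y_true, y_pred):
    """Fraction of samples whose predicted label equals the true label."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

def per_class_accuracy(y_true, y_pred, class_names):
    """Return {class_name: fraction of that class predicted correctly}."""
    total = Counter(y_true)
    correct = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    return {class_names[c]: correct[c] / total[c] for c in sorted(total)}

# Invented labels for three classes.
names = ["cat", "dog", "horse"]
y_true = [0, 0, 1, 1, 1, 2]
y_pred = [0, 1, 1, 1, 0, 2]

acc = overall_accuracy(y_true, y_pred)
by_class = per_class_accuracy(y_true, y_pred, names)
print(acc)
print(by_class)
```

Here 4 of 6 samples are correct overall, while the per-class view shows that "cat" is right only half of the time.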


# Show predictions on test images
show_images(data=test_ds, model=model, GRID=[6,5], size=(25,25))


(IV) Complete code

# Import the required libraries
import os

import numpy as np
import matplotlib.pyplot as plt
import plotly.express as px

from tensorflow.keras.applications import ResNet50V2
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.layers import GlobalAveragePooling2D as GAP
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.preprocessing.image import ImageDataGenerator as IDG

# Training set
root_path = './animals10/raw-img/'

# Each sub-folder name is a class name
class_names = sorted(os.listdir(root_path))
n_classes = len(class_names)

print(f"Class Names : {class_names}")
print(f"Number of Classes  : {n_classes}")

# Number of images per class
class_dis = [len(os.listdir(os.path.join(root_path, name))) for name in class_names]
print(class_dis)

# Visualization
fig = px.pie(names=class_names, values=class_dis, title="Mammal class distribution (training set)")
fig.update_layout({'title': {'x': 0.45}})
fig.show()

# Test set
pred_root_path = './animals10/test/'

# Each sub-folder name is a class name
testclass_names = sorted(os.listdir(pred_root_path))
n_testclasses = len(testclass_names)

print(f"Class Names : {testclass_names}")
print(f"Number of Classes  : {n_testclasses}")

# Number of images per class (iterate over the test folders, not the training ones)
testclass_dis = [len(os.listdir(os.path.join(pred_root_path, name))) for name in testclass_names]
print(testclass_dis)

# Visualization
figtest = px.pie(names=testclass_names, values=testclass_dis, title="Mammal class distribution (test set)")
figtest.update_layout({'title': {'x': 0.45}})
figtest.show()

# Preprocessing: rescale pixels to [0, 1]; only the training generator augments
train_gen = IDG(rescale=1./255, rotation_range=10, horizontal_flip=True, validation_split=0.1)
test_gen = IDG(rescale=1./255)

# Load the data. class_mode='sparse' yields integer class indices,
# which is what sparse_categorical_crossentropy expects.
train_ds = train_gen.flow_from_directory(root_path,
                                         target_size=(256,256),
                                         shuffle=True,
                                         batch_size=32,
                                         subset='training',
                                         class_mode='sparse')
valid_ds = train_gen.flow_from_directory(root_path,
                                         target_size=(256,256),
                                         shuffle=True,
                                         batch_size=32,
                                         subset='validation',
                                         class_mode='sparse')
# The test set goes through test_gen, so it is rescaled but not augmented
test_ds = test_gen.flow_from_directory(pred_root_path,
                                       target_size=(256,256),
                                       shuffle=True,
                                       batch_size=32,
                                       class_mode='sparse')

def show_images(data, GRID=(6, 5), size=(20, 25), model=None, class_names=class_names):

    # Rows and columns of the grid
    n_rows, n_cols = GRID

    # Take one batch from the generator
    images, labels = next(iter(data))

    # Pick distinct random images; a batch may hold fewer than the grid needs
    n_images = min(n_rows * n_cols, len(images))
    ids = np.random.choice(len(images), size=n_images, replace=False)

    # Predict the whole batch once instead of once per image
    preds = None
    if model is not None:
        preds = np.argmax(model.predict(images, verbose=0), axis=1)

    # Figure size
    plt.figure(figsize=size)

    for i, idx in enumerate(ids, start=1):
        image, label = images[idx], class_names[int(labels[idx])]

        plt.subplot(n_rows, n_cols, i)
        plt.imshow(image)
        plt.axis('off')

        # Title: true class, plus the prediction when a model is given
        if preds is None:
            title = f"Class : {label}"
        else:
            title = f"Class : {label} \nPred : {class_names[preds[idx]]}"
        plt.title(title)

    # Show plot
    plt.show()

# Quick check on the training data
show_images(data=train_ds)

# Pretrained backbone, frozen so that only the new classification head is trained
base_model = ResNet50V2(input_shape=(256,256,3), include_top=False)
base_model.trainable = False

# Model name
name = "animals"

model = Sequential([
    base_model,
    GAP(),
    Dense(256, activation='relu', kernel_initializer='he_normal'),
    Dropout(0.2),
    Dense(n_classes, activation='softmax')  # one output per class folder
], name=name)

# Compile the model
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

# Callbacks: stop once val_loss has not improved for 3 epochs, and save the best model to disk
cbs = [EarlyStopping(patience=3, restore_best_weights=True), ModelCheckpoint(name + ".h5", save_best_only=True)]

# Train
h2 = model.fit(train_ds, validation_data=valid_ds, callbacks=cbs, epochs=10)

# Plot accuracy and loss curves
accuracy = h2.history['accuracy']
loss = h2.history['loss']
val_loss = h2.history['val_loss']
val_accuracy = h2.history['val_accuracy']

# EarlyStopping may end training early, so take the epoch count from the history
epochs_range = range(len(accuracy))

plt.figure(figsize=(17, 7))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, accuracy, 'bo', label='Training Accuracy')
plt.plot(epochs_range, val_accuracy, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Accuracy: Training vs. Validation')
plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, 'bo', label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.title('Loss: Training vs. Validation')
plt.legend(loc='upper right')
plt.show()

# Load the saved model
model_path = './animals.h5'
model = load_model(model_path)
model.summary()

# Evaluate the model on the test set
model.evaluate(test_ds)

# Show predictions on test images
show_images(data=test_ds, model=model, GRID=[6,5], size=(25,25))

From: https://www.cnblogs.com/yh-bg/p/animals.html
