首页 > 其他分享 >实验12-使用keras预训练模型完成猫狗识别

实验12-使用keras预训练模型完成猫狗识别

时间:2024-04-27 14:44:19浏览次数:22  
标签:#%% 12 keras train test 识别 os dir history

版本python3.7 tensorflow版本为tensorflow-gpu版本2.6

运行结果:

这里我用Gpu进行加速,训练一回9秒,如果不启用gpu,训练一回会很慢。

 

 代码:

#-*- codeing = utf-8 -*-
#@Time : 2022/10/2 11:44
#@Author : 程浩
#@File : 猫狗识别.py
#@Software: PyCharm
import tensorflow as tf
import keras
from keras import layers
import numpy as np
import os
import shutil
from tensorflow.keras.applications import VGG16
from tensorflow.keras.optimizers import Adam



#%%

base_dir = './dc'
train_dir = base_dir + '/train'
train_dog_dir = train_dir + '/dog'
train_cat_dir = train_dir + '/cat'
test_dir = base_dir + '/test'
test_dog_dir = test_dir + '/dog'
test_cat_dir = test_dir + '/cat'
dc_dir = '.=/dc/train'

#%%

if not os.path.exists(base_dir):

    os.mkdir(base_dir)

    os.mkdir(train_dir)
    os.mkdir(train_dog_dir)
    os.mkdir(train_cat_dir)
    os.mkdir(test_dir)
    os.mkdir(test_dog_dir)
    os.mkdir(test_cat_dir)

    fnames = ['cat.{}.jpg'.format(i) for i in range(1000,2000)]
    for fname in fnames:
        src = os.path.join(dc_dir, fname)
        dst = os.path.join(train_cat_dir, fname)
        shutil.copyfile(src, dst)
    #
    fnames = ['cat.{}.jpg'.format(i) for i in range(1000, 1500)]
    for fname in fnames:
        src = os.path.join(dc_dir, fname)
        dst = os.path.join(test_cat_dir, fname)
        shutil.copyfile(src, dst)

    fnames = ['dog.{}.jpg'.format(i) for i in range(1000)]
    for fname in fnames:
        src = os.path.join(dc_dir, fname)
        dst = os.path.join(train_dog_dir, fname)
        shutil.copyfile(src, dst)

    fnames = ['dog.{}.jpg'.format(i) for i in range(1000, 1500)]
    for fname in fnames:
        src = os.path.join(dc_dir, fname)
        dst = os.path.join(test_dog_dir, fname)
        shutil.copyfile(src, dst)

#%%

from keras.preprocessing.image import ImageDataGenerator
train_datagen = ImageDataGenerator(rescale=1./255)
test_datagen = ImageDataGenerator(rescale=1./255)

#%%

train_generator = train_datagen.flow_from_directory(
    train_dir,
    target_size=(200, 200),
    batch_size=20,
    class_mode='binary'
)

test_generator = test_datagen.flow_from_directory(
    test_dir,
    target_size=(200, 200),
    batch_size=20,
    class_mode='binary'
)

#%% md

# keras内置经典网络实现

#%%
covn_base = VGG16(weights=None, include_top=False)


#%%

covn_base.summary()

#%%

model = keras.Sequential()
model.add(covn_base)
model.add(layers.GlobalAveragePooling2D())
model.add(layers.Dense(512, activation='relu'))
model.add(layers.Dense(1, activation='sigmoid'))

#%%

model.summary()

#%%

covn_base.trainable = False #设置权重不可变,卷积基不可变

#%%

model.summary()

#%%
model.compile(optimizer=Adam(lr=0.001),
              loss='binary_crossentropy',
              metrics=['acc'])


#%%

history = model.fit_generator(
    train_generator,
    steps_per_epoch=100,
    epochs=15,
    validation_data=test_generator,
    validation_steps=50)

#%%

import matplotlib.pyplot as plt
# %matplotlib inline

#%%
print(history.history)

plt.plot(history.epoch, history.history['loss'], 'r', label='loss')
# plt.plot(history.epoch, history.history['val_loss'], 'b--', label='val_loss')
plt.legend()

#%%

plt.plot(history.epoch, history.history['acc'], 'r')
# plt.plot(history.epoch, history.history['val_acc'], 'b--')

 

标签:#%%,12,keras,train,test,识别,os,dir,history
From: https://www.cnblogs.com/daitu66/p/18162034

相关文章

  • 12-项目成本管理(4/10 十大管理)
    11.1管理基础11.1.1重要性和意义项目成本管理重点关注完成项目活动所需资源的成本,但同时也考虑项目决策对项目产品、服务或成果的使用成本、维护成本和支持成本的影响。项目成本管理应考虑干系人对成本的要求,不同的干系人会在不同的时间,用不同的方法测算项目成本。对于小项......
  • 动画的12项基本法则
    目录挤压与伸展(Squashandstretch)预期动作(Anticipation)演出方式(Staging)接续动作与关键动作(Straightaheadactionandposetopose)跟随动作与重叠动作(Followthroughandoverlappingaction)渐快与渐慢(Slowinandslowout)弧形(Arcs)附属动作(Secondaryaction)时间控制(T......
  • 初中中考阅读理解难题一网打尽!句子结构深度解析+答案揭秘,助你轻松冲刺高分!-012
    PDF格式公众号回复关键字:ZKYDT012原文1Richardfoundthebirdintheforest,didn’the?解析1Richard,found发现了,thebird这只鸟,intheforest在森林里,didn’the?不是吗理查德在森林里发现了这只鸟,不是吗?2Hesawastrangebirdinabush.他在灌木丛......
  • ROS1学习记录(12.0)
    学习视频:19.launch启动文件的使用方法_哔哩哔哩_bilibili使用launch启动可以减少不断打开终端重复输入调用不同节点的工作; 使用:其中name的意义是取代掉程序文件中的初始化的节点名字,相当于另取一个别名进行代表,代表这个节点,这样在调用相同节点时名字不同就不会起冲突;pkg,na......
  • 力扣-LCR 126. 斐波那契数
    1.题目题目地址(LCR126.斐波那契数-力扣(LeetCode))https://leetcode.cn/problems/fei-bo-na-qi-shu-lie-lcof/题目描述斐波那契数(通常用F(n)表示)形成的序列称为斐波那契数列。该数列由0和1开始,后面的每一项数字都是前面两项数字的和。也就是:F(0)=0,F(1)=1F(n......
  • 使用浏览器F12控制台模拟js点击隐藏功能
    F12里能看到有下载功能,但是前端页面上隐藏了(看到display参数是none)F12里先找个功能测试下,比如下面的旋转功能切换到“控制台”,输入下面,“pageRotateCw”是iddocument.getElementById('pageRotateCw').click();把pageRotateCw替换到download即可下载删除这个参数也可以......
  • 12_自动创建设备节点
    自动创建设备节点自动创建设备节点简介​ 在嵌入式Linux中使用mdev来实现设备节点文件的自动创建和删除。​ udev是一种工具,它能够根据系统中的硬件设备的状态动态更新设备文件,包括设备文件的创建,删除等。设备文件通常放在/dev目录下。使用udev后,在/dev目录下就......
  • Socket.D v2.4.12 发布(新增 python 实现)
    Socket.D协议?Socket.D是一个网络应用协议。在微服务、移动应用、物联网等场景,可替代http、websocket等。协议详情参考《官网介绍》。支持:tcp,udp,ws,kcp传输。目前:java,kotlin,javascript,node.js,python语言环境可用。go,rust,c/c++,.net正在开发中。forJava更新......
  • 【Azure Event Hub】解决Event Hub SDK出现无法识别 com.azure.core.client.traits.To
    问题描述使用ServiceBusSDK编写消费端应用时,遇见了错误信息:SDK版本  <dependency>   <groupId>com.azure</groupId>   <artifactId>azure-messaging-eventhubs</artifactId>   <version>5.4.0</version>  </dependency>......
  • 视频监控技术及AI识别技术在城市智慧灯杆综合管理平台中的应用
    随着城市化进程的加快,智慧城市建设已经成为当今城市发展的重要趋势。作为智慧城市建设的重要组成部分,智慧灯杆综合管理平台在城市管理和公共服务方面发挥着越来越重要的作用。智慧灯杆综合管理平台是一种基于物联网技术的智能化管理平台,通过集成各类传感器、控制器、通信设备等多......