首页 > 其他分享 >深度学习笔记2:数据增强

深度学习笔记2:数据增强

时间:2023-11-22 10:38:14浏览次数:40  
标签:layers 增强 keras 笔记 dataset 深度 test model size

  上一节由于训练数据集样本量较小,模型过早拟合最终我们在测试数据集的分类精度只达到了70%,本章节我们通过使用数据增强降低过拟合的方法。使用数据增强之后,模型的分类精度将提高到 80%~85%。数据增强是指从现有的训练样本中生成更多的训练数据,做法是利用一些能够生成可信图像的随机变换来增强(augment)样本。数据增强的目标是,模型在训练时不会两次查看完全相同的图片。这有助于模型观察到数据的更多内容,从而具有更强的泛化能力。

数据准备

定义数据增强代码

from tensorflow import keras
from tensorflow.keras import layers

data_augmentation = keras.Sequential([
    layers.RandomFlip("horizontal"),        
    layers.RandomRotation(0.1),        
    layers.RandomZoom(0.2),    
    ])

加载数据集

import  pathlib
new_base_dir = pathlib.Path('C:/Users/wuchh/.keras/datasets/dogs-vs-cats-small')

batch_size = 32
img_height = 180
img_width = 180

train_dataset = keras.preprocessing.image_dataset_from_directory(
    new_base_dir / 'train' ,
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size
)

validation_dataset = keras.preprocessing.image_dataset_from_directory(
    new_base_dir / 'train' ,
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size
)
显示几张增强后的训练图像
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 10)) for images, _ in train_dataset.take(1): for i in range(9): augmented_images = data_augmentation(images) ax = plt.subplot(3, 3, i + 1) plt.imshow(augmented_images[0].numpy().astype("uint8")) plt.axis("off") plt.show()

 数据增强神经网络模型

inputs = keras.Input(shape=(180, 180, 3))
x = data_augmentation(inputs) #数字增强
x = layers.Rescaling(1./255)(x)
x = layers.Conv2D(filters=32, kernel_size=3, activation="relu")(x)
x = layers.MaxPooling2D(pool_size=2)(x)
x = layers.Conv2D(filters=64, kernel_size
=3, activation="relu")(x)
x = layers.MaxPooling2D(pool_size=2)(x)
x = layers.Conv2D(filters=128, kernel_size=3, activation="relu")(x)
x = layers.MaxPooling2D(pool_size=2)(x)
x = layers.Conv2D(filters=256, kernel_size=3, activation="relu")(x)
x = layers.MaxPooling2D(pool_size=2)(x)
x = layers.Conv2D(filters=256, kernel_size=3, activation="relu")(x)
x = layers.Flatten()(x)
x = layers.Dropout(0.5)(x)
outputs = layers.Dense(1, activation="sigmoid")(x)
model = keras.Model(inputs=inputs, outputs=outputs)
model.compile(loss="binary_crossentropy",optimizer="rmsprop",metrics=["accuracy"])

训练卷积神经网络

callbacks = [keras.callbacks.ModelCheckpoint(filepath="convnet_from_scratch_with_augmentation.model",        
                                             save_best_only=True,monitor="val_loss")]
history = model.fit(    train_dataset,    epochs=100,    validation_data=validation_dataset,    callbacks=callbacks)

在测试集上评估模型 

test_dataset = keras.preprocessing.image_dataset_from_directory(
new_base_dir / 'test' ,
seed=123,
image_size=(img_height, img_width),
batch_size=batch_size)

test_model = keras.models.load_model("convnet_from_scratch_with_augmentation.model")
test_loss, test_acc = test_model.evaluate(test_dataset)
print(f"Test accuracy: {test_acc:.3f}")
>>> test_model = keras.models.load_model("convnet_from_scratch_with_augmentation.model")
>>> test_loss, test_acc = test_model.evaluate(test_dataset)
 3/32 [=>............................] - ETA: 2s - loss: 0.4401 - accuracy: 0.7604Corrupt JPEG da 9/32 [=======>......................] - ETA: 1s - loss: 0.5328 - accuracy: 0.7535Corrupt JPEG da32/32 [==============================] - 2s 69ms/step - loss: 0.4752 - accuracy: 0.7890
>>> print(f"Test accuracy: {test_acc:.3f}")
Test accuracy: 0.789
>>>

  这次我的测试精度达到了 78.9%,这个进步不少!

 

标签:layers,增强,keras,笔记,dataset,深度,test,model,size
From: https://www.cnblogs.com/haozi0804/p/17783922.html

相关文章

  • 学习记录笔记
    学习记录笔记A*算法奇乐编程学院B站视频练习网站......
  • Linux读书笔记第5章
    在学习Linux的进程管理过程中,我总结了以下几个关键点:1.进程的创建和终止:Linux中的进程可以通过fork()系统调用来创建新的进程,该系统调用会复制当前进程的所有属性,并创建一个新的进程。另外,exec()系列的系统调用可以用于在新创建的进程中加载新的程序。而进程的终止可以通过调用e......
  • Linux进程管理学习感悟与笔记
    1.ps   'ps'是Linux中最基础的浏览系统中的进程的命令。能列出系统中运行的进程,包括进程号、命令、CPU使用量、内存使用量等。下述选项可以得到更多有用的消息。ps -a - 列出所有运行中/激活进程ps -ef |grep - 列出需要进程ps -aux - 显示进程信息,包括无终端......
  • 群论学习笔记
    群论学习笔记好厉害的东西。定义一个群\(\left\langle\mathbb{G},\circ\right\rangle\)由一个集合\(\mathbb{G}\)以及一个二元运算\(\circ:\mathbb{G}\times\mathbb{G}\to\mathbb{G}\)构成。群的4个性质:封闭性:对于\(a,b\in\mathbb{G},c=a\circb\),......
  • 《实现领域驱动设计》笔记——领域、子域和限界上下文
    总览从广义上讲,领域(Domain)即是一个组织所做的事情以及其中所包含的一切。商业机构通常会确定一个市场,然后在这个市场中销售产品和服务。每个组织都有它自己的业务范围和做事方式。这个业务范围以及在其中所进行的活动便是领域。当你为某个组织开发软件时,你面对的便是这个......
  • 蛤蟆先生去看心理医生-阅读笔记 All In One
    蛤蟆先生去看心理医生-阅读笔记AllInOne心理学作者:[英]罗伯特·戴博德出版社:天津人民出版社出品方:果麦文化原作名:CounsellingForToads:APsychologicalAdventure译者:陈赢出版年:2020-8-1页数:208定价:38.00元装帧:平装ISBN:9787201161693de......
  • 【刷题笔记】115. Distinct Subsequences
    题目Giventwostrings s and t,return thenumberofdistinctsubsequencesof s whichequals t.Astring's subsequence isanewstringformedfromtheoriginalstringbydeletingsome(canbenone)ofthecharacterswithoutdisturbingtheremainingch......
  • Git学习笔记:基础使用
    本随笔用于记录随笔作者在一般情况下使用Git的一些步骤和操作,主要用于在经过一段时间没有使用Git后能够通过该随笔马上回忆起基础操作,所以该随笔一开始并不会介绍Git的高级特性。本随笔内容摘录自官方教程随笔作者还在学习当中,难免会出现书写上和技术上的错误,如果发现类似错误,欢......
  • 组队学习-学习笔记P1
    组队学习-学习笔记P1Task01课程简介、安装Installation创建Conda环境在AnacondaPowershellPrompt中运行:condacreate-np2spython=3.10#conda环境创建其中-n代表创建的环境名称,这里是p2s,并指定Python版本为3.10这里的Python版本需要根据你自己下载的Pytho......
  • 笔记·环境搭建
    笔记·环境搭建下载与配置下载miniconda或者anaconda安装时配置如下右键开始图标后点击Windows终端(管理员)在终端中输入Set-ExecutionPolicy-ScopeCurrentUserRemoteSigned选择A后关闭(没反应不用管)打开miniconda输入condainit更换镜像源校园网联合镜像站(https://......