首页 > 其他分享 >深度学习笔记4:在卷积基上添加数据增强代码块和分类器

深度学习笔记4:在卷积基上添加数据增强代码块和分类器

时间:2023-12-14 09:01:28浏览次数:34  
标签:loss plt 基上 keras 卷积 dataset 分类器 history size

  特征提取的另一种方式是将原有模型与一个新的密集分类器相连接,以构建一个新的模型,然后对整个模型进行端到端的训练。这种方法在输入数据上进行整体训练,使模型能够更好地适应数据特性并提取更有效的特征。通过这种方式,模型的性能可以得到进一步提高,同时也能更好地捕捉到数据中的复杂模式。

冻结卷积基

from tensorflow import keras
conv_base = keras.applications.vgg16.VGG16(
   weights="imagenet",    
   include_top=False,    
   #input_shape=(180, 180, 3)
)
conv_base.trainable = False

在卷积基上添加数据增强代码块和分类器

data_augmentation = keras.Sequential([
  layers.RandomFlip("horizontal"),
  layers.RandomRotation(0.1),        
  layers.RandomZoom(0.2),    
])  

inputs = keras.Input(shape=(180, 180, 3))
x = data_augmentation(inputs)
x = keras.applications.vgg16.preprocess_input(x) 
x = conv_base(x)
x = layers.Flatten()(x)   
x = layers.Dense(256)(x)
x = layers.Dropout(0.5)(x)
outputs = layers.Dense(1, activation="sigmoid")(x)
model = keras.Model(inputs, outputs)
model.compile(loss="binary_crossentropy",optimizer="rmsprop",metrics=["accuracy"])

加载训练数据

import  pathlib

batch_size = 32
img_height = 180
img_width = 180

new_base_dir = pathlib.Path('C:/Users/wuchh/.keras/datasets/dogs-vs-cats-small')

train_dataset = keras.preprocessing.image_dataset_from_directory(
    new_base_dir / 'train' ,
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size
)


validation_dataset = keras.preprocessing.image_dataset_from_directory(
    new_base_dir / 'train' ,
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size
)


test_dataset = keras.preprocessing.image_dataset_from_directory(
    new_base_dir / 'test' ,
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size
)

训练模型

callbacks = [keras.callbacks.ModelCheckpoint( 
  filepath="feature_extraction_with_data_augmentation.model",
  save_best_only=True,
  monitor="val_loss")]

history = model.fit(    train_dataset,    epochs=50,    validation_data=validation_dataset,    callbacks=callbacks)  

绘制训练结果 

import matplotlib.pyplot as plt
acc = history.history["accuracy"]
val_acc = history.history["val_accuracy"]
loss = history.history["loss"]
val_loss = history.history["val_loss"]
epochs = range(1, len(acc) + 1)
plt.plot(epochs, acc, "bo", label="Training accuracy")
plt.plot(epochs, val_acc, "b", label="Validation accuracy")
plt.title("Training and validation accuracy")
plt.legend()
plt.figure()
plt.plot(epochs, loss, "bo", label="Training loss")
plt.plot(epochs, val_loss, "b", label="Validation loss")
plt.title("Training and validation loss")
plt.legend()
plt.show()

 在测试集上评估模型

test_model = keras.models.load_model("feature_extraction_with_data_augmentation.model")
test_loss, test_acc = test_model.evaluate(test_dataset)
print(f"Test accuracy: {test_acc:.3f}")

32/32 [==============================] - 15s 452ms/step - loss: 2.0066 - accuracy: 0.9790
Test accuracy: 0.979

 

总之,鉴于模型在验证数据上取得的好结果,这有点令人失望。模型的精度始终取决于评估模型的样本集。有些样本集可能比其他样本集更难以预测,在一个样本集上得到的好结果,并不一定能够在其他样本集上完全复现。

 

标签:loss,plt,基上,keras,卷积,dataset,分类器,history,size
From: https://www.cnblogs.com/haozi0804/p/17785316.html

相关文章

  • 论文精读:STMGCN利用时空多图卷积网络进行移动边缘计算驱动船舶轨迹预测(STMGCN: Mobile
    《STMGCN:MobileEdgeComputing-EmpoweredVesselTrajectoryPredictionUsingSpatio-TemporalMultigraphConvolutionalNetwork》论文链接:https://doi.org/10.1109/TII.2022.3165886摘要利用移动边缘计算MEC范例提出基于时空多图卷积网络(STMGCN)的轨迹预测框。STMGCN由三......
  • 论文精读:基于具有时空感知的稀疏多图卷积混合网络的大数据驱动船舶轨迹预测(Big data d
    论文精读:基于具有时空感知的稀疏多图卷积混合网络的大数据驱动船舶轨迹预测《Bigdatadrivenvesseltrajectorypredictionbasedonsparsemulti-graphconvolutionalhybridnetworkwithspatio-temporalawareness》论文链接:https://doi.org/10.1016/j.oceaneng.2023.115......
  • 网络分类器 cgroup 【ChatGPT】
    https://www.kernel.org/doc/html/v6.6/admin-guide/cgroup-v1/net_cls.html网络分类器cgroup网络分类器cgroup提供了一个接口,用于给网络数据包打上一个类别标识符(classid)。流量控制器(tc)可以用来为来自不同cgroup的数据包分配不同的优先级。此外,Netfilter(iptables)也可以......
  • 基于卷积神经网络实现高速公路表面图像裂缝检测程序
    作者简介:Java领域优质创作者、CSDN博客专家、CSDN内容合伙人、掘金特邀作者、阿里云博客专家、51CTO特邀作者、多年架构师设计经验、腾讯课堂常驻讲师主要内容:Java项目、Python项目、前端项目、人工智能与大数据、简历模板、学习资料、面试题库、技术互助收藏点赞不迷路 关注作......
  • 聊聊神经网络模型流程与卷积神经网络的实现
    神经网络模型流程神经网络模型的搭建流程,整理下自己的思路,这个过程不会细分出来,而是主流程。在这里我主要是把整个流程分为两个主流程,即预训练与推理。预训练过程主要是生成超参数文件与搭设神经网络结构;而推理过程就是在应用超参数与神经网络。卷积神经网络的实现在聊聊卷......
  • 机器学习中的典型算法——卷积神经网络(CNN)
    1.机器学习的定位AI,是我们当今这个时代的热门话题,那AI到底是啥?通过翻译可知:人工智能,而人工智能的四个核心要素:-数据-算法-算力-场景然后机器学习是人工智能的一部分,机器学习里面又有新的特例:深度学习。通俗来说机器学习即使用机器去学习一部分数据,然后去预测新的数据所属......
  • 聊聊卷积神经网络CNN
    卷积神经网络(ConvolutionalNeuralNetwork,CNN)是一种被广泛应用于图像识别、语音识别和自然语言处理等领域的深度学习模型。与RNN、Transformer模型组成AI的三大基石。在卷积神经网络中,相比较普通的神经网络,增加了卷积层(Convolution)和池化层(Pooling)。其结构一般将会是如下:......
  • 斯坦福大学引入FlashFFTConv来优化机器学习中长序列的FFT卷积
    斯坦福大学的FlashFFTConv优化了扩展序列的快速傅里叶变换(FFT)卷积。该方法引入Monarch分解,在FLOP和I/O成本之间取得平衡,提高模型质量和效率。并且优于PyTorch和FlashAttention-v2。它可以处理更长的序列,并在人工智能应用程序中打开新的可能性。处理长序列的效率一直是机器学习......
  • c4w2_深度卷积网络案例探究
    深度卷积模型:案例探究为什么要学习一些案例呢?就像通过看别人的代码来学习编程一样,通过学习卷积神经模型的案例,建立对卷积神经网络的(CNN)的“直觉”。并且可以把从案例中学习到的思想、模型移植到另外的任务上去,他们往往也表现得很好。接下来要学习的神经网络:经典模型:LeNet5、A......
  • c4w1_卷积神经网络
    卷积神经网络计算机视觉问题计算机视觉(computervision)是因深度学习而快速发展的领域之一,它存进了如自动驾驶、人脸识别等应用的发展,同时计算机视觉领域的发展还可以给其他领域提供思路。计算机视觉应用的实例:图片分类(识别是不是一只猫)、目标检测(检测途中汽车行人等)、图片风格......