首页 > 其他分享 >【漫话机器学习系列】043.提前停止训练(Early Stopping)

【漫话机器学习系列】043.提前停止训练(Early Stopping)

时间:2025-01-11 12:59:13浏览次数:3  
标签:loss val 32 self Early Stopping 模型 漫话 accuracy

92b0a9c675454cca982fc4338f413dfc.jpeg

提前停止训练(Early Stopping)

提前停止(Early Stopping) 是一种在训练机器学习模型(尤其是深度学习模型)时常用的正则化技术,用于防止过拟合并提升模型的泛化能力。它通过监控验证集的性能,在性能不再提高或开始下降时终止训练,从而选择性能最佳的模型。


工作原理

提前停止的基本思想是:

  1. 在每个训练轮次(epoch)后,评估模型在验证集上的性能(通常使用损失函数值或评价指标,如准确率)。
  2. 如果验证集性能在多个轮次内未改善,则停止训练并恢复到性能最佳的模型状态。

实现步骤

  1. 分割数据集: 将训练数据分为训练集和验证集,训练集用于优化模型参数,验证集用于监控模型的泛化性能。

  2. 设定监控指标: 选择一个监控指标(如验证损失、验证准确率等),作为衡量模型性能的标准。

  3. 设定耐心值(Patience): 耐心值是指允许验证集性能在指定轮次内未改善的次数。如果超过耐心值还未见性能提升,则停止训练。

  4. 保存最佳模型: 在训练过程中,记录验证集性能最优的模型状态,停止训练后使用该状态作为最终模型。


优点

  1. 防止过拟合:通过终止训练,避免模型过度拟合训练数据。
  2. 提高泛化能力:选择验证集上性能最优的模型,提升模型对未见数据的表现。
  3. 节省训练时间:减少不必要的迭代,节约计算资源。
  4. 动态调整:适应数据集的不同复杂度,不需要预设固定的训练轮次。

缺点

  1. 需要验证集:需要分出一部分数据作为验证集,可能导致训练数据减少。
  2. 过早停止的风险:模型可能在某些训练阶段出现短暂波动,提前停止可能会错过更好的优化结果。
  3. 适合深度学习模型:对于小规模模型或简单问题,提前停止的效果可能不明显。

实现方式

1. 使用 TensorFlow/Keras 实现

 
import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.callbacks import EarlyStopping

# 示例数据集
X_train = np.random.rand(1000, 20)  # 1000个样本,每个样本20个特征
y_train = np.random.randint(2, size=(1000, 1))  # 1000个样本的二分类标签
X_val = np.random.rand(200, 20)  # 200个样本,每个样本20个特征
y_val = np.random.randint(2, size=(200, 1))  # 200个样本的二分类标签

model = Sequential([
    Dense(64, activation='relu', input_shape=(X_train.shape[1],)),
    Dense(1, activation='sigmoid')
])

# 定义 EarlyStopping 回调,监控验证集损失,如果连续5个epoch没有改善则停止训练,并恢复最佳权重
early_stopping = EarlyStopping(
    monitor='val_loss',  # 监控的指标
    patience=5,          # 在验证集性能不提升的轮数后停止
    restore_best_weights=True  # 恢复验证集性能最优的模型
)

# 编译模型
model.compile(optimizer='adam', loss='mse', metrics=['accuracy'])

# 训练模型,使用早停机制
model.fit(
    X_train, y_train,
    validation_data=(X_val, y_val),
    epochs=100,
    callbacks=[early_stopping]
)

运行结果 

Epoch 1/100
32/32 [==============================] - 1s 8ms/step - loss: 0.2522 - accuracy: 0.5210 - val_loss: 0.2504 - val_accuracy: 0.5350
Epoch 2/100
32/32 [==============================] - 0s 2ms/step - loss: 0.2491 - accuracy: 0.5320 - val_loss: 0.2502 - val_accuracy: 0.5300
Epoch 3/100
32/32 [==============================] - 0s 2ms/step - loss: 0.2484 - accuracy: 0.5320 - val_loss: 0.2507 - val_accuracy: 0.4950
Epoch 4/100
32/32 [==============================] - 0s 2ms/step - loss: 0.2468 - accuracy: 0.5260 - val_loss: 0.2521 - val_accuracy: 0.4950
Epoch 5/100
32/32 [==============================] - 0s 2ms/step - loss: 0.2456 - accuracy: 0.5560 - val_loss: 0.2524 - val_accuracy: 0.5150
Epoch 6/100
32/32 [==============================] - 0s 2ms/step - loss: 0.2452 - accuracy: 0.5450 - val_loss: 0.2540 - val_accuracy: 0.5000
Epoch 7/100
32/32 [==============================] - 0s 2ms/step - loss: 0.2457 - accuracy: 0.5500 - val_loss: 0.2529 - val_accuracy: 0.4750

 

2. 使用 PyTorch 实现

import torch
import torch.nn as nn

class EarlyStopping:
    def __init__(self, patience=5, delta=0, path='checkpoint.pt'):
        self.patience = patience
        self.delta = delta
        self.best_loss = None
        self.counter = 0
        self.early_stop = False
        self.path = path

    def __call__(self, val_loss, model):
        if self.best_loss is None or val_loss < self.best_loss - self.delta:
            self.best_loss = val_loss
            self.counter = 0
            torch.save(model.state_dict(), self.path)
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

# 定义示例数据集
X_train = torch.randn(1000, 20)  # 1000个样本,每个样本20个特征
y_train = torch.randint(0, 2, (1000, 1))  # 1000个样本的二分类标签
X_val = torch.randn(200, 20)  # 200个样本,每个样本20个特征
y_val = torch.randint(0, 2, (200, 1))  # 200个样本的二分类标签

# 定义模型
model = nn.Sequential(
    nn.Linear(20, 64),
    nn.ReLU(),
    nn.Linear(64, 1),
    nn.Sigmoid()
)

# 定义训练和验证函数
def train():
    pass

def validate():
    return torch.tensor(0.5)  # 示例验证损失

# 使用示例
early_stopping = EarlyStopping(patience=5)

for epoch in range(100):
    train()  # 训练过程
    val_loss = validate()  # 验证损失
    early_stopping(val_loss, model)

    if early_stopping.early_stop:
        model.load_state_dict(torch.load('checkpoint.pt'))
        break

 


总结

提前停止训练是机器学习和深度学习中的一种简单高效的正则化方法,能够显著提升模型的泛化能力,同时减少训练时间。结合耐心值(patience)、监控指标以及最佳模型保存机制,可以灵活地应用到各种场景中。

 

 

 

标签:loss,val,32,self,Early,Stopping,模型,漫话,accuracy
From: https://blog.csdn.net/IT_ORACLE/article/details/145048155

相关文章

  • 探索Bearly Code Interpreter:远程代码执行与数据交互的完美结合
    #探索BearlyCodeInterpreter:远程代码执行与数据交互的完美结合##引言随着人工智能和编程领域的快速发展,安全高效的代码执行环境变得越来越重要。这篇文章将介绍BearlyCodeInterpreter,一个允许远程执行代码的强大工具,使得如代码解释器等功能的实现更加安全可靠。我......
  • 3.4.4 __ipipe_init_early之再论虚拟中断
    点击查看系列文章=》 InterruptPipeline系列文章大纲-CSDN博客3.4.4__ipipe_init_early之再论虚拟中断     根据《3.4.1.2IPIPE对Linux中断号的改造》的分析,IPIPE引入的虚拟中断virtualinterrupt的概念,其中前10个虚拟中断本质上是利用SGI实现的IPI中断。IPIPE在......
  • 3.4.3 __ipipe_init_early之初始化root domain
    点击查看系列文章=》 InterruptPipeline系列文章大纲-CSDN博客3.4.3__ipipe_init_early之初始化rootdomain      如下图所示,红框里面的函数当前都是空的,本章还是分析蓝框中的代码片段。第295行,变量ipd指向了ipipe_root即ipd代表rootdomain。第305行,rootdoma......
  • 漫话linux:基础IO,软硬链接,动静态库管理
    1.软硬链接    1.软链接:是一个独立文件,具有独立的inode,也有独立的数据块,它的数据块里面保存的是指向的文件的路径,公用inode        1.建立软连接ln-s目标文件或目录,链接名 目标文件或目录表示路径,链接名代表命令,无论在哪里输入命令就能调......
  • 多模态学习之论文阅读:《PREDICTING AXILLARY LYMPH NODE METASTASIS IN EARLY BREAST
    《PREDICTINGAXILLARYLYMPHNODEMETASTASISINEARLYBREASTCANCERUSINGDEEPLEARNINGONPRIMARYTUMORBIOPSYSLIDES》(一)要点提出一个基于注意力机制的多实例学习框架,构建了一个深度学习模型。该模型利用WSIs和临床数据预测早期乳腺癌(EBC)患者的腋窝淋巴结(ALN)转移状态......
  • E. Nearly Shortest Repeating Substring
    #include<iostream>#include<algorithm>#include<cstring>#include<cmath>usingnamespacestd;intn,m;intmain(){ cin>>n; while(n--) { //strings; cin>>m; strings; cin>>s; intres=m; f......
  • E. Nearly Shortest Repeating Substring
    原题链接题解1.模拟题,注意细节2.时间复杂度\(O(n·sqrt(n))\)code#include<bits/stdc++.h>usingnamespacestd;intn;strings;intcheck(intlen){intflag=0;for(intk=0;k<len;k++){inta[26]={0};for(inti=k;i<n;i+=len)......
  • Codeforces Round 937 (Div. 4)----->E. Nearly Shortest Repeating Substring
    一,思路:1.这题很容易想到枚举n的因数(时间复杂度n^(1/2)),然后根据这个长度枚举字符串,看是否满足最多只有一个不相同(时间复杂度n)。总的时间复杂度是(n根号n)的级别。又n是1e5级别所以可以过。但是当n是1e6就不行了。2.难点在于如何判断,一个字符串的不同字符数量,主要是hshah......
  • 初中英语优秀范文100篇-062Going to Sleep Early Is a Good Habit早一点睡觉是个好习
    PDF格式公众号回复关键字:SHCZFW062记忆树1Goingtosleepearlyisofgreatbenefittome.翻译早一点睡觉对我非常有益。简化记忆早睡句子结构"Goingtosleepearly"是主语,表示一个动作或状态。"is"是系动词,用来连接主语和表语,表示主语的特征或状态。"ofgreat......
  • Json Schema介绍 和 .net 下的实践 - 基于Lateapexearlyspeed.Json.Schema - 基础1 -
    本系列旨在介绍JsonSchema的常见用法,以及.net实现库Lateapexearlyspeed.Json.Schema的使用这篇文章将介绍JsonSchema中的type关键字,和string类型的常见验证功能。用例基于.net的LateApexEarlySpeed.Json.Schemanugetpackage。这是新创建的一个JsonSchema在.net下的高性能......