首页 > 其他分享 >过拟合与欠拟合、批量标准化

过拟合与欠拟合、批量标准化

时间:2024-09-17 17:23:46浏览次数:3  
标签:训练 批量 模型 标准化 正则 参数 拟合 数据

过拟合与欠拟合

过拟合(Overfitting)

1、基本概念:过拟合指的是模型在训练数据上表现很好,但在未见过的测试数据上表现较差的情况。过拟合发生的原因是模型过于复杂,能够记住训练数据的细节和噪声,而不是学习数据的通用模式。

2、特征

  • 模型在训练数据上的准确度高。

  • 模型在测试数据上的准确度较低。

  • 模型的参数数量过多,容易记忆训练数据。

3、防止过拟合的方法

  • 数据集扩增:增加更多的训练数据,可以减少过拟合的风险。

  • 正则化:通过添加正则化项,如L1正则化(Lasso)或L2正则化(Ridge),来惩罚模型参数的大小,使模型更简单。

  • 特征选择:选择最重要的特征,降低模型的复杂度。

  • 交叉验证:使用交叉验证来估计模型的性能,选择最佳的模型参数。

  • 早停止:在训练过程中监控验证集的性能,当性能开始下降时停止训练,以防止过拟合。

欠拟合(Underfitting)

1、基本概念:欠拟合表示模型太过简单,无法捕获数据中的关键特征和模式。模型在训练数据和测试数据上的性能都较差。

2、特征

  • 模型在训练数据上的准确度较低。

  • 模型在测试数据上的准确度也较低。

  • 模型可能太简单,参数数量不足。

3、防止欠拟合的方法

  • 增加模型复杂度:使用更复杂的模型,例如增加神经网络的层数或增加决策树的深度。

  • 增加特征:添加更多的特征或进行特征工程,以捕获更多数据的信息。

  • 减小正则化强度:如果使用了正则化,可以降低正则化的强度,使模型更灵活。

  • 调整超参数:调整模型的超参数,如学习率、批量大小等,以改善模型的性能。

  • 使用更多数据:如果可能的话,增加训练数据可以提高模型的性能。

总的来说,过拟合和欠拟合都是需要非常注意的问题。

选择合适的模型复杂度、正则化方法和特征工程技巧可以帮助在训练机器学习模型时避免这些问题,获得更好的泛化性能。

解决过拟合

  • L1 正则化 更适合用于产生稀疏模型,会让部分权重完全为零,适合做特征选择。

  • L2 正则化 更适合平滑模型的参数,避免过大参数,但不会使权重变为零,适合处理高维特征较为密集的场景。

L2正则化

L2 正则化通过在损失函数中添加权重参数的平方和来实现,目标是惩罚过大的参数值。

L_{\text{total}}(\theta) = L(\theta) + \lambda \cdot \frac{1}{2} \sum_{i} \theta_i^2

  • L(\theta) 是原始损失函数(比如均方误差、交叉熵等)。

  • \lambda 是正则化强度,控制正则化的力度。

  • \theta_i是模型的第 $$i$$ 个权重参数。

  • \frac{1}{2} \sum_{i} \theta_i^2 是所有权重参数的平方和,称为 L2 正则化项。

L2 正则化会惩罚权重参数过大的情况,通过参数平方值对损失函数进行约束。

梯度更新

\theta_{t+1} = \theta_t - \eta \left( \nabla L(\theta_t) + \lambda \theta_t \right)

  • \eta 是学习率。

  • \nabla L(\theta_t)是损失函数关于参数\theta_t的梯度。

  • \lambda \theta_t 是 L2 正则化项的梯度,对应的是参数值本身的衰减。

参数越大惩罚力度就越大,从而让参数逐渐趋向于较小值,避免出现过大的参数。

作用

  1. 防止过拟合:当模型过于复杂、参数较多时,模型会倾向于记住训练数据中的噪声,导致过拟合。L2 正则化通过抑制参数的过大值,使得模型更加平滑,降低模型对训练数据噪声的敏感性。

  2. 限制模型复杂度:L2 正则化项强制权重参数尽量接近 0,避免模型中某些参数过大,从而限制模型的复杂度。通过引入平方和项,L2 正则化鼓励模型的权重均匀分布,避免单个权重的值过大。

  3. 提高模型的泛化能力:正则化项的存在使得模型在测试集上的表现更加稳健,避免在训练集上取得极高精度但在测试集上表现不佳。

  4. 平滑权重分布:L2 正则化不会将权重直接变为 0,而是将权重值缩小。这样模型就更加平滑的拟合数据,同时保留足够的表达能力。

import torch.optim as optim
optimizer = optim.SGD(model.parameters(), lr=0.01, weight_decay=0.001)  # L2 正则化

L1正则化

L1 正则化通过在损失函数中添加权重参数的绝对值之和来约束模型的复杂度。

L_{\text{total}}(\theta) = L(\theta) + \lambda \sum_{i} |\theta_i|

  • L(\theta)是原始损失函数。

  • \lambda 是正则化强度,控制正则化的力度。

  • |\theta_i| 是模型第 i 个参数的绝对值。

  • \sum_{i} |\theta_i|是所有权重参数的绝对值之和,这个项即为 L1 正则化项。

L1 正则化依赖于参数的绝对值,其梯度更新时不是简单的线性缩小,而是通过符号函数来直接调整参数的方向。

作用
  1. 稀疏性:L1 正则化的一个显著特性是它会促使许多权重参数变为 。这是因为 L1 正则化倾向于将权重绝对值缩小到零,使得模型只保留对结果最重要的特征,而将其他不相关的特征权重设为零,从而实现 特征选择 的功能。

  2. 防止过拟合:通过限制权重的绝对值,L1 正则化减少了模型的复杂度,使其不容易过拟合训练数据。相比于 L2 正则化,L1 正则化更倾向于将某些权重完全移除,而不是减小它们的值。

  3. 简化模型:由于 L1 正则化会将一些权重变为零,因此模型最终会变得更加简单,仅依赖于少数重要特征。这对于高维度数据特别有用,尤其是在特征数量远多于样本数量的情况下。

  4. 特征选择:因为 L1 正则化会将部分权重置零,因此它天然具有特征选择的能力,有助于自动筛选出对模型预测最重要的特征。

l1_lambda = 0.001
# 计算 L1 正则化项并将其加入到总损失中
l1_norm = sum(p.abs().sum() for p in model.parameters())
loss = loss + l1_lambda * l1_norm

Dropout

Dropout 是一种在训练过程中随机丢弃部分神经元的技术。它通过减少神经元之间的依赖来防止模型过于复杂,从而避免过拟合。

import torch
import torch.nn as nn

def dropout():
    dropout=nn.Dropout(p=0.5)
    x=torch.randn(2,2)
    print(x)
    print("------------------")
    print(dropout(x))
    
if __name__ == "__main__":
    dropout()
    
"""
tensor([[-0.3970, -1.8862],
        [-0.5632,  0.0390]])
------------------
tensor([[-0.7940, -3.7724],
        [-1.1264,  0.0000]])
"""

Dropout过程:

  1. 按照指定的概率把部分神经元的值设置为0;

  2. 为了规避该操作带来的影响,需对非 0 的元素使用缩放因子1/(1-p)进行强化。

权重影响

简化模型

  • 减少网络层数和参数: 通过减少网络的层数、每层的神经元数量或减少卷积层的滤波器数量,可以降低模型的复杂度,减少过拟合的风险。

  • 使用更简单的模型: 对于复杂问题,使用更简单的模型或较小的网络架构可以减少参数数量,从而降低过拟合的可能性。

数据增强

通过对训练数据进行各种变换(如旋转、裁剪、翻转、缩放等),可以增加数据的多样性,提高模型的泛化能力。

from torchvision import transforms

transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor()
])


"""
transforms.Compose([...])
transforms.Compose接受一个转换操作列表,并按顺序应用这些转换。它允许你将多个转换操作组合在一起,形成一个统一的转换流程。

转换操作列表
1. transforms.RandomHorizontalFlip()
功能:随机水平翻转图像。
参数:默认情况下,p=0.5,即有50%的概率会对图像进行水平翻转。
用途:用于数据增强,增加模型的泛化能力。
2. transforms.RandomVerticalFlip()
功能:随机垂直翻转图像。
参数:默认情况下,p=0.5,即有50%的概率会对图像进行垂直翻转。
用途:同样用于数据增强,使模型能够更好地处理不同角度的图像。
3. transforms.RandomRotation(10)
功能:随机旋转图像。
参数:degrees,表示旋转的角度范围。这里的10表示图像将以10度为范围进行随机旋转(-10到10度之间)。
用途:增强数据,使模型能够更好地处理不同角度下的图像。
4. transforms.ToTensor()
功能:将PIL Image或numpy数组转换为PyTorch的Tensor,并且归一化到[0.0, 1.0]。
用途:将图像数据转换为适合在网络中使用的格式,同时进行归一化处理。
"""

早停

早停是一种在训练过程中监控模型在验证集上的表现,并在验证误差不再改善时停止训练的技术。这样可避免训练过度,防止模型过拟合。

模型集成

通过将多个不同模型的预测结果进行集成,可以减少单个模型过拟合的风险。常见的集成方法包括投票法、平均法和堆叠法。

 交叉验证

使用交叉验证技术可以帮助评估模型的泛化能力,并调整模型超参数,以防止模型在训练数据上过拟合。

这些方法可以单独使用,也可以结合使用,以有效地防止参数过大和过拟合。根据具体问题和数据集的特点,选择合适的策略来优化模型的性能。

批量标准化

批量标准化(Batch Normalization, BN)是一种广泛使用的神经网络正则化技术,核心思想是对每一层的输入进行标准化,然后进行缩放和平移,旨在加速训练、提高模型的稳定性和泛化能力。

实现过程

批量标准化的基本思路是在每一层的输入上执行标准化操作,并学习两个可训练的参数:缩放因子 \gamma和偏移量\beta

1.计算均值和方差

均值

\mu_B = \frac{1}{m} \sum_{i=1}^m x_i

方差

\sigma_B^2 = \frac{1}{m} \sum_{i=1}^m (x_i - \mu_B)^2

2.标准化

使用计算得到的均值和方差对数据进行标准化,使得每个特征的均值为0,方差为1。

标准化后的值

\hat{x}_i = \frac{x_i - \mu_B}{\sqrt{\sigma_B^2 + \epsilon}}

3.缩放和平移

标准化后的数据通常会通过可训练的参数进行缩放和平移,以恢复模型的表达能力。

  • 缩放(Gamma)
    y_i = \gamma \hat{x}_i
     

  • 平移(Beta)
    y_i = \gamma \hat{x}_i + \beta

训练和推理阶段

  • 训练阶段: 在训练过程中,均值和方差是基于当前批次的数据计算得到的。

  • 推理阶段: 在推理阶段,批量标准化使用的是训练过程中计算得到的全局均值和方差,而不是当前批次的数据。这些全局均值和方差通常会被保存在模型中,用于推理时的标准化过程。

作用

提高神经网络的训练稳定性、加速训练过程并减少过拟合

可以从一下几个方面来改善:

1 缓解梯度问题

标准化处理可以防止激活值过大或过小,避免了激活函数(如 Sigmoid 或 Tanh)饱和的问题,从而缓解梯度消失或爆炸的问题。

2 加速训练

由于 BN 使得每层的输入数据分布更为稳定,因此模型可以使用更高的学习率进行训练。这可以加快收敛速度,并减少训练所需的时间。

3 减少过拟合

  • 类似于正则化:虽然 BN 不是一种传统的正则化方法,但它通过对每个批次的数据进行标准化,可以起到一定的正则化作用。它通过在训练过程中引入了噪声(由于批量均值和方差的估计不完全准确),这有助于提高模型的泛化能力。

  • 避免对单一数据点的过度拟合:BN 强制模型在每个批次上进行标准化处理,减少了模型对单个训练样本的依赖。这有助于模型更好地学习到数据的整体特征,而不是对特定样本的噪声进行过度拟合。

import torch 
import torch.nn as nn

def test():
    x=torch.randn(2,3,4,4)
    print(x)
    
    bn=nn.BatchNorm2d(3)
    print(bn)
    
if __name__=='__main__':
    test()



"""
tensor([[[[-0.1629, -0.3630,  0.6086,  1.2669],
          [-0.7454,  0.1635, -0.1000,  0.9490],
          [ 0.2711,  1.9755, -0.6669, -0.3346],
          [-0.2467,  0.9544, -0.3537, -0.8904]],

         [[ 0.9441, -0.7221, -0.0377,  0.3374],
          [-2.2795, -1.1555,  0.9555,  0.4566],
          [-0.1251,  0.5129, -1.6877, -0.3519],
          [ 0.5455,  1.1250,  0.6385, -0.1447]],

         [[ 0.8211,  0.2494, -0.4131,  1.2432],
          [-0.6434,  1.1120,  1.1102,  0.8328],
          [ 0.0868,  0.2222,  0.1554, -0.7188],
          [ 0.8627, -0.7993, -0.8812,  0.9972]]],


        [[[-1.1189, -0.9412, -0.9145, -0.0048],
          [ 0.6170,  1.2101,  0.1813, -0.5363],
          [ 0.9798,  0.4064,  0.5711,  0.2156],
          [ 1.6940,  0.4776,  0.1171, -0.1421]],

         [[ 0.6171, -1.2645, -0.1189, -0.3172],
          [-0.5279,  0.3126,  1.5111,  0.8772],
          [-1.2101, -1.4024,  1.5457, -1.0882],
          [-1.4969, -0.3039, -0.6469, -0.3612]],

         [[-0.8796,  0.6566,  1.0026,  0.2472],
          [ 0.6985, -0.4325,  0.5768,  1.2399],
          [-0.8927, -0.3637, -0.5471, -1.9263],
          [ 0.9424,  1.6031,  1.5086,  0.2109]]]])
BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
"""

标签:训练,批量,模型,标准化,正则,参数,拟合,数据
From: https://blog.csdn.net/m0_68153457/article/details/142315208

相关文章

  • C#方法将数据库图片批量插入到EXCEL中
    效果图一般数据库图片查询出来为byte[]类型这里使用的是Spire.Officefor.NETnet4.0和net6.0可以使用附件的dll,其他版本可去官网下载相应的dll官方网站:https://www.e-iceblue.com/GitHub:https://github.com/eiceblueNuGet:https://www.nuget.org/packages/FreeSpire.Off......
  • Python使用starmap函数批量更新数据库
    在数据库操作中,有时候需要对多条记录进行批量更新操作,而这些记录的更新逻辑可能是相同的,只是参数不同。starmap函数可以更加高效地实现批量更新数据库的操作。importsqlite3fromitertoolsimportstarmap#连接数据库conn=sqlite3.connect('example.db')cursor=conn......
  • AI视频批量自动剪辑软件
     小咖批量剪辑助手是一款视频批量自动剪辑软件,具有智能化、批量化、操作简单等特点。该软件适用于自动化处理和生产视频,旨在帮助用户实现批量化生产产品推广视频的功能。三、安装与配置安装步骤:下载程序压缩包:访问官方网站或指定下载地址,下载小咖批量剪辑助手程序压缩......
  • AI视频批量自动剪辑软件
     小咖批量剪辑助手是一款视频批量自动剪辑软件,具有智能化、批量化、操作简单等特点。该软件适用于自动化处理和生产视频,旨在帮助用户实现批量化生产产品推广视频的功能。三、安装与配置安装步骤:下载程序压缩包:访问官方网站或指定下载地址,下载小咖批量剪辑助手程序压缩......
  • 代理ip批量检测工具,采用多线程并发编程,支持http,https,socks4,socks5协议!
     工具使用c++编程语言,采用多线程并发检测技术:支持ipv4及ipv6代理ip批量检测。支持httphttpssocks4及socks5代理服务器的批量检测。支持所有windows版本运行!导入方式支持手工选择文件及拖放文件。导入格式支持三种格式:第一种:用|号分割2409:8a50:8019:e470:a8d7:bdf0:fbfe:8b5......
  • 编程日记 批量导入数据
    编程日记批量导入数据1.用可视化界面:适合一次性导入,数据量可控2.写程序:for循环,建议分批,不要一把梭哈(可以用接口控制),要保证可控、幂等,注意线上环境和测试环境是有区别的导入1000w条,fori1000w(不能再main方法里面写,会报空指针异常,userMapper无法注入)缺点是.class并不是一个......
  • 【办公类】大组工会学习(文心一言+Python批量)
    背景需求:每学期要写一份工会大组学习读后感(9月-1月,共5次)学习内容9月、10月、11月、12月、1月的学习内容文字稿在班级里,我擅长电脑工作,所以这种写的工作都包了。中2班三位老师一共写3篇,加上上个班级的搭档也让我写一份,本次我要写4份学习读后感。随着AI技术的深入,我想......
  • 【办公类】幼儿健康数据模版批量更改日期(保健老师填写)
    背景需求今天下发通知三个园区的保健老师需要填写1.2023学年(202406)的六一体检数据2.2024学年(202409)的新生入园体检数据我先把上一轮填写过的数据模版下载下来(套用模版)把EXCEL下载到原始文件夹里模版下载完成,我想到去年2023年9月用这些前年2022年9月的模版发给保......
  • 知识产权与标准化【软考】
    文章目录一、知识产权1.1保护范围与对象1.2保护期限1.3知识产权人确定一、知识产权1.1保护范围与对象软件著作权人享有权利发表权:决定软件是否公之于众的权力署名权:表明开发者身份,在软件上的署名的权力修改权:可对软件进行增删改的权力复制权:将软件制作一份或多份......
  • HDFS批量清理过期文件
    #!/bin/bashsource~/.bashrc#HADOOP所在的bin目录HADOOP_BIN_PATH=/opt/cloudera/parcels/CDH/bin#待检测的HDFS目录d1=/tmp1d2=/tmp/sac-sac1d3=/tmp/cep-bu4d4=/tmp/test_data_standardd5=/tmp/test_data_standard_sac#将待检测的目录(可以为多个)加载至数组中......