首页 > 其他分享 >什么是Layer Normalization?

什么是Layer Normalization?

时间:2024-12-05 17:59:26浏览次数:5  
标签:Layer LN 什么 after sample single var Normalization before

一、概念

        前面的文章中,我们介绍了Batch NormalizationBN的目的是使得每个batch的输入数据在每个维度上的均值为0、方差为1(batch内,数据维度A的所有数值均值为0、方差为1,维度B、C等以此类推),这是由于神经网络的每一层输出数据分布都会发生变化,随着网络层数的增加,内部协变量的偏移程度会变大。我们在数据预处理阶段使用sklearn等工具进行的Normalization仅仅解决了第一层输入的问题,而隐藏层中各层的输入问题仍然存在。因此我们将BN嵌入到模型结构内部,用于把每一个batch的数据拉回正态分布。

        然而,BN通过对每个维度进行正态分布处理,会使得各个维度之间的数值大小关系失真,也就是单一样本内部的特征关系被打乱了。显然,这对于处理文本向量等序列数据来说并不友好,文本向量内部的语义关系会受到BN的影响。因此,预训练模型、大语言模型等内部一般不会采用BN,而是采用Layer Normalization。

        Layer Normalization对神经网络中每一层的输出进行归一化处理,确保单条样本内部各特征的均值为0、方差为1

  • 计算均值和方差:对每个样本的特征维度计算均值和方差。
  • 归一化处理:使用计算出的均值和方差对当前样本进行归一化,使其均值为0,方差为1。
  • 缩放和平移:引入可学习的参数进行尺度和偏移变换,以恢复模型的表达能力。

        Layer Normalization能够减少训练过程中的梯度爆炸或消失问题,从而提高模型的稳定性和训练效率。尤其是在RNN和Transformer等序列模型中,LN所实现的稳定数据分布有助于模型层与层之间的信息流更加平滑。

二、LN示例

        下面,我们给出一个LN的简单示例。与Batch Normalization不同,Layer Normalization不依赖于mini-batch,而是对每一个样本独立进行归一化,这使得它适用于各种数据规模,包括小批量和单个样本。

import torch
import torch.nn as nn

# 构造一个单一样本,包含5个特征
sample = torch.tensor([2.0, 3.0, 5.0, 1.0, 4.0], requires_grad=True)
print("Original Sample:", sample)

# 定义Layer Normalization层
# 特征数量(特征维度)为5
ln = nn.LayerNorm(normalized_shape=[5])

# 应用Layer Normalization
sample_norm = ln(sample)
print("Normalized Sample:", sample_norm)

# 检查均值和方差
mean = sample_norm.mean()
var = sample_norm.var()
print("Mean:", mean)
print("Variance:", var)

三、python应用

        这里,我们在构建网络的过程中加入LN,并对比前后的数据差异。

import torch
import torch.nn as nn
import matplotlib.pyplot as plt

# 设置随机种子以确保结果可复现
torch.manual_seed(0)

# 创建一个简单的模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.linear = nn.Linear(100, 50)  # 一个线性层
        self.ln = nn.LayerNorm(50)  # Layer Normalization层

    def forward(self, x):
        x = self.linear(x)
        x = self.ln(x)
        return x

# 创建模型实例
model = SimpleModel()

# 生成模拟数据:100个样本,每个样本100个特征
x = torch.randn(100, 100, requires_grad=True)

# 前向传播,计算LN前的数据
x_linear = model.linear(x)
x_linear = x_linear.detach()

# 计算LN前的数据均值和方差
mean_before = x_linear.mean(dim=0)
var_before = x_linear.var(dim=0)

# 应用LN
x_ln = model(x)
x_ln = x_ln.detach()

# 计算LN后的数据均值和方差
mean_after = x_ln.mean(dim=0)
var_after = x_ln.var(dim=0)

# 随机选择一个样本
sample_index = 0
single_sample_before = x_linear[sample_index].unsqueeze(0)
single_sample_after = x_ln[sample_index].unsqueeze(0)

# 计算单个样本的均值和方差
mean_single_before = single_sample_before.mean()
var_single_before = single_sample_before.var()
mean_single_after = single_sample_after.mean()
var_single_after = single_sample_after.var()

# 绘制LN前后数据的分布
fig, ax = plt.subplots(2, 3, figsize=(18, 12))

# 绘制LN前的数据分布
ax[0, 0].hist(x_linear.detach().numpy().flatten(), bins=30, color='blue', alpha=0.7)
ax[0, 0].set_title('Before LN: Data Distribution')

# 绘制LN后的数据分布
ax[0, 1].hist(x_ln.detach().numpy().flatten(), bins=30, color='green', alpha=0.7)
ax[0, 1].set_title('After LN: Data Distribution')

# 绘制单个样本LN前的数据分布
ax[0, 2].hist(single_sample_before.detach().numpy().flatten(), bins=30, color='red', alpha=0.7)
ax[0, 2].set_title('Single Sample Before LN')

# 绘制LN前的数据均值和方差
ax[1, 0].bar(range(50), var_before, color='blue', alpha=0.7)
ax[1, 0].set_title('Before LN: Variance per Feature')
ax[1, 0].set_xticks(range(0, 50, 5))

# 绘制LN后的数据均值和方差
ax[1, 1].bar(range(50), var_after, color='green', alpha=0.7)
ax[1, 1].set_title('After LN: Variance per Feature')
ax[1, 1].set_xticks(range(0, 50, 5))

# 绘制单个样本LN后的均值和方差
ax[1, 2].bar(range(1), var_single_after.item(), color='red', alpha=0.7)
ax[1, 2].set_title('Single Sample After LN: Variance')
ax[1, 2].set_xticks([])

plt.tight_layout()
plt.show()

# 打印LN前后的数据均值和方差
print(f"Mean before LN: {mean_before}")
print(f"Mean after LN: {mean_after}")
print(f"Variance before LN: {var_before}")
print(f"Variance after LN: {var_after}")
print(f"Mean of single sample before LN: {mean_single_before}")
print(f"Variance of single sample before LN: {var_single_before}")
print(f"Mean of single sample after LN: {mean_single_after}")
print(f"Variance of single sample after LN: {var_single_after}")

        可见右下角子图,LN之后,单条样本内部已经拉成正态分布了。

31183e9a82284f10925a517cd166efd7.png

四、总结

        BN和LN都是缓解深度学习模型梯度消失或者梯度爆炸重要技巧,实际建模过程中我们也可以通过对比加入BN或者LN前后的模型表现来调整最终的模型架构。但值得注意的是,在选择BN或者LN的时候,我们需要想清楚到底单一维度的正态分布对当前任务来说更有意义还是说单一样本内部数值的正态分布更有意义。

 

标签:Layer,LN,什么,after,sample,single,var,Normalization,before
From: https://blog.csdn.net/ChaneMo/article/details/144208309

相关文章

  • git 中 rebase 是什么样的操作,应该从哪个分支rebase到哪个分支
    使branch_1rebase(变基)到branch_2branch_1是当前活动分支,使用rebasebranch_2,把branch_2分支的提交放在branch_1提交的前面,这样使branch_1合并了branch1且使branch_1和branch_2的提交是线性的一般来说,个人理解应该这么用:在dev分支中有新提交,且master也有了......
  • 物料堆放检测视频分析服务器违规生产检测:安防摄像机里的视频采集参数有什么意义
    在安防领域,摄像机的图像质量是衡量其性能的关键指标之一。一个高质量的摄像机不仅需要优质的硬件基础,如高性能的DSP处理器和高灵敏度的图像传感器,还需要通过精细的调整和优化来发挥其最大潜力。本文将深入探讨如何通过理解和调整摄像机的关键视频图像采集参数,来提升摄像机的图像效......
  • IT行业的流程管理该怎么优化?有什么好用的工具?
    无论是大规模制造业还是科技创新型企业,优化流程管理都能够显著提高工作效率、降低成本并增强企业的市场适应力。那么,如何才能做好流程管理,打造高效流程管理体系呢?首先,要搞清楚流程管理的五大步骤。1.信息收集工作要做好在进行任何流程优化前,企业首先需要收集大量关于现有流程......
  • RTSP播放器EasyPlayer.js报错The play() request was interrupted because video-only
    随着技术的发展,越来越多的H5流媒体播放器开始支持H.265编码格式。例如,EasyPlayer.jsH5播放器能够支持H.264、H.265等多种音视频编码格式,这使得播放器能够适应不同的视频内容和网络环境。那么为什么会出现Theplay()requestwasinterruptedbecausevideo-onlybackgroundmed......
  • 跑AI大模型的K8s与普通K8s有什么不同?
    跑AI大模型的K8s与普通K8s有什么不同? 摘要:在面对大模型AI火热的当下,咱们从程序员三大件“计算、存储、网络”出发,一起看看这种跑大模型AI的K8s与普通的K8s有什么区别?有哪些底层就可以构筑AI竞争的地方。本文分享自华为云社区《跑AI大模型的K8s与普通K8s有什么不同?》,作者......
  • 什么是堡垒机(运维系统)
    堡垒机(BastionHost),也称为跳板机、边界机或前置机,是一种特别配置的计算机系统,它被设计为网络中的第一个防线。堡垒机通常位于一个组织的网络和外部互联网之间,是唯一允许从外部直接访问的内部主机。由于其特殊的地位,堡垒机经过了强化的安全配置,并且运行着专门设计来抵御攻击的操作......
  • 源代码加密是什么?如何做源代码加密?
    源代码加密是什么?如何做源代码加密?在软件开发过程中,版本管理工具如SVN和GIT是不可或缺的组成部分,它们帮助团队管理源代码的变更和版本。然而,这些工具也面临着源代码泄露的安全风险。如果不针对数据进行加密保护,很容易出现“一锅端”的现象。所以源代码开发环境复杂,涉及的开发软件......
  • QQOP数据:什么是op数据号?怎么提取op数据?能不能大量提取?
    ......
  • OpenAI 终于揭示了为什么 ChatGPT 不愿意提到 “David Meyer“
    如果您上周末上网,您可能会看到关于一个名为DavidMayer的人的奇怪新闻。他之所以成为热门人物,并不是因为某个重大事件或某个病毒式传播时刻,而是因为ChatGPT出现了一个奇怪的故障。无论用户如何努力,都无法让聊天机器人吐出他的名字。相反,它要么说到一半就僵住了,要么声......
  • 为什么不推荐使用jax ( jax vs pytorch)—— google推出jax后为什么迟迟没有得到业界
    在2017年后,Google的TensorFlow在与Facebook的pytorch的竞争中落败,于是为了重夺业内位置,Google在将开放重点从TensorFlow转为新开发一种新的工具框架,那就是jax。虽然在某种意义上来说Google已经放弃了TensorFlow,但是在Google内部依然保持着部分人员再继续维护和开发TensorFlow,但是......