首页 > 其他分享 >transformer 中的bert是如何初始化的

transformer 中的bert是如何初始化的

时间:2023-03-15 20:45:14浏览次数:60  
标签:bert transformer 初始化 self module init weights initialize

动机:在看BertForMaskedLM 的实现代码时,发现在class init的时候有一个self.post_init() 函数,希望看一下它内部调用的哪个函数,以及如果我们自己定义了一些新的模型参数或者embedding怎么进行初始化?

在代码里有两个init_weights 函数,分别是post_init调用的,另一个我们可以用于初始化我们自己的参数:
1.def init_weights(self):

 def init_weights(self):
        """
        If needed prunes and maybe initializes weights. If using a custom `PreTrainedModel`, you need to implement any
        initialization logic in `_init_weights`.
        """
        # Prune heads if needed
        if self.config.pruned_heads:
            self.prune_heads(self.config.pruned_heads)

        if _init_weights:
            # Initialize weights
            self.apply(self._initialize_weights)

            # Tie weights should be skipped when not initializing all weights
            # since from_pretrained(...) calls tie weights anyways
            self.tie_weights()

Q: self.apply(self._initialize_weights) 是什么意思呢?
A: 在 Hugging Face 的 Transformers 库中,self.apply(self._initialize_weights) 这行代码的作用是在模型对象创建时,将 _initialize_weights 方法应用到模型的所有参数中,从而初始化所有的权重。

具体来说,self.apply() 是 PyTorch 中 nn.Module 的一个方法,用于将一个函数应用到模型的所有模块中。这个函数可以是任意的自定义函数,它会遍历模型的所有模块,并对每个模块应用这个函数。在这个过程中,模块本身作为参数传递给函数,因此函数可以对模块进行任意的操作。

在这里,我们将 self._initialize_weights() 方法作为参数传递给 self.apply() 方法,这样就可以将这个方法应用到模型的所有模块中。由于 _initialize_weights() 方法会遍历模型的所有模块,并对每个模块的权重进行初始化,因此这个操作可以确保模型的所有权重都被正确地初始化了。

以下是 self.apply() 方法的代码实现,可以看到,这个方法会递归遍历模型的所有子模块,并对每个子模块调用 fn() 函数,最终还会对模型本身调用 fn() 函数。在我们的例子中,fn() 函数就是 _initialize_weights() 方法。

def apply(self, fn):
    """Applies ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self.
    Typical use includes initializing the parameters of a model (see also :meth:`torch.nn.Module.apply`).
    """
    for module in self.children():
        module.apply(fn)
    fn(self)

Q: _initialize_weights() 方法是怎样定义的呢
A: 在 Hugging Face 的 Transformers 库中,PreTrainedModel 类的 _initialize_weights() 方法是在 modeling_utils.py 文件中定义的。这个方法用于初始化模型的权重,并且在模型的构造函数中被调用。

具体来说,_initialize_weights() 方法会遍历模型的所有模块,对每个模块的权重进行初始化。如果模块包含可训练的权重,则会使用 PyTorch 的默认的权重初始化方法来初始化这些权重。如果模块是一个 nn.Linear 层,则会使用正态分布的方法来初始化权重,同时使用小的偏置值来避免死亡 ReLU 问题。

以下是 _initialize_weights() 方法的代码实现,注意到这个方法会根据模块的类型不同采取不同的初始化方法,以确保每个权重都能够得到合适的初始化。

def _initialize_weights(self, module):
    """Initialize the weights."""
    if isinstance(module, (nn.Linear, nn.Embedding)):
        # 使用默认的初始化方法来初始化可训练的权重
        module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
    if isinstance(module, nn.Linear) and module.bias is not None:
        # 避免死亡 ReLU 问题,使用小的偏置值来初始化
        module.bias.data.zero_()
    elif isinstance(module, nn.LayerNorm):
        # 初始化层归一化的权重为 1,偏置为 0
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)

2.def _init_weights(self, module):

 def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, nn.Linear):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)

这个函数和刚刚的_initialize_weights作用一致,_initialize_weights也可以写成如下样子:

def _initialize_weights(self, module):
        """
        Initialize the weights if they are not already initialized.
        """
        if getattr(module, "_is_hf_initialized", False):
            return
        self._init_weights(module)
        module._is_hf_initialized = True

Q: 我们如果有新定义的参数如何进行初始化?
A:如果你初始化了新的embedding,你可以如下方式初始化

# Build new embeddings
new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
new_embeddings.to(old_embeddings.weight.device, dtype=old_embeddings.weight.dtype)

# initialize all new embeddings (in particular added tokens)
self._init_weights(new_embeddings)

其他:
我们一般会用 from_pretrained 加载预训练好的模型参数,在这种情况下,大概的模型加载流程如下:

  1. 找到正确的基础模型类进行初始化
  2. 使用伪随机初始化来初始化该类(通过使用_init_weights您提到的函数)
  3. 找到具有预训练权重的文件
  4. 在适用的情况下使用预先训练的权重覆盖我们刚刚创建的模型的权重,在初始化参数时,如果模型结构与预训练模型不同,那么只有与预训练模型相同的部分才会被初始化。

标签:bert,transformer,初始化,self,module,init,weights,initialize
From: https://www.cnblogs.com/carolsun/p/17219871.html

相关文章

  • GPU服务器无root权限conda初始化
    1.给anaconda文件写入权限sudochmoda+w.conda如果没有权限则会在创建环境时报以下错误NoWritableEnvsDirError:Nowriteableenvsdirectoriesconfigured.-......
  • CVPR2023 | 集成预训练金字塔结构的Transformer模型
    前言 本文提出了一种新的预训练模型架构(iTPN),该架构由多个金字塔形的Transformer层组成。每个层都包含多个子层,其中一些是普通的self-attention和feed-forward层,而另一些......
  • CSS - css初始化
    常见的初始化:html,body,ul,li,ol,dl,dd,dt,p,h1,h2,h3,h4,h5,h6,form,fieldset,legend,img{margin:0;padding:0;}fieldset,img,in......
  • AAAI 2023 | 一种通用的粗-细视觉Transformer加速方案
    前言 VisionTransformers中,输入图像的空间维度会出现相当大的冗余,从而导致大量的计算成本。因此,本文中提出了一种由粗到精的视觉变换器(CF-ViT)来减轻计算负担,同时保持性......
  • Crossformer: Transformer Utilizing Cross-Dimension Dependency for Multivariate T
    用于时序预测的Transformer也是基于分块思路,跨时间、跨维度对齐https://openreview.net/forum?id=vSVLM2j9eiehttps://github.com/Thinklab-SJTU/Crossformer上海交通......
  • Transformer中的encoder与decoder
    Transformer是一种非常强大的神经网络架构,被广泛应用于自然语言处理任务中。它的核心部分是由若干个Encoder和Decoder组成的。下面简要介绍一下Encoder和Decoder的区别。......
  • 别再蒸馏3层BERT了!变矮又能变瘦的DynaBERT了解一下
    一只小狐狸带你解锁炼丹术&NLP秘籍神经网络模型除了部署在远程服务器之外,也会部署在手机、音响等智能硬件上。比如在自动驾驶的场景下,大部分模型都得放在车上的终端里,不然荒......
  • Transformer
    Reference:https://builtin.com/artificial-intelligence/transformer-neural-network1.AdvantagesoverRNNOvercomesthevanishinggradientissuebymulti-heade......
  • 分析类初始化
    分析类初始化!什么时候类会初始化主动引用main方法被调用时,其所在的类会首先被初始化new一个类对象调用类的静态成员和方法(除了final常量)反射调用时初始化子类时......
  • 注意力机制和Transformer原理,其他文章看不懂就看这个吧,根据《python深度学习》 和 《
      注意力机制和Transformer原理,网上一堆文章都没有说清楚,自己根据《python深度学习》和《动手学深度学习》这两本书结合起来总结下。两本书的地址:https://zh.d2l.a......