
Fixing the shape mismatch (transposed weights) when loading GPT-2's TensorFlow-pretrained Linear weights into PyTorch Linear layers


The error to fix:

RuntimeError: Error(s) in loading state_dict for PyTorchBasedGPT2:

size mismatch for transformer.h.0.attn.c_attn.weight: copying a param with shape torch.Size([768, 2304]) from checkpoint, the shape in current model is torch.Size([2304, 768])......

 

I. Analyzing the cause of the error

In PyTorch, a Linear layer's weight is stored with shape [out_features, in_features], whereas in TensorFlow a Linear (Dense) weight is stored with shape [in_features, out_features].

This is because the two libraries use different mathematical conventions (see https://www.null123.com/question/detail-2816063.html):

PyTorch: y = Wx + b, with W of shape [out_features, in_features]

TensorFlow: y = xW + b, with W of shape [in_features, out_features]
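
A minimal sketch of the two layouts (using the 768 → 2304 c_attn projection from the error above):

import torch

# PyTorch stores the nn.Linear weight as [out_features, in_features]
# and computes y = x @ W.T + b (equivalently y = Wx + b).
linear = torch.nn.Linear(768, 2304)
print(linear.weight.shape)  # torch.Size([2304, 768])

# A TF-style checkpoint stores the same layer's weight as [768, 2304]
# and uses y = x @ W + b, which is exactly the transposed shape reported above.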

Loading GPT-2's pretrained parameters directly into a GPT-2 architecture implemented with PyTorch Linear layers therefore fails:

PyTorchBasedGPT2.from_pretrained("openai-community/gpt2")
RuntimeError: Error(s) in loading state_dict for PyTorchBasedGPT2:
    size mismatch for transformer.h.0.attn.c_attn.weight: copying a param with shape torch.Size([768, 2304]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
    size mismatch for transformer.h.0.mlp.c_fc.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
    size mismatch for transformer.h.0.mlp.c_proj.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
    size mismatch for transformer.h.1.attn.c_attn.weight: copying a param with shape torch.Size([768, 2304]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
    size mismatch for transformer.h.1.mlp.c_fc.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
    size mismatch for transformer.h.1.mlp.c_proj.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
    size mismatch for transformer.h.2.attn.c_attn.weight: copying a param with shape torch.Size([768, 2304]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
    size mismatch for transformer.h.2.mlp.c_fc.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
    size mismatch for transformer.h.2.mlp.c_proj.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
    size mismatch for transformer.h.3.attn.c_attn.weight: copying a param with shape torch.Size([768, 2304]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
    size mismatch for transformer.h.3.mlp.c_fc.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
    size mismatch for transformer.h.3.mlp.c_proj.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
    size mismatch for transformer.h.4.attn.c_attn.weight: copying a param with shape torch.Size([768, 2304]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
    size mismatch for transformer.h.4.mlp.c_fc.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
    size mismatch for transformer.h.4.mlp.c_proj.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
    size mismatch for transformer.h.5.attn.c_attn.weight: copying a param with shape torch.Size([768, 2304]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
    size mismatch for transformer.h.5.mlp.c_fc.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
    size mismatch for transformer.h.5.mlp.c_proj.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
    size mismatch for transformer.h.6.attn.c_attn.weight: copying a param with shape torch.Size([768, 2304]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
    size mismatch for transformer.h.6.mlp.c_fc.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
    size mismatch for transformer.h.6.mlp.c_proj.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
    size mismatch for transformer.h.7.attn.c_attn.weight: copying a param with shape torch.Size([768, 2304]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
    size mismatch for transformer.h.7.mlp.c_fc.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
    size mismatch for transformer.h.7.mlp.c_proj.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
    size mismatch for transformer.h.8.attn.c_attn.weight: copying a param with shape torch.Size([768, 2304]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
    size mismatch for transformer.h.8.mlp.c_fc.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
    size mismatch for transformer.h.8.mlp.c_proj.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
    size mismatch for transformer.h.9.attn.c_attn.weight: copying a param with shape torch.Size([768, 2304]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
    size mismatch for transformer.h.9.mlp.c_fc.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
    size mismatch for transformer.h.9.mlp.c_proj.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
    size mismatch for transformer.h.10.attn.c_attn.weight: copying a param with shape torch.Size([768, 2304]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
    size mismatch for transformer.h.10.mlp.c_fc.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
    size mismatch for transformer.h.10.mlp.c_proj.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
    size mismatch for transformer.h.11.attn.c_attn.weight: copying a param with shape torch.Size([768, 2304]) from checkpoint, the shape in current model is torch.Size([2304, 768]).
    size mismatch for transformer.h.11.mlp.c_fc.weight: copying a param with shape torch.Size([768, 3072]) from checkpoint, the shape in current model is torch.Size([3072, 768]).
    size mismatch for transformer.h.11.mlp.c_proj.weight: copying a param with shape torch.Size([3072, 768]) from checkpoint, the shape in current model is torch.Size([768, 3072]).
    You may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method.

 

II. Solution

The original weights need to be transposed before the model is loaded with Model.from_pretrained(). (The ignore_mismatched_sizes=True suggested in the error message would not help here: it merely skips the mismatched tensors and leaves them randomly initialized.)

1. Pull the model from Hugging Face; model_path is the Hugging Face repo name

import torch
import transformers

model_path = "openai-community/gpt2"
model = transformers.AutoModelForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16)
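
You can confirm the TF-style layout by inspecting one of the affected weights before converting (module paths as in the error message above):

# The Hugging Face GPT-2 checkpoint keeps the TF-style [in_features, out_features] layout.
print(model.transformer.h[0].attn.c_attn.weight.shape)  # torch.Size([768, 2304])
print(model.transformer.h[0].mlp.c_fc.weight.shape)     # torch.Size([768, 3072])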

2. Transpose the Linear weight matrices in the original checkpoint

If you are not sure how to reach these matrices, print the model first:

print(model)
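
The printout shows c_attn, c_proj, and c_fc as Conv1D modules. Despite the name, Hugging Face's Conv1D is just a linear layer that keeps the TF-style weight layout; the simplified sketch below (adapted from transformers.pytorch_utils.Conv1D) shows why:

import torch

class Conv1D(torch.nn.Module):
    # Simplified from transformers.pytorch_utils.Conv1D: despite the name,
    # this computes y = x @ W + b with W stored as [nx, nf] = [in_features, out_features].
    def __init__(self, nf, nx):
        super().__init__()
        self.nf = nf
        self.weight = torch.nn.Parameter(torch.empty(nx, nf))
        self.bias = torch.nn.Parameter(torch.zeros(nf))
        torch.nn.init.normal_(self.weight, std=0.02)

    def forward(self, x):
        size_out = x.size()[:-1] + (self.nf,)
        # Flatten the leading dimensions, apply the affine map, restore the shape.
        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
        return x.view(size_out)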

Fetch the weights and transpose them; here that means c_attn and c_proj in attn, and c_fc and c_proj in mlp. (These layers are named like convolutions, but as the sketch above shows they are really Linear layers. Note that attn.c_proj never appears in the size-mismatch list because its weight is square, [768, 768], yet its values still must be transposed.)

for layer in model.transformer.h:
    # .contiguous() returns a new tensor with the same data but a contiguous memory layout
    layer.attn.c_attn.weight = torch.nn.Parameter(layer.attn.c_attn.weight.transpose(0, 1).contiguous())
    layer.attn.c_proj.weight = torch.nn.Parameter(layer.attn.c_proj.weight.transpose(0, 1).contiguous())
    layer.mlp.c_fc.weight = torch.nn.Parameter(layer.mlp.c_fc.weight.transpose(0, 1).contiguous())
    layer.mlp.c_proj.weight = torch.nn.Parameter(layer.mlp.c_proj.weight.transpose(0, 1).contiguous())
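
Bias vectors have shape [out_features] under both conventions, so only the weights need converting. A quick sanity check after the loop (a minimal sketch; the shapes follow from GPT-2's hidden size of 768):

# Every converted weight should now match PyTorch's [out_features, in_features] layout.
assert model.transformer.h[0].attn.c_attn.weight.shape == (2304, 768)
assert model.transformer.h[0].mlp.c_fc.weight.shape == (3072, 768)
assert model.transformer.h[0].mlp.c_proj.weight.shape == (768, 3072)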

3. Finally, save the model to the target path

output_dir = "new_gpt2"
model.save_pretrained(output_dir)

The PyTorch GPT-2-style model can now load its parameters cleanly from that path (note that the converted checkpoint is no longer loadable by the stock transformers GPT-2 class, which would hit the same mismatch in reverse):

model = PyTorchBasedGPT2.from_pretrained("new_gpt2")
print(model)

The model now loads and prints without size-mismatch errors.

From: https://www.cnblogs.com/pplap/p/18141452
