首页 > 其他分享 >聊聊ChatGLM中P-tuning v2的应用

聊聊ChatGLM中P-tuning v2的应用

时间:2024-01-11 11:36:51浏览次数:31  
标签:tuning self torch prefix past v2 ChatGLM hidden size

论文PDF地址:https://arxiv.org/pdf/2110.07602.pdf

转载请备注出处:https://www.cnblogs.com/zhiyong-ITNote/

P-Tuning v2

摘录自第三部分

image.png

桔色块指代可训练的prompt embedding;蓝色块是由固定(冻结)的预训练语言模型 存储或计算的embedding。

Deep Prompt Tuning

continuous prompts(连续提示) 仅仅能够插入到input embedding序列层。如此,有两个问题:首先由于序列长度的约束限制,可调参数的数量有限。其次,输入的embedding对模型预测有间接的影响。
为了解决这些问题,P-Tuning v2使用deep prompt tuning的方案。正如上图的b部分,prompt作为prefix token插入到不同的层中。一方面,p-tuning v2有更多可调的特定任务参数(从 0.01% 到 0.1%~3%),扩大了任务的容量也提高了参数效率;另一方面,添加到更深层的prompt对模型的预测会有更多直接的影响。

转载请备注出处:https://www.cnblogs.com/zhiyong-ITNote/

理解

在P-tuning v2的方案中,从图直观来看,有两个关键的点:

  1. prompts会加在序列的前端,而不仅仅是插入到input embedding
  2. 每一层都会插入prompts

v2版本主要基于p-tuning和prefix-tuning技术。prompt 向量是在模型的 embedding 层与其他输入 token 的 embedding 相拼接的,且通过在预训练模型的每一层引入可训练的 prompt 向量来提高模型对特定任务的适应性。
p-tuning主要是利用一个prompt encoder,将prompt先encoder再与input embedding进行拼接。
prefix-tuning是在Transformer的Encoder和Decoder的网络中都加了一些特定的前缀。
而基于这两种技术的v2版本,则是将两者结合。在embedding与transformer模块都做了prompt向量的插入。
ChatGLM中,首先要对prompt做encode,作为前缀prefix拼接插入到input embedding与transformer模型中。


# 转载请备注出处:https://www.cnblogs.com/zhiyong-ITNote/

class PrefixEncoder(torch.nn.Module):
    """
    The torch.nn model to encode the prefix
    Input shape: (batch-size, prefix-length)
    Output shape: (batch-size, prefix-length, 2*layers*hidden)
    """

    def __init__(self, config):
        super().__init__()
        self.prefix_projection = config.prefix_projection
        if self.prefix_projection:
            # Use a two-layer MLP to encode the prefix
            self.embedding = torch.nn.Embedding(config.pre_seq_len, config.hidden_size)
            self.trans = torch.nn.Sequential(
                torch.nn.Linear(config.hidden_size, config.hidden_size),
                torch.nn.Tanh(),
                torch.nn.Linear(config.hidden_size, config.num_layers * config.hidden_size * 2)
            )
        else:
            self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_layers * config.hidden_size * 2)

    def forward(self, prefix: torch.Tensor):
        if self.prefix_projection:
            prefix_tokens = self.embedding(prefix)
            past_key_values = self.trans(prefix_tokens)
        else:
            past_key_values = self.embedding(prefix)
        return past_key_values

在ChatGLMModel中调用并插入到每一个transformer模型层中。

class ChatGLMModel(ChatGLMPreTrainedModel):
    '''
    省略其它....
    '''
    def __init__(self, config: ChatGLMConfig, empty_init=True):
        if self.pre_seq_len is not None:
            for param in self.parameters():
                param.requires_grad = False
            self.prefix_tokens = torch.arange(self.pre_seq_len).long()
            # encode prompt
            self.prefix_encoder = PrefixEncoder(config)
            self.dropout = torch.nn.Dropout(0.1)

    # 调用prompt
	def get_prompt(self, batch_size, device, dtype=torch.half):
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
        # 调用prompt并返回
        past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
        past_key_values = past_key_values.view(
            batch_size,
            self.pre_seq_len,
            self.num_layers * 2,
            self.num_attention_heads,
            self.hidden_size // self.num_attention_heads
        )
        # seq_len, b, nh, hidden_size
        past_key_values = self.dropout(past_key_values)
        past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
        # past_key_values = [(v[0], v[1]) for v in past_key_values]
        return past_key_values

    # 返回transformer模型
    def get_layer(layer_id):
        return GLMBlock(
            self.hidden_size,
            self.num_attention_heads,
            self.layernorm_epsilon,
            layer_id,
            inner_hidden_size=self.inner_hidden_size,
            hidden_size_per_attention_head=self.hidden_size_per_attention_head,
            layernorm=LayerNorm,
            use_bias=True,
            params_dtype=self.params_dtype,
            position_encoding_2d=self.position_encoding_2d,
            empty_init=empty_init
        )
                
	def forward(
            self,
            input_ids: Optional[torch.LongTensor] = None,
            position_ids: Optional[torch.LongTensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
            inputs_embeds: Optional[torch.LongTensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPast]:
    	# 其它代码
        if past_key_values is None:
            if self.pre_seq_len is not None:
                # 调用prompt
                past_key_values = self.get_prompt(batch_size=input_ids.shape[0], device=input_ids.device,
                                                  dtype=inputs_embeds.dtype)
            else:
                past_key_values = tuple([None] * len(self.layers))

            if attention_mask is None:
                attention_mask = self.get_masks(
                    input_ids,
                    device=input_ids.device
                )
    	# 其它代码
    	for i, layer in enumerate(self.layers):

            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)
            # 准备参数传递到layer
            layer_past = past_key_values[i]
        	# 每个layer 是一个GLMBlock即transformer模型层
            if self.gradient_checkpointing and self.training:
                # 将prompt传递到每个层中
                layer_ret = torch.utils.checkpoint.checkpoint(
                    layer,
                    hidden_states,
                    position_ids,
                    attention_mask,
                    torch.tensor(i),
                    layer_past,
                    use_cache,
                    output_attentions
                )
            else:
                layer_ret = layer(
                    hidden_states,
                    position_ids=position_ids,
                    attention_mask=attention_mask,
                    layer_id=torch.tensor(i),
                    layer_past=layer_past,
                    use_cache=use_cache,
                    output_attentions=output_attentions
                )
        # 其它代码

参考

大模型微调之P-tuning方法解析

通俗解读大模型微调(Fine Tuning)

转载请备注出处:https://www.cnblogs.com/zhiyong-ITNote/

标签:tuning,self,torch,prefix,past,v2,ChatGLM,hidden,size
From: https://www.cnblogs.com/zhiyong-ITNote/p/17958162

相关文章

  • AP8854 宽压降压电源管理芯片12-80V 7v2.5A 应用于电动车手暖套的PBC线路
    AP8854一款宽电压范围降压型DC-D电源管理芯片,内部集成使能开关控制、基准电源、误差放大器、过热保护、限流保护、短路保护等功能,非常适合宽电压输入降压使用。AP8854带使能控制,可以大大节省外围器件,更加适合电池场合使用,具有很高的方案性价比。产品特点:电压输入范围10V至120......
  • 使用cv2.getOptimalNewCameraMatrix函数,变为圆形是出现什么错误
    cv2.getOptimalNewCameraMatrix函数用于计算一个新的相机矩阵,以进行图像畸变校正。这个函数的目标是通过考虑畸变的影响,生成一个新的相机矩阵,使得校正后的图像更接近理想的情况。cv2.getOptimalNewCameraMatrix(cameraMatrix,distCoeffs,imageSize,alpha,newImgSize)其中......
  • AP8854 宽压降压电源管理芯片12-80V 7v2.5A 应用于电动车手暖套的PBC线路
    AP8854一款宽电压范围降压型DC-D电源管理芯片,内部集成使能开关控制、基准电源、误差放大器、过热保护、限流保护、短路保护等功能,非常适合宽电压输入降压使用。AP8854带使能控制,可以大大节省外围器件,更加适合电池场合使用,具有很高的方案性价比。产品特点:电压输入范围10V至......
  • opensuse修改cgroup到v2
    识别Linux节点上的cgroup版本cgroup版本取决于正在使用的Linux发行版和操作系统上配置的默认cgroup版本。要检查你的发行版使用的是哪个cgroup版本,请在该节点上运行stat-fc%T/sys/fs/cgroup/命令:对于cgroupv2,输出为cgroup2fs。对于cgroupv1,输出为tmpfs......
  • 界面组件DevExpress WPF v23.2 - 更轻量级的主题支持
    DevExpressWPFSubscription拥有120+个控件和库,将帮助您交付满足甚至超出企业需求的高性能业务应用程序。通过DevExpressWPF能创建有着强大互动功能的XAML基础应用程序,这些应用程序专注于当代客户的需求和构建未来新一代支持触摸的解决方案。DevExpressWPF控件日前正式发布了......
  • 世微AP3464同步降压恒压IC 4-30V2.4A输出车充专用驱动芯片
    AP3464是一款支持宽电压输入的同步降压电源管理芯片,输入电压4-30V范围内可实现2.4A的连续电流输出。通过调节FB端口的分压电阻,设定输出1.8V到28V的稳定电压。AP3464具有的恒压/恒流(CC/CV)特性。AP3464采用电流模式的环路控制原理,实现了快速的动态响应。A......
  • 百度地图JavaScript API v2.0创建地图
    接口文档:https://lbsyun.baidu.com/index.php?title=jspopular3.0https://lbs.baidu.com/faq/api?title=webapi地图创建代码:<!DOCTYPEhtml><html><head> <metahttp-equiv="Content-Type"content="text/html;charset=utf-8"/> &......
  • LLM增强LLM;通过预测上下文来提高文生图质量;Spikformer V2;同时执行刚性和非刚性编辑的
    文章首发于公众号:机器感知LLM增强LLM;通过预测上下文来提高文生图质量;SpikformerV2;同时执行刚性和非刚性编辑的通用图像编辑框架LLMAugmentedLLMs:ExpandingCapabilitiesthroughComposition本文研究了如何高效地组合现有的基础模型以实现新功能的问题,文章提出了CALM(Comp......
  • 大语言模型优化方法简介:Prompt、RAG、Fine-tuning
    GPT、LLama、Gemini等大语言模型虽展现出强大能力,但在实际应用中仍有问题,例如在准确性、知识更新速度和答案透明度方面,仍存在挑战。论文“Retrieval-AugmentedGenerationforLargeLanguageModels:ASurvey(面向大语言模型的检索增强生成技术:调查)”https://arxiv.org/abs/231......
  • ChatGLM-6B应用
    ChatGLM-6B是一个开源的、支持中英双语的对话语言模型,基于GeneralLanguageModel(GLM)架构,具有62亿参数。结合模型量化技术,用户可以在消费级的显卡上进行本地部署(INT4量化级别下最低只需6GB显存)。ChatGLM-6B使用了和ChatGPT相似的技术,针对中文问答和对话进行了优化。......