
ChatGLM2 Source Code Analysis: `ChatGLMModel`


# The full GLM model: embedding layer, encoder stack, and output layer
class ChatGLMModel(ChatGLMPreTrainedModel):
    def __init__(self, config: ChatGLMConfig, device=None, empty_init=True):
        super().__init__(config)
        # If `empty_init` is set, PyTorch modules are created without initializing their parameters
        if empty_init:
            init_method = skip_init
        else:
            init_method = default_init
        init_kwargs = {}
        if device is not None:
            init_kwargs["device"] = device
        # Word embedding layer
        self.embedding = init_method(Embedding, config, **init_kwargs)
        # NL: number of transformer layers
        self.num_layers = config.num_layers
        # GC: number of multi-query (KV) groups
        self.multi_query_group_num = config.multi_query_group_num
        # HS: number of channels per KV head
        self.kv_channels = config.kv_channels

        # SL: maximum sequence length
        self.seq_length = config.seq_length
        rotary_dim = (
            config.hidden_size // config.num_attention_heads if config.kv_channels is None else config.kv_channels
        )
        # Rotary position embedding (RoPE)
        self.rotary_pos_emb = RotaryEmbedding(rotary_dim // 2, original_impl=config.original_rope, device=device,
                                              dtype=config.torch_dtype)
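        # Sanity check, assuming the released ChatGLM2-6B config
        # (hidden_size=4096, num_attention_heads=32, kv_channels=128):
        # rotary_dim = 128, so RotaryEmbedding receives 128 // 2 = 64 dims,
        # i.e. RoPE rotates only half of each attention head's channels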
        # GLM encoder (the stack of transformer blocks)
        self.encoder = init_method(GLMTransformer, config, **init_kwargs)
        # Output layer (LM head)
        self.output_layer = init_method(nn.Linear, config.hidden_size, config.padded_vocab_size, bias=False,
                                        dtype=config.torch_dtype, **init_kwargs)
        self.pre_seq_len = config.pre_seq_len
        self.prefix_projection = config.prefix_projection
        if self.pre_seq_len is not None:
            # A prefix sequence length (PSL) is set, i.e. P-Tuning v2 mode:
            # freeze all existing parameters (disable gradients)
            for param in self.parameters():
                param.requires_grad = False
            # [0, 1, ..., PSL - 1]
            self.prefix_tokens = torch.arange(self.pre_seq_len).long()
            # Initialize the prefix encoder and its dropout
            self.prefix_encoder = PrefixEncoder(config)
            self.dropout = torch.nn.Dropout(0.1)
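            # Net effect of this branch: only the freshly created
            # `prefix_encoder` remains trainable (it was built after the
            # freeze loop); this is the P-Tuning v2 setup, where a learned
            # per-layer KV prefix stands in for full fine-tuning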

    def get_input_embeddings(self):
        return self.embedding.word_embeddings

    def get_prompt(self, batch_size, device, dtype=torch.half):
        # prefix_tokens = [0, 1, ..., PSL - 1]
        # [PSL] => [1, PSL] => [BS, PSL]
        prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(device)
        # [BS, PSL, KVS=NL * HS * 2GC]
        past_key_values = self.prefix_encoder(prefix_tokens).type(dtype)
        # [BS, PSL, KVS=NL * HS * 2GC] => [BS, PSL, 2NL, GC, HS]
        past_key_values = past_key_values.view(
            batch_size,
            self.pre_seq_len,
            self.num_layers * 2,
            self.multi_query_group_num,
            self.kv_channels
        )
        
        past_key_values = self.dropout(past_key_values)
        # [BS, PSL, 2NL, GC, HS] => [2NL, PSL, BS, GC, HS] => NL * [2, PSL, BS, GC, HS]
        past_key_values = past_key_values.permute([2, 1, 0, 3, 4]).split(2)
        return past_key_values
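    # Shape walkthrough for `get_prompt`, assuming the ChatGLM2-6B config
    # (NL=28, GC=2, HS=128) with PSL=64 and BS=4:
    #   prefix_encoder output: [4, 64, 28 * 2 * 2 * 128] = [4, 64, 14336]
    #   view:                  [4, 64, 56, 2, 128]
    #   permute:               [56, 64, 4, 2, 128]
    #   split(2):              tuple of 28 tensors, each [2, 64, 4, 2, 128],
    #                          i.e. one (key, value) stack per transformer layer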

    def forward(
            self,
            input_ids,
            position_ids: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.BoolTensor] = None,
            full_attention_mask: Optional[torch.BoolTensor] = None,
            past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
            use_cache: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
    ):
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # The input is a batch of token IDs with shape [BS, SL]
        batch_size, seq_length = input_ids.shape
        # Pass the token IDs through the embedding layer to get embedding vectors
        if inputs_embeds is None:
            inputs_embeds = self.embedding(input_ids)

        # If PSL is set (P-Tuning v2)
        if self.pre_seq_len is not None:
            # If no KV cache was provided, build one from the learned prefix (PSL entries per layer)
            if past_key_values is None:
                past_key_values = self.get_prompt(batch_size=batch_size, device=input_ids.device,
                                                  dtype=inputs_embeds.dtype)
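            # Prepend PSL ones to the padding mask so every query can attend
            # to the learned prefix positions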
            if attention_mask is not None:
                attention_mask = torch.cat([attention_mask.new_ones((batch_size, self.pre_seq_len)),
                                            attention_mask], dim=-1)

        if full_attention_mask is None:
            # An explicit full attention mask is only materialized when the
            # batch contains padding, or when more than one new token is fed
            # alongside a KV cache; otherwise default causal masking is used
            if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
                full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)

        # Compute the rotary position embeddings (RoPE)
        # Build the cos/sin cache for all `seq_length` positions
        rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
        # If position IDs are provided, use them to index into the cache;
        # otherwise, take the first SL rows of the cache
        if position_ids is not None:
            rotary_pos_emb = rotary_pos_emb[position_ids]
        else:
            rotary_pos_emb = rotary_pos_emb[None, :seq_length]
        # [BS, SL, ES] => [SL, BS, ES]
        rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()
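        # e.g. during incremental decoding the caller passes position_ids such
        # as [[cur_pos]], so the single new token is rotated by its absolute
        # position, while prefill simply uses the first SL cache rows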

        # Feed the word embeddings and position embeddings to the encoder
        hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(
            inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,
            kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_hidden_states
        )

        # Return the GLM output, the per-layer KV caches, and the per-layer hidden states
        if not return_dict:
            return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=presents,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
        )

    def quantize(self, weight_bit_width: int):
        # Quantize the encoder (transformer block) weights in place to the
        # given bit width; the embedding and output layers are left untouched
        from .quantization import quantize
        quantize(self.encoder, weight_bit_width)
        return self
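
To see the module end to end, here is a minimal usage sketch. It assumes the model is loaded through the Hugging Face `transformers` remote-code path published with `THUDM/chatglm2-6b`, where the `ChatGLMModel` above is exposed as `model.transformer`; it is an illustration, not part of the original file.

import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).half().cuda()

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
with torch.no_grad():
    # model.transformer is the ChatGLMModel walked through above
    out = model.transformer(input_ids=inputs["input_ids"], use_cache=True)

# ChatGLM2 keeps the [SL, BS, HS] layout internally, so the sequence axis
# comes first in the returned hidden states
print(out.last_hidden_state.shape)
print(len(out.past_key_values))  # one (key, value) pair per layer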
