首页 > 编程语言 >ChatGLM2 源码解析:`ChatGLMForConditionalGeneration.forward`

ChatGLM2 源码解析:`ChatGLMForConditionalGeneration.forward`

时间:2023-09-04 18:33:36浏览次数:53  
标签:None hidden return ChatGLM2 源码 forward logits Optional self

class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
    def __init__(self, config: ChatGLMConfig, empty_init=True, device=None):
        super().__init__(config)

        self.max_sequence_length = config.max_length
        self.transformer = ChatGLMModel(config, empty_init=empty_init, device=device)
        self.config = config
        self.quantized = False

        if self.config.quantization_bit:
            self.quantize(self.config.quantization_bit, empty_init=True)
    
    def forward(
            self,
            input_ids: Optional[torch.Tensor] = None,
            position_ids: Optional[torch.Tensor] = None,
            attention_mask: Optional[torch.Tensor] = None,
            past_key_values: Optional[Tuple[torch.FloatTensor]] = None,
            inputs_embeds: Optional[torch.Tensor] = None,
            labels: Optional[torch.Tensor] = None,
            use_cache: Optional[bool] = None,
            output_attentions: Optional[bool] = None,
            output_hidden_states: Optional[bool] = None,
            return_dict: Optional[bool] = None,
            return_last_logit: Optional[bool] = False,
    ):
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.transformer(
            input_ids=input_ids,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = transformer_outputs[0]
        # `return_last_logit`表示只保留最后一个单词的
        if return_last_logit:
            hidden_states = hidden_states[-1:]
        # 将编码器输出传入输出层得到单词概率
        lm_logits = self.transformer.output_layer(hidden_states)
        # [SL, BS, ...] => [BS, SL, ...]
        lm_logits = lm_logits.transpose(0, 1).contiguous()

        loss = None
        if labels is not None:
            lm_logits = lm_logits.to(torch.float32)

            # 让第 i 个词前面的单词预测第 i 个词
            # 假如原文是 [A, B, C, D, E]
            # logits = [A, B, C, D],labels = [B, C, D, E]
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # 单词 Logits 变形为 [BS * (SL - 1), VS]
            # 标签变形为 [BS * (SL - 1)]
            # 计算交叉熵
            loss_fct = CrossEntropyLoss(ignore_index=-100)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

            lm_logits = lm_logits.to(hidden_states.dtype)
            loss = loss.to(hidden_states.dtype)

        # 返回损失、单词 Logits、KV 缓存、编码器输出、以及编码器注意力矩阵
        if not return_dict:
            output = (lm_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

标签:None,hidden,return,ChatGLM2,源码,forward,logits,Optional,self
From: https://www.cnblogs.com/apachecn/p/17677794.html

相关文章

  • ChatGLM2 源码解析:`GLMTransformer`
    #编码器模块,包含所有GLM块classGLMTransformer(torch.nn.Module):"""Transformerclass."""def__init__(self,config:ChatGLMConfig,device=None):super(GLMTransformer,self).__init__()self.fp32_residual_co......
  • ChatGLM2 源码解析:`MLP`
    classMLP(torch.nn.Module):"""MLP.MLPwilltaketheinputwithhhiddenstate,projectitto4*hhiddendimension,performnonlineartransformation,andprojectthestatebackintohhiddendimension.""&quo......
  • 百度上传下载组件源码
    ​ 以ASP.NETCoreWebAPI 作后端 API ,用 Vue 构建前端页面,用 Axios 从前端访问后端 API,包括文件的上传和下载。 准备文件上传的API #region 文件上传  可以带参数        [HttpPost("upload")]        publicJsonResultuploadProject(I......
  • java智慧工地:智慧工地大数据中心源码
    智慧工地技术架构:微服务+Java+SpringCloud+Vue+UniApp+MySql智慧工地形成安全、质量、进度、人员、机械、绿色施工六大针对性解决方案。 安全管理围绕重大危险源提供管控,可视化跟踪消防、安防、基坑、高支模、临边防护、卸料平台等设施设备的安全状态、管理痕迹、趋势预测,......
  • 分享实用工具源码--实现Windows IDE中查看Linux下编译信息
    作者:fbysss关键字:实用工具源码 Windows下查看Linux编译信息一、背景:本人写C程序不多,更不用说Linux下了。偶然一个机会,接了个这样的活,vi我用的还马马虎虎,但程序超过一千行,看起来就有些眼花了。于是只好在VC下编写代码,ftp传到Linux服务器,再用gcc编译,出错了再到VC下修改,再上传,如......
  • 直播带货源码,iOS 获取图片主题色
    直播带货源码,iOS获取图片主题色 -(void)getMostColorFormImage:(UIImage*)image{  WEAKSELF  [imagegetPaletteImageColorWithMode:ALL_MODE_PALETTEwithCallBack:^(PaletteColorModel*recommendColor,NSDictionary*allModeColorDic,NSError*error){   ......
  • 直播源码,自定义progressBar样式
    直播源码,自定义progressBar样式1、layout中xml布局如下: <RelativeLayout  android:layout_height="16dp"  android:layout_width="match_parent">  <ProgressBar    style="?android:attr/progressBarStyleHorizontal"    android......
  • 百度上传下载控件源码
    ​ 我们平时经常做的是上传文件,上传文件夹与上传文件类似,但也有一些不同之处,这次做了上传文件夹就记录下以备后用。首先我们需要了解的是上传文件三要素:1.表单提交方式:post(get方式提交有大小限制,post没有)2.表单的enctype属性:必须设置为multipart/form-data.3.表单必须......
  • 一口气用Python写了13个小游戏(附源码)
    今天给大家分享13个游戏源码,可以自己复现玩玩,研究下里面的编程逻辑,对学习编程(特别是初学者)应该会有很大帮助。1、吃金币源码分享:importosimportcfgimportsysimportpygameimportrandomfrommodulesimport*'''游戏初始化'''definitGame():#初始化pygame,设......
  • 【腾讯云 Cloud Studio 实战训练营】使用在线编程的方式用Nuxt3开发一个后台管理系统(
    前言大家好,我是刘明,开源技术爱好者,十年创业老兵。CSDN近期联合腾讯云、Coding、CloudStudio组织了【腾讯云CloudStudio实战训练营活动】,苦于前些日子一直在备考注册会计师,没有很好的体验CloudStudio的云IDE产品。现在考试结束了,体验了一把云IDE,不禁感慨云端开发原来可以这么......