首页 > 编程语言 >大语言模型生成模型的源码结构复习

大语言模型生成模型的源码结构复习

时间:2023-12-26 17:14:50浏览次数:34  
标签:lm 复习 模型 labels next tokens token 源码 logits

modeling_gpt2.py:1099

        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(lm_logits.device)
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

这个是静电的生成模型. 看lm_logits[:,:-1,:], labels[:,1:]这个就是输入0,1,2,3,4
然后输出的特征shape 是 1, 5, 30000. 字典大小3w
labels也是0,1,2,3,4
所以网络输出的含义是: lm_logits[:,:-1,:]=1,2,3,4 所以跟labels[:,1:]=1,2,3,4 算交叉熵即可.
总结. 给gpt2或者所有的causallm模型喂入 序列1,2,3,4那么他输出的shape大小是1,4,3w. 其中4的维度上第一个值表示的是1的后续token预测值在3w上的概率分布. 第二个值表示2的后续token在....,
用\(*\)表示估计值的语言来说就是 2,3,4,5.
所以你会经常看到代码 取lm_logits[:,-1,:]再softmax就预测了下一个token是什么. 也就是从1,2,3,4预测得到了5.

参考代码:
D:\Users\admin\miniconda3\Lib\site-packages\transformers\generation\utils.py:2531行

next_token_logits = outputs.logits[:, -1, :]
            next_tokens = torch.argmax(next_tokens_scores, dim=-1)

            # finished sentences should have their next token be a padding token
            if eos_token_id is not None:
                if pad_token_id is None:
                    raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)

从这里面代码我们也很清楚看到每生成一个token,他就torch.cat到之前的input_ids里面.再循环生成.

            if stopping_criteria(input_ids, scores):
                this_peer_finished = True

一直到这个结尾判定成功就停止生成了.

以上就是gpt2的生成代码的全部分析了. 需要一定掌握.非常非常重要.

标签:lm,复习,模型,labels,next,tokens,token,源码,logits
From: https://www.cnblogs.com/zhangbo2008/p/17928803.html

相关文章

  • 7. Java 内存模型
    Java内存模型Java内存模型(JavaMemoryModel)的主要目的是定义程序中各种变量的访问规则,即关注在虚拟机中把变量值存储到内存和从内存中取出变量值这样的底层细节1.主内存与工作内存Java内存模型规定了所有的变量都存储在主内存(MainMemory)中(虚拟机内存的一部分)。每条线程......
  • Runway官宣下场通用世界模型!解决视频AI最大难题,竟靠AI模拟世界?
    前言 Runway突然发布公告,宣称要开发通用世界模型,解决AI视频最大难题,未来要用AI模拟世界。本文转载自新智元仅用于学术分享,若侵权请联系删除欢迎关注公众号CV技术指南,专注于计算机视觉的技术总结、最新技术跟踪、经典论文解读、CV招聘信息。CV各大方向专栏与各个部署框架最全......
  • 基于SpringBoot+Vue的毕业设计系统的开发设计实现(源码+lw+部署文档+讲解等)
    (文章目录)前言:heartpulse:博主介绍:✌全网粉丝10W+,CSDN特邀作者、博客专家、CSDN新星计划导师、全栈领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java、小程序技术领域和毕业项目实战✌:heartpulse:......
  • Spring MVC 源码分析 - 一个请求的旅行过程
    在上一篇《WebApplicationContext容器的初始化》文档中分析了SpringMVC是如何创建两个容器的,其中创建RootWebApplicationContext 后,调用其refresh()方法会触发刷新事件,完成SpringIOC初始化相关工作,会初始化各种SpringBean到当前容器中,该系列文档暂不分析我们先来了解一......
  • 查看onnx模型结构-使用Netron模块
    查看onnx模型结构-使用Netron模块1安装$pipinstallnetron2可选-查看安装的路径$pipshownetron3查看onnx结构importnetron#�??�?ONNX模�??�??件�??路�?onnx_model_path=r'yolo5/yolov5n-seg_toXiaoLiu/model/yolov5n-seg.onnx'#�?�"�netron�?��?��?�??ONNX模�??net......
  • 【源码系列#04】Vue3侦听器原理(Watch)
    专栏分享:vue2源码专栏,vue3源码专栏,vuerouter源码专栏,玩具项目专栏,硬核......
  • AI大模型:从GPT4到BERT,发展趋势与比较
    1.背景介绍自从深度学习技术在2012年的ImageNet大赛中取得了突破性的成果以来,人工智能领域的发展就不断加速。随着计算能力和数据规模的不断提高,人工智能技术的应用也不断拓展。在自然语言处理(NLP)领域,GPT(GenerativePre-trainedTransformer)和BERT(BidirectionalEncoderRepresenta......
  • 自然语言理解与语言模型:结合的力量
    1.背景介绍自然语言理解(NaturalLanguageUnderstanding,NLU)和自然语言模型(LanguageModel,LM)是人工智能领域中的两个重要概念。NLU涉及到从自然语言文本中抽取出有意义的信息,以便于进行进一步的处理和分析。而自然语言模型则是一种用于预测给定上下文中下一个词的统计模型。在这......
  • Diffusion 扩散模型
    Diffusion扩散模型对比GAN和VAE扩散原理扩散过程:加噪声,均匀分布到整个空间重参数:避免梯度消失、爆炸复原过程:去噪声,恢复原始图像损失函数:交叉熵损失,变分推断训练流程 对比GAN和VAE原先,图像生成领域最常见生成模型有GAN和VAE。后来,Diffusion扩散模型也是生成模型,且在......
  • 一. 什么是LLM(大语言模型)?
    1.发展历程语言建模的研究始于20世纪90年代,最初采用了统计学习方法,通过前面的词汇来预测下一个词汇。然而,这种方法在理解复杂语言规则方面存在一定局限性。随后,研究人员不断尝试改进,其中在2003年,深度学习先驱Bengio在他的经典论文《ANeuralProbabilisticLanguageModel》中,首次......