首页 > 其他分享 >Llama3学习记录

Llama3学习记录

时间:2024-09-19 21:51:24浏览次数:12  
标签:dim 记录 Llama3 self torch freqs 学习 xq xk

Llama3学习记录

Llama3是一个稠密的transformer网络模型,应用于预测文本序列的下一个token。相较于先前版本的Llama模型,其性能提升主要来自于数据质量的提升以及多样性,并且也受益于模型参数的增加

1. 网络架构

  • 由上图可知,Llama3是一个decoder only的网络模型

  • Llama3模型具体架构层如上图,可以看到,Llama模型使用了前置的RMSNorm层,并且在注意力机制中采用了GQA架构,并且在Q、K上使用了RoPE旋转位置编码
  • 由于模型是预测下一个token,因此Llama在训练时,会mask掉卫位于当前token之后的token

2. 核心概念

2.1 RMSNorm:

RMSNorm是LayerNorm的变体,通过激活值的均方根来实现归一化

优点:

  1. 不计算均值,相比于LayerNorm,减少了计算开销
  2. 避免了过度归一化,使得训练更加稳定:没有对均值进行归一化,只归一化方差,因此可以保留均值的信息,减少对信息的破坏

2.1.1 计算步骤:

  1. 计算均方根值RMS(x):对输入的特征x计算其均方根值,x是输入特征向量,n是特征维度

  1. 进行归一化,利用计算的均方根进行归一化:

  1. 可以通过可学习的参数进行缩放和平移,g是缩放参数,b是平移参数

2.1.2 代码实现:

# 计算归一化结果并进行缩放(self.weight为缩放参数)
# torch.rsqrt:计算每个元素平方根的倒数

class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

2.2 Rope:旋转位置编码

  • 对于输入的向量进行分解,看作多个二维向量的组合。Rope对每个二位向量进行旋转变换,对于位置p,旋转角度公式为:image,其中image是一个伴随维度变化的常数,用来控制旋转速度,每个二维向量的旋转公式为:

  • 其计算的矩阵形式为:

  • Rope会随着相对位置的增加,逐渐减小
  • 代码实现:
# 计算频率矩阵,返回余弦和正弦的频率矩阵
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device) 
    freqs = torch.outer(t, freqs).float() 
    freqs_cos = torch.cos(freqs) 
    freqs_sin = torch.sin(freqs) 
    return freqs_cos, freqs_sin

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(shape)

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cos: torch.Tensor,
    freqs_sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:

    # 重塑 xq 和 xk,使其与复数表示相匹配
    xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
    xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)

    # 重塑形为了广播
    freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
    freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)

    # 应用旋转嵌入
    xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
    xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
    xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
    xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos

    # 将最后两维度拉平。
    xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
    xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)

    return xq_out.type_as(xq), xk_out.type_as(xk)

2.3 Attn实现

class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        
        self.group = args.n_group
        self.heads = args.n_heads
        self.kv_heads = args.n_heads // args.n_group
        assert args.n_heads % self.kv_heads == 0
        self.head_dim = args.dim // args.n_heads
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        self.dropout = args.dropout
        mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
        mask = torch.triu(mask, diagonal=1)
        self.register_buffer("mask", mask)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
    ):
        bsz, seqlen, _ = x.shape

        # QKV
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seqlen, self.heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.kv_heads, self.head_dim)

        # RoPE relative positional embeddings
        xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)

        # grouped multiquery attention: expand out keys and values
        xk = repeat_kv(xk, self.group)  # (bs, seqlen, n_local_heads, head_dim)
        xv = repeat_kv(xv, self.group)  # (bs, seqlen, n_local_heads, head_dim)

        # make heads into a batch dimension
        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)


        # 先不使用flash attn,从零走一遍流程!
        scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
        assert hasattr(self, 'mask')
        scores = scores + self.mask[:, :, :seqlen, :seqlen]   # (bs, n_local_heads, seqlen, cache_len + seqlen)
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
        scores = self.attn_dropout(scores)
        output = torch.matmul(scores, xv)  # (bs, n_local_heads, seqlen, head_dim)

        # restore time as batch dimension and concat heads
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

        # 最终送入output层并正则,得到最终结果。
        output = self.wo(output)
        output = self.resid_dropout(output)
        return output

3. Llama3本地部署

按照HF库中文档的简单部署与使用

import transformers
import torch

model_id = "meta-llama/Meta-Llama-3-8B"

pipeline = transformers.pipeline(
    "text-generation", model=model_id, model_kwargs={"torch_dtype": torch.bfloat16}, device_map="auto"
)
pipeline("Hey how are you doing today?")

标签:dim,记录,Llama3,self,torch,freqs,学习,xq,xk
From: https://www.cnblogs.com/AiHorizon/p/18421452

相关文章

  • 俄罗斯的Alexey V. Gubin开发的数据恢复文件-零假设恢复只读模式下对扫描/恢复数据起
    //主要特征//WindowsFAT,NTFS,Linuxext2/3/4和XFS卷格式化的驱动器或“RAW文件系统”外部驱动器,USB拇指驱动器和存储卡带有ZARDataRecovery免费版本的数码照片恢复RAID数据恢复NAS数据恢复MBR损坏数据恢复具有多个逻辑驱动器的分区表恢复支持长文件名和国家文......
  • [学习笔记]树链剖分(简易版) 及其LCA
    树链剖分先讲解一下一些基础定义(都是在树上)重儿子:一个节点中所有儿子中子树大小最大的一个儿子(每个节点最多有一个重儿子)轻儿子:一个节点除重儿子外所有的节点重链:若干个重儿子组成的链链顶:一条链中深度最小的节点以下图为例子(红色连续线段为重链)对于节......
  • 《学习公社》继续教育快速学习操作指南
    一、总结本教程面向2024年度学习公社继续教育学时快速学习,自动完成所选全部必修课及选修课。操作过程简单明了。目标网址:https://www.ttcdw.cn(其他套壳学习公社的网站均可使用,例如家庭教育指导培训平台:http://www.jkyjtjy.com)联系微信:clm8618888(任何网站项目均可定制达到同......
  • Springboot基于springbootvue小学生学习阅读平台785j5(程序+源码+数据库+调试部署+开发
    本系统(程序+源码+数据库+调试部署+开发环境)带论文文档1万字以上,文末可获取,系统界面在最后面。系统程序文件列表开题报告内容一、项目背景在当今信息化时代,阅读对于小学生而言不仅是获取知识的重要途径,也是培养思维能力和文化素养的关键环节。然而,传统的学习方式往往受限于......
  • 学习-2024/9/18
    双列集合双列集合的特点双列集合Map的常见APIMap是双列集合的顶层接口,它的功能是全部双列集合都可以继承使用的Vput(Kkey,Vvalue)添加元素再添加元素的时候,如果键不存在,那么直接把键值对象添加到map中,方法返回null在添加数据的时候,如果键是存在的,那么会把原有......
  • java学习9.19
    结合前端,在本地运行实现登陆操作。将在输入框的数据传给服务器,服务器再通过调用数据库的数据进行对比,实现简单的判断逻辑到这里的我就感觉内容多了起来,在之前连接数据库,数据库操作的时候,跟着教程走,只是知道简单的用法也能在之后自行配置这里的话数据库等操作变成了一个环节,还有......
  • 口胡记录
    先开个坑,不一定填。主要记录一些口胡了但是没写的题。String题面:给定两个字符串\(a\),\(b\),我们称这两个字符串的所有子序列为坏字符串。求最短的非坏字符串。做法:首先要解决一个问题,假设你有一个字符串你需要判断这个字符串是否是坏的,怎么快速判断?我们预处理出nxta[i......
  • Android NotificationListenerService的实操记录
    文章目录背景介绍主要方法技术细节背景介绍Android在4.3的版本中(即API18)加入了NotificationListenerService,根据SDK的描述(AndroidDeveloper)可以知道,当系统收到新的通知或者通知被删除时,会触发NotificationListenerService的回调方法。同时在Android4.4中新增......
  • 全面系统的AI学习路径,帮助普通人也能玩转AI
    前言现如今AI技术和应用的发展可谓是如火如荼,它们在各个领域都展现出了巨大的潜力和影响力。AI的出现对于我们这些普通人而言也是影响匪浅,比如说使用AI工具GPT来写文档查问题、使用AI辅助编程工具帮助我们写代码、并且可是使用AI来实现人工客服等。那么普通人如何学习AI呢?别再当别......
  • 【第十二章:Sentosa_DSML社区版-机器学习回归】
    【第十二章:Sentosa_DSML社区版-机器学习回归】12.1 线性回归1.算子介绍        线性回归模型(BuildLRNode)是一个非常经典有效的回归模型,它假设所有特征变量和目标变量之间存在线性关系。通过训练来求得各个特征的权重以及截距。同时可以通过L1,L2正则化来减少模型......