首页 > 其他分享 >Transformer模型代码(详细注释,适合新手)

Transformer模型代码(详细注释,适合新手)

时间:2024-06-18 11:33:09浏览次数:10  
标签:Transformer idx self 注释 num position 新手 model logits

# Hyperparameters
batch_size = 4  # How many batches per training step
context_length = 16  # Length of the token chunk each batch
d_model = 64  # The size of our model token embeddings
num_blocks = 8  # Number of transformer blocks
num_heads = 4  # Number of heads in Multi-head attention
learning_rate = 1e-3  # 0.001
dropout = 0.1  # Dropout rate
max_iters = 5000  # Total of training iterations <- Change this to smaller number for testing
eval_interval = 50  # How often to evaluate
eval_iters = 20  # Number of iterations to average for evaluation
device = 'cuda' if torch.cuda.is_available() else 'cpu'  # Use GPU if it's available.
TORCH_SEED = 1337
torch.manual_seed(TORCH_SEED)


class FeedForward(nn.Module):
    def __init__(self):
        super().__init__()
        # 模型维度
        self.d_model = d_model
        # 丢弃率
        self.dropout = dropout
        # 第一个线性层
        self.ln1 = nn.Linear(in_features=self.d_model, out_features=self.d_model * 4)
        # ReLU激活函数
        self.relu = nn.ReLU()
        # 第二个线性层
        self.ln2 = nn.Linear(in_features=self.d_model * 4, out_features=self.d_model)
        # 丢弃层
        self.dp = nn.Dropout(dropout)
    def forward(self, x):
        # 输入形状为 batch_size, seq_len, d_model
        x = self.ln1(x)
        x = self.relu(x)
        x = self.ln2(x)
        out = self.dp(x)
        return out




class Attention(nn.Module):
    def __init__(self, head_size):
        """
        参数:
            head_size (int): 每个注意力头的大小。
            d_model (int): 输入张量的特征维度大小。
            context_length (int): 上下文长度,即时间步数。
            dropout (float): Dropout 层的丢弃率。
        """
        super().__init__()
        # 设置模型的参数
        self.d_model = d_model
        self.head_size = head_size
        self.context_length = context_length
        self.dropout = dropout
        # 定义用于计算注意力权重的线性层
        self.key_layer = nn.Linear(in_features=self.d_model, out_features=self.head_size, bias=False)
        self.query_layer = nn.Linear(in_features=self.d_model, out_features=self.head_size, bias=False)
        self.value_layer = nn.Linear(in_features=self.d_model, out_features=self.head_size, bias=False)
        # 生成下三角掩码
        self.register_buffer('tril', torch.tril(torch.ones((self.context_length, self.context_length))))
        # Dropout 层,用于防止过拟合
        self.dropout_layer = nn.Dropout(self.dropout)

    def forward(self, x):
        #todo 输入输出维度一样

        # 获取输入张量的形状信息
        B, T, C = x.shape  # Batch size, Time steps(current context_length), Channels(dimensions)
        # 确保时间步数不超过上下文长度,且特征维度与模型参数匹配
        assert T <= self.context_length
        assert C == self.d_model
        # 通过线性层得到查询、键和值张量
        q = self.query_layer(x)
        k = self.key_layer(x)
        v = self.value_layer(x)

        # 缩放点积注意力:Q @ K^T / sqrt(d_k)
        weights = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        # 应用掩码
        weights = weights.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        weights = F.softmax(input=weights, dim=-1)
        weights = self.dropout_layer(weights)

        # 加权求和:weights @ V
        out = weights @ v
        return out


class MultiHeadAttention(nn.Module):
    def __init__(self, head_size):
        """
        定义一个多头注意力机制模型。

        参数:
            num_heads (int): 注意力头的数量。
            head_size (int): 每个注意力头的大小。
            d_model (int): 输入张量的特征维度大小。
            context_length (int): 上下文长度,即时间步数。
            dropout (float): Dropout 层的丢弃率。
        """
        super().__init__()
        # 设置模型的参数
        self.num_heads = num_heads
        self.head_size = head_size
        self.d_model = d_model
        self.context_length = context_length
        self.dropout = dropout
        # 创建多个注意力头
        self.heads = nn.ModuleList([Attention(head_size=self.head_size)
                                    for _ in range(self.num_heads)])
        # 线性投影层,用于将多头注意力的输出映射回原始的特征维度
        self.projection_layer = nn.Linear(in_features=self.num_heads * self.head_size, out_features=self.d_model)
        # Dropout 层,用于防止过拟合
        self.dropout_layer = nn.Dropout(self.dropout)

    def forward(self, x):
        """
        参数:
            x (torch.Tensor): 输入张量,形状为(batch_size, seq_len, d_model)。
        返回:
            torch.Tensor: 多头注意力机制的输出张量,形状与输入张量相同。
        """
        # 对每个注意力头执行前向传播,并将它们的输出拼接在一起
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        # 通过线性投影层将多头注意力的输出映射回原始的特征维度
        out = self.projection_layer(out)
        # 应用 Dropout 层,防止过拟合
        out = self.dropout_layer(out)
        return out


class TransformerBlock(nn.Module):

    def __init__(self):
        super().__init__()
        # 设置模型的参数
        self.d_model = d_model
        self.context_length = context_length
        self.head_size = d_model // num_heads  # 注意力头的大小
        self.num_heads = num_heads
        self.dropout = dropout
        # 多头注意力层
        self.multi_head_attention_layer = MultiHeadAttention(self.head_size)
        # 前馈神经网络层
        self.feed_forward_layer = FeedForward()
        # Layer normalization 层
        self.layer_norm_1 = nn.LayerNorm(normalized_shape=self.d_model)
        self.layer_norm_2 = nn.LayerNorm(normalized_shape=self.d_model)

    def forward(self, x):
        """
        定义模型的前向传播过程。
        参数:
            x (torch.Tensor): 输入张量,形状为(batch_size, seq_len, d_model)。
        返回:
            torch.Tensor: Transformer 块的输出张量,形状与输入张量相同。
        """
        # 注意:操作的顺序与原始的 Transformer 论文不同
        # 这里的顺序是:LayerNorm -> Multi-head attention -> LayerNorm -> Feed forward
        # 使用 Layer normalization 对输入张量进行归一化
        x_normalized_1 = self.layer_norm_1(x)
        # 执行多头注意力机制,并将输出与输入张量相加(残差连接)
        x = x + self.multi_head_attention_layer(x_normalized_1)
        # 使用 Layer normalization 对得到的结果进行归一化
        x_normalized_2 = self.layer_norm_2(x)
        # 执行前馈神经网络,并将输出与之前的结果相加(残差连接)
        x = x + self.feed_forward_layer(x_normalized_2)
        return x

class TransformerLanguageModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.d_model = d_model
        self.context_length = context_length
        self.num_heads = num_heads
        self.num_blocks = num_blocks
        self.dropout = dropout
        self.max_token_value = max_token_value
        # 设置 token 嵌入查找表
        self.token_embedding_lookup_table = nn.Embedding(num_embeddings=self.max_token_value + 1, embedding_dim=self.d_model)

        # 运行所有的 Transformer 块
        # 与原始论文不同,这里在所有块之后添加了一个最终的层规范化
        self.transformer_blocks = nn.Sequential(*(
                [TransformerBlock(num_heads=self.num_heads) for _ in range(self.num_blocks)] +
                [nn.LayerNorm(self.d_model)]
        ))
        self.language_model_out_linear_layer = nn.Linear(in_features=self.d_model, out_features=self.max_token_value)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        """
        # 设置位置嵌入查找表
        # 遵循原始 Transformer 论文相同的方法(正弦和余弦函数)
        """
        position_encoding_lookup_table = torch.zeros(self.context_length, self.d_model)
        position = torch.arange(0, self.context_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, self.d_model, 2).float() * (-math.log(10000.0) / self.d_model))
        position_encoding_lookup_table[:, 0::2] = torch.sin(position * div_term)
        position_encoding_lookup_table[:, 1::2] = torch.cos(position * div_term)
        # 将 position_encoding_lookup_table 从 (context_length, d_model) 更改为 (T, d_model)
        position_embedding = position_encoding_lookup_table[:T, :].to(device)
        x = self.token_embedding_lookup_table(idx) + position_embedding
        x = self.transformer_blocks(x)
        # “logits” 是我们的模型在应用 softmax 之前的输出值
        logits = self.language_model_out_linear_layer(x)

        if targets is not None:
            B, T, C = logits.shape
            logits_reshaped = logits.view(B * T, C)
            targets_reshaped = targets.view(B * T)
            loss = F.cross_entropy(input=logits_reshaped, target=targets_reshaped)
        else:
            loss = None
        return logits, loss

    def generate(self, idx, max_new_tokens):
        # idx 是当前上下文中索引的 (B,T) 数组
        for _ in range(max_new_tokens):
            # 将 idx 裁剪到我们位置嵌入表的最大尺寸
            idx_crop = idx[:, -self.context_length:]
            # 获取预测值
            logits, loss = self(idx_crop)
            # 从 logits 中获取最后一个时间步,其中 logits 的维度为 (B,T,C)
            logits_last_timestep = logits[:, -1, :]
            # 应用 softmax 获取概率
            probs = F.softmax(input=logits_last_timestep, dim=-1)
            # 从概率分布中采样
            idx_next = torch.multinomial(input=probs, num_samples=1)
            # 将采样的索引 idx_next 追加到 idx 中
            idx = torch.cat((idx, idx_next), dim=1)
        return idx

本文提供了transformer代码附带详细注释,要注意本文的transformer并非传统的encoder-decoder结构的,而是主流的gpt结构(decoder-only),不了解decoder-only的同学,可以参考我的另一篇文章,链接放在最后。我过几天会出一个介绍gpt模型结构的的文章,欢迎大家前来讨论。http://t.csdnimg.cn/IGCUL

标签:Transformer,idx,self,注释,num,position,新手,model,logits
From: https://blog.csdn.net/SWZ156/article/details/139647527

相关文章

  • Dcat admin laravel 快速安装及生成相应页面(新手)
    使用工具:phpEnv,phpStorm操作步骤:安装阿里云Composer镜像:打开命令行工具,如CMD或PowerShell。切换到自己安装phpEnv的www目录下我的是D:\Studysoft\phpEnv\www 。执行以下命令以设置全局Composer镜像:composerconfig-grepo.packagistcomposerhttps://mirror......
  • 史上最全最详细的适合新手的从零搭建一个Linux的HTTP服务器
    一.概念梳理    http(hyper-text-transmission-protocol)超文本传输协议,顾名思义就是传输超文本(html)的协议,具体地来说,我们不需要知道html怎么写,我们只需要梳理服务器的数据接收和响应.具体业务具体分析,你们可以自行丰富内容,这里只做最简单的功能演示.http协......
  • 新手如何入门Web3?
    一、什么是Web3?Web3是指下一代互联网,它基于区块链技术,致力于将各种在线活动变得更加安全、透明和去中心化。Web3是一个广义的概念,涵盖了包括数字货币、去中心化应用、智能合约等在内的多个方面。它的主要特点包括去中心化、区块链技术的运用、智能合约的执行、用户主权和数......
  • 新手如何入门Web3?
    一、什么是Web3?Web3是指下一代互联网,它基于区块链技术,致力于将各种在线活动变得更加安全、透明和去中心化。Web3是一个广义的概念,涵盖了包括数字货币、去中心化应用、智能合约等在内的多个方面。它的主要特点包括去中心化、区块链技术的运用、智能合约的执行、用户主权和数......
  • 新手教学系列-​​​​​​基础知识(SSH使用)
    基础知识(SSH使用)什么是sshSecureShell(安全外壳协议,简称SSH)是一种加密的网络传输协议,可在不安全的网络中为网络服务提供安全的传输环境[1]。SSH通过在网络中创建安全隧道来实现SSH客户端与服务器之间的连接[2]。虽然任何网络服务都可以通过SSH实现安全传输,SSH最常见的用途是......
  • Structure-Aware Transformer for Graph Representation Learning
    目录概SAT代码ChenD.,O'BrayL.andBorgwardtK.Structure-awaretransformerforgraphrepresentationlearning.ICML,2022.概Graph+Transformer+修改attention机制.SATTransformer最重要的就是attention机制:\[\text{Attn}(x_v)=\sum_{v\in......
  • 移动硬盘数据恢复方法哪个好?六个硬盘恢复,新手也能用!
    移动硬盘数据恢复方法哪个好?移动硬盘,作为我们存储重要数据的常用设备,一旦里面的视频、文档、音频等资料突然消失,确实会令人烦恼和担忧。然而,因为数据丢失的原因可能多种多样,因此恢复方法也会有所不同。所以,建议您在遇到的第一时间去想办法恢复。那么,移动硬盘数据恢复方法哪个好......
  • Illustrated Transformer笔记
    AttentionIsAllYouNeed编码器端Self-attention层用处:将对其他相关单词的“理解”融入我们当前正在处理的单词的方法,类似于RNN通过保持隐藏状态让RNN将其已处理的先前单词/向量的表示与当前正在处理的单词/向量结合起来将单词输入转化为Embedding之后,将Embedding和QKV......
  • 回归预测 | Matlab实现Transformer多输入单输出回归预测
    回归预测|Matlab实现Transformer多输入单输出回归预测目录回归预测|Matlab实现Transformer多输入单输出回归预测效果一览基本介绍程序设计参考资料效果一览基本介绍1.Matlab实现Transformer多变量回归预测;2.运行环境为Matlab2023b;3.输入多个特征,输......
  • 攻防世界web新手题fileinclude&fileclude
    题目1:fileinclude工具:BurpsuiteHackbarV2火狐浏览器的扩展应用解题关键:学会文件包含的命令以及学会读懂php脚本解题过程:首先对该网站进行抓包,发现潜藏的php脚本这段代码的作用是:检查是否display_errors配置项被设置为打开,如果没有则将其打开,这样可以在页面上......