首页 > 编程语言 >使用统计方法在AMD GPU上使用JAX Profiler可靠地比较大型生成AI模型中的算法性能

使用统计方法在AMD GPU上使用JAX Profiler可靠地比较大型生成AI模型中的算法性能

时间:2024-08-28 15:24:52浏览次数:15  
标签:qr trace JAX AI einsum AMD shape -- matmul

Using statistical methods to reliably compare algorithm performance in large generative AI models with JAX Profiler on AMD GPUs — ROCm Blogs

摘要

本文提供了一份详细的指南,介绍如何在JAX实现的生成AI模型中测量和比较各种算法的性能。利用JAX Profiler和统计分析,本文展示了如何可靠地评估关键步骤并比较AMD GPU上算法的性能。

引言

在GPU加速计算的动态领域,追求最佳性能和效率需要有效的性能分析技术。性能分析通过仔细检查执行时间、内存利用率和内核占用率等指标,提供了对基于GPU的应用程序行为和性能特征的全面了解。这对于大规模生成AI模型尤为重要,因为优化性能可以显著提升最终用户体验和收入来源。通过利用性能分析技术,开发人员可以找出低效之处,深入了解运行时行为,并最终优先考虑战略性优化工作,从而带来显著的性能提升。
JAX是谷歌的一款开源数值计算库(尽管不是官方的谷歌产品),由于其能够利用硬件加速器和自动微分的能力,正在生成AI领域引起广泛关注。最初用于高性能机器学习研究,JAX的函数式编程方法和对GPU及TPU的支持使其成为构建和部署大型语言模型(LLMs)和其他前沿生成AI应用的首选。值得注意的是,像 X.AI这样的公司利用JAX开发开源模型如Grok-1,进一步推动了该库在生成AI领域的流行。凭借其性能、灵活性及其适合先进AI模型开发和部署的特点,JAX继续在受欢迎程度上不断攀升。

ROCm博客系列此前已探索过各种性能分析工具,如 *rocprof*,可以用于在AMD GPU上分析模型性能,还有针对TensorFlow和PyTorch的框架特定性能分析工具。尽管JAX的官方页面涵盖了其性能分析工具的基本用法,本教程深入探讨了更高级的技术。例如,它解释了在评估算法时,如何在考虑到大量随机噪声的情况下确定一种算法是否显著优于另一种算法。本文通过统计分析和假设检验,展示了如何可靠地测量和比较在大型语言模型中执行相同步骤的不同算法的性能。具体而言,它比较了在JAX-based生成预训练变换器(GPT)模型的`CausalSelfAttention`组件中,使用`einsum`与`matmul`实现两个矩阵乘法步骤的性能。(参见博客中关于在JAX中实现GPT模型的文章)。要了解更多关于`einsum`的信息,请访问这篇博客。 

实现

要实现此代码示例,请首先设置ROCm环境,并安装必要的软件包和Python脚本。值得注意的是,该代码示例是平台无关的,这意味着只要加速计算平台和Python包配置正确,它就兼容AMD GPU以及其他GPU或TPU。

环境设置

按照以下步骤为本教程设置运行环境:

1. 在Linux shell中使用下面的代码拉取并运行docker容器:

docker run -it --ipc=host --network=host --device=/dev/kfd --device=/dev/dri \
           --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined \
           --name=nanogpt rocm/pytorch:rocm6.1_ubuntu22.04_py3.10_pytorch_2.1.2 /bin/bash

2. 在docker容器内运行以下代码,以安装必要的Python包并配置XLA环境变量:

python3 -m pip install --upgrade pip
pip install optax==0.2.2 flax==0.8.2 transformers==4.38.2 tiktoken==0.6.0 datasets==2.17.1 perfetto==0.7.0 matplotlib==3.8.4 scipy==1.13.0
python3 -m pip install https://github.com/ROCmSoftwarePlatform/jax/releases/download/jaxlib-v0.4.26/jaxlib-0.4.26+rocm610-cp310-cp310-manylinux2014_x86_64.whl
python3 -m pip install https://github.com/ROCmSoftwarePlatform/jax/archive/refs/tags/jaxlib-v0.4.26.tar.gz
pip install numpy==1.22.0
export XLA_FLAGS="--xla_gpu_autotune_level=0"

3. 使用以下命令从 ROCm/rocm-blogs GitHub 存储库下载用于该博客的文件。

git clone https://github.com/ROCm/rocm-blogs.git
cd rocm-blogs/blogs/artificial-intelligence/nanoGPT-JAX

4. 将`nanoGPT-JAX`文件夹中的`model.py`和`sample.py`脚本替换为当前博客在GitHub上*src*文件夹中的对应文件。具体参考此链接

特别需要注意的是,对`model.py`文件的修改如下面代码块所示。新添加的两行代码使用`jax.named_scope`为两个矩阵乘法步骤注释唯一名称,这是一个将用户指定名称纳入JAX名称堆栈的上下文管理器。程序随后使用指定名称提取这些步骤的相关性能数据。该技巧对于快速将同类型操作的日志映射到应用程序或模型中的每个步骤非常宝贵,因为默认的日志名称可能会在同类型操作之间非常相似或令人困惑。下面的代码块封装了两个不同的矩阵乘法步骤,并分别为它们指派了不同的范围名称`attn_q_k`和`attn_att_v`。

class CausalSelfAttention(nn.Module):
    config: GPTConfig

    @nn.compact
    def __call__(self, x, train=False, rng1=None, rng2=None):
        assert self.config.n_embd % self.config.n_head == 0
        B, T, C = x.shape # batch size, sequence length, embedding dimensionality (n_embd)
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v  = jnp.split(nn.Dense(self.config.n_embd * 3, name="c_attn")(x), 3, axis=-1)
        k = k.reshape(B, T, self.config.n_head, C // self.config.n_head).swapaxes(1, 2) # (B, nh, T, hs)
        q = q.reshape(B, T, self.config.n_head, C // self.config.n_head).swapaxes(1, 2) # (B, nh, T, hs)
        v = v.reshape(B, T, self.config.n_head, C // self.config.n_head).swapaxes(1, 2) # (B, nh, T, hs)
+       with jax.named_scope("attn_q_k"):
+           att = (jnp.einsum('bhts,bhqs->bhtq', q, k, optimize=True) if self.config.use_einsum else jnp.matmul(q, k.swapaxes(-2, -1))) * (1.0 / jnp.sqrt(k.shape[-1]))
-       att = (jnp.einsum('bhts,bhqs->bhtq', q, k, optimize=True) if self.config.use_einsum else jnp.matmul(q, k.swapaxes(-2, -1))) * (1.0 / jnp.sqrt(k.shape[-1]))
        mask = jnp.tril(jnp.ones((T, T))).reshape((1, 1, T, T))
        att = jnp.where(mask == 0, float('-inf'), att)
        att = nn.softmax(att, axis=-1)
        att = nn.Dropout(self.config.dropout, name='attn_dropout', deterministic=not train)(att, rng=rng1)
+       with jax.named_scope("attn_att_v"):
+           y = jnp.einsum('bhts,bhsq->bhtq', att, v, optimize=True) if self.config.use_einsum else jnp.matmul(att, v)   # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
-       y = jnp.einsum('bhts,bhsq->bhtq', att, v, optimize=True) if self.config.use_einsum else jnp.matmul(att, v)   # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.swapaxes(1, 2).reshape(B, T, C)  # re-assemble all head outputs side by side
        # output projection
        y = nn.Dense(self.config.n_embd, name='c_proj')(y)
        y = nn.Dropout(self.config.dropout, name='resid_dropout', deterministic=not train)(y, rng=rng2)

        return y

对于`sample.py`文件的主要修改包括使用`jax.profiler.start_trace()`和`jax.profiler.stop_trace()`包裹负责运行基于JAX的GPT模型的推理的函数。这将记录每个生成样本的跟踪信息。或者,您可以使用`jax.profiler.trace()`上下文管理器来捕获跟踪,具体可参见本指南。每个样本的性能分析输出将存储在单独的文件夹中,这使得分析个别跟踪更为方便。

for i in range(num_samples): 
+   jax.profiler.start_trace(profile_dir+f'_{i}')
    output = generate([jnp.array(start_ids)], seed+i)
+   jax.profiler.stop_trace()
    print(f'\nGenerated output __{i}__: \n__________________________________\n{decode(output[0].tolist())}\n__________________________________')

使用不同的矩阵乘法算法对GPT模型进行性能分析

为了演示性能分析,本示例比较了`einsum`和`matmul`这两种在注意力计算步骤中执行矩阵乘法的内置方法。`use_einsum`标志控制了是选择`einsum`还是`matmul`进行矩阵乘法。运行以下命令以收集这两种不同算法的性能分析输出:

# Generate profiling output using matmul
python sample.py --init_from='gpt2' --max_new_tokens=50 --start="The weather today is" --num_samples=10 --profile_dir="trace_file_matmul"

# Generate profiling output using einsum
python sample.py --init_from='gpt2' --max_new_tokens=50 --start="The weather today is" --num_samples=10 --profile_dir="trace_file_einsum" --override_args="{'use_einsum':True}"

每条命令调用`sample.py`文件生成10个样本,每个样本包含最多50个新生成的tokens。这会生成20个文件夹(每种算法10个文件夹,每个生成的样本一个文件夹),这些文件夹包含性能分析的输出。在每个文件夹中,性能分析输出存储在一个压缩的`.gz`文件中。在docker终端中运行以下命令来解压输出: 

for i in {0..9}; do
    gzip -d trace_file_einsum_$i/plugins/profile/202*/*.json.gz
    gzip -d trace_file_matmul_$i/plugins/profile/202*/*.json.gz
done

统计分析与两种算法的性能测试

现在,你可以读取剖析数据并进行统计分析。对于每个迭代(对应于每种算法生成的一个样本),程序比较两种算法在矩阵乘法执行时间(以纳秒为单位)分布上的差异。可以使用箱线图来直观地检查差异。Wilcoxon秩和检验(Mann-Whitney U检验)用来确定位置参数(如均值和中位数)是否显著不同。较短的执行时间表示更好的性能。

下面的代码块导入了分析所需的包,并定义了绘制箱线图的函数。

import glob
from perfetto.trace_processor import TraceProcessor
from scipy.stats import ranksums
import matplotlib.pyplot as plt


def plot_boxplot(df1, df2, columns1, columns2=None, df1_lab='matmul', df2_lab='einsum'):
    """
    Plot boxplots for specified columns in two DataFrames. This function will 
    be used to compare the distribution of running time for the two algorithms
    we profiled.

    Args:
    df1 (pandas.DataFrame): First DataFrame.
    df2 (pandas.DataFrame): Second DataFrame.
    columns1 (list): List of column names from the first DataFrame to plot.
    columns2 (list): List of column names from the second DataFrame to plot.
    df1_lab (string): Label for df1 in the plot.
    df2_lab (string): Label for df2 in the plot.
    """
    if columns2 is None:
        columns2 = columns1
    # Combine data from both DataFrames
    data = [df1[col] for col in columns1] + [df2[col] for col in columns2]
    
    # Create labels for boxplots
    labels = [df1_lab + '_' + col for col in columns1] + [df2_lab + '_' + col for col in columns2]
    
    # Plot boxplots
    plt.figure(figsize=(10, 6))
    plt.boxplot(data, labels=labels)
    plt.xlabel('Algorithms')
    plt.ylabel('Time in nanoseconds')
    plt.title('Performance comparison on the scale of nanoseconds')
    plt.xticks(rotation=45)
    plt.grid(True)
    plt.show()

程序随后比较了每次样本生成迭代中两种算法的执行时间。它在SQL查询中使用`where display_value like "%attn_q_k%"来过滤在第一个named_scope`中的操作。你可以修改SQL查询以探索不同的列并计算感兴趣的指标。

程序省略了第一次迭代,因为第一次迭代包括编译时间,这会使比较失真。它打印了每种算法的执行时间的均值和标准偏差,以及数据框的形状,以确保剖析器和SQL查询捕获了所有事件。例如,对于包含12层的模型,并且每个样本最多生成50个新的token(导致对模型的最多50次函数调用),应捕获最多`12*50=600`次矩阵乘法事件。

最后,程序打印了Wilcoxon秩和检验的统计量和p值,该检验评估两种算法的执行时间分布的位置参数(如均值和中位数)是否显著不同。尽管t检验广泛用于检验两种总体的均值是否相等,但由于样本中存在许多异常值,因此示例使用秩基非参数检验。这些异常值可能显著降低t检验的可靠性。

for i in range(1, 10):
    # Process the profiling data for matmul
    tp = TraceProcessor(trace=glob.glob(f'trace_file_matmul_{i}/plugins/profile/202*/*.json'))
    # SQL query to get the operations enclosed by the named_scope
    query_text='''INCLUDE PERFETTO MODULE slices.slices;
    WITH arg_sets_0 AS (
        SELECT DISTINCT arg_set_id, display_value
        FROM args
        WHERE key = 'args.name'
    )
    SELECT name, display_value, dur
        FROM _slice_with_thread_and_process_info
        INNER JOIN arg_sets_0 ON arg_sets_0.arg_set_id = _slice_with_thread_and_process_info.arg_set_id
    where display_value like "%attn_q_k%"
    '''
    # Query the profiling data and convert to dataframe
    qr_matmul = tp.query(query_text).as_pandas_dataframe()
    # Process the profiling data for einsum
    tp = TraceProcessor(trace=glob.glob(f'trace_file_einsum_{i}/plugins/profile/202*/*.json'))
    # Query the profiling data and convert to dataframe
    qr_einsum = tp.query(query_text).as_pandas_dataframe()
    print(f'###########i={i}###########')
    print('#'*30)
    # Print out the mean, standard dev. and shape for each algorithm
    print(f'Matmul: Mean={qr_matmul.dur.mean()}, std. dev.={qr_matmul.dur.std()}, shape of df:{qr_matmul.shape}')
    print(f'Einsum: Mean={qr_einsum.dur.mean()}, std. dev.={qr_einsum.dur.std()}, shape of df:{qr_einsum.shape}')
    plot_boxplot(qr_matmul, qr_einsum, ['dur'])
    stat, p = ranksums(qr_matmul['dur'], qr_einsum['dur'])
    print(f'Test statistic={stat}, p_val={p}')

下面是两次迭代的截断输出,所有九次迭代都观察到了相同的模式。

###########i=1###########
##############################
Matmul: Mean=6461.875, std. dev.=504.8818364954699, shape of df:(600, 3)
Einsum: Mean=5813.346666666666, std. dev.=455.80420754410954, shape of df:(600, 3)
Test statistic=20.22982266255362, p_val=5.349499343834845e-91

###########i=2###########
##############################
Matmul: Mean=6293.076666666667, std. dev.=514.1309448993132, shape of df:(600, 3)
Einsum: Mean=5797.615, std. dev.=397.86885546863283, shape of df:(600, 3)
Test statistic=16.932946075063718, p_val=2.5717953759559878e-64

基于结果,可以明显看出,对于矩阵乘法算法`einsum`比`matmul`在计算`query`和`key`矩阵之间的矩阵乘法时显著更快。但对于在`attention`和`value`矩阵之间的矩阵乘法时,`matmul`如何表现呢?结果显示在下面的代码块中:

for i in range(1, 10):
    # Process the profiling data for matmul
    tp = TraceProcessor(trace=glob.glob(f'trace_file_matmul_{i}/plugins/profile/202*/*.json'))
    # SQL query to get the operations enclosed by the named_scope
    query_text='''INCLUDE PERFETTO MODULE slices.slices;
    WITH arg_sets_0 AS (
        SELECT DISTINCT arg_set_id, display_value
        FROM args
        WHERE key = 'args.name'
    )
    SELECT name, display_value,dur
        FROM _slice_with_thread_and_process_info
        INNER JOIN arg_sets_0 ON arg_sets_0.arg_set_id = _slice_with_thread_and_process_info.arg_set_id
    where display_value like "%attn_att_v%"
    '''
    # Query the profiling data and convert to dataframe
    qr_matmul = tp.query(query_text).as_pandas_dataframe()
    # Process the profiling data for einsum
    tp = TraceProcessor(trace=glob.glob(f'trace_file_einsum_{i}/plugins/profile/202*/*.json'))
    # Query the profiling data and convert to dataframe
    qr_einsum = tp.query(query_text).as_pandas_dataframe()
    print(f'###########i={i}###########')
    print('#'*30)
    # Print out the mean, standard dev. and shape for each algorithm
    print(f'Matmul: Mean={qr_matmul.dur.mean()}, std. dev.={qr_matmul.dur.std()}, shape of df:{qr_matmul.shape}')
    print(f'Einsum: Mean={qr_einsum.dur.mean()}, std. dev.={qr_einsum.dur.std()}, shape of df:{qr_einsum.shape}')
    plot_boxplot(qr_matmul, qr_einsum, ['dur'])
    stat, p = ranksums(qr_matmul['dur'], qr_einsum['dur'])
    print(f'Test statistic={stat}, p_val={p}')

下面是两次迭代的截断输出,所有九次迭代都观察到了相同的模式。

###########i=1###########
##############################
Matmul: Mean=5204.543333333333, std. dev.=882.6151202759834, shape of df:(600, 3)
Einsum: Mean=6360.556666666666, std. dev.=373.461514250933, shape of df:(600, 3)
Test statistic=-21.986424230986046, p_val=3.884153635651101e-107

###########i=2###########
##############################
Matmul: Mean=5145.61, std. dev.=876.5247080600369, shape of df:(600, 3)
Einsum: Mean=6396.01, std. dev.=381.7892458942073, shape of df:(600, 3)
Test statistic=-22.450480914300588, p_val=1.2659476932444539e-111

这次,令人惊讶的是,`matmul`显著比`einsum`更快。这表明一种矩阵乘法算法并不总是优于另一种。矩阵的大小、形状和其他操作(如矩阵转置)等因素可能会影响速度。这突显了在应用或模型关键步骤中选择最佳算法时使用剖析技术的重要性。另外,如果你检查同一算法在箱线图中的数据点范围,可能会注意到许多异常值。这就是为什么在得出有效结论时,统计分析和适当的方法是如此重要的原因。本例中也使用了秩基检验而非经典的t检验,因为后者通常对异常值敏感。

总结

在剖析应用或模型性能时应应用稳健的统计分析和测试,以确保随机噪声的影响不会损害我们结论的有效性。 

标签:qr,trace,JAX,AI,einsum,AMD,shape,--,matmul
From: https://blog.csdn.net/eidolon_foot/article/details/141617667

相关文章

  • 基于Ubuntu部署企业级kubernetes集群---k8s集群容器运行时Containerd准备
    1.Containerd部署文件获取1.下载 Containerd文件wgethttps://github.com/containerd/containerd/releases/download/v1.7.21/cri-containerd-1.7.21-linux-amd64.tar.gz2.查看下载的文件 3.解压到当前文件到根目录下tarxfcri-containerd-1.7.21-linux-amd64.tar.g......
  • 负责缓解超级智能AI风险的OpenAI团队已失去近半数成员,一位前研究员表示
     根据前治理研究员DanielKokotajlo的说法,OpenAI已经失去了将近一半从事AI安全工作的人员。“这不是一个有组织的行动。我认为这只是个人逐渐放弃,”Kokotajlo在周二发布的一篇《财富》报道中表示。2023年4月离开OpenAI的Kokotajlo说,这家ChatGPT制造商最初有大约30人在处理......
  • 文字游侠AI工具:一个高效内容创作的革命性助手,效率一键提高20倍!
    在当今快节奏、高效率要求的信息时代,传统的内容生产方式已经难以满足不断增长的网络信息需求。随着人工智能技术的飞速发展,一系列创新的AI工具应运而生,极大地改变了我们处理信息和创造内容的方式。其中,文字游侠AI工具凭借其出色的性能和便利性,成为了许多内容创作者的首选利器。......
  • OpenAI Images Generations API 申请及使用
    OpenAIImagesGenerationsAPI申请及使用DALL-E3是OpenAI开发的两个版本的图像生成模型,它们能够根据文本描述生成高质量的图像。本文档主要介绍OpenAIImagesGenerationsAPI操作的使用流程,利用它我们可以轻松使用官方OpenAIDALL-E的图像生成功能。申请流程......
  • OpenAI Chat Completion API 申请及使用
    OpenAIChatCompletionAPI申请及使用OpenAIChatGPT是一款非常强大的AI对话系统,只要输入提示词,就能在短短几秒内生成流畅自然的回复。ChatGPT以其出色的语言理解和生成能力在业界独树一帜,如今,ChatGPT早已在各个行业和领域广泛应用,其影响力愈发显著。无论是日常对话......
  • AI大模型prompt "自洽性"和"思维树" 这两种的区别
    一个是从多个角度对同一问题给出不同解答,选择最好的那个另外一个就像一棵树,有主干,还有分支,每个分支上还有更细分的理由比如:自洽性夏季气温升高是因为太阳光线更直接地照射到地球上。在夏天,太阳的光线以更垂直的角度到达地球表面,导致热量更集中。夏天,白天时间长,太阳照射的......
  • 三步教会你使用ai辅助背诵面试题、书籍
    一、可以使用智普清言app免费使用,点击创建智能体电脑版左下角手机版登录后往左滑即可找到二、使用相关提示词可以有效避免ai重复回答、乱答提示词大家可以在实际使用中,不断更改,创建完毕后点击编辑智能体即可三、使用时再把提示词发给ai一遍,可以增加智能率结束......
  • 【ACMMM2024】Multi-Scale and Detail-Enhanced Segment Anything Model for Salient
    论文:https://arxiv.org/pdf/2408.04326代码:https://github.com/BellyBeauty/MDSAM论文的研究动机就是使用SAM来解决显著性检测(SOD)问题,主要有两个改进:提出了LightweightMulti-ScaleAdapter,LMSA来微调SAM提出了Multi-LevelFusionModule,MLFM和DetailEnhancementM......
  • 用 Higress AI 网关降低 AI 调用成本 - 阿里云天池云原生编程挑战赛参赛攻略
    作者介绍:杨贝宁,爱丁堡大学博士在读,研究方向为向量数据库《Higress AI网关挑战赛》正在火热进行中,Higress社区邀请了目前位于排行榜top5的选手杨贝宁同学分享他的心得。下面是他整理的参赛攻略:背景我们要在Higress网关中编写WebAssembly(wasm)插件,使得在http请求的各个......
  • 所以你被要求对AI内容进行“人性化处理”
    过去12个月里,“如何让AI内容更人性化”的搜索量增长了943%。越来越多的人试图“通过人为检查”,无论是为了让AI内容不那么糟糕,还是为了蒙混过AI内容检测工具。“使AI内容人性化”这一过程已成为那些渴望增长的公司的默认内容策略。我看到到处都是被压抑的作家,他们询问如何......