首页 > 编程语言 >Nvidia Tensor Core-WMMA API编程入门

Nvidia Tensor Core-WMMA API编程入门

时间:2023-04-12 19:02:48浏览次数:50  
标签:__ Core Tensor IMAD 0x0 RZ API U32 WMMA

1 WMMA (Warp-level Matrix Multiply Accumulate) API

对于计算能力在7.0及以上的CUDA设备,可以使用CUDA C++ API调用Tensor Core,支持形如D = AB + C的混合精度的矩阵乘运算。
template<typename Use, int m, int n, int k, typename T, typename Layout=void> class fragment;

void load_matrix_sync(fragment<...> &a, const T* mptr, unsigned ldm);
void load_matrix_sync(fragment<...> &a, const T* mptr, unsigned ldm, layout_t layout);
void store_matrix_sync(T* mptr, const fragment<...> &a, unsigned ldm, layout_t layout);
void fill_fragment(fragment<...> &a, const T& v);
void mma_sync(fragment<...> &d, const fragment<...> &a, const fragment<...> &b, const fragment<...> &c, bool satf=false);
  • fragment:Tensor Core数据存储类,支持matrix_a、matrix_b和accumulator
  • load_matrix_sync:Tensor Core数据加载API,支持将矩阵数据从global memory或shared memory加载到fragment
  • store_matrix_sync:Tensor Core结果存储API,支持将计算结果从fragment存储到global memory或shared memory
  • fill_fragment:fragment填充API,支持常数值填充
  • mma_sync:Tensor Core矩阵乘计算API,支持D = AB + C或者C = AB + C

2 示例

以m16n16k16为例,实现HGEMM:C = AB,其中矩阵A(M * K,row major)、B(K * N,col major)和C(M * N,row major)的精度均为FP16。首先我们看如何使用CUDA Core写HGEMM naive算法。

2.1 CUDA Core

按照每个线程计算矩阵C中的一个元素来构建naive kernel,首先确定当前线程处理矩阵C的元素坐标,再遍历K并直接从global memory中加载所需A、B矩阵元素到寄存器参与计算,最后将计算结果从寄存器直接写回矩阵C。所有block计算完成之后即可得到矩阵C。这个例子不能说简单,只能说技术含量不高,不过我们只是为了对比。
#define DIV_CEIL(x, y) (((x) + (y) - 1) / (y))

__global__ void naiveKernel(const half *__restrict__ A, const half *__restrict__ B, half *__restrict__ C, size_t M,
                                size_t N, size_t K) {
    size_t row = threadIdx.x + blockDim.x * blockIdx.x;
    size_t col = threadIdx.y + blockDim.y * blockIdx.y;
    if (row < M && col < N) {
        half tmp = 0.0;
        for (size_t i = 0; i < K; ++i) {
            tmp += A[row * K + i] * B[i + col * K];
        }
        C[row * N + col] = tmp;
    }
}

void hgemmNaive(half *A, half *B, half *C, size_t M, size_t N, size_t K) {
    dim3 block(16, 16);
    dim3 grid(DIV_CEIL(M, block.x), DIV_CEIL(N, block.y));

    naiveKernel<<<grid, block>>>(A, B, C, M, N, K);
}

2.2 Tensor Core

我们再来看如何用WMMA API来构建naive kernel,参考cuda sample。与CUDA Core naive不同的是,WMMA需要按照每个warp处理一个矩阵C的WMMA_M * WMMA_N大小的tile的思路来构建,因为Tensor Core的计算层级是warp级别,计算的矩阵元素也是二维的。接下来,与CUDA Core naive的处理思路一致,首先确定当前warp处理矩阵C的tile坐标,声明计算tilie所需的fragment,再以WMMA_K为步长遍历K并直接从global memory中加载所需A、B矩阵tile到fragment参与计算,最后将计算结果从fragment直接写回矩阵C。所有block计算完成之后即可得到矩阵C。 值得注意的是,load_matrix_sync和store_matrix_sync都是按stride访问矩阵元素。
#include <mma.h>

#define WARP_SIZE 32

#define WMMA_M 16
#define WMMA_N 16
#define WMMA_K 16

using namespace nvcuda;

__global__ void wmmaNaiveKernel(const half *__restrict__ A, const half *__restrict__ B, half *__restrict__ C, size_t M,
                                size_t N, size_t K) {
    size_t warpM = (blockIdx.x * blockDim.x + threadIdx.x) / WARP_SIZE;
    size_t warpN = (blockIdx.y * blockDim.y + threadIdx.y);

    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, half, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K, half, wmma::col_major> b_frag;
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K, half> c_frag;

    wmma::fill_fragment(c_frag, 0.0f);

    for (size_t i = 0; i < K; i += WMMA_K) {
        size_t aCol = i;
        size_t aRow = warpM * WMMA_M;
        size_t bCol = warpN * WMMA_N;
        size_t bRow = i;

        if (aRow < M && aCol < K && bRow < K && bCol < N) {
            wmma::load_matrix_sync(a_frag, A + aCol + aRow * K, K);
            wmma::load_matrix_sync(b_frag, B + bRow + bCol * K, K);

            wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
        }
    }

    size_t cCol = warpN * WMMA_N;
    size_t cRow = warpM * WMMA_M;

    if (cRow < M && cCol < N) {
        wmma::store_matrix_sync(C + cCol + cRow * N, c_frag, N, wmma::mem_row_major);
    }
}

void hgemmWmmaNaive(half *A, half *B, half *C, size_t M, size_t N, size_t K) {
    dim3 block(128, 4);
    dim3 grid((M - 1) / (WMMA_M * block.x / WARP_SIZE) + 1, (N - 1) / (WMMA_N * block.y) + 1);

    wmmaNaiveKernel<<<grid, block>>>(A, B, C, M, N, K);
}

2.3 区别

从上述两个naive kernel的代码来看调用CUDA Core和Tensor Core的区别如下:
  • 计算层级:CUDA Core是线程级别,Tensor Core是warp级别
  • 计算维度:CUDA Core是一维逐点计算,Tensor Core是二维逐tile计算
  • 计算依赖:WMMA调用Tensor Core需要借助数据存储类fragment,CUDA Core不需要借助其他

3 底层代码

我们再对上述WMMA naive kernel做进一步探索,看一下它在RTX A6000(sm_86,CUDA 11.3)上对应的PTX和SASS。

3.1 PTX

dump出对应的PTX代码如下,好像不那么简单了。
.visible .entry _Z15wmmaNaiveKernelPK6__halfS1_PS_mmm(
.param .u64 _Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_0,
.param .u64 _Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_1,
.param .u64 _Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_2,
.param .u64 _Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_3,
.param .u64 _Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_4,
.param .u64 _Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_5
)
{
.reg .pred %p<8>;
.reg .b16 %rs<2>;
.reg .f32 %f<2>;
.reg .b32 %r<58>;
.reg .b64 %rd<28>;

ld.param.u64 %rd9, [_Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_0];
ld.param.u64 %rd10, [_Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_1];
ld.param.u64 %rd11, [_Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_2];
ld.param.u64 %rd14, [_Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_3];
ld.param.u64 %rd12, [_Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_4];
ld.param.u64 %rd13, [_Z15wmmaNaiveKernelPK6__halfS1_PS_mmm_param_5];
mov.u32 %r19, %ntid.x;
mov.u32 %r20, %ctaid.x;
mov.u32 %r21, %tid.x;
mad.lo.s32 %r22, %r20, %r19, %r21;
mov.u32 %r23, %ntid.y;
mov.u32 %r24, %ctaid.y;
mov.u32 %r25, %tid.y;
mad.lo.s32 %r26, %r24, %r23, %r25;
mov.f32 %f1, 0f00000000;

	{ cvt.rn.f16.f32 %rs1, %f1;}

	mov.b32 %r50, {%rs1, %rs1};
mul.wide.u32 %rd1, %r26, 16;
shr.u32 %r27, %r22, 1;
and.b32 %r28, %r27, 2147483632;
cvt.u64.u32 %rd2, %r28;
setp.lt.u64 %p2, %rd2, %rd14;
setp.lt.u64 %p3, %rd1, %rd12;
and.pred %p1, %p2, %p3;
setp.eq.s64 %p4, %rd13, 0;
mov.u32 %r51, %r50;
mov.u32 %r52, %r50;
mov.u32 %r53, %r50;
@%p4 bra $L__BB0_5;

mul.lo.s64 %rd3, %rd2, %rd13;
cvt.u32.u64 %r2, %rd13;
mul.lo.s64 %rd4, %rd1, %rd13;
cvta.to.global.u64 %rd5, %rd10;
cvta.to.global.u64 %rd6, %rd9;
mov.u64 %rd27, 0;
not.pred %p5, %p1;
mov.u32 %r51, %r50;
mov.u32 %r52, %r50;
mov.u32 %r53, %r50;

$L__BB0_2:
@%p5 bra $L__BB0_4;

add.s64 %rd16, %rd27, %rd3;
shl.b64 %rd17, %rd16, 1;
add.s64 %rd18, %rd6, %rd17;
wmma.load.a.sync.aligned.row.m16n16k16.global.f16 {%r29, %r30, %r31, %r32, %r33, %r34, %r35, %r36}, [%rd18], %r2;
add.s64 %rd19, %rd27, %rd4;
shl.b64 %rd20, %rd19, 1;
add.s64 %rd21, %rd5, %rd20;
wmma.load.b.sync.aligned.col.m16n16k16.global.f16 {%r37, %r38, %r39, %r40, %r41, %r42, %r43, %r44}, [%rd21], %r2;
wmma.mma.sync.aligned.row.col.m16n16k16.f16.f16 {%r53, %r52, %r51, %r50}, {%r29, %r30, %r31, %r32, %r33, %r34, %r35, %r36}, {%r37, %r38, %r39, %r40, %r41, %r42, %r43, %r44}, {%r53, %r52, %r51, %r50};

$L__BB0_4:
add.s64 %rd27, %rd27, 16;
setp.lt.u64 %p6, %rd27, %rd13;
@%p6 bra $L__BB0_2;

$L__BB0_5:
not.pred %p7, %p1;
@%p7 bra $L__BB0_7;

mul.lo.s64 %rd22, %rd2, %rd12;
add.s64 %rd23, %rd22, %rd1;
cvta.to.global.u64 %rd24, %rd11;
shl.b64 %rd25, %rd23, 1;
add.s64 %rd26, %rd24, %rd25;
cvt.u32.u64 %r45, %rd12;
wmma.store.d.sync.aligned.row.m16n16k16.global.f16 [%rd26], {%r53, %r52, %r51, %r50}, %r45;

$L__BB0_7:
ret;

}

不过我们主要关注WMMA相关的PTX指令,如下所示。可以看到这里正是Nvidia提供的WMMA PTX指令来调用Tensor Core,所以无论是使用WMMA API编程,还是使用WMMA PTX指令编程,底层差别不会太大。

wmma.load.a.sync.aligned.row.m16n16k16.global.f16
wmma.load.b.sync.aligned.col.m16n16k16.global.f16
wmma.mma.sync.aligned.row.col.m16n16k16.f16.f16
wmma.store.d.sync.aligned.row.m16n16k16.global.f16

3.2 SASS

进一步dump出对应的SASS代码,似乎也不简单。
      IMAD.MOV.U32 R1, RZ, RZ, c[0x0][0x28] 
      S2R R0, SR_CTAID.X 
      ISETP.NE.U32.AND P2, PT, RZ, c[0x0][0x188], PT 
      ULDC.64 UR4, c[0x0][0x118] 
      CS2R R8, SRZ 
      S2R R10, SR_CTAID.Y 
      ISETP.NE.AND.EX P2, PT, RZ, c[0x0][0x18c], PT, P2 
      S2R R5, SR_TID.Y 
      S2R R3, SR_TID.X 
      IMAD R10, R10, c[0x0][0x4], R5 
      IMAD R0, R0, c[0x0][0x0], R3 
      IMAD.WIDE.U32 R10, R10, 0x10, RZ 
      CS2R R2, SRZ 
      SHF.R.U32.HI R0, RZ, 0x1, R0 
      ISETP.GE.U32.AND P0, PT, R10, c[0x0][0x180], PT 
      LOP3.LUT R13, R0, 0x7ffffff0, RZ, 0xc0, !PT 
      ISETP.GE.U32.AND.EX P0, PT, R11, c[0x0][0x184], PT, P0 
      ISETP.LT.U32.AND P1, PT, R13, c[0x0][0x178], PT 
      ISETP.LT.U32.AND.EX P0, PT, RZ, c[0x0][0x17c], !P0, P1 
@!P2  BRA 0x7f1eaefc0160 
      BSSY B0, 0x7f1eaefc0160 
      IMAD.MOV.U32 R0, RZ, RZ, RZ 
      CS2R R8, SRZ 
      IMAD.MOV.U32 R15, RZ, RZ, RZ 
      IMAD.MOV.U32 R2, RZ, RZ, RZ 
      BSSY B1, 0x7f1eaefc0100 
@!P0  BRA 0x7f1eaefc00f0 
      S2R R16, SR_LANEID 
      IMAD R17, R11, c[0x0][0x188], RZ 
      IMAD.MOV.U32 R14, RZ, RZ, R0 
      IMAD.MOV.U32 R23, RZ, RZ, c[0x0][0x188] 
      IMAD.WIDE.U32 R6, R10, c[0x0][0x188], R14 
      SHF.R.U32.HI R12, RZ, 0x1, R23 
      IMAD R17, R10, c[0x0][0x18c], R17 
      LEA R21, P2, R6, c[0x0][0x168], 0x1 
      IMAD.WIDE.U32 R4, R13, c[0x0][0x188], R14 
      IMAD.IADD R7, R7, 0x1, R17 
      IMAD.MOV.U32 R17, RZ, RZ, RZ 
      IMAD R5, R13, c[0x0][0x18c], R5 
      LEA.HI.X R7, R6, c[0x0][0x16c], R7, 0x1, P2 
      SHF.R.U32.HI R19, RZ, 0x2, R16 
      LOP3.LUT R16, R16, 0x3, RZ, 0xc0, !PT 
      IMAD.WIDE.U32 R16, R19, R12, R16 
      LEA R19, P1, R4, c[0x0][0x160], 0x1 
      LEA.HI.X R5, R4, c[0x0][0x164], R5, 0x1, P1 
      LEA R18, P1, R16, R19, 0x2 
      LEA R20, P2, R16, R21, 0x2 
      LEA.HI.X R19, R16, R5, R17, 0x2, P1 
      LEA.HI.X R21, R16, R7, R17, 0x2, P2 
      IMAD.WIDE.U32 R16, R23, 0x10, R18 
      LDG.E R4, [R18.64] 
      IMAD.WIDE.U32 R22, R23, 0x10, R20 
      LDG.E R24, [R20.64] 
      LDG.E R25, [R20.64+0x10] 
      LDG.E R6, [R18.64+0x10] 
      LDG.E R5, [R16.64] 
      LDG.E R7, [R16.64+0x10] 
      LDG.E R26, [R22.64] 
      LDG.E R27, [R22.64+0x10] 
      WARPSYNC 0xffffffff 
      HMMA.16816.F16 R8, R4, R24, R8 
      HMMA.16816.F16 R2, R4, R26, R2 
      NOP 
      BSYNC B1 
      IADD3 R0, P1, R0, 0x10, RZ 
      IMAD.X R15, RZ, RZ, R15, P1 
      ISETP.GE.U32.AND P1, PT, R0, c[0x0][0x188], PT 
      ISETP.GE.U32.AND.EX P1, PT, R15, c[0x0][0x18c], PT, P1 
@!P1  BRA 0x7f1eaefbfe90 
      BSYNC B0 
@!P0  EXIT 
      S2R R4, SR_LANEID 
      IMAD.MOV.U32 R15, RZ, RZ, c[0x0][0x180] 
      WARPSYNC 0xffffffff 
      IMAD.WIDE.U32 R10, R13, c[0x0][0x180], R10 
      SHF.R.U32.HI R15, RZ, 0x1, R15 
      IMAD.MOV.U32 R5, RZ, RZ, RZ 
      LEA R7, P0, R10, c[0x0][0x170], 0x1 
      IMAD R11, R13, c[0x0][0x184], R11 
      LEA.HI.X R11, R10, c[0x0][0x174], R11, 0x1, P0 
      SHF.R.U32.HI R0, RZ, 0x2, R4 
      LOP3.LUT R4, R4, 0x3, RZ, 0xc0, !PT 
      IMAD.WIDE.U32 R4, R0, R15, R4 
      LEA R6, P0, R4, R7, 0x2 
      LEA.HI.X R7, R4, R11, R5, 0x2, P0 
      IMAD.WIDE.U32 R4, R15, 0x20, R6 
      STG.E [R6.64], R8 
      STG.E [R4.64], R9 
      STG.E [R6.64+0x10], R2 
      STG.E [R4.64+0x10], R3 
      EXIT 
      BRA 0x7f1eaefc02b0
      NOP
      NOP
      NOP
      NOP
      NOP
      NOP
      NOP
      NOP
      NOP
      NOP
      NOP
      NOP

我们依然主要关注WMMA相关的SASS指令,如下所示。可以发现WMMA161616在底层是通过两个HMMA16816指令实现,同样地,SASS指令也是Nvidia提供的另一种调用Tensor Core的编程方法。

HMMA.16816.F16
Nvidia Tensor Core初探中提到Nvidia提供了四种调用Tensor Core的编程方法,这里提到了三种,还有一种是MMA PTX指令,其中MMA16816 PTX指令底层实现即是HMMA16816指令,后续会在MMA PTX相关文章中提及。

4 其他

4.1 HGEMM优化

学习WMMA API的目标在于调用Tensor Core优化HGEMM,相比于cublas,WMMA的性能究竟如何?    

标签:__,Core,Tensor,IMAD,0x0,RZ,API,U32,WMMA
From: https://www.cnblogs.com/bruceleely/p/17310854.html

相关文章

  • 国内好用的API平台,还是BAT的实力最强
    应用程序编程接口API即(ApplicationProgrammingInterface),现在众多企业的应用系统中常用的开放接口,对接相应的系统、软件功能,简化专业化的程序开发。下面这些厂商提供的API都不错,比如:百度API、阿里云API、腾讯云API、聚合API传送门国内好用的API平台,还是BAT的实力最强......
  • 在.NET Core使用 HttpClient 的正确方式
    前言HttpClient是.NETFramework、.NETCore或.NET5以上版本中的一个类,用于向WebAPI发送HTTP请求并接收响应。它提供了一些简单易用的方法,如GET、POST、PUT和DELETE,可以很容易地构造和发送HTTP请求,并处理响应数据。它是我们比较常用的官方HTTP请求组件,那么你们都......
  • 基于vue3+ts+vite的项目使用‘unplugin-auto-import/vite’插件,自动全局导入api的注意
    ​ 1.首先安装插件npminstallunplugin-auto-import@vitejs/plugin-vue-D2.安装完成后在vite.congfig.ts中配置,红色部分就是关于插件的基础自动导入部分,这样就可以将vue和router的相关api全局导入了,importAutoImportfrom"unplugin-auto-import/vite"exportdefaultde......
  • APISIX 是怎么保护用户的敏感数据不被泄露的?
    本文以APISIX作为例子,为大家介绍了如何借助GlobalDataEncryption功能来保护敏感数据,确保不会有任何敏感数据进行明文存储,这样即使etcd中所有存储的数据都被盗取,也不会造成敏感数据泄露,从而有效提升了APISIX的安全性。作者刘维,API7.ai技术工程师,ApacheAPISIXContrib......
  • 如何利用空号检测 API 降低企业的业务成本
    引言电话营销和短信营销是现代营销活动的重要手段,但是无效号码的存在导致了很多资源浪费。利用空号检测API就可以避免企业向无效号码发送营销信息,降低营销成本。本文将介绍空号检测API的工作原理以及如何利用空号检测API降低业务成本的多种方式,包括降低电话营销成本、降......
  • Asp.NetWebApi跨域解决方案
    最近新开项目,项目初期本地部署访问webapi,api采用JWT验证,发现加上Authorization请求头后就报跨域问题,网上很多方案例如修改webconfig请求头什么的,均测试未果,多方尝试后有了以下方案,一行代码便完全解决跨域问题,特此记录!1、NuGet下载Microsoft.AspNet.WebApi.Cors  2、App_Sta......
  • 手机号码归属地 API 实现个性化推荐的思路分析
    前言随着移动互联网和智能手机的普及,越来越多的人使用手机上网和购物,移动营销已成为企业获取用户和提升品牌知名度的重要手段。手机号码归属地API作为移动营销的关键工具,具有广阔的应用前景。本文将探讨如何利用手机号码归属地API进行个性化推荐和精准广告投放,希望对大家有......
  • Apipost智能Mock功能详解
    在接口开发过程中,Mock功能可以帮助开发者快速测试和验证接口的正确性和稳定性,以便快速迭代和修复问题。Apipost推出智能Mock功能,可以在智能期望中填写一些触发条件,开启后,Apipost会根据已设置的触发条件,自动匹配旗下的参数判断规则,若满足条件,则会启用预设的期望。这篇文章会模拟用户......
  • 从0到1手把手教你ASP.NET Core Web API项目配置接口文档Swagger(一)
    一、创建ASP.NETCoreWebAPI项目(若项目已创建,则可跳过本节内容)1、双击打开VS2022。2、单击“创建新项目”,如下图。3、选择“ASP.NETCoreWebAPI”类型,然后单击“下一步”,如下图。4、“项目名称”此处填写为“AllTestDemo”;“位置”此处放在E盘根目录;“解决方案名称”此......
  • 微信小程序开发——getLocation:fail the api need to be declared in the requiredPr
    getLocation:failtheapineedtobedeclaredintherequiredPrivateInfosfieldinapp.json/ext.json异常解析:app.json中没配置requiredPrivateInfos参数,按下边示例代码配置即可。示例代码:{..."permission":{"scope.userLocation":{"desc&qu......