首页 > 其他分享 >[cuda]RMSNorm核函数解析

[cuda]RMSNorm核函数解析

时间:2023-08-20 11:22:05浏览次数:36  
标签:__ ... idx float RMSNorm cuda hidden 解析 size

计算原理

\(RMSNorm = x * (sqrt(1/n * (x_i)^2 + eps)) * g\)

torch实现

class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

先算出norm的值,然后再计算g*norm, 其中norm为平方和的根。注意这里是先转化为float进行进行norm运算,norm的结果再转为对应type。

cuda实现

__global__ void rms_norm_kernel(
  scalar_t* __restrict__ out,             // [num_tokens, hidden_size]
  const scalar_t* __restrict__ input,     // [num_tokens, hidden_size]
  const scalar_t* __restrict__ weight,    // [hidden_size]
  const float epsilon,
  const int num_tokens,
  const int hidden_size) {
  __shared__ float s_variance;
  float variance = 0.0f;

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    const float x = (float) input[blockIdx.x * hidden_size + idx];
    variance += x * x;
  }
  variance = blockReduceSum<float>(variance);
  if (threadIdx.x == 0) {
    s_variance = rsqrtf(variance / hidden_size + epsilon);
  }
  __syncthreads();

  for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
    float x = (float) input[blockIdx.x * hidden_size + idx];
    out[blockIdx.x * hidden_size + idx] = ((scalar_t) (x * s_variance)) * weight[idx];
  }
}

这里variance计算了不同block间同余位置上的x的平方和,经过blockReduceSum则将部分和进行相加,得到全部的平方和,并在线程0下计算平方根。同时更新同一block下的所有output。

结果对比

torch run time: 0.5862712860107422 ms
torch.Size([200, 2048])
cuda run time: 0.06341934204101562 ms
torch.Size([200, 2048])
tensor([[ 0.2473, -0.4733, -1.5234,  ..., -1.0379,  0.2188, -1.7629],
        [-0.0408, -0.9154,  0.6396,  ...,  0.1713, -1.1047,  0.7188],
        [-1.0582, -0.0282,  0.7803,  ...,  1.4090,  1.4131,  1.7266],
        ...,
        [ 0.4701,  0.2073,  1.7602,  ..., -0.4985, -1.0406, -0.4027],
        [ 0.0527, -1.2559,  0.2172,  ..., -0.2953, -1.3365,  0.2298],
        [ 1.0274,  2.4901, -0.2216,  ...,  0.5723,  1.3783,  0.6167]],
       device='cuda:0', grad_fn=<MulBackward0>)
tensor([[ 0.2473, -0.4733, -1.5234,  ..., -1.0379,  0.2188, -1.7629],
        [-0.0408, -0.9154,  0.6396,  ...,  0.1713, -1.1047,  0.7188],
        [-1.0582, -0.0282,  0.7803,  ...,  1.4090,  1.4131,  1.7266],
        ...,
        [ 0.4701,  0.2073,  1.7602,  ..., -0.4985, -1.0406, -0.4027],
        [ 0.0527, -1.2559,  0.2172,  ..., -0.2953, -1.3365,  0.2298],
        [ 1.0274,  2.4901, -0.2216,  ...,  0.5723,  1.3783,  0.6167]],
       device='cuda:0')
max diff:  tensor(4.7684e-07, device='cuda:0', grad_fn=<MaxBackward1>)

本次测试的大小为[200, 2048], 即token长为200,feature dim长度为2048,可以看到torch的运行时间为0.58ms,cuda的运行时间为0.06ms,效率提升了一个数量级,而误差max diff为1e-7级别,是可接受的范围。

标签:__,...,idx,float,RMSNorm,cuda,hidden,解析,size
From: https://www.cnblogs.com/wildkid1024/p/17643752.html

相关文章

  • 深入解析 Redis 持久化机制
    引言我们都知道,Redis的数据存储在内存中,一旦服务器宕机,内存中的数据将全部丢失。因此,对Redis来说,实现数据的持久化,避免从后端数据库中进行恢复,是至关重要的。本篇我们详细讲解下Redis的三种持久化机制,分别是 AOF(AppendOnlyFile) 日志和 RDB快照 以及 混合持久化。......
  • C/C++ 中 static 关键字解析
    局部静态变量的特点:全局数据区执行到函数内对象声明处首次初始化,若没有显示初始化,自动初始化为0,且只初始化一次始终驻留在全局区,直到程序结束,作用域为局部作用域,在函数或语句块内,生命周期到程序结束全局静态变量的特点:全局区在main函数执行前分配内存并初始化注意:......
  • 一线大厂性能优化实战解析,看到就是赚到
    前言我们平时在使用软件的过程中是不是遇到过这样的情况:"这个app怎么还没下载完!"、太卡了吧!"、"图片怎么还没加载出来!"、"怎么刚进去就卡了!"、"这怎么点了一下就退出了!"等等,这些情况其实包含了我们性能优化的主要内容.,性能的优化是一个老生常谈的点,也是一个比较重要的点.特......
  • “金九银十”的秋招季,请收下这套互联网中大厂Android面试题大全(含答案解析)
    金九银十,每年9、10月份各大互联网公司都会周期性地发生人事变动,无论是刚进入社会的职场菜鸟,还是准备跳槽的老手,都想在这个时候获得新的工作,或迎来晋升涨薪的最佳机会。不同于往年的是今年的互联网寒冬好像更冷一点,形式更严峻了一些,不少公司都在裁员,可能在求职中有一大部分人经历了......
  • 快解析内网穿透便捷访问内网私有云
    快解析内网穿透软件的首要优势在于其不改变企业现有IT架构的特点。传统的内网穿透解决方案常常需要对企业网络进行重构,这不仅增加了工作量,还可能带来不稳定的因素。而快解析则巧妙地绕过了这一问题,让您能够在保持原有网络设备和配置的前提下,快速实现内网穿透,也不需要公网IP或迁移数......
  • 外网连接局域网的几种方式?快解析内网穿透安全便利吗?
    外网连接局域网是一项网络连接中的关键技术,它能够让远程用户通过互联网访问内部局域网中的资源和服务。外网连接局域网为企业提供了更大的灵活性和便捷性,但也需要严格的安全措施来防止未经授权的访问。 外网连接局域网的几种方式在将外网连接到局域网时,有三种常见的方式,那就是端口......
  • 2023年Android中高级最全面试题(含大厂原题+解析)
    前言又快要到了一年一度的金九银十黄金跳槽时节,也是互联网大厂疯狂招人的时期,现在应该有很多Android程序员已经按耐不住了。但是现在网上的面试题资料太多了,而且有些面试题已经过时甚至是漏洞百出。今天结合自己前段时间的面试经历和几位大厂大佬交流讨论总结出这份2023年Android中......
  • AOP源码解析:AspectJExpressionPointcutAdvisor类
    先看看AspectJExpressionPointcutAdvisor的类图再了解一下切点(Pointcut)表达式,它指定触发advice的方法,可以精确到返回参数,参数类型,方法名1packageconcert;23publicinterfacePerformance{4voidperform();5}AspectJExpressionPointcutAdvisor源码,官......
  • Windows设置本地DNS域名解析hosts文件配置--九五小庞
    DNSDomainNameSystem(域名系统):为了加快定位IP地址的速度,将域名映射进行层层缓存的系统.目的:互联网通过IP(10.223.146.45)定位浏览器建立连接,但是我们不易区别IP,为了方便用户辨识IP所代表的意义,操作系统会将IP和域名进行转换(roadmapsupporter.com)。IP比作IDCard:373×××××,......
  • 兰吉尔电表抄表数据采集费率时段通讯报文解析说明
    地址内容说明C748[  {    "1":"65",    "2":"255255112552552552552551280255",//年年月日周时分秒毫秒时区差时间状态,255代表未设置,所以是1月1日    "3":"0" //采用Week_active第0项  },  {    &q......