首页 > 其他分享 >torch.einsum 的计算过程

torch.einsum 的计算过程

时间:2024-08-09 09:28:08浏览次数:10  
标签:... einsum 4.0 torch 张量 3.0 计算 维度

概论

a = torch.randn(3, 2, 2)
b = torch.randn(3)
c = torch.einsum('...chw,c->...hw', a, b)

上面的 einsum 如何计算的?
简单说,把 b 广播为 a 的形状,然后做矩阵乘法,即逐位相乘运算,注意,不是点积,是逐位的相乘运算。
注:这里符合背景需求,背景是,a 是深度学习的某个张量,b是a的权重,要求 a 的每一个元素都要乘以权重 b ,来得到实际有效的值。
然后,再把矩阵乘积的结果逐位相加后,得到最后结果,同时也去掉了维度c。

运算过程

具体运算细节如下:

为了详细解释 c = torch.einsum('...chw,c->...hw', a, b) 的计算过程,我们可以逐步分析每个部分的运算,并通过一个具体的例子说明结果的产生过程。

1. 张量 ab 的形状与内容

  • a 是一个形状为 (3, 2, 2) 的张量,假设其值为:
    a = torch.tensor([[[0.1, 0.2],
                       [0.3, 0.4]],
    
                      [[0.5, 0.6],
                       [0.7, 0.8]],
    
                      [[0.9, 1.0],
                       [1.1, 1.2]]])
    
  • b 是一个形状为 (3,) 的张量,假设其值为:
    b = torch.tensor([2.0, 3.0, 4.0])
    

2. einsum 表达式 '...chw,c->...hw' 解析

  • ...chw:

    • ... 匹配任意数量的前导维度,在本例中没有前导维度。
    • c 对应的是第一个维度(形状为3)。
    • h 对应第二个维度(形状为2)。
    • w 对应第三个维度(形状为2)。
  • c->...hw:

    • c 对应 b 中的元素,作为缩并维度,它会在计算中被消除(通过乘法与求和操作)。
    • ...hw 表示最终保留高度(h)和宽度(w)维度。

3. 具体计算过程

步骤 1:广播 b 以匹配 a 的形状

张量 b(形状 (3,))会被广播到形状 (3, 2, 2),与张量 a 对应的维度对齐:

b_broadcasted = torch.tensor([[[2.0, 2.0],
                               [2.0, 2.0]],

                              [[3.0, 3.0],
                               [3.0, 3.0]],

                              [[4.0, 4.0],
                               [4.0, 4.0]]])

步骤 2:逐元素相乘

a 和广播后的 b 逐元素相乘,结果为:

result = a * b_broadcasted
result = torch.tensor([[[0.2, 0.4],
                        [0.6, 0.8]],

                       [[1.5, 1.8],
                        [2.1, 2.4]],

                       [[3.6, 4.0],
                        [4.4, 4.8]]])

步骤 3:沿着 c 维度求和

对上面的结果,沿着 c 维度(也就是第一个维度)求和:

  • (0.2, 1.5, 3.6) 这三个数在 c 维度上相加,得到 5.3
  • (0.4, 1.8, 4.0) 这三个数在 c 维度上相加,得到 6.2
  • 以此类推,求得最终结果 c
c = torch.tensor([[5.3, 6.2],
                  [7.1, 8.0]])

4. 总结结果

最终的结果 c 是一个形状为 (2, 2) 的张量。每个元素是通过 a 中与 b 对应的元素逐元素相乘后,再沿 c 维度(通道维度)求和得到的。

因此,使用 torch.einsum('...chw,c->...hw', a, b) 后,得到的张量 c 对应的是去掉 c 维度后的加权和,仅保留高度 h 和宽度 w 维度的结果。

广播运算

问:b 广播为 b_broadcasted ,代码如何写

在 PyTorch 中,可以使用 unsqueezeexpand 方法来将张量 b 广播为与 a 形状匹配的 b_broadcasted。具体代码如下:

import torch

# 原始张量 a 和 b
a = torch.randn(3, 2, 2)
b = torch.randn(3)

# 扩展 b 的维度,使其能够广播到与 a 形状匹配
b_broadcasted = b.unsqueeze(-1).unsqueeze(-1).expand_as(a)

print(b_broadcasted)

步骤解释:

  1. unsqueeze(-1):将 b 的形状从 (3,) 变为 (3, 1),即在最后一维添加一个大小为 1 的新维度。
  2. unsqueeze(-1):再一次将形状从 (3, 1) 变为 (3, 1, 1),在最后再添加一个大小为 1 的维度。
  3. expand_as(a):将 b 广播到与 a 相同的形状 (3, 2, 2)

结果示例:

假设 b 的值为 [2.0, 3.0, 4.0],那么 b_broadcasted 将变为:

tensor([[[2.0, 2.0],
         [2.0, 2.0]],

        [[3.0, 3.0],
         [3.0, 3.0]],

        [[4.0, 4.0],
         [4.0, 4.0]]])

这就是张量 b 在计算过程中如何被广播为 b_broadcasted 的代码实现。

标签:...,einsum,4.0,torch,张量,3.0,计算,维度
From: https://www.cnblogs.com/litifeng/p/18350171

相关文章

  • 深入讲解云计算
    深入讲解云计算目录概述云计算的基本概念2.1云计算定义2.2云计算模型2.3云计算服务模式云计算的核心技术3.1虚拟化技术3.2容器技术3.3分布式计算3.4数据存储与管理3.5网络技术云计算的服务类型4.1基础设施即服务(IaaS)4.2平台即服务(PaaS)4.3软件即服务(SaaS)4.......
  • nodejs语言,MySQL数据库;springboot的个性化资讯推荐系统66257(免费领源码)计算机毕业设计
    摘 要随着科学技术的飞速发展,社会的方方面面、各行各业都在努力与现代的先进技术接轨,通过科技手段来提高自身的优势,个性化资讯推荐系统当然也不能排除在外。个性化资讯推荐系统是以实际运用为开发背景,运用软件工程原理和开发方法,采用springboot技术构建的一个管理系统。整......
  • c#语言,SQL server数据库;基于Web的社区人员管理系统的设计与实现36303(免费领源码)计算机
    目 录摘要1绪论1.1慨述1.2课题意义1.3B/S体系结构介绍1.4ASP.NET框架介绍2 社区人员管理系统分析2.1可行性分析2.2系统流程分析2.2.1数据增加流程2.2.2数据修改流程52.2.3数据删除流程52.3系统功能分析62.3.1功能性分析62.3.2非功能性......
  • springboot党员信息管理系统-计算机毕设定制-附项目源码(可白嫖)50966
    目 录摘要1绪论1.1系统开发背景1.2系统发展趋势1.3研究方法1.4论文结构与章节安排2 党员信息管理系统系统分析2.1可行性分析2.1.1技术可行性分析2.1.2经济可行性分析2.1.3法律可行性分析2.2系统功能分析2.2.1功能性分析2.2.2非功能性分......
  • Django+记账管理系统-计算机毕设定制-附项目源码(可白嫖)50377
    摘 要本文课题研究的记账管理系统,系统的主要功能模块包括记账信息、企业类型、公告信息、公告类型等,采取面对对象的开发模式进行软件的开发和硬体的架设,能很好的满足实际使用的需求,完善了对应的软体架设以及程序编码的工作,采用Django开发框架,MySQL数据库,Ajax异步交互,根据Aj......
  • springboot在线众筹平台的设计与实现-计算机毕设定制-附项目源码(可白嫖)50388
    springboot在线在线众筹平台摘 要随着社会的发展,社会的各行各业都在利用信息化时代的优势。计算机的优势和普及使得各种信息系统的开发成为必需。在线众筹平台,主要的模块包括管理员和用户,实现功能包括:首页、轮播图、系统公告、资源管理(新闻列表、新闻分类)系统用户(管理员......
  • PHP+历史文化学习与交流网站-计算机毕设定制-附项目源码(可白嫖)50444
    目   录摘  要Abstract第1章  前  言1.1 研究背景1.2 研究现状1.3 系统开发目标第2章  系统开发环境2.1开发技术2.2 MVVM模式介绍2.3 MYSQL数据库2.4 B/S结构52.5 Thinkphp框架介绍5第3章  需求分析3.1 需求分析3.2......
  • einsum 函数
    einsum是Einsteinsummation的缩写,即爱因斯坦求和约定。einsum函数源自NumPy,后来在PyTorch等其他科学计算库中也得到了实现。它是一种强大而灵活的函数,可以用来处理各种张量运算,如矩阵乘法、转置、批量点积、内积、外积等。爱因斯坦求和约定(EinsteinSummationConvent......
  • torch 维度
    a=torch.randn(3)atensor([0.4789,0.1794,-1.5215])a.unsqueeze(-1)tensor([[0.4789],[0.1794],[-1.5215]])上述代码中,tensor变量a是一个一维tensor变量。当用unsqueeze(扩张)为二维变量后,为什么变为竖着写了?这个小问题,以前也没有细想过。现......
  • 计算机网络考研手写笔记_物理层
    1通信基础概念1)信源、信宿、信号、信道2)码元3)速率4)噪声2编码和调制3数据转为信号4传输介质1)导向型(双绞线、同轴电缆、光纤)2)非导向型(无线电波、微波通信)3)物理层接口的特性(机械、电气、功能、过程)5物理层设备1)中继器2)集线器......