PyTorch 中的 expand
函数用于扩展张量的形状,使其在某些维度上“看起来”像被复制了多次,但实际上它不会复制数据,从而节省内存和计算资源。扩展后的张量共享原始张量的内存空间,因此原始张量和扩展后的张量是同一个数据的视图。
以下是 torch.expand
函数的一些基本用法:
1. 扩展一维张量: 将一维张量扩展到更高维度,例如,将一维张量扩展为二维张量的第一维度。
import torch x = torch.tensor([1, 2, 3]) y = x.expand(2, -1) # 扩展为 2x3 的矩阵 print(y) # 输出: # tensor([[1, 2, 3], # [1, 2, 3]])
-1
表示该维度的大小与原始张量相同。
2. 使用 -1
自动扩展: 使用 -1
可以自动扩展张量到与输入张量相同的大小。
x = torch.tensor([1, 2, 3]) y = x.expand(-1, 2) # 扩展为 3x2 的矩阵 print(y) # 输出: # tensor([[1, 1], # [2, 2], # [3, 3]])
3. 扩展多维张量: 扩展多维张量时,可以指定多个维度进行扩展。
x = torch.tensor([[1, 2], [3, 4]]) y = x.expand(2, -1, -1) # 扩展为 2x2x2 的张量 print(y) # 输出: # tensor([[[1, 2], # [1, 2]], # [[3, 4], # [3, 4]]])
4. 扩展与广播的区别: expand
与 broadcast
相似,但 broadcast
在进行操作时会复制数据,而 expand
不会。expand
更适用于减少内存使用。
5. 扩展与复制: 如果你需要一个实际复制了数据的新张量,可以使用 expand
后跟 clone
。
x = torch.tensor([1, 2, 3]) y = x.expand(2, -1).clone() # 扩展后复制数据 print(y) # 输出: # tensor([[1, 2, 3], # [1, 2, 3]])
使用 expand
时,扩展的维度大小可以是具体的数值,也可以是 -1
,表示该维度的大小与原始张量相同。如果扩展的维度大小大于原始张量,PyTorch 会抛出错误。
(摘自kimi)
标签:tensor,torch,扩展,张量,pytorch,维度,expand,函数 From: https://www.cnblogs.com/picassooo/p/18302453