我有一个巨大的稀疏矩阵 S。每列有 1 个值为 1 的条目。其余值为零。我经常需要将 S 与另一个矩阵 A (
S@A
) 相乘。 它实际上只是一个“选择器”,所以在 numpy 中实现很简单。
torch 有一些 API 可以帮助我解决这个问题吗?
S@A
的结果将在gpu中广泛使用(例如进一步乘法、求解线性系统等)。所以我也想在 GPU 中实现。
S@A
in GPU too.
很遗憾,在 PyTorch 中没有直接的方法可以将单位向量矩阵的乘法存储和执行为稀疏矩阵以实现高效的 GPU 计算。PyTorch 中的稀疏矩阵支持仍然有限,而且像描述的那样的单位矩阵的存储格式不提供利用此类结构进行乘法的优化途径。
但是,我们可以通过利用的矩阵
S
的特殊属性(每列仅一个 1 值)来实现高效的 GPU 加速乘法。
我们可以将
S
中每个 1 的位置存储为索引,并使用这些索引从
A
中收集相应的行。这种方法避免了显式矩阵乘法,并且可以非常高效,尤其是在 GPU 上。
以下是操作方法:
import torch
def sparse_select(A, indices):
"""
A: 要从中选择行的密集张量 (batch_size, m, k)
indices: 包含要为每列选择的行索引的张量 (batch_size, n)
返回:
选定行的张量 (batch_size, n, k)
"""
batch_size, m, k = A.size()
n = indices.size(1)
# 将索引转换为线性索引
linear_indices = indices + torch.arange(batch_size)[:, None] * m
# 使用 gather 从 A 中收集行
selected_rows = A.view(batch_size * m, k)[linear_indices.view(-1)].view(batch_size, n, k)
return selected_rows
# 示例用法
batch_size = 2
n = 4
m = 5
k = 3
# 示例矩阵 A
A = torch.randn(batch_size, m, k).cuda()
# 示例索引(每列一个 1 的矩阵 S)
indices = torch.tensor([[1, 0, 3, 2],
[4, 2, 1, 0]]).cuda()
# 使用索引选择行
selected_rows = sparse_select(A, indices)
# 打印结果
print(selected_rows)
在此代码中:
-
indices
张量存储S
中每个 1 的位置(行索引)。 -
sparse_select()
函数获取密集矩阵A
和索引作为输入。 -
它计算必要的线性索引以从
A
中收集正确的行,从而有效地执行S@A
的操作,而无需实际构造或存储S
作为稀疏矩阵。
这种方法利用了 PyTorch 中的高效索引操作,并且在 GPU 上运行良好,因为它避免了与稀疏矩阵乘法相关的开销。
请记住,此方法假设
S
的结构是每列有一个 1。如果的
S
矩阵发生变化,可能需要调整此代码以适应不同的稀疏模式。