在PyTorch中,torch.diag
函数可以用于创建对角线张量或提取给定矩阵的对角线元素。以下是一些详细的使用例子:
-
创建对角矩阵:如果输入是一个向量(1D张量),
torch.diag
将返回一个2D方阵,其中输入向量的元素作为对角线元素。例如:a = torch.randn(3) print(a) # 输出:tensor([ 0.5950,-0.0872, 2.3298]) print(torch.diag(a)) # 输出:tensor([[ 0.5950, 0.0000, 0.0000], # [ 0.0000,-0.0872, 0.0000], # [ 0.0000, 0.0000, 2.3298]])
-
提取对角线元素:如果输入是一个矩阵(2D张量),
torch.diag
将返回一个1D张量,包含输入矩阵的对角线元素。例如:a = torch.randn(3, 3) print(a) # 输出:tensor([[-0.4264, 0.0255,-0.1064], # [ 0.8795,-0.2429, 0.1374], # [ 0.1029,-0.6482,-1.6300]]) print(torch.diag(a, 0)) # 输出:tensor([-0.4264, -0.2429, -1.6300])
-
指定对角线:
torch.diag
函数还允许你通过diagonal
参数指定要提取或使用的对角线。diagonal=0
表示主对角线,diagonal>0
表示主对角线上方的对角线,diagonal<0
表示主对角线下方的对角线。例如,提取矩阵的第二条对角线:print(torch.diag(a, 1)) # 输出:tensor([ 0.0255, 0.1374])
这些例子展示了如何使用torch.diag
函数来创建对角矩阵或提取对角线元素,以及如何通过diagonal
参数来指定对角线。这些操作在矩阵分解和转换等数学和深度学习任务中非常有用。
喜欢本文,请点赞、收藏和关注!
标签:tensor,diag,torch,矩阵,Pytorch,对角线,0.0000 From: https://blog.csdn.net/jimn2000/article/details/141563315