import torch
# tensor索引和切片
a = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
b = torch.tensor([[10, 10, 10], [10, 10, 10], [10, 10, 10]])
print("a的值:\n", a)
# a的值:
# tensor([[1, 2, 3],
# [4, 5, 6],
# [7, 8, 9]])
# ----------------索引-------------------
print("a[1,2]第2行第3列:", a[1, 2])
# a[1,2]第2行第3列: tensor(6)
print("a[-1,-1]:", a[-1, -1]) # 索引位置可以用负数来表示,-1 表示最后索引位置,-2 表示倒数第2个索引位置
# a[-1,-1]: tensor(9)
print("a[1,[0,2]]第2行的第1列和第3列:", a[1, [0, 2]])
# a[1,[0,2]]第2行的第1列和第3列: tensor([4, 6])
print("a[[0,1],[0,2]]获取(0,0)和(1,2)位置上的值:", a[[0, 1], [0, 2]])
# a[[0,1],[0,2]]获取(0,0)和(1,2)位置上的值: tensor([1, 6])
# --布尔索引
index = a > 4 # 判断各位置上的数据是否大于4
print("a>4的索引:\n", index) # 大于4的位置上为True,小于4的位置上为False
print("a>4的值:", a[index]) # 位置上为True的值
# a>4的索引:
# tensor([[False, False, False],
# [False, True, True],
# [ True, True, True]])
# a>4的值: tensor([5, 6, 7, 8, 9])
print("大于5输出a的值否则输出b的值:\n", torch.where(a > 5, a, b))
# 大于5输出a的值否则输出b的值:
# tensor([[10, 10, 10],
# [10, 10, 6],
# [ 7, 8, 9]])
# ----------------切片-------------------
print(a)
# tensor([[1, 2, 3],
# [4, 5, 6],
# [7, 8, 9]])
print("第1列:", a[:, 0])
# 第1列: tensor([1, 4, 7])
print("第3列:", a[:, 2])
# 第3列: tensor([3, 6, 9])
print("第3列:", a[:, -1])
# 第3列: tensor([3, 6, 9])
print("第1,2列:\n", a[:, 0:2]) # 取所有行,索引为0和1的列,即第1和2列
# 第1,2列:
# tensor([[1, 2],
# [4, 5],
# [7, 8]])
print("第1行:", a[0, :])
# 第1行: tensor([1, 2, 3])
print("第3行:", a[2, :])
# 第3行: tensor([7, 8, 9])
print("第3行:", a[-1, :])
# 第3行: tensor([7, 8, 9])
# ---有步长的切片
print(a)
# tensor([[1, 2, 3],
# [4, 5, 6],
# [7, 8, 9]])
print("a[::2, ::2]:\n", a[::2, ::2]) # 步长为2,即取索引为0和2的行,索引为0和2的列
# a[::2, ::2]:
# tensor([[1, 3],
# [7, 9]])
print("a[:, ::2]:\n", a[:, ::2]) # 带步长为2,即取所有的行,索引为0和2的列
# a[:, ::2]:
# tensor([[1, 3],
# [4, 6],
# [7, 9]])
t = torch.tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10], [11, 12, 13, 14, 15], [16, 17, 18, 19, 20]])
print(t)
# tensor([[ 1, 2, 3, 4, 5],
# [ 6, 7, 8, 9, 10],
# [11, 12, 13, 14, 15],
# [16, 17, 18, 19, 20]])
print("t[1::2, 1::2]:\n", t[1::2, 1::2])
# t[1::2, 1::2]:
# tensor([[ 7, 9],
# [17, 19]])
print("t[1::2, ::3]\n", t[1::2, ::3])
# t[1::2, ::3]
# tensor([[ 6, 9],
# [16, 19]])
标签:10,Tensor,--,torch,索引,PyTorch,print,True,tensor
From: https://blog.csdn.net/m0_74895132/article/details/141993542