# NOTE: "点击查看代码" ("click to view code") is a blog-paste artifact; commented out so the script runs.
# -*- coding: utf-8 -*-
# @Author : 钱力
# @Time : 2024/7/26 14:24
import torch
# --- Merging tensors ----------------------------------------------------
A = torch.arange(16).view(2, 8)
B = A * 10
# torch.cat stitches tensors together along an EXISTING dimension
# (here dim=1, i.e. side by side column-wise).
C = torch.cat([A, B], dim=1)
print(C)
# torch.stack inserts a NEW dimension and interleaves the inputs along
# it — a layout often used for time-series / sequence batches.
D = torch.stack([A, B], dim=1)
print(D)
# --- Splitting tensors --------------------------------------------------
print('=====================================================')
a = torch.arange(10).reshape(5, 2)
# torch.chunk(a, n) cuts the tensor into n (nearly) equal pieces along
# dim 0; with 5 rows and n=2 this yields chunks of 3 and 2 rows.
print(torch.chunk(a, 2))
# torch.split(a, size) cuts into pieces of the given length along dim 0.
print(torch.split(a, 2))
# torch.split also accepts an explicit list of section lengths;
# they must sum to the dimension size (3 + 1 + 1 == 5).
print(torch.split(a, [3, 1, 1]))
# --- Broadcasting with expand() -----------------------------------------
# expand() repeats a tensor along dimensions of size 1 out to a larger
# size; it returns a view, so only size-1 dims may be expanded.
a = torch.arange(1, 7).reshape(1, 2, 3)
print(a.size())
print(a)
a = a.expand(2, 2, 3)  # the leading size-1 dim is replicated twice
print(a.size())
print(a)
# --- Permutation, contiguity and storage sharing ------------------------
a = torch.arange(9).reshape(3, 3)
print('a:', a)
# Transposing produces a VIEW: same storage, different strides.
b = a.t()
print('b:', b)
print(b.stride())        # stride pattern of the transposed view
print(b.is_contiguous()) # False: logical order no longer matches memory
# contiguous() materialises the view into freshly laid-out memory.
c = b.contiguous()
print(c.stride())
print(c.is_contiguous())
# a and b share one storage buffer; c owns its own copy.
print('ptr of storage of a', a.untyped_storage().data_ptr())
print('ptr of storage of b', b.untyped_storage().data_ptr())
print('ptr of storage of c', c.untyped_storage().data_ptr())
# --- reshape vs view ----------------------------------------------------
a = torch.arange(9).reshape(3, 3)
b = a.permute(1, 0)
# reshape() handles non-contiguous tensors (copying when necessary)...
print(b.reshape(9))
# ...while view() requires contiguous memory and raises otherwise:
# print(b.view(9))  # RuntimeError on this transposed (non-contiguous) view
print(b.contiguous().view(9))