首页 > 其他分享 >2-1张量数据结构

2-1张量数据结构

时间:2023-12-22 20:56:17浏览次数:29  
标签:tensor dtype torch 张量 add print 数据结构

0.配置

Pytorch的基本数据结构是张量Tensor。张量及多维数组。Pytorch的张量和numpy中的array很类似。 本节我们主要介绍张量的数据类型、张量的维度、张量的尺寸、张量和numpy数组等基本概念。

import torch

print('torch.__version__=' + torch.__version__)

"""
torch.__version__=2.1.1+cu118
"""

1.张量的数据类型

张量的数据类型和numpy.array基本一一对应,但不支持str类型。

包括:

torch.float64(torch.double)

torch.float32(torch.float)

torch.float16

torch.int64(torch.long)

torch.int32(torch.int)

torch.int16

torch.int8

torch.uint8

torch.bool

一般神经网络建模使用的都是torch.float32类型

import numpy as np
import torch

# 自动推断数据类型
i = torch.tensor(1)
print(i, i.dtype)

x = torch.tensor(2.0)
print(x, x.dtype)

b = torch.tensor(True)
print(b, b.dtype)

"""
tensor(1) torch.int64
tensor(2.) torch.float32
tensor(True) torch.bool
"""

# 指定数据类型
i = torch.tensor(1, dtype=torch.int32)
print(i, i.dtype)

x = torch.tensor(2.0, dtype=torch.float32)
print(x, x.dtype)

"""
tensor(1, dtype=torch.int32) torch.int32
tensor(2.) torch.float32
"""

# 使用特定类型构造函数
i = torch.IntTensor(1)
print(i, i.dtype)

x = torch.Tensor(np.array(2.0))  # 等价于torch.FloatTensor
print(x, x.dtype)

b = torch.BoolTensor(np.array([1, 0, 2, 0]))
print(b, b.dtype)

"""
tensor([1073741824], dtype=torch.int32) torch.int32
tensor(2.) torch.float32
tensor([ True, False,  True, False]) torch.bool
"""

# 不同类型进行转换
i = torch.tensor(1)
print(i, i.dtype)

x = i.float()
print(x, x.dtype)

y = i.type(torch.float)
print(y, y.dtype)

z = i.type_as(x)
print(z, z.dtype)

"""
tensor(1) torch.int64
tensor(1.) torch.float32
tensor(1.) torch.float32
tensor(1.) torch.float32
"""

2.张量的维度

不同类型的数据可以用不同维度的张量来表示。

标量为0维张量,向量为1维张量,矩阵为2维张量。

彩色图像有rgb三个通道,可以表示为3维张量。

视频还有时间维,可以表示为4维张量。

可以简单总结为:有几层中括号,就是多少维张量。

scaler = torch.tensor(True)
print(scaler)
print(scaler.dim())

"""
tensor(True)
0
"""

vector = torch.tensor([1.0, 2.0, 3.0, 4.0])
print(vector)
print(vector.dim())

"""
tensor([1., 2., 3., 4.])
1
"""

matrix = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
print(matrix)
print(matrix.dim())

"""
tensor([[1., 2.],
        [3., 4.]])
2
"""

tensor3 = torch.tensor([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]])
print(tensor3)
print(tensor3.dim())

"""
tensor([[[1., 2.],
         [3., 4.]],

        [[5., 6.],
         [7., 8.]]])
3
"""

tensor4 = torch.tensor([[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]], [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]])
print(tensor4)
print(tensor4.dim())

"""
tensor([[[[1., 2.],
          [3., 4.]],

         [[5., 6.],
          [7., 8.]]],


        [[[1., 2.],
          [3., 4.]],

         [[5., 6.],
          [7., 8.]]]])
4
"""

3.张量的尺寸

可以使用shape属性或者size()方法查看张量在每一维的长度。

可以使用view方法改变张量的尺寸。

如果view方法改变尺寸失败,也可以使用reshape方法

scaler = torch.tensor(True)
print(scaler.size())
print(scaler.shape)

"""
torch.Size([])
torch.Size([])
"""

vector = torch.tensor([1.0, 2.0, 3.0, 4.0])
print(vector.size())
print(vector.shape)

"""
torch.Size([4])
torch.Size([4])
"""

matrix = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
print(matrix.size())
print(matrix.shape)

"""
torch.Size([2, 2])
torch.Size([2, 2])
"""

# 使用view可以改变张量尺寸
vector = torch.arange(0, 12)
print(vector)
print(vector.shape)

matrix = vector.view(3, 4)
print(matrix)
print(matrix.shape)

matrix43 = vector.view(4, -1)
print(matrix43)
print(matrix43.shape)

"""
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])
torch.Size([12])
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]])
torch.Size([3, 4])
tensor([[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11]])
torch.Size([4, 3])
"""

# 有些操作会让张量存储结构扭曲,直接使用view会失败,可以用reshape方法
matrix26 = torch.arange(0, 12).view(2, 6)
print(matrix26)
print(matrix26.shape)

# 转置操作让张量存储结构扭曲
matrix62 = matrix26.t()
print(matrix62.is_contiguous())  # 该张量在内存中是否是连续的

"""
tensor([[ 0,  1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10, 11]])
torch.Size([2, 6])
False
"""

# 直接使用view方法会失败。可以使用reshape方法
matrix34 = matrix62.reshape(3, 4)  # 等价于matrix62.contiguous().view(3, 4)
print(matrix34)

"""
tensor([[ 0,  6,  1,  7],
        [ 2,  8,  3,  9],
        [ 4, 10,  5, 11]])
"""

4.张量和numpy数组

可以用numpy方法从Tensor得到numpy数组,也可以用torch.from_numpy从numpy数组得到Tensor

这两种方法关联的Tensor和numpy数组是共享内存的。

如果改变其中一个,另外一个的值也会发生改变。

如果有需要,可以用张量的clone方法拷贝张量,中断这种关联。

此外,还可以使用item方法从标量张量得到对应的Python数值。

使用tolist方法从张量得到对应的Python数值列表。

arr = np.zeros(3)
tensor = torch.from_numpy(arr)
print('before add 1:')
print(arr)
print(tensor)

print('\nafter add 1:')
np.add(arr, 1, out=arr)
print(arr)
print(tensor)

"""
before add 1:
[0. 0. 0.]
tensor([0., 0., 0.], dtype=torch.float64)

after add 1:
[1. 1. 1.]
tensor([1., 1., 1.], dtype=torch.float64)
"""

# numpy方法从Tensor得到numpy数组
tensor = torch.zeros(3)
arr = tensor.numpy()
print('before add 1:')
print(tensor)
print(arr)

print('\nafter add 1:')
# 使用带下划线的方法表示计算结果会返回给调用张量
tensor.add_(1)  # torch.add(tensor, 1, out=tensor)
print(tensor)
print(arr)

"""
before add 1:
tensor([0., 0., 0.])
[0. 0. 0.]

after add 1:
tensor([1., 1., 1.])
[1. 1. 1.]
"""

# 可以使用clone方法中断这种关联
tensor = torch.zeros(3)
arr = tensor.clone().numpy()
print('before add 1:')
print(tensor)
print(arr)

print('\nafter add 1:')
tensor.add_(1)
print(tensor)
print(arr)

"""
before add 1:
tensor([0., 0., 0.])
[0. 0. 0.]

after add 1:
tensor([1., 1., 1.])
[0. 0. 0.]
"""

# item方法和tolist方法可以将张量转换成Python数值和数值列表
scalar = torch.tensor(1.0)
s = scalar.item()
print(s)
print(type(s))

tensor = torch.rand(2, 2)
t = tensor.tolist()
print(t)
print(type(t))

"""
1.0
<class 'float'>
[[0.9156631827354431, 0.2283121943473816], [0.969607412815094, 0.8414113521575928]]
<class 'list'>
"""

标签:tensor,dtype,torch,张量,add,print,数据结构
From: https://www.cnblogs.com/lotuslaw/p/17922349.html

相关文章

  • [2024深圳市考][计算机素质测试考纲](二)算法和数据结构
    前言因篇幅有限,本文仅对考纲中的考点做基本介绍。更详细的内容请自行学习:【双语字幕】CS61B数据结构|整合版|UCBDataStructureSpring2021【中英双字】普林斯顿大学-算法分析AlgorithmAnalysis2015COS423一、基本概念二、数组三、链表四、栈和队列五、递......
  • Week1——STL 与基础数据结构专题训练
    https://blog.csdn.net/qq_46025844/article/details/127948957 实训概要实训专题STL与基础数据结构专题训练实训目的掌握STL常用的算法、容器、容器适配器的使用方法。能够利用STL的算法、容器、容器适配器求解问题。题目列表A:摘苹果B:立方和C:计算个数D:后缀表达式的值E:做蛋糕......
  • 【数据结构】第二章——线性表(2)
    线性表的顺序表示导言大家好,很高兴又和各位见面啦!!!在上一个篇章中,我们简单了解了一下线性表的基础知识以及一下重要的术语。在今天的篇章中我们将来开始正式介绍线性表的顺序存储——又称顺序表。我们将会在本章介绍什么是顺序表,对于顺序表的操作我们又应该如何实现。接下来,我们就来......
  • 金牌导航-数据结构优化DP
    数据结构优化DP例题A题解设\(f_{i,j}\)表示以第\(i\)位为结尾,长度为\(j\)的严格单调上升子序列的数量。那么显然有\(f_{i,j}=\sum_{k=1}^{i-1}f_{k,j-1}\times(a_k<a_i)\)然后发现这玩应\(O(n^2m)\)直接寄掉了。考虑优化。发现只有当\(a_k<a_i\)时才会有贡献。......
  • python 数据结构与算法知识图
    1.算法思想:递归、分治(归并排序、二分查找、快速排序)、贪心(贪心策略排序+当前最优)、动态规划(最优子结构+递推式)、回溯(解空间:排列树+子集树、深度搜索+剪枝)、分支限界(解空间:排列树+子集树、广度搜索+剪枝))2.排序算法:(low:冒泡、插入、选择;mid:快排、归并、堆排(完全二叉树),其他:桶排序、基......
  • Databend 源码阅读: Meta-service 数据结构
    作者:张炎泼(XP)DatabendLabs成员,Databend分布式研发负责人https://github.com/drmingdrmer引言Databend是一款开源的云原生数据库,采用Rust语言开发,专为云原生数据仓库的需求而设计。面向云架构:Databend是完全面向云架构的数据库,可以在云环境中灵活部署和扩展简介|......
  • Databend 源码阅读: Meta-service 数据结构
    作者:张炎泼(XP)DatabendLabs成员,Databend分布式研发负责人https://github.com/drmingdrmer引言Databend是一款开源的云原生数据库,采用Rust语言开发,专为云原生数据仓库的需求而设计。面向云架构:Databend是完全面向云架构的数据库,可以在云环境中灵活部署和扩展简介|......
  • 数据结构
    数据结构有:1.数组;2.栈;3.队列;4.链表(单链表、双向链表、循环链表);5.数;6.散列表;7.堆;8.图。一、数组内存连续,可通过元素下标访问。二、栈先进后出三、队列先进先出四、链表物理存储不连续,因为存储了相邻元素的物理地址,所以逻辑上连续。五、树每个节点有零个或多个子节点;没......
  • 【面试官版】【持续更新中】融合滤波算法+数据结构+激光视觉SLAM+C++面试题汇总
    C++部分什么时候需要写虚函数、什么时候需要写纯虚函数?只继承接口为纯虚函数强调覆盖父类重写,或者父类也需要实现一定的功能,为虚函数指针传参和引用传参区别?引用传参本质上是传递原参数地址,指针传参本质还是值传递,生成拷贝指针,拷贝指针和原指针指向的为同一块内存。因此改变......
  • 数据结构之<图>的介绍
    图(Graph)的概念:在数据结构中,图是由节点(顶点)和边组成的非线性数据结构。图用于表示不同对象之间的关系,其中节点表示对象,边表示对象之间的连接或关系。1.图的基本组成元素:节点(Vertex或Node):表示图中的实体或对象。节点可以有不同的属性和值。在某些情况下,节点也被称为顶点。边(Edge):......