首页 > 编程语言 >Python pytorch 坐标系变换与维度转换

Python pytorch 坐标系变换与维度转换

时间:2024-04-22 11:37:06浏览次数:23  
标签:tensor Python contiguous per pytorch 维度 array size view

前言

深度学习中经常要用到张量坐标系变换与维度转换,因此记录一下,避免混淆

坐标系变换

坐标系变换(矩阵转置),主要是调换tensor/array的维度

pytorch

import torch

def info(tensor):
    print(f"tensor: {tensor}")
    print(f"tensor size: {tensor.size()}")
    print(f"tensor is contiguous: {tensor.is_contiguous()}")
    print(f"tensor stride: {tensor.stride()}")

tensor = torch.rand([1,2,3])
info(tensor)

# output:
# tensor: tensor([[[0.9516, 0.2289, 0.0042],
#          [0.2808, 0.4321, 0.8238]]])
# tensor size: torch.Size([1, 2, 3])
# tensor is contiguous: True
# tensor stride: (6, 3, 1)

per_tensor = tensor.permute(1,2,0)
info(per_tensor)

# output:
# tensor: tensor([[[0.9516, 0.2808],
#          [0.2289, 0.4321],
#          [0.0042, 0.8238]]])
# tensor size: torch.Size([1, 3, 2])
# tensor is contiguous: False
# tensor stride: (6, 1, 3)

numpy

import numpy as np

def np_info(array):
    print(f"array: {array}")
    print(f"array size: {array.shape}")
    print(f"array is contiguous: {array.flags['C_CONTIGUOUS']}")
    print(f"array stride: {array.strides}")

array = np.random.rand(1,2,3)
np_info(array)

# output:
# array: [[[0.58227139 0.32251543 0.12221412]
#   [0.72647191 0.42323578 0.65290986]]]
# array size: (1, 2, 3)
# array is contiguous: True
# array stride: (48, 24, 8)

trans_array = np.transpose(array, (0,2,1))
np_info(trans_array)

# output:
# array: [[[0.58227139 0.72647191]
#   [0.32251543 0.42323578]
#   [0.12221412 0.65290986]]]
# array size: (1, 3, 2)
# array is contiguous: False
# array stride: (48, 8, 24)

所以对于高维的tensor来说,其实并没有改变数据的相对位置,只是旋转了这个data的(超)立方体,即改变(超)立方体的观察角度

维度变换

tensor.view()

view()主要是将tensor转化为想要的张量尺寸,但并不影响contiguous属性
view()相当于tensor的一个引用,通过它会直接对原tensor进行操作,不会产生拷贝,输出和输入是共享内部存储的

view_tensor = tensor.view(3,2,1)
info(view_tensor)

# output:
# tensor: tensor([[[0.9516],
#          [0.2289]],
# 
#         [[0.0042],
#          [0.2808]],
# 
#         [[0.4321],
#          [0.8238]]])
# tensor size: torch.Size([3, 2, 1])
# tensor is contiguous: True
# tensor stride: (2, 1, 1)

但当对contiguous为false的tensor进行view操作时,则会报错

view_per_tensor  = per_tensor.view(2,3) 

#output:
# ---------------------------------------------------------------------------
# RuntimeError                              Traceback (most recent call last)
# /tmp/ipykernel_388070/1679121630.py in <module>
# ----> 1 view_per_tensor  = per_tensor.view(2,3)
#       2 # info(per_tensor)
#       3 info(view_per_tensor)
#       4 print(view_per_tensor.data_ptr() == per_tensor.data_ptr())

# RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

tensor.reshape()

torch.Tensor.reshape()可以对任意tensor进行操作,相当于torch.Tensor.view() + torch.Tensor.contiguous().view(),也就是说,reshape操作也不一定会开辟新的内存空间,如果tensor是连续的话,实际上调用的view的实现,而当tensor不连续且步长不兼容的时候,就会对tensor进行深拷贝。

reshape_per_tensor = per_tensor.reshape(2,3) 
info(reshape_per_tensor)

# output:
# tensor: tensor([[0.9384, 0.9049, 0.8476],
#         [0.5196, 0.7949, 0.0637]])
# tensor size: torch.Size([2, 3])
# tensor is contiguous: True
tensor stride: (3, 1)

Ref

  1. https://blog.csdn.net/wulele2/article/details/127337439
  2. https://blog.csdn.net/wxfighting/article/details/122758553

标签:tensor,Python,contiguous,per,pytorch,维度,array,size,view
From: https://www.cnblogs.com/liuliu55/p/18150227

相关文章

  • python中列表、字典和字符串的互相转换
    我们在python使用中经常会用到需要把字符串转为list或者字典,及把list或字典转为字符串(写文件,f.write()只能写字符串,插入数据库时,也只能用字符串)具体使用方法总结了一下:1、字符串转lists='a,b,c'l=s.split(',')  #把字符串s以逗号分割,分割出的list给到l ......
  • python os库将字符串转化为路径
    前言在python编程中,经常需要对文件进行读取操作,而os库提供了一些方法处理文件和目录的路径官方文档如下:https://docs.python.org/zh-cn/3/library/os.html本文主要记录如何将字符串转化为路径1.os.path.join()主要将多个字符串进行拼接,从而形成路径importosos.path.join......
  • 很强!4.7k star,推荐一款Python工具,可实现自动化操作!!
    1、介绍在日常工作中,肯定会遇到一些重复性的工作,不管是点击某个按钮、写东西,打印东西,还是复制粘贴拷贝资料之类的,需要进行大量的重复操作。按键精灵大家都听说过,传统的方式,大家可以使用按键精灵将操作录制一遍,形成脚本,剩余的工作让计算机自动循环执行,应对这些重复性的任务。但今......
  • python 二进制序列类型 bytes 和 bytearray
    bytesbytes定义bytes是一个不可变序列,用于存储字节数据。bytes对象包含范围在0到255之间的整数序列,通常用于处理二进制数据、文本数据的字节表示、以及网络通信中的原始数据传输。创建bytes对象使用b'...'表示字节字符串,各个字符以ASCII对应的单字节值表示。使用byte......
  • 用 Python(PyVISA) 实现仪器自动化
    本文介绍一个远程仪器控制的例子,包含一些Python脚本实现自动在示波器上进行简单的测量。Python介绍Python是免费和开源的,它为核心开发人员提供了责任、庞大的支持基础以及Python用户检查和改进其代码库的能力。Python有很多包用来扩展了Python的基本功能。Python的包可......
  • 人工智能:更多有用的 Python 库
    #为什么你选择成为一名程序员?#​目录推荐JupyterLab入门复杂的矩阵运算其它人工智能和机器学习的Python库推荐前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站在进入主题之前,我们先讨论几个人工智能和机器学习中常用的重......
  • Python 使用Snap7读写西门子S7系列PLC
    1.简介Snap7Snap7是一个基于s7通信协议的开源软件包,作者是DavideNardella,该软件包封装了S7通信的底层协议,可使用普通电脑通过编程与西门子S7系列PLC进行通信Snap7三大对象组件:客户端,服务器,合作者。下面是三者关系,更详细介绍可看官网。本篇主要讲述的是Client模式,我们的pc机作......
  • Python环境和PyCharm搭建教程
    1、python下载和安装1、访问Python官网:https://www.python.org/ 2、以Windows为例,我们选择一个稳定的版本进行安装,这里需要注意选择和自己操作系统类型一致的安装包,64位操作系统选择 64-bit/32位操作系统选择 32-bit,x86表示是32位机子/x86-64表示64位机子的。Stabl......
  • 关于Python能再Pycharm上运行而在VSCode下无法运行
    前提项目是由Pycharm创建并且编写,然后复制下来VSCode上运行问题Pycharm写了一个项目,项目的某个文件A要调用到项目其他文件B的某个方法b,在上运行Pycharm没问题,VSCode复制下来该干的都干了(依赖安装,venv环境),但是运行的时候就是报错说,找不到模块B的路径,但是点引用却又能转到对应的......
  • 基于python语言命令行模式的nmap扫描- python-nmap
    使用python命令行模式进行nmap扫描,简化流程首先安装python环境https://www.python.org/downloads/安装nmap,python-nmap需要借助nmap运行https://nmap.org/download写python-nmap脚本https://github.com/home-assistant-libs/python-nmap代码:点击查看代码......