首页 > 其他分享 >机器学习 - PyTorch里的aggregation

机器学习 - PyTorch里的aggregation

时间:2024-03-16 18:58:41浏览次数:22  
标签:10 机器 tensor dtype torch 数据类型 aggregation PyTorch print

在PyTorch里,可以在tensor里找到min, max, mean, sum 等aggregation值。

直接上代码

import torch 

x = torch.arange(0, 100, 10)
print(x)
print(f"Minimum: {x.min()}")
print(f"Minimum: {torch.min(x)}")
print(f"Maximum: {x.max()}")
print(f"Maximum: {torch.max(x)}")
print(f"Mean: {x.type(torch.float32).mean()}")
print(f"Mean: {torch.mean(x.type(torch.float32))}")
print(f"Sum: {x.sum()}")
print(f"Sum: {torch.sum(x)}")

# 结果如下
tensor([ 0, 10, 20, 30, 40, 50, 60, 70, 80, 90])
Minimum: 0
Minimum: 0
Maximum: 90
Maximum: 90
Mean: 45.0
Mean: 45.0
Sum: 450
Sum: 450


可以在tensor里找到最大值和最小值的位置,用到 torch.argmax()torch.argmin()

print(f"Index where max value occurs: {x.argmax()}")
print(f"Index where min value occurs: {x.argmin()}")

# 结果如下
Index where max value occurs: 9
Index where min value occurs: 0


在深度学习中,会经常出现的问题是tensor的数据类型不对。如果一个tensor的数据类型是 torch.float64 ,而另一个tensor的数据类型是 torch.float32,运行起来就出错了。
要改变tensor的数据类型,可以使用 torch.Tensor.type(dtype=None) 其中的 dtype 参数是你想用的数据类型。

代码如下:

tensor = torch.arange(10., 100., 10.)
print(tensor.dtype)
tensor_float16 = tensor.type(torch.float16)
print(tensor_float16)
tensor_int8 = tensor.type(torch.int8)
print(tensor_int8)

# 输出
torch.float32
tensor([10., 20., 30., 40., 50., 60., 70., 80., 90.], dtype=torch.float16)
tensor([10, 20, 30, 40, 50, 60, 70, 80, 90], dtype=torch.int8)


看到这里了,给个赞呗~

标签:10,机器,tensor,dtype,torch,数据类型,aggregation,PyTorch,print
From: https://blog.csdn.net/BSCHN123/article/details/136765210

相关文章

  • pytorch CV入门 - 汇总
    初次编辑:2024/2/14;最后编辑:2024/3/9参考网站-微软教程:https://learn.microsoft.com/en-us/training/modules/intro-computer-vision-pytorch更多的内容可以参考本作者其他专栏:Pytorch基础:https://blog.csdn.net/qq_33345365/category_12591348.htmlPytorchNLP基础:https......
  • Pytorch基础-汇总
    本教程翻译自微软教程:https://learn.microsoft.com/en-us/training/paths/pytorch-fundamentals/初次编辑:2024/3/1;最后编辑:2024/3/4本教程包含以下内容:介绍pytorch基础和张量操作介绍数据集介绍归一化介绍构建模型层的基本操作介绍自动微分相关知识介绍优化循环(optimiz......
  • 机器学习模型—CatBoost
    机器学习模型—CatBoost作为俄罗斯科技公司Yandex推出的开源机器学习库,CatBoost可以说是当前GradientBoosting算法发展的新里程碑。相较于广为人知的XGBoost,CatBoost在处理类别特征、纵向样本采样和有序训练数据方面做出了创新性的改进,展现了卓越的性能。我们经常遇到......
  • 【机器学习智能硬件开发全解】(五)—— 政安晨:嵌入式系统基本素养【总线、地址、指令集
    在智能硬件领域中,一个核心概念是嵌入式系统,整体结构可以分为以下几个主要组成部分:控制器:控制器是嵌入式系统的核心,负责处理和执行系统中的各种任务和功能。它通常由中央处理器(CPU)和相关的外围设备(如存储器、时钟、中断控制器等)组成。存储器:存储器用于存储系统的程序代码和......
  • 【机器学习】机器学习创建算法第2篇:K-近邻算法【附代码文档】
    机器学习(算法篇)完整教程(附代码资料)主要内容讲述:机器学习算法课程定位、目标,K-近邻算法,1.1K-近邻算法简介,1.2k近邻算法api初步使用定位,目标,学习目标,1什么是K-近邻算法,1Scikit-learn工具介绍,2K-近邻算法API,3案例,4小结。K-近邻算法,1.3距离度量学习目标,1欧式距离,2......
  • pytorch使用pytorch_wavelets包错误:ValueError: step must be greater than zero 错误
    错误描述在使用pytorch_wavelets包的DWT1DInverse时,发现报错信息如下:Traceback(mostrecentcalllast):File"/work/GDN/test/test_DWT.py",line24,inx_=idwt((YL,YH))File"/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py",line550......
  • 机器学习 - PyTorch一些常用的用法
    如果我们要创建2维随机数importtorchrandom_tensor=torch.rand(size=(3,4))print(random_tensor)#输出tensor([[0.0137,0.7773,0.0150,0.2406],[0.6414,0.7830,0.7603,0.1866],[0.8157,0.8269,0.0438,0.0314]])有时候需要通过加......
  • SpaceX 星舰发射「成功一半」;首位具身 AI 机器人面世丨 RTE 开发者日报 Vol.166
       开发者朋友们大家好: 这里是**「RTE开发者日报」**,每天和大家一起看新闻、聊八卦。我们的社区编辑团队会整理分享RTE(RealTimeEngagement)领域内「有话题的新闻」、「有态度的观点」、「有意思的数据」、「有思考的文章」、「有看点的会议」,但内容仅代......
  • 机器学习练手项目-猫狗分类器
    机器学习练手项目-猫狗分类器作者简介:一名后端开发人员,每天分享后端开发以及人工智能相关技术,行业前沿信息,面试宝典。座右铭:未来是不可确定的,慢慢来是最快的。个人主页:极客李华-CSDN博客合作方式:私聊+这个专栏内容:用最低价格鼓励和博主一起在寒假打卡高频大厂算法题,连续一......
  • 工匠的发展与兴衰趋势-机器人篇
    这是一篇纯纯调侃的博客,如有雷同纯属意外。之前,写过:从2050回顾2020,职业规划与技术路径(节选)从2050回顾2020,职业规划与技术路径(节选)补充 未来以“工”为主的就业机会趋势是越来越少,也就是从业人员的感受是越来越卷。学生通常最多困惑或者反馈的现象。如果从零搭建一台......