首页 > 其他分享 >Pytorch torch.mean() 平均值的简单用法

Pytorch torch.mean() 平均值的简单用法

时间:2023-06-04 10:33:53浏览次数:35  
标签:dim tensor torch Pytorch print 2.5000 mean

Pytorch torch.mean()的简单用法

简单来说就是求平均数。
比如以下的三种简单情况:

import torch

x1 = torch.Tensor([1, 2, 3, 4])
x2 = torch.Tensor([[1],
                   [2],
                   [3],
                   [4]])
x3 = torch.Tensor([[1, 2],
                   [3, 4]])
y1 = torch.mean(x1)
y2 = torch.mean(x2)
y3 = torch.mean(x3)
print(y1)
print(y2)
print(y3)


输出:
tensor(2.5000)
tensor(2.5000)
tensor(2.5000)
 
也就是说,在没有指定维度的情况下,就是对所有数进行求平均

更多的时候用到的是有维度的情形,如:
二维张量求均值:

import torch

x = torch.Tensor([1, 2, 3, 4, 5, 6]).view(2, 3)
y_0 = torch.mean(x, dim=0) ##  每列求均值
y_1 = torch.mean(x, dim=1) ###  每行求均值
print(x)
print(y_0)
print(y_1)

 


输出:

tensor([[1., 2., 3.],
        [4., 5., 6.]])
tensor([2.5000, 3.5000, 4.5000])
tensor([2., 5.])
 

输入tensor的形状为(2, 3),其中2为第0维,3为第1维。对第0维求平均,得到的结果为形状为(1, 3)的tensor;对第1维求平均,得到的结果为形状为(2, 1)的tensor。
可以理解为,对哪一维做平均,就是将该维所有的数做平均,压扁成1层(实际上这一层就给合并掉了,比如上面的例子,2维的tensor在求平均数后变成了1维),而其他维的形状不影响。
如果要保持维度不变(例如在深度网络中),则可以加上参数keepdim=True:

y = torch.mean(x, dim=1, keepdim=True)
 

三维张量求均值:

 

 

import torch
import numpy as np

# ======初始化一个三维矩阵=====
A = torch.ones((4,3,2))

# ======替换三维矩阵里面的值======
A[0] = torch.ones((3,2)) *1
A[1] = torch.ones((3,2)) *2
A[2] = torch.ones((3,2)) *3
A[3] = torch.ones((3,2)) *4

print(A)

B = torch.mean(A ,dim=0)
print(B)

B = torch.mean(A ,dim=1)
print(B)

B = torch.mean(A ,dim=2)
print(B)

输出结果

tensor([[[1., 1.],
         [1., 1.],
         [1., 1.]],

        [[2., 2.],
         [2., 2.],
         [2., 2.]],

        [[3., 3.],
         [3., 3.],
         [3., 3.]],

        [[4., 4.],
         [4., 4.],
         [4., 4.]]])
tensor([[2.5000, 2.5000],
        [2.5000, 2.5000],
        [2.5000, 2.5000]])
tensor([[1., 1.],
        [2., 2.],
        [3., 3.],
        [4., 4.]])
tensor([[1., 1., 1.],
        [2., 2., 2.],
        [3., 3., 3.],
        [4., 4., 4.]])

 

 

 

 REF

https://blog.csdn.net/qq_40714949/article/details/115485140

标签:dim,tensor,torch,Pytorch,print,2.5000,mean
From: https://www.cnblogs.com/emanlee/p/17455263.html

相关文章

  • cmd+ssh配置远程服务器Anaconda3_2023+pytorch
    一、上传Anaconda3到远程服务器注意:如果要将这个东西安装在anaconda3文件夹里的话,当前这个目录里不能有这个文件夹。(安的时候会自动创建) 二、安装Anaconda31.win+r快捷键打开cmd输入ssh 可以看到已经与服务器建立连接 2.输入ssh<用户名>@主机IP......
  • 【Pytorch】ValueError: not enough values to unpack (expected 2, got 1)问题解决
    在运行开源项目时出现了这个问题,网上很多说删回车或者都改成英文符号,但是我都试了,没用后来自己摸索出的方法是:先更改数据集的格式,之前分隔符是\t,把数据集中的分隔符改成空格,再把语句中的\t也换成空格,然后就不会报错了。改前:改后:......
  • pytorch 训练 RuntimeError Unable to find a valid cuDNN algorithm to run convolut
    pytorch训练RuntimeError:UnabletofindavalidcuDNNalgorithmtorunconvolutionpytorch训练RuntimeError:UnabletofindavalidcuDNNalgorithmtorunconvolution#问题描述:python:3.95pytorch:1.10.2pythontrain.py--img640--batch64--epochs600--da......
  • Pytorch 分布式训练
    PytorchDDP分布式训练介绍近期一直在用torch的分布式训练,本文调研了目前Pytorch的分布式并行训练常使用DDP模式(DistributedDataParallell ),从基本概念,初始化启动,以及第三方的分布式训练框架展开介绍。最后以一个Bert情感分类给出完整的代码例子:torch-ddp-examples。基本......
  • Pytorch rendezvous 分布式
    PyTorch中的rendezvous后端是一种服务,它帮助分布式训练作业中的进程相互发现并协商角色和等级。它还提供了一个屏障和一个一致的作业成员和状态视图。 rendezvous后端是作为torch.distributed.elastic.rendezvous.RendezvousHandler的子类实现的,它定义了创建、加入和销毁rendez......
  • 安装pytorch
    pytorch官网https://pytorch.org/创建一个环境名为:pytorchpython版本为3.9激活;然后输入:condainstallpytorchtorchvisiontorchaudiopytorch-cuda=11.7-cpytorch-cnvidia安装最好离线安装测试是否安装成功importtorch......
  • python spark kmeans demo
    官方的demofromnumpyimportarrayfrommathimportsqrtfrompysparkimportSparkContextfrompyspark.mllib.clusteringimportKMeans,KMeansModelsc=SparkContext(appName="clusteringExample")#Loadandparsethedatadata=sc.textFile("/......
  • spark Bisecting k-means(二分K均值算法)
    Bisectingk-means(二分K均值算法)    二分k均值(bisectingk-means)是一种层次聚类方法,算法的主要思想是:首先将所有点作为一个簇,然后将该簇一分为二。之后选择能最大程度降低聚类代价函数(也就是误差平方和)的簇划分为两个簇。以此进行下去,直到簇的数目等于用户给定的数目K为止。......
  • MATLAB用改进K-Means(K-均值)聚类算法数据挖掘高校学生的期末考试成绩|附代码数据
    全文链接:http://tecdat.cn/?p=30832最近我们被客户要求撰写关于K-Means(K-均值)聚类算法的研究报告,包括一些图形和统计输出。本文首先阐明了聚类算法的基本概念,介绍了几种比较典型的聚类算法,然后重点阐述了K-均值算法的基本思想,对K-均值算法的优缺点做了分析,回顾了对K-均值改进......
  • 在树莓派上实现numpy的LSTM长短期记忆神经网络做图像分类,加载pytorch的模型参数,推理mn
    这几天又在玩树莓派,先是搞了个物联网,又在尝试在树莓派上搞一些简单的神经网络,这次搞得是LSTM识别mnist手写数字识别训练代码在电脑上,cpu就能训练,很快的:importtorchimporttorch.nnasnnimporttorchvisionimportnumpyasnpimportosfromPILimportImage#定义LSTM......