首页 > 其他分享 >动手学强化学习(八.1):torch.gather

动手学强化学习(八.1):torch.gather

时间:2024-03-05 17:14:29浏览次数:28  
标签:dim gather tensor 索引 torch t1 动手

tensor.gather()的作用就是按照索引取对应的数据出来。之前看图解PyTorch中的torch.gather函数,那个图示看得我有点懵逼,所以自己画了两张图总结了一下规律来理解一下。

首先新建一个3*3的二维矩阵。

import torch
​
t1 = torch.tensor([[1, 2, 3],
                   [4, 5, 6],
                   [7, 8, 9]])

tensor.gather()主要有两个参数,第一个是dim,用来判断是对行还是列进行索引;第二个是索引的矩阵(这个必须是tensor,不能是list类型),这个索引是令人不太好理解的地方,下面我用两三个例子帮助理解一下。

按列取 -> "上下结构"

按列取,那么dim = 0。

t1.gather(dim = 0, index = torch.tensor([[1, 2, 0]]))
​
# tensor([[4, 8, 3]])
▲这是一个常规的例子,分别对每一列拿一个值出来组成新的矩阵,注意图中标示的颜色对应
t1.gather(dim = 0, index = torch.tensor([[1],
                                         [2],
                                         [0]]))
# tensor([[4],
#         [7],
#         [1]])
▲这个实际上是一个奇怪的例子,只取第一列的值,注意结合上图进行理解

按照上面的逻辑,你就可以尝试各种奇怪的索引:

t1.gather(dim = 0, index = torch.tensor([[1, 0],
                                         [2, 1],
                                         [0, 0]]))
# tensor([[4, 2],
#         [7, 5],
#         [1, 2]])

按行取 -> "左右结构"

这里需要将dim=1。

t1.gather(dim = 1, index = torch.tensor([[1, 2, 0]]))
​
# tensor([[2, 3, 1]])
▲因为设置的是按行索引,自然我们的矩阵要放到左边来成为“左右结构”
t1.gather(dim = 1, index = torch.tensor([[1], [2], [0]]))
​
# tensor([[2],
#         [6],
#         [7]])
▲对每一行取一个数值

不过我感觉按行还是按列是异曲同工,按照你自己的习惯来吧。

除此之外也可以类似于numpy的直接用索引值去抓:

t1[[0, 2], [2, 1]]
# tensor([3, 8])

标签:dim,gather,tensor,索引,torch,t1,动手
From: https://www.cnblogs.com/zhangxianrong/p/18054436

相关文章

  • 动手学强化学习(八.2):double-DQN
    一、代码importrandomimportgymimportnumpyasnpimporttorchimporttorch.nn.functionalasFimportmatplotlib.pyplotaspltimportrl_utilsfromtqdmimporttqdmclassQnet(torch.nn.Module):'''只有一层隐藏层的Q网络'''de......
  • Windows环境下Pytorch项目搭建在Docker中运行
    Windows环境下Pytorch项目搭建在Docker中运行1.安装windows版本的Docker​ 网上已有诸多博客教程,这里就不再赘述。2.搭建本地Pytorch环境​ 搭建本地Pytorch环境的方式我使用了两种方式,推荐使用第一种。​ 第一种:​ (1)在dockerhub中(https://hub.docker.com),找到自己版本......
  • 动手学强化学习(七.1):DQN 算法代码
    一、代码如下:importrandomimportgymimportnumpyasnpimportcollectionsfromtqdmimporttqdmimporttorchimporttorch.nn.functionalasFimportmatplotlib.pyplotaspltimportrl_utilsclassReplayBuffer:'''经验回放池'''......
  • 动手学强化学习(七):DQN 算法
    第7章DQN算法7.1简介在第5章讲解的Q-learning算法中,我们以矩阵的方式建立了一张存储每个状态下所有动作\(Q\)值的表格。表格中的每一个动作价值\(Q(s,a)\)表示在状态\(s\)下选择动作\(a\)然后继续遵循某一策略预期能够得到的期望回报。然而,这种用表格存储动作价值的做......
  • 动手学强化学习(六):Dyna-Q
    第6章Dyna-Q算法6.1简介在强化学习中,“模型”通常指与智能体交互的环境模型,即对环境的状态转移概率和奖励函数进行建模。根据是否具有环境模型,强化学习算法分为两种:基于模型的强化学习(model-basedreinforcementlearning)和无模型的强化学习(model-freereinforcementlearn......
  • 动手学强化学习(五):时序差分算法代码
    一、单步sarsaimportmatplotlib.pyplotaspltimportnumpyasnpfromtqdmimporttqdm#tqdm是显示循环进度条的库classCliffWalkingEnv:def__init__(self,ncol,nrow):self.nrow=nrow#4self.ncol=ncol#12self.x=0#记录......
  • 动手学强化学习(五):时序差分算法
    第5章时序差分算法5.1简介第4章介绍的动态规划算法要求马尔可夫决策过程是已知的,即要求与智能体交互的环境是完全已知的(例如迷宫或者给定规则的网格世界)。在此条件下,智能体其实并不需要和环境真正交互来采样数据,直接用动态规划算法就可以解出最优价值或策略。这就好比对于......
  • 动手学强化学习(五):值迭代与策略迭代代码
    一、策略迭代importcopyclassCliffWalkingEnv:"""悬崖漫步环境"""def__init__(self,ncol=12,nrow=4):self.ncol=ncol#定义网格世界的列self.nrow=nrow#定义网格世界的行#转移矩阵P[state][action]=[(p,next_state,......
  • 动手学强化学习(四):动态规划算法
    第4章动态规划算法4.1简介动态规划(dynamicprogramming)是程序设计算法中非常重要的内容,能够高效解决一些经典问题,例如背包问题和最短路径规划。动态规划的基本思想是将待求解问题分解成若干个子问题,先求解子问题,然后从这些子问题的解得到目标问题的解。动态规划会保存已解决......
  • pytorch报错:Variable._execution_engine.run_backward( # Calls into the C++ engine
    GPU模式下运行pytorch代码报错,pytorch为2.2.1,NVIDIA驱动版本535.161.07File"/home/devil/anaconda3/envs/sample-factory/lib/python3.11/site-packages/torch/_tensor.py",line522,inbackwardtorch.autograd.backward(File"/home/devil/anaconda3/envs/sample-......