首页 > 其他分享 >offline RL | 读读 Decision Transformer

offline RL | 读读 Decision Transformer

时间:2024-02-27 20:23:27浏览次数:23  
标签:Transformer return Decision attention go embedding RL offline DT



目录


1 Transformer 是一种 seq2seq 建模方法

(著名的 GPT 的全称是 Generative Pre-trained Transformer)

学习 Transformer:

seq2seq 的输入输出:

  • 在 nlp 领域貌似是 word embedding,然后再使用 word2vec 之类得到单词(?)

Attention 与 Transformer:

  • attention:
    • key query value:key 用来提取关键信息、query 用来提取查询、value 用来提取值。它们具有矩阵形式,用 k q v 矩阵去乘输入的 vector,得到 k q v 的 vector。
    • k q v 举例:希望投票选举,query - 评委的重要程度、key - 评委的职称(?)、value - 评委的投票结果,最后按照 (query × key^T) × value 的形式,对投票结果进行加权计算。 multi-head-attention 就是使用多组 k q v,可能表示我们希望关注多个方面。
  • encoder:
    • 一个 encoder 块包含一个 attention + 一个 feed forward 层(大概就是全连接层)。
    • 我们使用 k q v 的 attention 模块,一下对所有 token 的矩阵(维度 num of tokens × embedding size)得到一个 latent z(维度 num of tokens × latent size)。
    • 然后对于每个原句子中的 token,各过各的全连接层(feed forward)(?)最后得到一个 维度 num of tokens × embedding size 的矩阵。
    • 残差连接(Res):将输入和 多头注意力层 或全连接神经网络的输出 相加,再传递给下一层,避免梯度递减的问题。
  • decoder:
    • 一个 decoder 块包含两个 attention + 一个 feed forward 层。
    • attention 1 用来处理自己输出的信息,因此它在说第 n 个单词之前,只能以自己说出的前 n-1 个单词作为输入,使用一个掩码(?)来实现:掩码多头自注意力(Masked-Multi-head self attention)。
    • attention 2 用来处理 encoder 给出的 num of tokens × embedding size 的 embedding,attention 1 的输出也是其输入的一部分。
  • 这样,看图应该就能看懂了。
img

2 建模 RL 的 sequence

我们的 sequence:{return-to-go, state, action, return-to-go, ...}

  • 形式类似于 \(\{s_t,a_t,r_t,s_{t+1},\cdots\}\) 。
  • return-to-go: \(\hat R_t=\sum_{t'=t}^Tr_{t'}\) ,是从此刻 t 到 episode 结束的,in-discounted reward 的加和。
  • 感觉 return-to-go 类似于 HER 的预期目标,比较 hindsight。

3 如何训练 DT

对 sequence {s, a, R, s, a, R, ...} 进行处理:

  • 对每个 modality(s a R),都学习了将它们转换为 embedding 的线性层。
  • 对于具有视觉输入的环境,状态被输入到卷积编码器而不是线性层中。
  • 此外,每个时间步的 embedding 都会被学习并添加到每个 token 中 —— 这与 transformer 使用的 positional embedding(三角函数?)不同,因为一个时间步对应于三个 token。

(搬图,搬运文字说明)

img
  • 一条轨迹按照 s a R 顺序排列好后,每个元素都是图 1 下部的一个小圆圈,类似于 NLP 中的一个个单词。
  • 然后,每个元素经过一个 mlp 做 embedding 后,再加上 position encoding,就得到了 tokens,也就是图 1 下部的一个个五颜六色的小长方块。

训练:

  • 使用 offline trajectory 的 dataset(D4RL 之类)。
  • 从离线轨迹数据集中,抽取 sequence 长度为 K 的 minibatch。
  • 训练:对 input token \(s_t\) 的那个 prediction head,再加一个 mlp 来预测 \(a_t\) (上图上部输出 \(a_t\) 的橙色方块)。
    • 训 action 时,对离散动作使用 cross-entropy loss,连续则使用 MSE。
    • DT 每隔三个 token 才 decode 一个,因为作为 policy 只需要输出 action。但其实,output tokens 由对应 return-to-go、state、action 的 token 组成,所以自然只留下对应 action 的 tokens(?)
    • 发现,去预测 state 或 return-to-go 并不能提高性能,尽管在 DT 框架里,很容易这么做。
  • 上述训练部分是想让 DT 学会,在某个特定状态 s 下,达到 return-to-go R,所需要做的动作 a。
  • 详见 Algorithm 1 伪代码,感觉写的很清楚。

4 如何部署 DT policy / 如何 inference

  • inference 过程就是,首先提出一个 target return(我们希望 agent 在一个 episode 里能达到的 return),作为初始的 return-to-go,然后 DT 按照训练过程中学到的 如何达成 return-to-go 的方法,选择 action。
  • 每走一步,就将上一步的 return-to-go 减去这一步的 reward,得到下一个 return-to-go,从而不断地更新我们期望 DT 达到的 return 目标,同时 DT 根据我们的目标,不断选择 action。
  • 详见 Algorithm 1 伪代码。
  • evaluate 时,只保留 length = K 的 context,对应于前面训练时 sequence length = K。
    • 通常认为,当使用 frame stacking(Atari 的帧堆叠)时,K = 1 已经 MDP,足以用于 RL 算法。然而,当 K = 1 时,Decision Transformer 的性能明显更差,这表明过去的信息对 Atari 游戏有用(非 MDP?)。(具体实验中,Atari 的 K = 30 50,MuJoco 的 K = 5 20)
    • 一个假设是,当我们表示 一些策略的分布时(例如序列建模),上下文允许 transformer 识别,哪个策略生成了该动作,从而实现更好的学习和 / 或改进训练动态。

5 技术细节

  • 训练的一些超参数,encoder / decoder:可参见 Table 8 9。
  • 在 inference 过程中,如何选择 return-to-go:使用 dataset 中最大 return 的一倍或 5 倍。
  • Warmup tokens 为 512 ∗ 20,是 <begin> <end> 这种 token 嘛?

6 一些讨论

  • (Section 5.7)DT 为什么不需要 pessimistic value 或行为正则化?作者猜想:pessimistic value 和行为正则化是为了避免 value function approximation 带来的问题,但 DT 并不需要显式优化一个函数(?)
  • (Section 5.8)声称 DT + Go-Explore(RL exploration 方法,感觉像打表)可以帮助 online policy。
  • Credit assignment 貌似是一类工作,通过分解 reward function,使得某些“重要”状态包含了大部分 credit。


标签:Transformer,return,Decision,attention,go,embedding,RL,offline,DT
From: https://www.cnblogs.com/moonout/p/18037872

相关文章

  • Hyperledger Fabric出块配置详解
    HyperledgerFabric的出块主要是Orderer节点负责,出块配置位于创世区块中,支持定时出块、达到一定交易数出块两种条件。出块配置位于configtx.yaml中,修改出块配置后需要重新生成创世区块。相关参数若需要修改fabric的出块机制,则需要调整以下配置参数:BatchTimeout:出块超时时间,最......
  • 【报错解决】【Mathtype】lease restart Word to load MathType addin properly
    打开Mathtype安装目录例如我的C:\software\MathModel\MATHTYPE继续进入目录C:\software\MathModel\MATHTYPE\MathPage\64找到MathPage.wll复制该文件在word中查看加载项的路径,将刚才复制的文件粘贴到这个路径当中返回上一级目录,再次粘贴这个文件重启word,问题解决......
  • uniGui用UniURLFrame1填写表单
    参考自带例子:C:\ProgramFiles(x86)\FMSoft\Framework\uniGUI\Demos\Desktop\HTTPPostCallback-URLFrame-AutoTarget 添加步骤1] 2] 3] ......
  • Docker-Overlay2磁盘空间爆满清理方法
    Docker-Overlay2磁盘空间爆满清理方法在日常线上环境中,我们通常会来做利用Docker来做容器化管理,通过运行容器来执行任务等。但是,随着业务量的不断增大,容器的不断启动,往往会出现磁盘空间不足,1、第一种情况:是因为docker中部署的系统中日志内容的不断扩大。这种情况下,我们可手动,或定......
  • AI与人类联手,智能排序人类决策:RLHF标注工具打造协同标注新纪元,重塑AI训练体验
    AI与人类联手,智能排序人类决策:RLHF标注工具打造协同标注新纪元,重塑AI训练体验在大模型训练的RLHF阶段,需要人工对模型生成的多份数据进行标注排序,然而目前缺乏开源可用的RLHF标注平台。RLHF标注工具是一个简单易用的,可以在大模型进行RLHF(基于人类反馈的强化学习)标注排序的......
  • 从零开始学Spring Boot系列-Hello World
    欢迎来到从零开始学SpringBoot的旅程!在这个系列的第二篇文章中,我们将从一个非常基础但重要的示例开始:创建一个简单的SpringBoot应用程序,并输出“HelloWorld”。1.环境准备首先,确保你的开发环境已经安装了以下工具:JavaDevelopmentKit(JDK):SpringBoot需要Java来运行,所......
  • 读人工不智能:计算机如何误解世界笔记02_Hello,world
    1. Hello,world1.1. “Hello,world”是布赖恩·克尼汉和丹尼斯·里奇于1978年出版的经典著作《C程序设计语言》中的第一个编程项目1.2. 贝尔实验室可以说是现代计算机科学界中的智库,地位好比巧克力界的好时巧克力1.3. 计算机科学界的大量创新都起源于贝尔实验室1.3.1. 激......
  • 05 Hello World
    05HelloWorld随便新建一个文件夹,存放代码新建一个java文件txt(文件后缀名为).javaHello.java【注意点】系统可能没显示文件后缀名,我们需要手动打开使用Notpad++编写代码publicclassHello{ publicstaticvoidmain(String[]args){ System.out.print("......
  • ../inst/include/Eigen/src/Core/MathFunctions.h:487:16: error: no member named 'R
    Asmentionedin conda-forge/r-base-feedstock#163(comment),IsuccessfullyinstalledsctransforminMacsiliconM1Maxbyfirstrun exportPKG_CPPFLAGS="-DHAVE_WORKING_LOG1P intheterminalandtheninstallthepackageinR.......
  • P9370 APIO2023 cyberland
    题面:https://www.luogu.com.cn/problem/P9370显然只有从\(0\)出发不经过\(H\)能到达的点是有用的。首先,考虑跑多源最短路,将\(arr=0\)的点都作为源点(当然\(0\)也是源点)。不难发现这样转化后,这些点即可视作\(arr=1\)。对于\(arr=2\)的点,由于能使用除以二技能的次数很......