GAE是强化学习中常用的一种advantage计算方法,被经常应用于A3C、A2C、TRPO、PPO中,但是在常见的GAE实现中都是不考虑截断情况下的,也就是truncation情况下,本文给出Google的一种truncation情况下的GAE计算方法的实现。
解释一下什么叫做truncation截断:
在强化学习中agent需要和环境进行一系列的交互从而形成一个连贯时序的episode,而这个episode的程度往往是被设定为有最大值的,即time_limit,当agent在一个时序连贯的交互中episode长度等于time_limit时就需要truncation截断,毕竟常见的就是cartpole问题中的200步长的限制,也就是episode长度最大不能超过200。
代码来源:
https://openi.pcl.ac.cn/devilmaycry812839668/google_brax_ppo_pytorch
@torch.jit.export
def compute_gae(self, truncation, termination, reward, values,
bootstrap_value):
truncation_mask = 1 - truncation
# Append bootstrapped value to get [v1, ..., v_t+1]
values_t_plus_1 = torch.cat(
[values[1:], torch.unsqueeze(bootstrap_value, 0)], dim=0)
deltas = reward + self.discounting * (
1 - termination) * values_t_plus_1 - values
deltas *= truncation_mask
acc = torch.zeros_like(bootstrap_value)
vs_minus_v_xs = torch.zeros_like(truncation_mask)
for ti in range(truncation_mask.shape[0]):
ti = truncation_mask.shape[0] - ti - 1
acc = deltas[ti] + self.discounting * (
1 - termination[ti]) * truncation_mask[ti] * self.lambda_ * acc
vs_minus_v_xs[ti] = acc
# Add V(x_s) to get v_s.
vs = vs_minus_v_xs + values
vs_t_plus_1 = torch.cat([vs[1:], torch.unsqueeze(bootstrap_value, 0)], 0)
advantages = (reward + self.discounting *
(1 - termination) * vs_t_plus_1 - values) * truncation_mask
return vs, advantages
这个实现中的一个关键就是GAE计算中的delta,如果一个timestep的下一步是truncation截断,那么这一步的delta直接赋值为0,而通过delta计算出advantage之后也需要将下一步为truncation截断的哪个timestep对应的advantage赋值为0,具体如:(与truncation_mask相乘)
advantages = (reward + self.discounting *
(1 - termination) * vs_t_plus_1 - values) * truncation_mask
不过需要注意的是,在大多数的情况下是可以不对截断情况作区分的,常见的方法就是直接将截断情况视作termination,也就是episode主动结束,即game over!!!因为大多数情况下,我们都是希望episode越长越好的,因此我们会设置一个比较大的time limit,比如这里的示例代码所在的例子中就是1000,而cartpole则为200,而除了time limit是一个较小数值以外的情况时这个截断都可以不用考虑的,因为并不会对整体算法性能有明显影响的。
个人github博客地址:
https://devilmaycry812839668.github.io/