
Reinforcement Learning Code in Practice - 07: The Actor-Critic Algorithm


Actor (policy network) and Critic (value network)

  • The Actor interacts with the environment and, guided by the Critic's value function, uses policy gradients to learn a better policy.
  • The Critic learns a value function from the data the Actor collects by interacting with the environment; this value function judges which actions are good and which are not in the current state, and thereby helps the Actor update its policy (the concrete update rules are written out right after this list).
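
Concretely, the Critic's one-step temporal-difference (TD) error serves as an estimate of the advantage that weights the Actor's policy-gradient update. With policy π_θ, value network V_φ and discount factor γ (0.98 in the code below, and with V_φ(s_{t+1}) taken as 0 once the episode has ended), the quantities computed in train() are:

    δ_t      = r_t + γ·V_φ(s_{t+1}) − V_φ(s_t)
    L_actor  = −log π_θ(a_t | s_t) · δ_t          (δ_t treated as a constant, hence the detach())
    L_critic = (V_φ(s_t) − (r_t + γ·V_φ(s_{t+1})))²

These correspond directly to delta, actor_loss and critic_loss in the training loop below.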
import random
import gym
import torch
import numpy as np
from matplotlib import pyplot as plt
from IPython import display

env = gym.make("CartPole-v0")  # old Gym API: reset() returns only the state, step() returns 4 values
# Agent state (initial observation)
state = env.reset()
# Number of discrete actions
actions = env.action_space.n
print(state, actions)

# The Actor is updated with policy gradients (input: state, output: action probabilities);
# the Critic is updated as a value function (input: state, output: state value)
actor_model = torch.nn.Sequential(torch.nn.Linear(4, 128),
                                 torch.nn.ReLU(),
                                 torch.nn.Linear(128, 2),
                                 torch.nn.Softmax(dim=1))
critic_model = torch.nn.Sequential(torch.nn.Linear(4, 128),
                                  torch.nn.ReLU(),
                                  torch.nn.Linear(128, 1))


def get_action(state):
    state = torch.FloatTensor(state).reshape(1,4)
    prob = actor_model(state)
    action = random.choices(range(2), weights=prob[0].tolist(), k=1)[0]
    return action

def get_data():
    state = env.reset()
    states = []
    actions = []
    rewards = []
    next_states = []
    dones = []
    
    done = False
    while not done:
        action = get_action(state)
        next_state, reward, done, _ = env.step(action)
        states.append(state)
        actions.append(action)
        rewards.append(reward)
        next_states.append(next_state)
        dones.append(done)
        state = next_state
        
    states = torch.FloatTensor(np.array(states)).reshape(-1, 4)
    rewards = torch.FloatTensor(rewards).reshape(-1, 1)
    actions = torch.LongTensor(actions).reshape(-1, 1)
    next_states = torch.FloatTensor(np.array(next_states)).reshape(-1, 4)
    dones = torch.LongTensor(dones).reshape(-1, 1)
    
    return states, rewards, actions, next_states, dones

def test():
    state = env.reset()
    rewards_sum = 0
    done = False
    
    while not done:
        action = get_action(state)
        state, reward, done, _ = env.step(action) 
        rewards_sum += reward
    return rewards_sum

def train():
    optimizer = torch.optim.Adam(actor_model.parameters(), lr=1e-3)
    optimizer_td = torch.optim.Adam(critic_model.parameters(), lr=1e-2)
    
    # Play N episodes, training once per episode
    for epoch in range(1000):
        states, rewards, actions, next_states, dones = get_data()
        # Current values and TD targets for the whole episode in one batch (discount factor 0.98)
        current_values = critic_model(states)
        next_state_values = critic_model(next_states) * 0.98
        next_state_values *= (1 - dones)
        next_values = rewards + next_state_values
        # TD error. Used only as a value; detach() blocks gradients from flowing back through it
        delta = (next_values - current_values).detach()
        
        # The actor re-scores the actions that were actually taken
        probs = actor_model(states)
        probs = probs.gather(dim=1, index=actions)
        actor_loss = (-probs.log() * delta).mean()
        # TD loss for the critic: mean squared error against the TD target
        critic_loss = torch.nn.MSELoss()(current_values, next_values.detach())
        
        optimizer.zero_grad()
        actor_loss.backward()
        optimizer.step()
        
        optimizer_td.zero_grad()
        critic_loss.backward()
        optimizer_td.step()
        
        if epoch % 100 == 0:
            result = sum([test() for _ in range(50)]) / 50
            print(epoch, result)
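
A minimal way to run the whole thing: call train(), then evaluate the final policy with the test() function defined above, for example:

train()
# Average return of the trained policy over 50 evaluation episodes
final_score = sum(test() for _ in range(50)) / 50
print("final average reward:", final_score)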

 

From: https://www.cnblogs.com/demo-deng/p/16894494.html
