An Example of SAC with a Continuous Action Space


"""My SAC continuous demo"""

import argparse
import copy
import os

import gym
import numpy as np
import torch
import torch.nn.functional as F

from torch import nn
from torch.distributions import Normal


def parse_args() -> argparse.Namespace:
    """Parse arguments."""
    parser = argparse.ArgumentParser(description="Training")
    parser.add_argument(
        "--log_path", type=str, help="Model path", default="./training_log/"
    )
    parser.add_argument(
        "--max_buffer_size", type=int, help="Max buffer size", default=100000
    )
    parser.add_argument(
        "--min_buffer_size", type=int, help="Min buffer size", default=50000
    )
    parser.add_argument("--hidden_width", type=int, help="Hidden width", default=256)
    parser.add_argument(
        "--gamma",
        type=float,
        help="gamma",
        default=0.99,
    )
    parser.add_argument("--tau", type=float, help="tau", default=0.005)
    parser.add_argument(
        "--learning_rate", type=float, help="Learning rate", default=1e-3
    )
    parser.add_argument(
        "--max_train_steps", type=int, help="Max training steps", default=100000
    )
    parser.add_argument("--batch_size", type=int, help="Batch size", default=256)
    parser.add_argument(
        "--evaluate_frequency", type=int, help="Evaluate frequency", default=10000
    )
    return parser.parse_args()


class ReplayBuffer:
    """Replay buffer for storing transitions."""

    def __init__(self, state_dim: int, action_dim: int) -> None:
        self.max_size = int(args.max_buffer_size)
        self.count = 0
        self.size = 0
        self.state = np.zeros((self.max_size, state_dim))
        self.action = np.zeros((self.max_size, action_dim))
        self.reward = np.zeros((self.max_size, 1))
        self.next_state = np.zeros((self.max_size, state_dim))
        self.done = np.zeros((self.max_size, 1))

    def store(
        self,
        state: np.ndarray,
        action: np.ndarray,
        reward: float,
        next_state: np.ndarray,
        done: bool,
    ) -> None:
        """Store a transition in the replay buffer."""
        self.state[self.count] = state
        self.action[self.count] = action
        self.reward[self.count] = reward
        self.next_state[self.count] = next_state
        self.done[self.count] = done
        self.count = (self.count + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample(self, batch_size: int) -> tuple:
        """Sample a batch of transitions."""
        index = np.random.choice(self.size, size=batch_size)
        batch_state = torch.tensor(self.state[index], dtype=torch.float)
        batch_action = torch.tensor(self.action[index], dtype=torch.float)
        batch_reward = torch.tensor(self.reward[index], dtype=torch.float)
        batch_next_state = torch.tensor(self.next_state[index], dtype=torch.float)
        batch_done = torch.tensor(self.done[index], dtype=torch.float)
        return batch_state, batch_action, batch_reward, batch_next_state, batch_done


class Actor(nn.Module):
    """Actor network."""

    def __init__(
        self, state_dim: int, action_dim: int, hidden_width: int, max_action: float
    ) -> None:
        super().__init__()
        self.max_action = max_action
        self.in_layer = nn.Sequential(
            nn.Linear(state_dim, hidden_width),
            nn.ReLU(inplace=True),
            nn.LayerNorm(hidden_width),
        )
        self.res_layer = nn.Sequential(
            nn.Linear(hidden_width, hidden_width),
            nn.ReLU(inplace=True),
            nn.LayerNorm(hidden_width),
            nn.Linear(hidden_width, hidden_width),
        )
        self.out_layer = nn.Sequential(
            nn.Linear(hidden_width, hidden_width),
            nn.ReLU(inplace=True),
            nn.LayerNorm(hidden_width),
        )
        self.mean_layer = nn.Sequential(nn.ReLU(), nn.Linear(hidden_width, action_dim))
        self.log_std_layer = nn.Sequential(
            nn.ReLU(inplace=True), nn.Linear(hidden_width, action_dim)
        )

    def forward(self, x: torch.Tensor, deterministic: bool = False) -> tuple:
        """Forward pass."""
        x = self.in_layer(x)
        x = self.out_layer(x + self.res_layer(x))
        mean = self.mean_layer(x)
        log_std = self.log_std_layer(x)
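        # Clamp log_std so the resulting standard deviation stays in a numerically safe range.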
        log_std = torch.clamp(log_std, -20, 2)
        std = torch.exp(log_std)
        dist = Normal(mean, std)
        if deterministic:
            action = mean
        else:
            action = dist.rsample()
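        # Log-probability of the pre-tanh Gaussian sample, then the tanh-squashing
        # correction in its numerically stable form:
        # log(1 - tanh(u)^2) = 2 * (log(2) - u - softplus(-2u)).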
        log_pi = dist.log_prob(action).sum(dim=1, keepdim=True)
        log_pi -= (2 * (np.log(2) - action - F.softplus(-2 * action))).sum(
            dim=1, keepdim=True
        )
        action = self.max_action * torch.tanh(action)
        return action, log_pi


class Critic(nn.Module):
    """Critic network."""

    def __init__(self, state_dim: int, action_dim: int, hidden_width: int) -> None:
        super().__init__()
        self.in_layer1 = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_width),
            nn.ReLU(inplace=True),
            nn.LayerNorm(hidden_width),
        )
        self.res_layer1 = nn.Sequential(
            nn.Linear(hidden_width, hidden_width),
            nn.ReLU(inplace=True),
            nn.LayerNorm(hidden_width),
            nn.Linear(hidden_width, hidden_width),
        )
        self.out_layer1 = nn.Sequential(
            nn.ReLU(inplace=True), nn.Linear(hidden_width, 1)
        )
        self.in_layer2 = nn.Sequential(
            nn.Linear(state_dim + action_dim, hidden_width),
            nn.ReLU(inplace=True),
            nn.LayerNorm(hidden_width),
        )
        self.res_layer2 = nn.Sequential(
            nn.Linear(hidden_width, hidden_width),
            nn.ReLU(inplace=True),
            nn.LayerNorm(hidden_width),
            nn.Linear(hidden_width, hidden_width),
        )
        self.out_layer2 = nn.Sequential(
            nn.ReLU(inplace=True), nn.Linear(hidden_width, 1)
        )

    def forward(self, state: torch.Tensor, action: torch.Tensor) -> tuple:
        """Forward pass."""
        state_action = torch.cat([state, action], 1)
        q1 = self.in_layer1(state_action)
        q1 = self.out_layer1(q1 + self.res_layer1(q1))
        q2 = self.in_layer2(state_action)
        q2 = self.out_layer2(q2 + self.res_layer2(q2))
        return q1, q2


class SACContinuous:
    """Soft Actor-Critic for continuous action space."""

    def __init__(self, state_dim: int, action_dim: int, max_action: float) -> None:
        self.gamma = args.gamma
        self.tau = args.tau
        self.batch_size = args.batch_size
        self.learning_rate = args.learning_rate
        self.hidden_width = args.hidden_width
        self.max_action = max_action
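        # Target entropy of -|A|: the usual heuristic for automatic temperature (alpha) tuning.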
        self.target_entropy = -action_dim
        self.log_alpha = torch.zeros(1, requires_grad=True)
        self.alpha = self.log_alpha.exp()
        self.alpha_optimizer = torch.optim.Adam([self.log_alpha], lr=self.learning_rate)
        self.actor = Actor(state_dim, action_dim, self.hidden_width, max_action)
        self.actor_optimizer = torch.optim.Adam(
            self.actor.parameters(), lr=self.learning_rate
        )
        self.critic = Critic(state_dim, action_dim, self.hidden_width)
        self.critic_target = copy.deepcopy(self.critic)
        self.critic_optimizer = torch.optim.Adam(
            self.critic.parameters(), lr=self.learning_rate
        )

    def choose_action(
        self, state: np.ndarray, deterministic: bool = False
    ) -> np.ndarray:
        """Choose action."""
        state = torch.unsqueeze(torch.tensor(state, dtype=torch.float), 0)
        action, _ = self.actor(state, deterministic)
        return action.data.numpy().flatten()

    def learn(self, replay_buffer: ReplayBuffer) -> None:
        """Update the critic, actor, and entropy temperature from one sampled batch."""
        batch_state, batch_action, batch_reward, batch_next_state, batch_done = (
            replay_buffer.sample(self.batch_size)
        )
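        # Soft Bellman target, computed with the target critic and the current policy:
        # y = r + gamma * (1 - done) * (min(Q1', Q2') - alpha * log_pi(a'|s')).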
        with torch.no_grad():
            batch_next_action, log_pi_ = self.actor(batch_next_state)
            target_q1, target_q2 = self.critic_target(
                batch_next_state, batch_next_action
            )
            target_q = batch_reward + self.gamma * (1 - batch_done) * (
                torch.min(target_q1, target_q2) - self.alpha * log_pi_
            )
        current_q1, current_q2 = self.critic(batch_state, batch_action)
        critic_loss = F.mse_loss(current_q1, target_q) + F.mse_loss(
            current_q2, target_q
        )
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()
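        # Freeze the critic so the actor update does not build gradients through it.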
        for params in self.critic.parameters():
            params.requires_grad = False
        action, log_pi = self.actor(batch_state)
        q1, q2 = self.critic(batch_state, action)
        q = torch.min(q1, q2)
        actor_loss = (self.alpha * log_pi - q).mean()
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()
        for params in self.critic.parameters():
            params.requires_grad = True
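        # Temperature (alpha) update: drive the policy entropy toward the target entropy.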
        alpha_loss = -(
            self.log_alpha.exp() * (log_pi + self.target_entropy).detach()
        ).mean()
        self.alpha_optimizer.zero_grad()
        alpha_loss.backward()
        self.alpha_optimizer.step()
        self.alpha = self.log_alpha.exp()
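        # Polyak (soft) update of the target critic parameters.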
        for param, target_param in zip(
            self.critic.parameters(), self.critic_target.parameters()
        ):
            target_param.data.copy_(
                self.tau * param.data + (1 - self.tau) * target_param.data
            )


def evaluate_policy(env, agent: SACContinuous) -> float:
    """Evaluate the policy."""
    state = env.reset()[0]
    done = False
    episode_reward = 0
    action_num = 0
    while not done:
        action = agent.choose_action(state, deterministic=True)
        next_state, reward, terminated, truncated, _ = env.step(action)
        done = terminated or truncated
        episode_reward += reward
        state = next_state
        action_num += 1
        if action_num >= 1000:
            print("action_num too large.")
            break
        if episode_reward <= -1000:
            print("episode_reward too small.")
            break
    return episode_reward


def training() -> None:
    """My demo training function."""
    env_name = "Pendulum-v1"
    env = gym.make(env_name)
    env_evaluate = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    max_action = float(env.action_space.high[0])
    agent = SACContinuous(state_dim, action_dim, max_action)
    replay_buffer = ReplayBuffer(state_dim, action_dim)
    evaluate_num = 0
    total_steps = 0
    while total_steps < args.max_train_steps:
        state = env.reset()[0]
        episode_steps = 0
        done = False
        while not done:
            episode_steps += 1
            action = agent.choose_action(state)
            next_state, reward, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            # Store only true termination for bootstrapping; time-limit truncation
            # should not zero out the next-state value.
            replay_buffer.store(state, action, reward, next_state, terminated)
            state = next_state
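            # Start gradient updates only after enough warm-up transitions have been collected.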
            if total_steps >= args.min_buffer_size:
                agent.learn(replay_buffer)
            if (total_steps + 1) % args.evaluate_frequency == 0:
                evaluate_num += 1
                evaluate_reward = evaluate_policy(env_evaluate, agent)
                print(
                    f"evaluate_num: {evaluate_num} \t evaluate_reward: {evaluate_reward}"
                )
            total_steps += 1
            if total_steps >= args.max_train_steps:
                break
    env.close()
    os.makedirs(args.log_path, exist_ok=True)
    torch.save(agent.actor.state_dict(), f"{args.log_path}/trained_model.pth")


def testing() -> None:
    """My demo testing function."""
    env_name = "Pendulum-v1"
    env = gym.make(env_name, render_mode="human")  # on-screen rendering (Gym >= 0.26 API)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    max_action = float(env.action_space.high[0])
    agent = SACContinuous(state_dim, action_dim, max_action)
    agent.actor.load_state_dict(torch.load(f"{args.log_path}/trained_model.pth"))
    state = env.reset()[0]
    total_rewards = 0
    for _ in range(1000):
        action = agent.choose_action(state, deterministic=True)
        next_state, reward, terminated, truncated, _ = env.step(action)
        total_rewards += reward
        state = next_state
        if terminated or truncated:
            state = env.reset()[0]
    env.close()
    print(f"SAC actor scores: {total_rewards}")


if __name__ == "__main__":
    args = parse_args()
    training()
    testing()
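
A minimal way to run the demo from the command line, assuming the script is saved as, say, sac_demo.py (a hypothetical filename) and a Gym version >= 0.26 with the classic-control environments is installed:

    python sac_demo.py --max_train_steps 100000 --evaluate_frequency 10000

Training collects transitions until the warm-up threshold is reached, prints an evaluation reward every evaluate_frequency steps, saves the trained actor under --log_path, and then replays it with on-screen rendering.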

  

From: https://www.cnblogs.com/qiandeheng/p/18457172

    功能实现背景说明我们已经有两个参数来控制长事务:statement_timeout和idle_in_transaction_session_timeout。但是,如果事务执行的命令足够短且不超过statement_timeout,并且命令之间的暂停时间适合idle_in_transaction_session_timeout,则事务可以无限期持续。在这种情况下,tra......