
Reinforcement Learning (3): PPO for Continuous Action Spaces


1. PPO for Continuous Actions
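
The script below is a complete training loop for PPO with continuous actions. The actor outputs the mean of a MultivariateNormal policy with a fixed diagonal covariance (set by the action_std hyperparameter), and each update step minimizes the negative of the standard PPO clipped surrogate objective

$$L^{\mathrm{CLIP}}(\theta)=\mathbb{E}_t\Big[\min\big(r_t(\theta)\,\hat{A}_t,\ \mathrm{clip}(r_t(\theta),\,1-\epsilon,\,1+\epsilon)\,\hat{A}_t\big)\Big],\qquad r_t(\theta)=\frac{\pi_\theta(a_t\mid s_t)}{\pi_{\theta_\mathrm{old}}(a_t\mid s_t)},$$

plus a value-function MSE term (coefficient 0.5) and minus an entropy bonus (coefficient 0.01). The advantage $\hat{A}_t$ is estimated as the normalized Monte Carlo return minus the critic's value estimate.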

import torch
import torch.nn as nn
from torch.distributions import MultivariateNormal
import gym
import numpy as np

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class Memory:
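    # Rollout buffer: stores one batch of on-policy transitions collected with policy_old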
    def __init__(self):
        self.actions = []
        self.states = []
        self.logprobs = []
        self.rewards = []
        self.is_terminals = []
    
    def clear_memory(self):
        del self.actions[:]
        del self.states[:]
        del self.logprobs[:]
        del self.rewards[:]
        del self.is_terminals[:]

class ActorCritic(nn.Module):
    def __init__(self, state_dim, action_dim, action_std):
        super(ActorCritic, self).__init__()
        # action mean range -1 to 1
        self.actor =  nn.Sequential(
                nn.Linear(state_dim, 64),
                nn.Tanh(),
                nn.Linear(64, 32),
                nn.Tanh(),
                nn.Linear(32, action_dim),
                nn.Tanh()
                )
        # critic
        self.critic = nn.Sequential(
                nn.Linear(state_dim, 64),
                nn.Tanh(),
                nn.Linear(64, 32),
                nn.Tanh(),
                nn.Linear(32, 1)
                )
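        # Fixed diagonal covariance for the Gaussian policy; action_std is a hyperparameter, not a learned parameter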
        self.action_var = torch.full((action_dim,), action_std*action_std).to(device)
        
    def forward(self):
        raise NotImplementedError
    
    def act(self, state, memory):
        action_mean = self.actor(state)
        cov_mat = torch.diag(self.action_var).to(device)
        
        dist = MultivariateNormal(action_mean, cov_mat)
        action = dist.sample()
        action_logprob = dist.log_prob(action)
        
        memory.states.append(state)
        memory.actions.append(action)
        memory.logprobs.append(action_logprob)
        
        return action.detach()
    
    def evaluate(self, state, action):   
        action_mean = self.actor(state)
        
        action_var = self.action_var.expand_as(action_mean)
        cov_mat = torch.diag_embed(action_var).to(device)
        
        dist = MultivariateNormal(action_mean, cov_mat)
        
        action_logprobs = dist.log_prob(action)
        dist_entropy = dist.entropy()
        state_value = self.critic(state)
        
        return action_logprobs, torch.squeeze(state_value), dist_entropy

class PPO:
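    # PPO agent: self.policy is optimized for K_epochs per batch; self.policy_old is a frozen copy used only to collect data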
    def __init__(self, state_dim, action_dim, action_std, lr, betas, gamma, K_epochs, eps_clip):
        self.lr = lr
        self.betas = betas
        self.gamma = gamma
        self.eps_clip = eps_clip
        self.K_epochs = K_epochs
        
        self.policy = ActorCritic(state_dim, action_dim, action_std).to(device)
        self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr, betas=betas)
        
        self.policy_old = ActorCritic(state_dim, action_dim, action_std).to(device)
        self.policy_old.load_state_dict(self.policy.state_dict())
        
        self.MseLoss = nn.MSELoss()
    
    def select_action(self, state, memory):
        state = torch.FloatTensor(state.reshape(1, -1)).to(device)
        return self.policy_old.act(state, memory).cpu().data.numpy().flatten()
    
    def update(self, memory):
        # Monte Carlo estimate of rewards:
        rewards = []
        discounted_reward = 0
        for reward, is_terminal in zip(reversed(memory.rewards), reversed(memory.is_terminals)):
            if is_terminal:
                discounted_reward = 0
            discounted_reward = reward + (self.gamma * discounted_reward)
            rewards.insert(0, discounted_reward)
        
        # Normalizing the rewards:
        rewards = torch.tensor(rewards, dtype=torch.float32).to(device)
        rewards = (rewards - rewards.mean()) / (rewards.std() + 1e-5)
        
        # convert list to tensor
        old_states = torch.squeeze(torch.stack(memory.states).to(device), 1).detach()
        old_actions = torch.squeeze(torch.stack(memory.actions).to(device), 1).detach()
        old_logprobs = torch.squeeze(torch.stack(memory.logprobs), 1).to(device).detach()
        
        # Optimize policy for K epochs:
        for _ in range(self.K_epochs):
            # Evaluating old actions and values :
            logprobs, state_values, dist_entropy = self.policy.evaluate(old_states, old_actions)
            
            # Finding the ratio (pi_theta / pi_theta__old):
            ratios = torch.exp(logprobs - old_logprobs.detach())

            # Finding Surrogate Loss:
            advantages = rewards - state_values.detach()   
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1-self.eps_clip, 1+self.eps_clip) * advantages
            loss = -torch.min(surr1, surr2) + 0.5*self.MseLoss(state_values, rewards) - 0.01*dist_entropy
            
            # take gradient step
            self.optimizer.zero_grad()
            loss.mean().backward()
            self.optimizer.step()
            
        # Copy new weights into old policy:
        self.policy_old.load_state_dict(self.policy.state_dict())
        
def main():
    ############## Hyperparameters ##############
    env_name = "BipedalWalker-v3"
    render = True
    solved_reward = 300         # stop training if avg_reward > solved_reward
    log_interval = 20           # print avg reward in the interval
    max_episodes = 10000        # max training episodes
    max_timesteps = 1500        # max timesteps in one episode
    
    update_timestep = 4000      # update policy every n timesteps
    action_std = 0.5            # constant std for action distribution (Multivariate Normal)
    K_epochs = 80               # update policy for K epochs
    eps_clip = 0.2              # clip parameter for PPO
    gamma = 0.99                # discount factor
    
    lr = 0.0003                 # parameters for Adam optimizer
    betas = (0.9, 0.999)
    
    random_seed = None
    #############################################
    
    # creating environment
    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    
    if random_seed:
        print("Random Seed: {}".format(random_seed))
        torch.manual_seed(random_seed)
        env.seed(random_seed)
        np.random.seed(random_seed)
    
    memory = Memory()
    ppo = PPO(state_dim, action_dim, action_std, lr, betas, gamma, K_epochs, eps_clip)
    print(lr,betas)
    
    # logging variables
    running_reward = 0
    avg_length = 0
    time_step = 0
    
    # training loop
    for i_episode in range(1, max_episodes+1):
        state = env.reset()
        for t in range(max_timesteps):
            time_step +=1
            # Running policy_old:
            action = ppo.select_action(state, memory)
            state, reward, done, _ = env.step(action)
            
            # Saving reward and is_terminals:
            memory.rewards.append(reward)
            memory.is_terminals.append(done)
            
            # update if its time
            if time_step % update_timestep == 0:
                ppo.update(memory)
                memory.clear_memory()
                time_step = 0
            running_reward += reward
            if render:
                env.render()
            if done:
                break
        
        avg_length += t
        
        # stop training if avg_reward > solved_reward
        if running_reward > (log_interval*solved_reward):
            print("########## Solved! ##########")
            torch.save(ppo.policy.state_dict(), './PPO_continuous_solved_{}.pth'.format(env_name))
            break
        
        # save every 500 episodes
        if i_episode % 500 == 0:
            torch.save(ppo.policy.state_dict(), './PPO_continuous_{}.pth'.format(env_name))
            
        # logging
        if i_episode % log_interval == 0:
            avg_length = int(avg_length/log_interval)
            running_reward = int(running_reward/log_interval)
            
            print('Episode {} \t Avg length: {} \t Avg reward: {}'.format(i_episode, avg_length, running_reward))
            running_reward = 0
            avg_length = 0
            
if __name__ == '__main__':
    main()
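
Note that the training loop uses the classic gym API (env.reset() returning only the observation, env.step() returning four values, and env.seed()). If you run it under gymnasium or gym >= 0.26 instead, the calls change slightly; a minimal sketch of the newer API (my adaptation, not part of the original post):

import gymnasium as gym

env = gym.make("BipedalWalker-v3")
state, _ = env.reset(seed=0)                    # reset() now returns (observation, info); seeding moved here
for _ in range(10):
    action = env.action_space.sample()          # stand-in for ppo.select_action(state, memory)
    state, reward, terminated, truncated, _ = env.step(action)
    done = terminated or truncated              # step() splits 'done' into two flags
    if done:
        state, _ = env.reset()
env.close()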
    

 

2. PPO Testing

import gym
from PPO_continuous import PPO, Memory
from PIL import Image
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

def test():
    ############## Hyperparameters ##############
    env_name = "BipedalWalker-v2"
    env = gym.make(env_name)
    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    
    n_episodes = 3          # num of episodes to run
    max_timesteps = 1500    # max timesteps in one episode
    render = True           # render the environment
    save_gif = False        # if True, save rendered frames as images in the ./gif folder
    
    # filename and directory to load model from
    filename = "PPO_continuous_" +env_name+ ".pth"
    directory = "./preTrained/"

    action_std = 0.5        # constant std for action distribution (Multivariate Normal)
    K_epochs = 80           # update policy for K epochs
    eps_clip = 0.2          # clip parameter for PPO
    gamma = 0.99            # discount factor
    
    lr = 0.0003             # parameters for Adam optimizer
    betas = (0.9, 0.999)
    #############################################
    
    memory = Memory()
    ppo = PPO(state_dim, action_dim, action_std, lr, betas, gamma, K_epochs, eps_clip)
    ppo.policy_old.load_state_dict(torch.load(directory + filename, map_location=device))
    
    for ep in range(1, n_episodes+1):
        ep_reward = 0
        state = env.reset()
        for t in range(max_timesteps):
            action = ppo.select_action(state, memory)
            state, reward, done, _ = env.step(action)
            ep_reward += reward
            if render:
                env.render()
            if save_gif:
                 img = env.render(mode = 'rgb_array')
                 img = Image.fromarray(img)
                 img.save('./gif/{}.jpg'.format(t))  
            if done:
                break
            
        print('Episode: {}\tReward: {}'.format(ep, int(ep_reward)))
        ep_reward = 0

    env.close()
    
if __name__ == '__main__':
    test()
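
When save_gif is enabled, the test loop writes one .jpg frame per timestep into the ./gif folder. A minimal sketch (my addition, assuming those filenames and an episode.gif output path) for stitching the saved frames into an animated GIF with Pillow:

import glob
import os
from PIL import Image

# Sort frames numerically by the timestep encoded in the filename (e.g. ./gif/42.jpg)
frame_paths = sorted(glob.glob('./gif/*.jpg'),
                     key=lambda p: int(os.path.splitext(os.path.basename(p))[0]))
frames = [Image.open(p) for p in frame_paths]
if frames:
    # duration is the per-frame display time in milliseconds; loop=0 repeats forever
    frames[0].save('./gif/episode.gif', save_all=True,
                   append_images=frames[1:], duration=40, loop=0)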
    
    

 

From: https://www.cnblogs.com/zhangxianrong/p/18044295
