首页 > 编程语言 >强化学习-表格型算法Q学习稳定倒立摆小车

强化学习-表格型算法Q学习稳定倒立摆小车

时间:2024-07-07 21:53:29浏览次数:13  
标签:observation 表格 self 学习 state num action table 倒立

[[Q 学习]] 是表格型算法的一种,主要维护了一个 Q-table,里面是 状态-动作 对的价值,分别由一个状态和一个动作来索引。

这里以一个经典的道理摆小车问题来说明如何使用 [[Q 学习]] 算法。
这里会用到两个类,agentbrainbrain 类中来维护 [[强化学习的基本概念|强化学习]] 算法的具体执行,agent 是一层封装,以后也可以用其他算法来实现 brain 类。整个的逻辑也可以参考[[强化学习基本程序框架]]。
首先是 agent

class Agent():
    def __init__(self, num_states, num_actions):
        self.brain = Brain(num_states, num_actions)

    def update_Q_fun(self, observation, reward, action, next_observation):
        self.brain.update_Q_table( observation, reward, action, next_observation)
        
    def get_action(self, observation,step):
        action = self.brain.decide_action(observation, step)
        return action
        

其中 get_action 就是根据状态选择一个动作,可以不放到 brain 类里面,一般都是 \(\epsilon\) -贪心算法在动作空间里面选动作。update_Q_fun 用来更新 Q-table,如果是其他算法,比如说 [[DQN]],换个名字就行。

然后是 brain

class Brain():
    def __init__(self, num_states, num_actions):
        self.num_actions = num_actions
        self.Q_table = np.random.uniform(low=0, high=1, size=(NUM_DIZITIZED**num_states, num_actions))
    
    def bins(self,clip_min, clip_max, num ):
        return bins(clip_min, clip_max, num)
    
    def digitize_state(self,observation) :
        cart_pos, cart_v, pole_angle, pole_v = observation
        digitized = [
        np.digitize(cart_pos, bins=self.bins(-2.4, 2.4, NUM_DIZITIZED)),
        np.digitize(cart_v, bins=self.bins(-3.0, 3.0, NUM_DIZITIZED)) ,
        np.digitize(pole_angle, bins=self.bins(-0.5, 0.5, NUM_DIZITIZED)) ,
        np.digitize(pole_v, bins=self.bins(-2.0, 2.0, NUM_DIZITIZED) )
    ]
        return sum([x * (NUM_DIZITIZED**i) for i, x in enumerate(digitized)])
    
    def update_Q_table(self, observation, reward, action, next_observation):
        state = self.digitize_state(observation) 
        state_next = self.digitize_state(observation_next)
        Max_Q_next = np.max(self.Q_table[state_next][:])
        self.Q_table[state,action] = self.Q_table[state,action] + ETA * (reward + GAMMA * Max_Q_next - self.Q_table[state,action])
        
    def decide_action(self, observation,episode):  
        state = self.digitize_state(observation)
        epsilon = 0.5 * (1 / (episode + 1))
        
        if epsilon <= np.random.uniform(0, 1):
            action = np.argmax(self.Q_table[state][:])
        else:
            action = np.random.choice(self.num_actions)
        return action

update_Q_table 就是根据时序差分的公式更新 Q-table。

\[Q(s_t,a_t)\leftarrow Q(s_t,a_t)+\alpha[R_t+\gamma\max_aQ(s_{t+1},a)-Q(s_t,a_t)] \]

其中,\(\alpha\) 是学习率,\(\gamma\) 是奖励累积的折扣系数。如果这里的 \(\max_aQ(s_{t+1},a)\) 换成 \(Q(s_{t+1},a_{t+1})\) 的话,就是 [[sarsa 算法]]。
decide_action 就是前面提到的 \(\epsilon\) -贪心算法选取动作,这里的 \(\epsilon\) 是随 episode 的数量衰减的。
digitize_state 是为了处理连续状态的。因为倒立摆小车的位置、速度、杆的角度这些信息是连续变量(尽管是在计算机中仿真,我们也认为是连续的),所以为了能在表格中维护,需要将状态进行离散化处理,比如位置在什么范围内就认为其状态是 1。为了减少内存的占用,示例里 NUM_DIZITIZED 等于 6,意思是只用 6 个数来划分表示单一维度里面的连续区间的状态。实际上,如果状态空间任一维度都很大或者状态空间本身就是连续的,后面会有 [[DQN]] 等算法可以处理。

仿真代码:

frames=[]
#环境初始化
env=gym.make('CartPole-v0')
observation = env.reset()#需要先重置环境

NUM_DIZITIZED = 6
GAMMA=0.99   # 时间折扣率
ETA=0.5       # 学习系数
MAX_STEPS=200
NUM_EPISODES = 200

agent = Agent(6,2)
complete_episodes = 0 
is_episode_final = False 

for episode in range(NUM_EPISODES):
    observation = env.reset()

    for step in range(0,MAX_STEPS):
        if is_episode_final:
            frames.append(env.render(mode='rgb_array')) #将各个时刻的图像添加到帧中
        
        action = agent.get_action (observation, episode)
        observation_next, _, done, _ = env.step(action)

		# 自定义的奖励部分
		# 如果结束的时候,已经稳定了190步,就给1的奖励,否则-1.没结束的时候奖励是0
        if done: 
            if step < 190:
                reward = -1 
                complete_episode = 0 
            else:
                reward = 1
                complete_episodes += 1 
        else:
            reward = 0 
        
        agent.update_Q_fun(observation,reward,action,observation_next)
        
        observation= observation_next
        
        if done:
            print(f'{episode} Episode: Finished after {step + 1} time steps')
            break
            
    if complete_episodes >= 10:
        print('10回合连续成功')
        is_episode_final = True
        
display_frames_as_gif(frames)

More Reading

[[边做边学深度强化学习:PyTorch程序设计实践]]

Reference

标签:observation,表格,self,学习,state,num,action,table,倒立
From: https://www.cnblogs.com/pomolnc/p/18288975

相关文章

  • 学习Linux LVM,这篇文章就够了
      (1)引言     LVM(LogicalVolumeManager)逻辑卷管理,是在硬盘分区和文件系统之间添加的一个逻辑层,为文件系统屏蔽下层硬盘分区布局,并提供一个抽象的盘卷,在盘卷上建立文件系统。管理员利用LVM可以在硬盘不用重新分区的情况下动态调整文件系统的大小,并且利用LVM管理的......
  • 昇思25天学习打卡营第11天 | LLM原理和实践:基于MindSpore实现BERT对话情绪识别
    1.基于MindSpore实现BERT对话情绪识别1.1环境配置#实验环境已经预装了mindspore==2.2.14,如需更换mindspore版本,可更改下面mindspore的版本号!pipuninstallmindspore-y!pipinstall-ihttps://pypi.mirrors.ustc.edu.cn/simplemindspore==2.2.14#该案例在min......
  • 昇思25天学习打卡营第10天 | 自然语言处理:RNN实现情感分类
    1.RNN实现情感分类1.2概述情感分类是自然语言处理中的经典任务,是典型的分类问题。本节使用MindSpore实现一个基于RNN网络的情感分类模型,实现如下的效果:输入:Thisfilmisterrible正确标签:Negative(负面)预测标签:Negative输入:Thisfilmisgreat正确标签:......
  • Kaggle网站免费算力使用,深度学习模型训练
    声明:本文主要内容为:kaggle网站数据集上传,训练模型下载、模型部署、提交后台运行等教程。1、账号注册此步骤本文略过,如有需要可以参考其他文章。2、上传资源不论是上传训练好的模型进行预测,还是训练用的数据集都可以按此步骤上传。如果是数据集的话,先要将数据集进行压缩,才......
  • 【机器学习】机器学习与时间序列分析的融合应用与性能优化新探索
    文章目录引言第一章:机器学习在时间序列分析中的应用1.1数据预处理1.1.1数据清洗1.1.2数据归一化1.1.3数据增强1.2模型选择1.2.1自回归模型1.2.2移动平均模型1.2.3长短期记忆网络1.2.4卷积神经网络1.3模型训练1.3.1梯度下降1.3.2随机梯度下降1.3.3Adam优......
  • 关于数据结构的学习心得
    介绍在备赛xcpc时,其实除了数据结构以外,绝大部分常用的大纲知识都学习了,但数据结构确实是练得最多的,本文主要介绍一下个人是如何学习数据结构的。数据结构概述数据结构大概是很多人比较抵触系统学习的东西,因为许多数据结构来说,光是板子就比其他领域的算法长很多。比如线段树,可......
  • 第一周学习报告
    在第一周,对Java进行初步了解,学习了Java的一些基础知识。学习主要参考于B站上的黑马程序员,以下为这周的学习报告day1打开CMD1.win+r2.输入CMD常见的CMD命令1.盘符名称+冒号盘符切换2.dir查看当前路径下的内容3.cd目录进入单级目录4.cd..回退到上一级目录5.cd目录1......
  • 跟着吴恩达学深度学习(二)
    前言第一门课的笔记见:跟着吴恩达学深度学习(一)本文对应了吴恩达深度学习系列课程中的第二门课程《改善深层神经网络:超参数调试、正则化以及优化》第二门课程授课大纲:深度学习的实用层面优化算法超参数调试、Batch正则化和程序框架目录1深度学习的实用层面 1.1 训练/......
  • 强化学习(Value Function Approximation)-Today9
    ValueFunctionApproximation主要是使用神经网络来求最优解问题,主要包括Algorithmforstatevaluefunction、Sarsa和valuefunctionapproximation的结合、Q-learning和valuefunctionapproximation的结合、DeepQ-learning。由于tables的数据不能处理很大的statespace或......
  • 昇思25天学习打卡营第14天|SSD目标检测
    今天学习的是SSD目标检测内容,首先介绍什么是SSD?SSD,全称SingleShotMultiBoxDetector,是WeiLiu在ECCV2016上提出的一种目标检测算法。使用NvidiaTitanX在VOC2007测试集上,SSD对于输入尺寸300x300的网络,达到74.3%mAP(meanAveragePrecision)以及59FPS;对于512x512的网......