[[Q 学习]] 是表格型算法的一种,主要维护了一个 Q-table,里面是 状态-动作
对的价值,分别由一个状态和一个动作来索引。
这里以一个经典的道理摆小车问题来说明如何使用 [[Q 学习]] 算法。
这里会用到两个类,agent
和 brain
。brain
类中来维护 [[强化学习的基本概念|强化学习]] 算法的具体执行,agent
是一层封装,以后也可以用其他算法来实现 brain
类。整个的逻辑也可以参考[[强化学习基本程序框架]]。
首先是 agent
类
class Agent():
def __init__(self, num_states, num_actions):
self.brain = Brain(num_states, num_actions)
def update_Q_fun(self, observation, reward, action, next_observation):
self.brain.update_Q_table( observation, reward, action, next_observation)
def get_action(self, observation,step):
action = self.brain.decide_action(observation, step)
return action
其中 get_action
就是根据状态选择一个动作,可以不放到 brain
类里面,一般都是 \(\epsilon\) -贪心算法在动作空间里面选动作。update_Q_fun
用来更新 Q-table,如果是其他算法,比如说 [[DQN]],换个名字就行。
然后是 brain
类
class Brain():
def __init__(self, num_states, num_actions):
self.num_actions = num_actions
self.Q_table = np.random.uniform(low=0, high=1, size=(NUM_DIZITIZED**num_states, num_actions))
def bins(self,clip_min, clip_max, num ):
return bins(clip_min, clip_max, num)
def digitize_state(self,observation) :
cart_pos, cart_v, pole_angle, pole_v = observation
digitized = [
np.digitize(cart_pos, bins=self.bins(-2.4, 2.4, NUM_DIZITIZED)),
np.digitize(cart_v, bins=self.bins(-3.0, 3.0, NUM_DIZITIZED)) ,
np.digitize(pole_angle, bins=self.bins(-0.5, 0.5, NUM_DIZITIZED)) ,
np.digitize(pole_v, bins=self.bins(-2.0, 2.0, NUM_DIZITIZED) )
]
return sum([x * (NUM_DIZITIZED**i) for i, x in enumerate(digitized)])
def update_Q_table(self, observation, reward, action, next_observation):
state = self.digitize_state(observation)
state_next = self.digitize_state(observation_next)
Max_Q_next = np.max(self.Q_table[state_next][:])
self.Q_table[state,action] = self.Q_table[state,action] + ETA * (reward + GAMMA * Max_Q_next - self.Q_table[state,action])
def decide_action(self, observation,episode):
state = self.digitize_state(observation)
epsilon = 0.5 * (1 / (episode + 1))
if epsilon <= np.random.uniform(0, 1):
action = np.argmax(self.Q_table[state][:])
else:
action = np.random.choice(self.num_actions)
return action
update_Q_table
就是根据时序差分的公式更新 Q-table。
其中,\(\alpha\) 是学习率,\(\gamma\) 是奖励累积的折扣系数。如果这里的 \(\max_aQ(s_{t+1},a)\) 换成 \(Q(s_{t+1},a_{t+1})\) 的话,就是 [[sarsa 算法]]。
decide_action
就是前面提到的 \(\epsilon\) -贪心算法选取动作,这里的 \(\epsilon\) 是随 episode
的数量衰减的。
digitize_state
是为了处理连续状态的。因为倒立摆小车的位置、速度、杆的角度这些信息是连续变量(尽管是在计算机中仿真,我们也认为是连续的),所以为了能在表格中维护,需要将状态进行离散化处理,比如位置在什么范围内就认为其状态是 1。为了减少内存的占用,示例里 NUM_DIZITIZED
等于 6,意思是只用 6 个数来划分表示单一维度里面的连续区间的状态。实际上,如果状态空间任一维度都很大或者状态空间本身就是连续的,后面会有 [[DQN]] 等算法可以处理。
仿真代码:
frames=[]
#环境初始化
env=gym.make('CartPole-v0')
observation = env.reset()#需要先重置环境
NUM_DIZITIZED = 6
GAMMA=0.99 # 时间折扣率
ETA=0.5 # 学习系数
MAX_STEPS=200
NUM_EPISODES = 200
agent = Agent(6,2)
complete_episodes = 0
is_episode_final = False
for episode in range(NUM_EPISODES):
observation = env.reset()
for step in range(0,MAX_STEPS):
if is_episode_final:
frames.append(env.render(mode='rgb_array')) #将各个时刻的图像添加到帧中
action = agent.get_action (observation, episode)
observation_next, _, done, _ = env.step(action)
# 自定义的奖励部分
# 如果结束的时候,已经稳定了190步,就给1的奖励,否则-1.没结束的时候奖励是0
if done:
if step < 190:
reward = -1
complete_episode = 0
else:
reward = 1
complete_episodes += 1
else:
reward = 0
agent.update_Q_fun(observation,reward,action,observation_next)
observation= observation_next
if done:
print(f'{episode} Episode: Finished after {step + 1} time steps')
break
if complete_episodes >= 10:
print('10回合连续成功')
is_episode_final = True
display_frames_as_gif(frames)
More Reading
[[边做边学深度强化学习:PyTorch程序设计实践]]