首页 > 其他分享 >动手学强化学习(五):值迭代与策略迭代代码

动手学强化学习(五):值迭代与策略迭代代码

时间:2024-03-03 16:11:57浏览次数:27  
标签:qsa nrow 迭代 代码 动手 range ncol env self

一、策略迭代

import copy
class CliffWalkingEnv:
    """ 悬崖漫步环境"""
    def __init__(self, ncol=12, nrow=4):
        self.ncol = ncol  # 定义网格世界的列
        self.nrow = nrow  # 定义网格世界的行
        # 转移矩阵P[state][action] = [(p, next_state, reward, done)]包含下一个状态和奖励
        self.P = self.createP()

    def createP(self):
        # 初始化
        P = [[[] for j in range(4)] for i in range(self.nrow * self.ncol)]
        # 4种动作, change[0]:上,change[1]:下, change[2]:左, change[3]:右。坐标系原点(0,0)
        # 定义在左上角
        change = [[0, -1], [0, 1], [-1, 0], [1, 0]]
        for i in range(self.nrow):
            for j in range(self.ncol):
                for a in range(4):
                    # 位置在悬崖或者目标状态,因为无法继续交互,任何动作奖励都为0
                    if i == self.nrow - 1 and j > 0:
                        P[i * self.ncol + j][a] = [(1, i * self.ncol + j, 0,
                                                    True)]
                        continue
                    # 其他位置
                    next_x = min(self.ncol - 1, max(0, j + change[a][0]))
                    next_y = min(self.nrow - 1, max(0, i + change[a][1]))
                    next_state = next_y * self.ncol + next_x
                    reward = -1
                    done = False
                    # 下一个位置在悬崖或者终点
                    if next_y == self.nrow - 1 and next_x > 0:
                        done = True
                        if next_x != self.ncol - 1:  # 下一个位置在悬崖
                            reward = -100
                    P[i * self.ncol + j][a] = [(1, next_state, reward, done)]
        return P

class PolicyIteration:
    """ 策略迭代算法 """
    def __init__(self, env, theta, gamma):
        self.env = env
        self.v = [0] * self.env.ncol * self.env.nrow  # 初始化价值为0
        self.pi = [[0.25, 0.25, 0.25, 0.25]
                   for i in range(self.env.ncol * self.env.nrow)]  # 初始化为均匀随机策略
        self.theta = theta  # 策略评估收敛阈值
        self.gamma = gamma  # 折扣因子

    def policy_evaluation(self):  # 策略评估
        cnt = 1  # 计数器
        while 1:
            max_diff = 0
            new_v = [0] * self.env.ncol * self.env.nrow
            for s in range(self.env.ncol * self.env.nrow):
                qsa_list = []  # 开始计算状态s下的所有Q(s,a)价值
                for a in range(4):
                    qsa = 0
                    for res in self.env.P[s][a]:
                        p, next_state, r, done = res
                        qsa += p * (r + self.gamma * self.v[next_state] * (1 - done))
                        # 本章环境比较特殊,奖励和下一个状态有关,所以需要和状态转移概率相乘
                    qsa_list.append(self.pi[s][a] * qsa)
                new_v[s] = sum(qsa_list)  # 状态价值函数和动作价值函数之间的关系
                max_diff = max(max_diff, abs(new_v[s] - self.v[s]))
            self.v = new_v
            if max_diff < self.theta: break  # 满足收敛条件,退出评估迭代
            cnt += 1
        print("策略评估进行%d轮后完成" % cnt)

    def policy_improvement(self):  # 策略提升
        for s in range(self.env.nrow * self.env.ncol):
            qsa_list = []
            for a in range(4):
                qsa = 0
                for res in self.env.P[s][a]:
                    p, next_state, r, done = res
                    qsa += p * (r + self.gamma * self.v[next_state] * (1 - done))
                qsa_list.append(qsa)
            maxq = max(qsa_list)
            cntq = qsa_list.count(maxq)  # 计算有几个动作得到了最大的Q值
            # 让这些动作均分概率
            self.pi[s] = [1 / cntq if q == maxq else 0 for q in qsa_list]
        print("策略提升完成")
        return self.pi

    def policy_iteration(self):  # 策略迭代
        while 1:
            self.policy_evaluation()
            old_pi = copy.deepcopy(self.pi)  # 将列表进行深拷贝,方便接下来进行比较
            new_pi = self.policy_improvement()
            if old_pi == new_pi: break


def print_agent(agent, action_meaning, disaster=[], end=[]):
    print("状态价值:")
    for i in range(agent.env.nrow):
        for j in range(agent.env.ncol):
            # 为了输出美观,保持输出6个字符
            print('%6.6s' % ('%.3f' % agent.v[i * agent.env.ncol + j]), end=' ')
        print()

    print("策略:")
    for i in range(agent.env.nrow):
        for j in range(agent.env.ncol):
            # 一些特殊的状态,例如悬崖漫步中的悬崖
            if (i * agent.env.ncol + j) in disaster:
                print('****', end=' ')
            elif (i * agent.env.ncol + j) in end:  # 目标状态
                print('EEEE', end=' ')
            else:
                a = agent.pi[i * agent.env.ncol + j]
                pi_str = ''
                for k in range(len(action_meaning)):
                    pi_str += action_meaning[k] if a[k] > 0 else 'o'
                print(pi_str, end=' ')
        print()

if __name__ == '__main__':
    env = CliffWalkingEnv()
    action_meaning = ['^', 'v', '<', '>']
    theta = 0.001
    gamma = 0.9
    agent = PolicyIteration(env, theta, gamma)
    agent.policy_iteration()
    print_agent(agent, action_meaning, list(range(37, 47)), [47])

 

二、值迭代

import copy
class CliffWalkingEnv:
    """ 悬崖漫步环境"""
    def __init__(self, ncol=12, nrow=4):
        self.ncol = ncol  # 定义网格世界的列
        self.nrow = nrow  # 定义网格世界的行
        # 转移矩阵P[state][action] = [(p, next_state, reward, done)]包含下一个状态和奖励
        self.P = self.createP()

    def createP(self):
        # 初始化
        P = [[[] for j in range(4)] for i in range(self.nrow * self.ncol)]
        # 4种动作, change[0]:上,change[1]:下, change[2]:左, change[3]:右。坐标系原点(0,0)
        # 定义在左上角
        change = [[0, -1], [0, 1], [-1, 0], [1, 0]]
        for i in range(self.nrow):
            for j in range(self.ncol):
                for a in range(4):
                    # 位置在悬崖或者目标状态,因为无法继续交互,任何动作奖励都为0
                    if i == self.nrow - 1 and j > 0:
                        P[i * self.ncol + j][a] = [(1, i * self.ncol + j, 0,
                                                    True)]
                        continue
                    # 其他位置
                    next_x = min(self.ncol - 1, max(0, j + change[a][0]))
                    next_y = min(self.nrow - 1, max(0, i + change[a][1]))
                    next_state = next_y * self.ncol + next_x
                    reward = -1
                    done = False
                    # 下一个位置在悬崖或者终点
                    if next_y == self.nrow - 1 and next_x > 0:
                        done = True
                        if next_x != self.ncol - 1:  # 下一个位置在悬崖
                            reward = -100
                    P[i * self.ncol + j][a] = [(1, next_state, reward, done)]
        return P

class ValueIteration:
    def __init__(self, env, theta, gamma):
        self.env = env
        self.v = [0] * self.env.ncol * self.env.nrow  # 初始化价值为0
        self.theta = theta  # 价值收敛阈值
        self.gamma = gamma
        # 价值迭代结束后得到的策略
        self.pi = [None for i in range(self.env.ncol * self.env.nrow)]
    def value_iteration(self):
        cnt = 0
        while 1:
            max_diff = 0
            new_v = [0] * self.env.ncol * self.env.nrow
            for s in range(self.env.ncol * self.env.nrow):
                qsa_list = []  # 开始计算状态s下的所有Q(s,a)价值
                for a in range(4):
                    qsa = 0
                    for res in self.env.P[s][a]:
                        p, next_state, r, done = res
                        qsa += p * (r + self.gamma * self.v[next_state] * (1 - done))
                    qsa_list.append(qsa)  # 这一行和下一行代码是价值迭代和策略迭代的主要区别
                new_v[s] = max(qsa_list)
                max_diff = max(max_diff, abs(new_v[s] - self.v[s]))
            self.v = new_v
            if max_diff < self.theta: break  # 满足收敛条件,退出评估迭代
            cnt += 1
        print("价值迭代一共进行%d轮" % cnt)
        self.get_policy()

    def get_policy(self):  # 根据价值函数导出一个贪婪策略
        for s in range(self.env.nrow * self.env.ncol):
            qsa_list = []
            for a in range(4):
                qsa = 0
                for res in self.env.P[s][a]:
                    p, next_state, r, done = res
                    qsa += p * (r + self.gamma * self.v[next_state] * (1 - done))
                qsa_list.append(qsa)
            maxq = max(qsa_list)
            cntq = qsa_list.count(maxq)  # 计算有几个动作得到了最大的Q值
            # 让这些动作均分概率
            self.pi[s] = [1 / cntq if q == maxq else 0 for q in qsa_list]

# class PolicyIteration:
#     """ 策略迭代算法 """
#     def __init__(self, env, theta, gamma):
#         self.env = env
#         self.v = [0] * self.env.ncol * self.env.nrow  # 初始化价值为0
#         self.pi = [[0.25, 0.25, 0.25, 0.25]
#                    for i in range(self.env.ncol * self.env.nrow)]  # 初始化为均匀随机策略
#         self.theta = theta  # 策略评估收敛阈值
#         self.gamma = gamma  # 折扣因子
# 
#     def policy_evaluation(self):  # 策略评估
#         cnt = 1  # 计数器
#         while 1:
#             max_diff = 0
#             new_v = [0] * self.env.ncol * self.env.nrow
#             for s in range(self.env.ncol * self.env.nrow):
#                 qsa_list = []  # 开始计算状态s下的所有Q(s,a)价值
#                 for a in range(4):
#                     qsa = 0
#                     for res in self.env.P[s][a]:
#                         p, next_state, r, done = res
#                         qsa += p * (r + self.gamma * self.v[next_state] * (1 - done))
#                         # 本章环境比较特殊,奖励和下一个状态有关,所以需要和状态转移概率相乘
#                     qsa_list.append(self.pi[s][a] * qsa)
#                 new_v[s] = sum(qsa_list)  # 状态价值函数和动作价值函数之间的关系
#                 max_diff = max(max_diff, abs(new_v[s] - self.v[s]))
#             self.v = new_v
#             if max_diff < self.theta: break  # 满足收敛条件,退出评估迭代
#             cnt += 1
#         print("策略评估进行%d轮后完成" % cnt)
# 
#     def policy_improvement(self):  # 策略提升
#         for s in range(self.env.nrow * self.env.ncol):
#             qsa_list = []
#             for a in range(4):
#                 qsa = 0
#                 for res in self.env.P[s][a]:
#                     p, next_state, r, done = res
#                     qsa += p * (r + self.gamma * self.v[next_state] * (1 - done))
#                 qsa_list.append(qsa)
#             maxq = max(qsa_list)
#             cntq = qsa_list.count(maxq)  # 计算有几个动作得到了最大的Q值
#             # 让这些动作均分概率
#             self.pi[s] = [1 / cntq if q == maxq else 0 for q in qsa_list]
#         print("策略提升完成")
#         return self.pi
# 
#     def policy_iteration(self):  # 策略迭代
#         while 1:
#             self.policy_evaluation()
#             old_pi = copy.deepcopy(self.pi)  # 将列表进行深拷贝,方便接下来进行比较
#             new_pi = self.policy_improvement()
#             if old_pi == new_pi: break

def print_agent(agent, action_meaning, disaster=[], end=[]):
    print("状态价值:")
    for i in range(agent.env.nrow):
        for j in range(agent.env.ncol):
            # 为了输出美观,保持输出6个字符
            print('%6.6s' % ('%.3f' % agent.v[i * agent.env.ncol + j]), end=' ')
        print()

    print("策略:")
    for i in range(agent.env.nrow):
        for j in range(agent.env.ncol):
            # 一些特殊的状态,例如悬崖漫步中的悬崖
            if (i * agent.env.ncol + j) in disaster:
                print('****', end=' ')
            elif (i * agent.env.ncol + j) in end:  # 目标状态
                print('EEEE', end=' ')
            else:
                a = agent.pi[i * agent.env.ncol + j]
                pi_str = ''
                for k in range(len(action_meaning)):
                    pi_str += action_meaning[k] if a[k] > 0 else 'o'
                print(pi_str, end=' ')
        print()

if __name__ == '__main__':
    env = CliffWalkingEnv()
    action_meaning = ['^', 'v', '<', '>']
    theta = 0.001
    gamma = 0.9
    agent = ValueIteration(env, theta, gamma)
    agent.value_iteration()
    print_agent(agent, action_meaning, list(range(37, 47)), [47])

 

标签:qsa,nrow,迭代,代码,动手,range,ncol,env,self
From: https://www.cnblogs.com/zhangxianrong/p/18050171

相关文章

  • day52 动态规划part10 代码随想录算法训练营 122. 买卖股票的最佳时机 II
    题目:122.买卖股票的最佳时机II我的感悟:只要定义清楚,就可以做出来的。理解难点:先判断等于听课笔记:看了文字版本,感觉还是我的思路最牛逼!!我的代码:classSolution:defmaxProfit(self,prices:List[int])->int:#dp[i]为截止到当前能获得的最大利润......
  • day53 动态规划part10 代码随想录算法训练营 121. 买卖股票的最佳时机
    题目:121.买卖股票的最佳时机我的感悟:soeasy 打印dp确实能发现问题理解难点:注意条件,及时更新dp听课笔记:看了,老师的代码,感觉没有我的简洁,哈哈!!我的代码:classSolution:defmaxProfit(self,prices:List[int])->int:#设dp[i]为截止到当前能获得......
  • 动手学强化学习(四):动态规划算法
    第4章动态规划算法4.1简介动态规划(dynamicprogramming)是程序设计算法中非常重要的内容,能够高效解决一些经典问题,例如背包问题和最短路径规划。动态规划的基本思想是将待求解问题分解成若干个子问题,先求解子问题,然后从这些子问题的解得到目标问题的解。动态规划会保存已解决......
  • day52 动态规划part9 代码随想录算法训练营 337. 打家劫舍 III
    题目:337.打家劫舍III我的感悟:跳过,目前树的不学理解难点:树的理解,以及树的遍历听课笔记:我的代码:通过截图:老师代码:#Definitionforabinarytreenode.#classTreeNode:#def__init__(self,val=0,left=None,right=None):#self.val=val#......
  • day52 动态规划part9 代码随想录算法训练营 213. 打家劫舍 II
    题目:213.打家劫舍II我的感悟:看了题解不难,就是环这个思路转化很重要!理解难点:环的转化为,首,尾。代码上面可以省略长度为2的校验听课笔记:分3中情况:不考虑首尾|考虑首|考虑尾而情况2和情况3包含了情况1我的代码:classSolution:defrob(self,nums:List[i......
  • 编写更好的C#代码的技巧
    转载:https://blog.csdn.net/WuLex/article/details/123353742?spm=1001.2101.3001.6650.6&utm_medium=distribute.pc_relevant.none-task-blog-2~default~BlogCommendFromBaidu~Rate-6-123353742-blog-101057958.pc_relevant_3mothn_strategy_and_data_recovery&depth_......
  • ASP.NET(C#)返回上一页(后退)代码
    转:ASP.NET(C#)返回上一页(后退)代码//方法一protectedvoidPage_Load(objectsender,EventArgse){    if(!IsPostBack)   {ViewState["BackUrl"]=Request.UrlReferrer.ToString();}}//////返回按钮点击事件///protectedvoidButton1_Click(object......
  • 代码随想录 第11天 | 20. 有效的括号 ● 1047. 删除字符串中的所有相邻重复项 ● 150.
    Leetcode:20.有效的括号-力扣(LeetCode)思路:就是用栈存左右括号,都为0就说明true,不为零说明有没有匹配成功的括号,是false,思路没有问题,时间超时了,还得用C++...,java更好的思路如下:如果是左括号,push右括号,如果是右括号,判断是否与栈顶元素匹配,JAVA//deque.isEmpty();这个方法返回......
  • 对于需要实时处理的代码语句 就用定时器中断模式,实现多线程模式,建议不要用查询模式。
    对于需要实时处理的代码语句就用定时器中断模式,实现多线程模式,建议不要用查询模式。 示例代码1:查看代码#include"delay.h"#include"sysInt.h"#include"intrins.h"charSMGDuan[]={0x5B,0x3F,0x5B,0x66, 0x40,0x40, 0x3F,0x3F}; //2024--MMcharsegDuan[]={0x3F,0......
  • 数组(基于代码随想录)的随笔
    数组数组基础知识数组是存放在连续内存空间上的相同类型数据的集合。数组的元素是不能删的,只能覆盖。那么二维数组在内存的空间地址是连续的么?Java的二维数组在内存中不是3*4的连续地址空间,而是四条连续的地址空间组成!数组的经典题目二分法二分法时间复杂度:O(logn)......