首页 > 编程语言 >【强化学习】A grid world 值迭代算法 (value iterator algorithm)

【强化学习】A grid world 值迭代算法 (value iterator algorithm)

时间:2024-05-14 16:12:03浏览次数:27  
标签:algorithm iterator self value colors grid np ax row

强化学习——值迭代算法

代码是在 jupyter notebook 环境下编写

只需要 numpymatplotlib 包。

此代码为学习赵世钰老师强化学习课程之后,按照公式写出来的代码,对应第四章第一节 value iterator algorithm

可以做的实验:

  • 调整 gama 值观察策略的变化
  • 调整惩罚值(fa)的大小观察策略的变化
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors



class grid_world(object):

    ##
    # isRand 是否要每次都随机生成地图
    # gama 就是 gama
    # fa n x n 的地图默认生成 n 个 forbidden area , 如果觉得不够可以手动指定
    #    取 max(n,fa) 作为 number of forbidden area
    # fval 为惩罚值
    
    def __init__(self,n,isRand=False,gama = 0.9,fa=0,fval=-1):

        self.n = n
        self.fval = fval
        # 生成 n x n 的网格
        self.grid = np.zeros((n,n))
            # 设置终点
        self.grid[n-1,n-1] = 1
        # 便于 a 的遍历
        self.a = np.array([ [-1,0],[0,1],[1,0],[0,-1],[0,0] ])
        # 便于 s 的遍历
        self.s = np.array(range(n*n))
        # 便于 r 的遍历
        self.r = np.array([fval,0,1])
        # qk(s,a) 
        self.q_sa = np.zeros((n*n,5))
        # gama
        self.gama = gama
        # v 初始化 为 0
        self.v = np.zeros(n*n)
        # pi(s|a) 策略
        self.pi_sa = np.zeros((n*n,5))

        
        
        #随机种子处理
        if isRand is False:
            np.random.seed(8)
        
        # 生成 n 个 forbidden area
            # size = n 生成 n 个
            # replace = False 生成的 n 个数字不能重复
        forbidden = np.random.choice(range(n*n-1), size=max(n,fa), replace=False)
        #print(forbidden)
        for i in forbidden:
            row,col = self.one2two(i)
            self.grid[row,col] = fval
        # 查看生成的矩阵
        self.show_grid_default()
       


    def train(self,k):

        # 训练 k 轮
        for l in range(k):

            # v_k+1 的值,等到一轮结束以后,再赋值给 self.v[]
            # 在一轮中暂存在这里
            tev = np.zeros(self.n*self.n)

            for i in range(len(self.s)):
                for j in range(len(self.a)):

                    sum_r = 0
                    # 遍历 r
                    for r in self.r:
                        sum_r = sum_r + r * self.p_rsa(r,i,j)

                    self.q_sa[i,j] = sum_r + self.gama * self.p_ssa(i,j)
                    
                # a_star 为当前 s action value 最大的值的下标
                a_star = np.argmax(self.q_sa[i])
                # pi(s|a) 存储策略,先把之前存储的策略清零,再把新的策略给赋值
                self.pi_sa[i,:] = 0
                self.pi_sa[i,a_star] = 1
                # 存储 v_k+1(s)
                tev[i] = self.q_sa[i,a_star]

            # 更新 state value
            self.v[:] = tev[:]
        #    print(self.q_sa[:,:])
        self.showPi()
        self.showV()

    
    # 在每个方格中显示当前的策略
    def showPi(self):
        data = self.v.reshape(self.n,self.n)
        # 创建图像和轴对象
        fig, ax = plt.subplots()
        
        # 使用 matshow 
        colors = [(0, 'red'), (0.5, 'white'), (1, 'yellow')]
        mycmap = mcolors.LinearSegmentedColormap.from_list('mycmap', colors)
        # 在包含负值时,要做 norm 处理不然会报错,不能在 colors 中有负值,且 colors 指定的值的顺序必须是 ascend (都报错了)
        norm = mcolors.TwoSlopeNorm(vmin=-3, vcenter=0, vmax=3)
        cax = ax.matshow(self.grid, cmap=mycmap,norm=norm)

        
        # 在每个单元格中添加文本
        for (i, j), val in np.ndenumerate(data):
            teval = '↓'
            for k in range(len(self.a)):
                a_star = np.argmax(self.q_sa[i*self.n + j])
                if a_star == 0 :
                    teval = '↑'
                elif a_star == 1 :
                    teval = '→'
                elif a_star == 2 :
                    teval = '↓'
                elif  a_star == 3 :
                    teval = '←'
                else:
                    teval = 'o'
                
            
            ax.text(j, i, f'{teval}', ha='center', va='center', color='black')
            


        # 设置网格线
        ax.set_xticks(np.arange(-0.5, self.n, 1), minor=True)
        ax.set_yticks(np.arange(-0.5, self.n, 1), minor=True)
        ax.grid(which='minor', color='black', linestyle='-', linewidth=1)

        
        # 显示图像
        plt.show()
        

    # 显示每一个方格中的  state value    
    def showV(self):
        data = self.v.reshape(self.n,self.n)
         # 创建图像和轴对象
        fig, ax = plt.subplots()
        
        # 使用 matshow 
        colors = [(0, 'red'), (0.5, 'white'), (1, 'yellow')]
        mycmap = mcolors.LinearSegmentedColormap.from_list('mycmap', colors)
        # 在包含负值时,要做 norm 处理不然会报错,不能在 colors 中有负值,且 colors 指定的值的顺序必须是 ascend (都报错了)
        norm = mcolors.TwoSlopeNorm(vmin=-3, vcenter=0, vmax=3)
        cax = ax.matshow(self.grid, cmap=mycmap,norm=norm)

        
        # 在每个单元格中添加文本
        for (i, j), val in np.ndenumerate(data):
            ax.text(j, i, f'{val:.1f}', ha='center', va='center', color='black')
            


        # 设置网格线
        ax.set_xticks(np.arange(-0.5, self.n, 1), minor=True)
        ax.set_yticks(np.arange(-0.5, self.n, 1), minor=True)
        ax.grid(which='minor', color='black', linestyle='-', linewidth=1)

        
        # 显示图像
        plt.show()
        
    

    # p(r|s,a)
    def p_rsa(self,r,s,a):
        #print(r,s,a)
        row,col = self.one2two(s)
        tx,ty = self.a[a] 
        row = row + tx
        col = col + ty
        if self.checkInWorld(row,col) and self.grid[row,col] == r:
            return True
        elif self.checkInWorld(row,col) == False and r == self.fval:
            return True
        else:
            return False

    # p(s'|s,a)
    # 这里我没有遍历,因为 s_i 与 a 已经确定 那么只有唯一的一个 s' 与之对应
    # 注意,越过了边界的话 v 是自己
    def p_ssa(self,s,a):
        row,col = self.one2two(s)
        tx,ty = self.a[a] 
        tr = row + tx
        tc = col + ty
        if self.checkInWorld(tr,tc) is False:
            return self.v[s]
        else:
            return self.v[tr * self.n + tc]

    # 查看是否超出了边界
    def checkInWorld(self,x,y):
        if x < 0 or x >= self.n or y < 0 or y >= self.n:
            return False
        else:
            return True

    # 由一维下标,转换为二维的坐标
    def one2two(self,x):
        row = x // self.n
        col = x % self.n
        return row,col
        

    # isShowWord 是否要在矩阵中写出数值
    # isShowBar 是否显示颜色条
    def show_grid_default(self,isShowWord=False,isShowBar=False):
        # 创建图像和轴对象
        fig, ax = plt.subplots()
        
        # 使用 matshow 
        colors = [(0, 'red'), (0.5, 'white'), (1, 'yellow')]
        mycmap = mcolors.LinearSegmentedColormap.from_list('mycmap', colors)
        # 在包含负值时,要做 norm 处理不然会报错,不能在 colors 中有负值,且 colors 指定的值的顺序必须是 ascend (都报错了)
        norm = mcolors.TwoSlopeNorm(vmin=-3, vcenter=0, vmax=3)
        cax = ax.matshow(self.grid, cmap=mycmap,norm=norm)

        if isShowWord:
            # 在每个单元格中添加文本
            for (i, j), val in np.ndenumerate(self.grid):
                ax.text(j, i, f'{val}', ha='center', va='center', color='black')
            
        # 添加颜色条
        if isShowBar:
            fig.colorbar(cax)

        # 设置网格线
        ax.set_xticks(np.arange(-0.5, self.n, 1), minor=True)
        ax.set_yticks(np.arange(-0.5, self.n, 1), minor=True)
        ax.grid(which='minor', color='black', linestyle='-', linewidth=1)

        
        # 显示图像
        plt.show()
        
        
        
        
        
grid = grid_world(10,isRand=False,fa=30,gama=0.9,fval = -100)
grid.train(500)

运行截图

image-20240514155508036image-20240514155624859image-20240514155712520

标签:algorithm,iterator,self,value,colors,grid,np,ax,row
From: https://www.cnblogs.com/hoppz/p/18191511

相关文章

  • [Algorithm] Prim's Algorithm
    Prim'salgorithmisapopularmethodusedincomputerscienceforfindingaminimumspanningtreeforaconnected,undirectedgraph.Thismeansitfindsasubsetoftheedgesthatformsatreethatincludeseveryvertex,wherethetotalweightofall......
  • IfcValue
    IfcValue类型定义IfcValue是一种选择类型,用于在更专业的选择类型IFcSimpleValue、IFcMeasureValue和IFcDerivedMeasureValue之间进行选择。IfcSimpleValue简单数据类型的基本定义类型的选择类型。IfcMeasureValueISO10303-41基本度量类型的一种选择类型。BucalDerivedMeasur......
  • java.lang.IllegalArgumentException: Invalid value type for attribute 'factoryBea
    简介前排提示:这个错误一般是由于Spring新版本导致的与其他框架不兼容现象,解决办法一般是升级其他框架版本。使用springboot-3.2.5和myabtis-plus-3.5.0搭建开发环境时,启动Springboot程序时报错,报错信息:点击查看代码java.lang.IllegalArgumentException:Invalidvalu......
  • ValueError: 'a' cannot be empty unless no samples are taken
    Here,Imettheerrormessageasfollows:defmaldroid_noniid(dataset,train_labels,num_users):num_shards,num_imgs=110,120idx_shard=[iforiinrange(num_shards)]dict_users={i:np.array([])foriinrange(num_users)}idxs=np......
  • Object.values()对象遍历
    Object.keys() 对象的遍历 返回给定对象所有可枚举属性的数组;是属性名组成的数组letobj={a:1,b:2,c:3};Object.keys(obj).map((key)=>{console.log(key,obj[key]);}); Object.values() 对象的遍历返回一个给定对象自身的所有属性值的......
  • 返回Rich return value结果思考
    本文是在写过的代码中进行回顾,有理解不对的地方,望请指正!在库(Library)或框架(Framework)设计中,"Richreturnvalue"是指返回值的丰富性,意味着函数返回的不仅仅是一个简单的值,而是一个包含了额外信息的复合类型。这样的设计可以提供更多的上下文信息,方便调用者理解和处理函数的执行......
  • TheAlgorithms/C - 各种基础算法、数据结构的 C 语言实现+armink/SFUD - 一款基于 JED
    1、OpenMV-RT-基于恩智浦i.MXRT系列的开源机器视觉AI模块OpenMV-RT是一款基于恩智浦最近主打的i.MXRT超高性能系列MCU的视觉模块,模块设计者是恩智浦大牛工程师宋岩(对,就是ARMCortex-M3权威指南中文版作者)。模块源代码: https://github.com/RockySong/micropython......
  • 2022 Benelux Algorithm Programming Contest (BAPC 22) A 、I、J、L
    A.AdjustedAverage(暴力枚举+二分查找)分析读完题目可以发现k很小,那么考虑暴力做法的时间复杂度为\(O(C_n^k)\),对于\(k\leq3\)的其实可以直接暴力创过去,但对于\(k=4\)的情况显然不适用。那么对应\(k=4\)的情况考虑优化,可以选择将数分为两个集合,先用一个set存下其中一个集合的所......
  • dotnet 9 WPF 支持 Style 的 Setter 填充内容时可忽略 Value 标签
    本文记录WPF在dotnet9的一项XAML编写语法改进点,此改进点用于解决编写Style的Setter进行给Value赋值时,不能将Value当成默认内容,需要多写Value标签的问题。通过此改进点可减少两行XAML代码在原先的WPF版本里面,对Style的Setter填充复杂的对象内容时,大概的......
  • 解决Vue3项目警告:xxxis-declared-but-its-value-is-never-read
    刚刚在Vue3项目引入的一个组件Person下有红线,系统给出了警告,这是因为TypeScript会检查代码中未使用的变量,我定义了'Person'的变量,但是后续代码没有使用到它,从而导致Vetur(Vue的语法检查工具)给出了这个警告。解决方法:方法一:你可以删除或者在代码中使用'Person'变量或类型,以......