首页 > 编程语言 >强化学习代码实战-04时序差分算法(SARSA)

强化学习代码实战-04时序差分算法(SARSA)

时间:2022-11-10 18:35:48浏览次数:46  
标签:return 04 差分 next SARSA action reward col row

import numpy as np
import random

# 获取一个格子的状态
def get_state(row, col):
    if row!=3:
        return 'ground'
    if row == 3 and col == 11:
        return 'terminal'
    if row == 3 and col == 0:
        return 'ground'
    return 'trap'

# 在某一状态下执行动作,获得对应奖励
def move(row, col, action):
    # 状态检查-进入陷阱或结束,则不能执行任何动作,获得0奖励
    if get_state(row, col) in ["trap", "terminal"]:
        return row, col, 0
    # 执行上下左右动作后,对应的位置变化
    if action == 0:
        row -= 1
    if action == 1:
        row += 1
    if action == 2:
        col -= 1
    if action == 3:
        col += 1
    # 最小不能小于零,最大不能大于3
    row = max(0, row)
    row = min(3, row)
    col = max(0, col)
    col = min(11, col)
    
    # 掉进trap奖励-100,其余每走一步奖励-1,让agent尽快完成任务
    reward = -1
    if get_state(row, col) == 'trap':
        reward = -100
    return row, col, reward

# 初始化Q表格,每个格子采取每个动作的分数,刚开始都是未知的故为零
Q = np.zeros([4, 12, 4])

# 根据当前所处的格子,选取一个动作
def get_action(row, col):
    # 以一定的概率探索
    if random.random() < 0.1:
        return np.random.choice(range(4))
    else:
        # 返回当前Q表格中分数最高的动作
        return Q[row, col].argmax()
    
# 计算当前格子的更新量(当前格子采取动作后获得的奖励,来到下一个格子及要进行的动作)
def update(row, col, action, reward, next_row, next_col, next_action):
    target = reward + Q[next_row, next_col, next_action] * 0.95
    value = Q[row, col, action]
    # 时序查分计算td_error
    td_error = 0.1 * (target - value)
    # 返回误差值
    return td_error

def train():
    for epoch in range(10000):
        # 每次迭代开始,随机一个起点,尽可能多地与环境交互,同时绑定一个动作
        row = np.random.choice(range(4))
        col = 0
        action = get_action(row, col)
        # 计算本轮奖励的总和,越来越大
        rewards_sum = 0
        
        # 一直取探索,直到游戏结束或者进入trap(要判断)
        while get_state(row, col) not in ["terminal", "trap"]:
            # 当前状态下移动一次,获得新的状态
            next_row, next_col, reward = move(row, col, action)
            next_action = get_action(next_row, next_col)
            rewards_sum += reward
            # 获取此次移动的更新量
            td_error = update(row, col, action, reward, next_row, next_col, next_action)
            # 更新Q表格
            Q[row, col, action] += td_error
            # 状态更新
            row, col, action = next_row, next_col, next_action
        if epoch % 500 == 0:
            print(f"epoch:{epoch}, rewards_sum:{rewards_sum}")
        

 

标签:return,04,差分,next,SARSA,action,reward,col,row
From: https://www.cnblogs.com/demo-deng/p/16877999.html

相关文章

  • NC207040 丢手绢
    题目描述链接:https://ac.nowcoder.com/acm/problem/207040来源:牛客网牛客幼儿园的小朋友们围成了一个圆圈准备玩丢手绢的游戏,但是小朋友们太小了,不能围成一个均匀的圆圈......
  • Ubuntu22.04配置静态IP
    1打开配置文件sudovim/etc/netplan/01-network-manager-all.yaml2输入以下配置network:version:2renderer:NetworkManagerethernets:ens33:......
  • NFC 读卡器ACR122U-A9接入Ubuntu 18.04系统
    虚拟机环境:VirtualBox图形用户界面 版本6.0.24r139119系统环境:18.04.1-Ubuntu需求:在Ubuntu环境下,接入NFC读卡器ACR122U接入方式:pcsc-lite封装了访问使用SCardAP......
  • leetcode704
    二分查找Category Difficulty Likes Dislikesalgorithms Easy(54.59%) 1037 -TagsCompanies给定一个n个元素有序的(升序)整型数组nums和一个目标值target,写一个......
  • [JavaScript-04]Switch
    1.Switch//Switch语句constcolor='green';switch(color){case'red':console.log('colorisred');break;case'blue':......
  • [Bug0049]SwitchHosts报错:没有写入 Hosts 文件的权限
    问题SwitchHosts报错:没有写入Hosts文件的权限解决方案1、打开如下目录C:\Windows\System32\drivers\etc2、右键hosts文件->点击安全->点击编辑->找到User......
  • Atcoder Grand Contest 004(A~F)
    这场半VP做的,就不分赛时赛后写了,直接放每道题的解法。A-DivideaCuboid当某一维的长度为偶数的时候,显然可以在这一维的中间切,两部分方块的最小差为\(0\)。当每一......
  • HCIP-ICT实战进阶04-ISIS原理与配置
    HCIP-ICT实战进阶04-ISIS原理与配置0前言IS-IS(IntermediateSystemtoIntermediateSystem,中间系统到中间系统)协议,和OSPF一样属于内部网关协议,也是一种采用SP......
  • Ubuntu 20.04 LTS/RTX30XX显卡 快速配置深度学习环境(一行命令)
    近日,新入一台RTX3080的服务器,目前好像还没办法很方便地在RTX30系列GPU上通过pip/conda安装TensorFlow或PyTorch。因为这些GPU需要CUDA11.1,而当前主流的Tensor......
  • 关于异常DBG_TERMINATE_PROCESS(0x40010004)
    简介DBG_TERMINATE_PROCESS表示进程被调试器终止。值为0x40010004。其定义如下:////MessageId:DBG_TERMINATE_PROCESS////MessageText:////Debuggerterminatedproce......