
An Example of Monte Carlo Tree Search

Date: 2024-10-24 21:20:14
Tags: node, return, self, state, example, search, action, Monte Carlo, def

""" My Monte Carlo Tree Search Demo """

import argparse
import math
import random

from copy import deepcopy
from typing_extensions import Self


def parse_args() -> argparse.Namespace:
    """Parse arguments."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, help="Fix random seed", default=0)
    parser.add_argument("--tape_length", type=int, help="Tape length", default=24)
    parser.add_argument(
        "--sample_times_limit", type=int, help="Sample times limit", default=100
    )
    parser.add_argument(
        "--exploration_constant", type=float, help="Exploration constant", default=1.0
    )
    return parser.parse_args()


def set_seed(seed: int) -> None:
    """Set seed for reproducibility."""
    random.seed(seed)


class Action:
    """An action that writes a 1 at a single tape position."""

    def __init__(self, write_position: int) -> None:
        self.write_position = write_position


class State:
    """Game state: a tape of 0/1 cells plus the number of writes made so far."""

    def __init__(self, tape_length: int) -> None:
        self.tape = [0] * tape_length
        self.tape_length = tape_length
        # Every position stays writable, including cells that already hold 1.
        self.possible_actions = [
            Action(write_position=i) for i in range(self.tape_length)
        ]
        self.written_times = 0

    def get_possible_actions(self) -> list:
        """Get possible actions."""
        if self.is_terminal():
            return []
        return self.possible_actions

    def take_action(self, action: Action) -> Self:
        """Return a new state with the action applied; self is left unchanged."""
        if action is None:
            return self
        new_state = deepcopy(self)
        new_state.tape[action.write_position] = 1
        new_state.written_times = self.written_times + 1
        return new_state

    def is_terminal(self) -> bool:
        """Return True once the tape has been written tape_length times."""
        return self.written_times == self.tape_length

    def get_reward(self) -> int:
        """Return the reward: the number of cells set to 1."""
        return sum(self.tape)

    def show_tape(self) -> None:
        """Show tape."""
        print(self.tape)


class TreeNode:
    """Tree node class."""

    def __init__(self, state: State, parent: Self) -> None:
        self.state = state
        self.is_terminal = state.is_terminal()
        self.is_fully_expanded = self.is_terminal
        self.parent = parent
        self.num_visits = 0
        self.total_reward = 0
        self.children = {}


class MCTS:
    """Monte Carlo Tree Search class."""

    def __init__(self, iteration_limit: int, exploration_constant: float) -> None:
        self.search_limit = iteration_limit
        self.exploration_constant = exploration_constant

    def search(self, initial_state: State) -> Action:
        """Search for the best action."""
        if initial_state.is_terminal():
            return None
        root = TreeNode(initial_state, None)
        for _ in range(self.search_limit):
            node = self.select_node(root)
            reward = self.rollout(node.state)
            self.back_propogate(node, reward)
        # An exploration value of 0.0 picks the child with the best mean reward.
        best_child = self.get_best_child(root, 0.0)
        return self.get_action(root, best_child)

    def select_node(self, node: TreeNode) -> TreeNode:
        """Select node."""
        while not node.is_terminal:
            if node.is_fully_expanded:
                node = self.get_best_child(node, self.exploration_constant)
            else:
                return self.expand(node)
        return node

    def get_best_child(self, node: TreeNode, exploration_value: float) -> TreeNode:
        """Return the child with the highest UCB1 score; ties are broken at random."""
        best_value = float("-inf")
        best_nodes = []
        for child in node.children.values():
            if child.num_visits == 0:
                return child
            node_value = (
                child.total_reward / child.num_visits
                + exploration_value
                * math.sqrt(2 * math.log(node.num_visits) / child.num_visits)
            )
            if node_value > best_value:
                best_value = node_value
                best_nodes = [child]
            elif node_value == best_value:
                best_nodes.append(child)
        return random.choice(best_nodes)

    def rollout(self, state: State) -> int:
        """Play uniformly random actions until the state is terminal; return the reward."""
        while not state.is_terminal():
            action = random.choice(state.get_possible_actions())
            state = state.take_action(action)
        return state.get_reward()

    def back_propogate(self, node: TreeNode, reward: int) -> None:
        """Back-propagate the reward from node up to the root."""
        while node is not None:
            node.num_visits += 1
            node.total_reward += reward
            node = node.parent

    def expand(self, node: TreeNode) -> TreeNode:
        """Add one untried child to node and return it."""
        actions = node.state.get_possible_actions()
        for action in actions:
            if action not in node.children:
                new_node = TreeNode(node.state.take_action(action), node)
                node.children[action] = new_node
                if len(actions) == len(node.children):
                    node.is_fully_expanded = True
                # Return after adding a single child, so that each search
                # iteration expands exactly one node.
                return new_node
        return None

    def get_action(self, father: TreeNode, child: TreeNode) -> Action:
        """Return the action that leads from father to child."""
        # Iterate over the children that exist, so a missing key cannot raise.
        for action, node in father.children.items():
            if node is child:
                return action
        return None


if __name__ == "__main__":
    args = parse_args()
    set_seed(args.seed)
    game_state = State(args.tape_length)
    searcher = MCTS(
        iteration_limit=args.sample_times_limit,
        exploration_constant=args.exploration_constant,
    )
    for _ in range(args.tape_length):
        best_action = searcher.search(initial_state=game_state)
        game_state = game_state.take_action(best_action)
        game_state.show_tape()
    print("Final reward:", game_state.get_reward())
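
The selection step in `get_best_child` scores each child by its mean reward plus an exploration bonus that shrinks as the child is visited more often. The following is a minimal, self-contained sketch of that score, not part of the script above: the helper name `ucb1_score` and the sample numbers are assumptions chosen only for illustration.

```python
import math


def ucb1_score(
    total_reward: float,
    child_visits: int,
    parent_visits: int,
    exploration_constant: float = 1.0,
) -> float:
    """Hypothetical helper mirroring the formula used in get_best_child."""
    if child_visits == 0:
        # An unvisited child is treated as maximally attractive.
        return float("inf")
    exploitation = total_reward / child_visits
    exploration = exploration_constant * math.sqrt(
        2 * math.log(parent_visits) / child_visits
    )
    return exploitation + exploration


# Two children with the same mean reward (10.0) under a parent visited 30 times.
often_visited = ucb1_score(total_reward=200.0, child_visits=20, parent_visits=30)
rarely_visited = ucb1_score(total_reward=50.0, child_visits=5, parent_visits=30)

# The less-visited child receives the larger exploration bonus.
print(rarely_visited > often_visited)
```

With `exploration_constant=0.0` the bonus vanishes and the score reduces to the mean reward, which is how `search` picks the final action.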

  

From: https://www.cnblogs.com/qiandeheng/p/18501358
