首页 > 其他分享 >actor critic 玩carpole游戏

actor critic 玩carpole游戏

时间:2024-05-13 13:53:42浏览次数:8  
标签:torch nn carpole state actor next critic

 

import gym
import torch
import torch.nn as nn
import torch.optim as optim
import pygame
import sys

# 定义Actor网络
class Actor(nn.Module):
    def __init__(self):
        super(Actor, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(4, 10),
            nn.ReLU(),
            nn.Linear(10, 2),
            nn.Softmax(dim=-1)
        )

    def forward(self, x):
        return self.fc(x)

# 定义Critic网络
class Critic(nn.Module):
    def __init__(self):
        super(Critic, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(4, 10),
            nn.ReLU(),
            nn.Linear(10, 1)
        )

    def forward(self, x):
        return self.fc(x)

# 训练模型
def train(actor, critic, actor_optimizer, critic_optimizer, state, action, reward, next_state, done):
    state = torch.tensor(state, dtype=torch.float)
    next_state = torch.tensor(next_state, dtype=torch.float)
    action = torch.tensor(action, dtype=torch.long)
    reward = torch.tensor(reward, dtype=torch.float)
    if done:
        next_value = 0
    else:
        next_value = critic(next_state).detach()
    
    # Critic loss
    value = critic(state)
    expected_value = reward + 0.99 * next_value
    critic_loss = (value - expected_value).pow(2).mean()
    
    # Actor loss
    probs = actor(state)
    dist = torch.distributions.Categorical(probs)
    log_prob = dist.log_prob(action)
    advantage = (expected_value - value).detach()  # TD error as advantage
    actor_loss = -log_prob * advantage
    
    # Update networks
    critic_optimizer.zero_grad()
    critic_loss.backward()
    critic_optimizer.step()
    
    actor_optimizer.zero_grad()
    actor_loss.backward()
    actor_optimizer.step()

# 设置环境和模型
env = gym.make('CartPole-v1')
actor = Actor()
critic = Critic()
actor_optimizer = optim.Adam(actor.parameters(), lr=0.001)
critic_optimizer = optim.Adam(critic.parameters(), lr=0.01)

pygame.init()
screen = pygame.display.set_mode((600, 400))
clock = pygame.time.Clock()

# 开始训练
for episode in range(10000):
    state = env.reset()
    done = False
    state = state[0]
    step= 0
    while not done:
        step += 1
        state_tensor = torch.tensor(state, dtype=torch.float)
        probs = actor(state_tensor)
        dist = torch.distributions.Categorical(probs)
        action = dist.sample().item()
        next_state, reward, done, _ ,_= env.step(action)
        
        train(actor, critic, actor_optimizer, critic_optimizer, state, action, reward, next_state, done)
        state = next_state
        
        # Pygame visualization
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                pygame.quit()
                sys.exit()

        # Drawing
        
        screen.fill((255, 255, 255))
        cart_x = int(state[0] * 100 + 300)
        pygame.draw.rect(screen, (0, 0, 255), (cart_x, 300, 50, 30))
        pygame.draw.line(screen, (255, 0, 0), (cart_x + 25, 300), (cart_x + 25 - int(50 * torch.sin(torch.tensor(state[2]))), 300 - int(50 * torch.cos(torch.tensor(state[2])))), 5)
        pygame.display.flip()
        clock.tick(200)

    print(f"第{episode}回合,玩{step}次挂了")

 

标签:torch,nn,carpole,state,actor,next,critic
From: https://www.cnblogs.com/LiuXinyu12378/p/18189050

相关文章

  • java.lang.IllegalArgumentException: Invalid value type for attribute 'factoryBea
    简介前排提示:这个错误一般是由于Spring新版本导致的与其他框架不兼容现象,解决办法一般是升级其他框架版本。使用springboot-3.2.5和myabtis-plus-3.5.0搭建开发环境时,启动Springboot程序时报错,报错信息:点击查看代码java.lang.IllegalArgumentException:Invalidvalu......
  • openfeign接口Springboot启动Bean报错未找到Singleton bean creation not allowed whi
    检查步骤检查springboot启动类是否标注@EnableFeignClients注解,未标注该注解会导致无法注入bean检查远程调用模块是否标注注解@FeignClient检查@FeignClient注解中是否写了正确的微服务名称(区分大小写)检查@FeignClient注解中标识的微服务是否启动​​原因:此处接......
  • LLaMA-Factory 训练 Llama3-Chinese-8B-Instruct 相关报错问题解决
    模型路径up主为llama中文社区模型地址https://www.modelscope.cn/models/FlagAlpha/Llama3-Chinese-8B-Instruct/summarysysinfov10032gnvcc--versioncuda11.8pythonimporttorchprint(torch.version)13.11pipinstallflash_attntimeout2下载whl报这个错......
  • GitHub two-factor authentication开启教程
    问题描述最近登录GitHub个人页面动不动就有一个提示框”......two-factorauthenticationwillberequiredforyouraccountstartingJan4,2024......“,点击去看了一下原来是GitHub对所有的用户登录都要开启双重身份认证,要在1月4号前完成解决办法GitHub个人页面点击右......
  • 原始翎风CLIENT8位 (13) actor的学习
    functionGetOffset(appr:integer):integer偏移大于1000退出nrace:=apprdiv10nrace0-90npos:=apprmod10npos0-9这个找的是怪物图片在文件中图片索引偏移量分为很多种,有偏移280,280是一个怪物的一组图片,例如MON1有偏移230,例如MON2有偏移360的,例如MON3appr应该......
  • A Critical Study on Data Leakage in Recommender System Offline Evaluation
    目录概主要内容数据集统计信息Top-NRecommendationListRecommendationAccuracy理想的切分方式代码JiY.,SunA.,ZhangJ.andLiC.Acriticalstudyondataleakageinrecommendersystemofflineevaluation.TOIS,2022.概本文讨论了现在的推荐系统评价方式(如L......
  • P6123 [NEERC2016] Hard Refactoring 题解
    本题说白了,就是一道big模拟!!!题意不再赘述,我们直接看思路。这里作者借鉴了某差分思想:末尾加空格,用于判断最后一个条件;若只有\(\le\),对给出的数字和数组第一个进行标记。标记的时候要+32769,因为数组中不存在负数下标,以免越界;若只有\(\ge\),就标记给出的数字和数组最后......
  • 使用Colab_LLaMA_Factory_LoRA微调_Llama3(可自定义数据)
    使用LLaMAFactory微调Llama-3中文对话模型项目主页: https://github.com/hiyouga/LLaMA-Factory这个过程超级简单,半个多小时在T4上就能跑完。完全可以替换成自己的数据,支持中文数据。安装LLaMAFactory依赖 1%cd/content/2%rm-rfLLaMA-Factory3!gitclo......
  • .net core,.net 6使用SoapCore开发webservice接口,以及使用HttpClientFactory动态访问we
    1.使用soapCorenuget包 2.新建接口及实现2.1新建接口 2.2新建实现 2.3新建接收实体 2.4返回实体 3.接口注入使用  4.启动程序,直接访问对应的asmx地址  ......
  • SpringBoot+MyBatisPlus报错 Invalid value type for attribute 'factoryBeanObjectTy
    依赖版本org.springframework.boot:spring-boot-starter-web:3.2.5com.baomidou:mybatis-plus-boot-starter:3.5.5错误Invalidvaluetypeforattribute'factoryBeanObjectType'问题原因:这个问题是由于依赖传递导致,在MyBatis起步依赖中的myBatis-spring版本过低,导致程......