首页 > 编程问答 >通过 GraalVM 在 Spring Boot 项目中运行 Python 脚本(强化学习)

通过 GraalVM 在 Spring Boot 项目中运行 Python 脚本(强化学习)

时间:2024-08-01 16:30:25浏览次数:22  
标签:python spring-boot reinforcement-learning graalvm graalpython

我想要完成的是通过 GraalVM 在 Spring Boot 项目中运行 python 脚本(在我的例子中是 agent.py)。

我已经安装了 GraalVM (ce jdk 21) 并给它一个路径。还在项目结构中定义并根据graalvm和graalvm的python更改了pom.xml。即使它给出了无法找到火炬模块的错误。我还有一个安装了 torch 的 conda env。

这是基于 [此项目] ( https://github.com/patrickloeber/snake-ai-pytorch ) 的 agent.py轻微的变化

import torch
import random
import numpy as np
from collections import deque
import pandas as pd
import platform
import time
from game import SnakeGameAI, Direction, Point
from model import Linear_QNet, QTrainer
from helper import plot

MAX_MEMORY = 100_000
BATCH_SIZE = 1000
LR = 0.001

class Agent:

    def __init__(self):
        self.n_games = 0
        self.epsilon = 0 # randomness
        self.gamma = 0.9 # discount rate
        self.memory = deque(maxlen=MAX_MEMORY) # popleft()
        self.model = Linear_QNet(11, 256, 3)
        self.trainer = QTrainer(self.model, lr=LR, gamma=self.gamma)
        self.metrics = []
        self.start_time = time.time()

    def get_state(self, game):
        head = game.snake[0]
        point_l = Point(head.x - 20, head.y)
        point_r = Point(head.x + 20, head.y)
        point_u = Point(head.x, head.y - 20)
        point_d = Point(head.x, head.y + 20)
        
        dir_l = game.direction == Direction.LEFT
        dir_r = game.direction == Direction.RIGHT
        dir_u = game.direction == Direction.UP
        dir_d = game.direction == Direction.DOWN

        state = [
            # Danger straight
            (dir_r and game.is_collision(point_r)) or 
            (dir_l and game.is_collision(point_l)) or 
            (dir_u and game.is_collision(point_u)) or 
            (dir_d and game.is_collision(point_d)),

            # Danger right
            (dir_u and game.is_collision(point_r)) or 
            (dir_d and game.is_collision(point_l)) or 
            (dir_l and game.is_collision(point_u)) or 
            (dir_r and game.is_collision(point_d)),

            # Danger left
            (dir_d and game.is_collision(point_r)) or 
            (dir_u and game.is_collision(point_l)) or 
            (dir_r and game.is_collision(point_u)) or 
            (dir_l and game.is_collision(point_d)),
            
            # Move direction
            dir_l,
            dir_r,
            dir_u,
            dir_d,
            
            # Food location 
            game.food.x < game.head.x,  # food left
            game.food.x > game.head.x,  # food right
            game.food.y < game.head.y,  # food up
            game.food.y > game.head.y  # food down
            ]

        return np.array(state, dtype=int)

    def remember(self, state, action, reward, next_state, done):
        self.memory.append((state, action, reward, next_state, done)) # popleft if MAX_MEMORY is reached

    def train_long_memory(self):
        if len(self.memory) > BATCH_SIZE:
            mini_sample = random.sample(self.memory, BATCH_SIZE) # list of tuples
        else:
            mini_sample = self.memory

        states, actions, rewards, next_states, dones = zip(*mini_sample)
        self.trainer.train_step(states, actions, rewards, next_states, dones)
        #for state, action, reward, next_state, done in mini_sample:
        #    self.trainer.train_step(state, action, reward, next_state, done)

    def train_short_memory(self, state, action, reward, next_state, done):
        self.trainer.train_step(state, action, reward, next_state, done)

    def get_action(self, state):
        # random moves: tradeoff exploration / exploitation
        self.epsilon = 80 - self.n_games
        final_move = [0,0,0]
        if random.randint(0, 200) < self.epsilon:
            move = random.randint(0, 2)
            final_move[move] = 1
        else:
            state0 = torch.tensor(state, dtype=torch.float)
            prediction = self.model(state0)
            move = torch.argmax(prediction).item()
            final_move[move] = 1

        return final_move
    
    def save_metrics_excel(self, file_name='metrics.xlsx'):
        hardware_info = {
            'system': platform.system(),
            'machine': platform.machine(),
            'platform': platform.platform(),
            'processor': platform.processor(),
            'cpu_count': torch.cuda.device_count() if torch.cuda.is_available() else 0,
            'gpu': torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'None'
        }
        total_time = time.time() - self.start_time

        data = {
            'hardware_info': hardware_info,
            'total_training_time': total_time,
            'metrics': self.metrics
        }

        # Flatten data for the Excel file
        metrics_df = pd.DataFrame(self.metrics)
        hardware_df = pd.DataFrame([hardware_info])
        training_time_df = pd.DataFrame([{'total_training_time': total_time}])

        with pd.ExcelWriter(file_name) as writer:
            metrics_df.to_excel(writer, sheet_name='Metrics', index=False)
            hardware_df.to_excel(writer, sheet_name='Hardware', index=False)
            training_time_df.to_excel(writer, sheet_name='Training Time', index=False)


def train():
    plot_scores = []
    plot_mean_scores = []
    total_score = 0
    record = 0
    agent = Agent()
    game = SnakeGameAI()
    while True:
        # get old state
        state_old = agent.get_state(game)

        # get move
        final_move = agent.get_action(state_old)

        # perform move and get new state
        reward, done, score = game.play_step(final_move)
        state_new = agent.get_state(game)

        # train short memory
        agent.train_short_memory(state_old, final_move, reward, state_new, done)

        # remember
        agent.remember(state_old, final_move, reward, state_new, done)

        if done:
            # train long memory, plot result
            game.reset()
            agent.n_games += 1
            agent.train_long_memory()

            if score > record:
                record = score
                agent.model.save()

            print('Game', agent.n_games, 'Score', score, 'Record:', record)

            plot_scores.append(score)
            total_score += score
            mean_score = total_score / agent.n_games
            plot_mean_scores.append(mean_score)
            plot(plot_scores, plot_mean_scores)

            # Collect metrics
            agent.metrics.append({
                'Game': agent.n_games,
                'Score': score,
                'Record': record,
                'Mean Score': mean_score,
                'Timestamp': time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())
            })

            # Save metrics to Excel periodically
            if agent.n_games % 10 == 0:  # Save every 10 games
                agent.save_metrics_excel()
            
            # Return a summary message
            return f'Training completed. Total games: {agent.n_games}, Final score: {score}, Record: {record}'


if __name__ == '__main__':
    train()

Spring Boot 方面

package com.rocksoft.denemeGraalVM;

import org.graalvm.polyglot.Context;
import org.graalvm.polyglot.Source;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Service;

import java.io.File;
import java.io.IOException;

@Service
public class SnakeGameService {

    @Value("${python.script.path}")
    private String scriptPath;

    public void runGame() {
        try {
            // GraalVM context for Python execution
            Context context = Context.newBuilder("python")
                    .allowAllAccess(true)
                    .build();

            // Load and evaluate the Python script
            Source source = Source.newBuilder("python", new File(scriptPath)).build();
            context.eval(source);

            org.graalvm.polyglot.Value result = context.getBindings("python").getMember("train").execute();
            String gameStatus = result.asString();

            // Handle the result
            System.out.println("Game executed successfully. Status: " + gameStatus);

        } catch (IOException e) {
            e.printStackTrace();
            System.err.println("Failed to load the Python script: " + e.getMessage());
        } catch (Exception e) {
            e.printStackTrace();
            System.err.println("Error during Python script execution: " + e.getMessage());
        }
    }

    @Scheduled(fixedRate = 60000) // runs every 60 seconds
    public void scheduledGameRun() {
        runGame();
    }

}

还有我使用的依赖项

<dependencies>
        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-web</artifactId>
        </dependency>

        <dependency>
            <groupId>org.springframework.boot</groupId>
            <artifactId>spring-boot-starter-test</artifactId>
            <scope>test</scope>
        </dependency>

        <dependency>
            <groupId>org.graalvm.sdk</groupId>
            <artifactId>graal-sdk</artifactId>
            <version>23.1.2</version>
        </dependency>

        <!-- https://mvnrepository.com/artifact/org.graalvm.polyglot/python-community -->
        <dependency>
            <groupId>org.graalvm.polyglot</groupId>
            <artifactId>python-community</artifactId>
            <version>23.1.2</version>
            <type>pom</type>
            <scope>runtime</scope>
        </dependency>


    </dependencies>

和错误

ModuleNotFoundError: No module named 'torch'
    at <python> <module>(Unknown)
    at org.graalvm.polyglot.Context.eval(Context.java:402)
    at com.rocksoft.denemeGraalVM.SnakeGameService.runGame(SnakeGameService.java:27)
    at com.rocksoft.denemeGraalVM.SnakeGameService.scheduledGameRun(SnakeGameService.java:46)
    at java.base/jdk.internal.reflect.DirectMethodHandleAccessor.invoke(DirectMethodHandleAccessor.java:103)
    at java.base/java.lang.reflect.Method.invoke(Method.java:580)
    at org.springframework.scheduling.support.ScheduledMethodRunnable.runInternal(ScheduledMethodRunnable.java:130)
    at org.springframework.scheduling.support.ScheduledMethodRunnable.lambda$run$2(ScheduledMethodRunnable.java:124)
    at io.micrometer.observation.Observation.observe(Observation.java:499)
    at org.springframework.scheduling.support.ScheduledMethodRunnable.run(ScheduledMethodRunnable.java:124)
    at org.springframework.scheduling.support.DelegatingErrorHandlingRunnable.run(DelegatingErrorHandlingRunnable.java:54)
    at java.base/java.util.concurrent.Executors$RunnableAdapter.call(Executors.java:572)
    at java.base/java.util.concurrent.FutureTask.runAndReset(FutureTask.java:358)
    at java.base/java.util.concurrent.ScheduledThreadPoolExecutor$ScheduledFutureTask.run(ScheduledThreadPoolExecutor.java:305)
    at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1144)
    at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:642)
    at java.base/java.lang.Thread.run(Thread.java:1583)
Error during Python script execution: ModuleNotFoundError: No module named 'torch'

我愿意接受任何建议。


这个错误信息表示 GraalVM 上下文找不到 'torch' 模块。尽管你在 conda 环境中安装了 torch,但 GraalVM 并不会自动使用你的 conda 环境。你需要告诉 GraalVM 上下文去哪里寻找 Python 模块。

以下是如何解决这个问题,并通过 GraalVM 在 Spring Boot 项目中运行你的 Python 脚本的步骤:

1. 使用虚拟环境

最好为你的项目创建一个虚拟环境,以避免依赖项冲突。你可以使用 conda 或 venv:

使用 conda:

bash conda create -n myenv python=3.8 # 将 3.8 替换为你需要的 Python 版本 conda activate myenv pip install torch pandas numpy

使用 venv:

bash python3 -m venv myenv source myenv/bin/activate pip install torch pandas numpy

2. 安装 GraalVM 的 Python 支持

确保你已经安装了 GraalVM 的 Python 支持(你已经在做这件事了)。

3. 修改你的 Spring Boot 代码

你需要修改你的 Spring Boot 代码,以告诉 GraalVM 上下文使用你的虚拟环境:

```java package com.rocksoft.denemeGraalVM;

import org.graalvm.polyglot.Context; import org.graalvm.polyglot.Source; import org.springframework.beans.factory.annotation.Value; import org.springframework.scheduling.annotation.Scheduled; import org.springframework.stereotype.Service;

import java.io.File; import java.io.IOException;

@Service public class SnakeGameService {

 @Value("${python.script.path}")
 private String scriptPath;

 @Value("${python.env.path}") 
 private String pythonEnvPath;

 public void runGame() {
     try {
         // 构建 GraalVM 上下文,并指定 Python 环境路径
         Context context = Context.newBuilder("python")
                 .allowAllAccess(true)
                 .option("python.ForceImportSite","true")
                 .environment("PYTHONHOME", pythonEnvPath)
                 .build();

         // 加载并执行 Python 脚本
         Source source = Source.newBuilder("python", new File(scriptPath)).build();
         context.eval(source);

         org.graalvm.polyglot.Value result = context.getBindings("python").getMember("train").execute();
         String gameStatus = result.asString();

         // 处理结果
         System.out.println("Game executed successfully. Status: " + gameStatus);

     } catch (IOException e) {
         e.printStackTrace();
         System.err.println("Failed to load the Python script: " + e.getMessage());
     } catch (Exception e) {
         e.printStackTrace();
         System.err.println("Error during Python script execution: " + e.getMessage());
     }
 }

 @Scheduled(fixedRate = 60000) // 每 60 秒运行一次
 public void scheduledGameRun() {
     runGame();
 }

} ```

  • ${python.env.path} 替换为你的虚拟环境路径(例如, /home/user/myenv )。
  • 确保 python.ForceImportSite 选项设置为 true ,这将强制 GraalVM 使用指定的 PYTHONHOME

4. 设置环境变量

在你的 Spring Boot 应用程序的配置文件 ( application.properties application.yml ) 中设置 python.env.path 属性:

properties python.script.path=path/to/your/agent.py python.env.path=/home/user/myenv # 替换为你的虚拟环境路径

5. 重新构建和运行你的应用程序

完成这些步骤后,重新构建并运行你的 Spring Boot 应用程序。GraalVM 现在应该能够找到 'torch' 模块并成功执行你的 Python 脚本。

标签:python,spring-boot,reinforcement-learning,graalvm,graalpython
From: 78819853

相关文章

  • 为什么Python库中还没有Gauss-Jordan消去函数?
    Gaussian-Jordan消除法是一种常用且方便的技术。在矩阵计算中,该方法得到广泛应用。但是,Python库中没有此方法的内置函数。设计这样的功能并不困难。开发人员似乎忽略了这个功能。这可能是什么原因?希望开发者能够重视这一点,并开发与该方法相关的功能。虽然高斯-约旦消元......
  • Python爬虫 - js逆向之扣出某平台的_signature加密字段
    前言好久没有做逆向案例分析了,最近都在看同行朋友写好的案例,感觉学到很多,算是取长补短了不多bb,机缘巧合下,拿到个目标网站分析首先抓个包,就看到请求参数里带着这几个参数主要就是_signature了,其他的参数都不重要得想必你一看就知道啥意思了。行,开始分析_signature......
  • python 音频处理(1)——重采样、音高提取
    采集数据->采样率调整使用torchaudio进行重采样(cpu版)首先导入相关包,既然使用torch作为我们的选项,安装torch环境我就不必多说了,如果你不想用torch可以使用后文提到的另一个库1importtorch2importtorchaudio3fromtorchaudio.transformsimportResample4fromtime......
  • 用于从连接到 LAN 交换机的四个 Arduino 扩展板读取/保存数据的 Python 代码
    我有四个Arduino扩展板连接到LAN交换机。每个人都会发出一条“HelloJane”消息。LAN交换机已连接到我的PC。我需要将每个Arduino的数据/消息保存在PC上的单独文件夹中。请帮助使用Python脚本/代码。我能够看到来自所有四个Arduino的消息,但无法将它们保存到文......
  • 使用 Python 生产者和消费者在 Kubernetes minikube 上设置 Kafka Kraft
    我正在尝试从kubernetes集群外部连接到kubernetesminikubekafkapod。服务器启动没有任何问题,但我无法设法将本地kafka生产者/消费者连接到外部kafkapod。在集群内的kafka服务器映像上,我将bootstrap-server设置为:bin/kafka-topics.sh--create--bootst......
  • python llama_index.indices.list.retrievers 导入错误
    fromllama_indeximportGPTListIndexfromllama_index.indices.list.retrieversimportListIndexLLMRetrieverdocuments=SimpleDirectoryReader('./data').load_data()index=GPTListIndex.from_documents(documents,service_context=service_context)r......
  • 如何使用 python 和 bs4 修复抓取 web 表输出 csv
    请帮帮我,,我想在“td”、“Barcode”和“namaproduk”中获取2个数据,但我得到的数据非常糟糕。我应该修复什么?importcsvimportrequestsfrombs4importBeautifulSoupoutfile=open("dataaa.csv","w",newline='')writer=csv.writer(outfile)page=0whilepag......
  • 基于Django的超市小程序+47822(免费领源码)可做计算机毕业设计JAVA、PHP、爬虫、APP、小
    基于django超市小程序摘 要随着我国经济迅速发展,人们对手机的需求越来越大,各种手机软件也都在被广泛应用,但是对于手机进行数据信息管理,对于手机的各种软件也是备受用户的喜爱,超市小程序被用户普遍使用,为方便用户能够可以随时进行超市小程序的数据信息管理,特开发了基于djan......
  • 【优秀python django系统案例】基于python的医院挂号管理系统,角色包括医生、患者、管
    随着信息技术的迅猛发展,传统的医院挂号管理方式面临着效率低下、排队时间长、信息不对称等诸多问题。这些问题不仅影响患者的就医体验,也加重了医院工作人员的负担。在此背景下,基于Python的医院挂号管理系统应运而生。该系统旨在通过信息化手段优化挂号流程,提高管理效率,提升医疗......
  • Python-PLAXIS自动化建模技术与典型岩土工程
    原文链接:Python-PLAXIS自动化建模技术与典型岩土工程https://blog.csdn.net/2301_78164062/article/details/140607885?spm=1001.2014.3001.5501一:Plaxis软件及PlaxisPythonAPI环境搭建1、Plaxis2D\Plaxis3D软件2、面向对象编程语言Python及其开发环境Spyder3、Plaxis输......