我想要完成的是通过 GraalVM 在 Spring Boot 项目中运行 python 脚本(在我的例子中是 agent.py)。
我已经安装了 GraalVM（CE，JDK 21）并配置了它的路径，也在项目结构中将其设为 SDK，并按照 GraalVM 及其 Python 支持（GraalPy）的要求修改了 pom.xml。但运行时仍然报错，提示找不到 torch 模块。我另外还有一个已安装 torch 的 conda 环境。
下面是 agent.py，它在 [此项目](https://github.com/patrickloeber/snake-ai-pytorch) 的基础上做了少量修改：
import torch
import random
import numpy as np
from collections import deque
import pandas as pd
import platform
import time
from game import SnakeGameAI, Direction, Point
from model import Linear_QNet, QTrainer
from helper import plot
# Hyper-parameters shared by Agent.
MAX_MEMORY = 100000  # replay-memory capacity; the deque drops the oldest entries beyond this
BATCH_SIZE = 1_000   # transitions sampled per long-memory training step
LR = 1e-3            # learning rate handed to QTrainer
class Agent:
    """Deep Q-learning agent for SnakeGameAI.

    Owns the Q-network, its trainer, a bounded replay memory and the
    per-game metrics that ``save_metrics_excel`` exports.
    """

    def __init__(self):
        self.n_games = 0
        self.epsilon = 0  # exploration threshold; recomputed in get_action
        self.gamma = 0.9  # discount rate
        # Replay memory: deque(maxlen=...) drops the oldest transition once full.
        self.memory = deque(maxlen=MAX_MEMORY)
        # 11 inputs match get_state's feature count and 3 outputs match the
        # moves get_action chooses from (256 is presumably the hidden size).
        self.model = Linear_QNet(11, 256, 3)
        self.trainer = QTrainer(self.model, lr=LR, gamma=self.gamma)
        self.metrics = []  # one dict per finished game
        self.start_time = time.time()  # used to report total training time

    def get_state(self, game):
        """Encode ``game`` as an 11-element 0/1 numpy int array.

        Features, in order: danger straight / right / left (relative to the
        current direction), current direction one-hot (left, right, up, down)
        and food position relative to the head (left, right, up, down).
        """
        head = game.snake[0]
        # Cells adjacent to the head; 20 is presumably the grid block size
        # used by game.py -- TODO confirm.
        point_l = Point(head.x - 20, head.y)
        point_r = Point(head.x + 20, head.y)
        point_u = Point(head.x, head.y - 20)
        point_d = Point(head.x, head.y + 20)

        dir_l = game.direction == Direction.LEFT
        dir_r = game.direction == Direction.RIGHT
        dir_u = game.direction == Direction.UP
        dir_d = game.direction == Direction.DOWN

        state = [
            # Danger straight
            (dir_r and game.is_collision(point_r)) or
            (dir_l and game.is_collision(point_l)) or
            (dir_u and game.is_collision(point_u)) or
            (dir_d and game.is_collision(point_d)),

            # Danger right
            (dir_u and game.is_collision(point_r)) or
            (dir_d and game.is_collision(point_l)) or
            (dir_l and game.is_collision(point_u)) or
            (dir_r and game.is_collision(point_d)),

            # Danger left
            (dir_d and game.is_collision(point_r)) or
            (dir_u and game.is_collision(point_l)) or
            (dir_r and game.is_collision(point_u)) or
            (dir_l and game.is_collision(point_d)),

            # Move direction
            dir_l,
            dir_r,
            dir_u,
            dir_d,

            # Food location
            game.food.x < game.head.x,  # food left
            game.food.x > game.head.x,  # food right
            game.food.y < game.head.y,  # food up
            game.food.y > game.head.y,  # food down
        ]
        # dtype=int turns the booleans into 0/1.
        return np.array(state, dtype=int)

    def remember(self, state, action, reward, next_state, done):
        """Store one transition; the oldest is evicted once MAX_MEMORY is reached."""
        self.memory.append((state, action, reward, next_state, done))

    def train_long_memory(self):
        """Train on a random sample of BATCH_SIZE stored transitions.

        When memory holds no more than BATCH_SIZE transitions, all of them
        are used.
        """
        if len(self.memory) > BATCH_SIZE:
            mini_sample = random.sample(self.memory, BATCH_SIZE)  # list of tuples
        else:
            mini_sample = self.memory
        # Transpose the transitions into per-field tuples for one batched step.
        states, actions, rewards, next_states, dones = zip(*mini_sample)
        self.trainer.train_step(states, actions, rewards, next_states, dones)

    def train_short_memory(self, state, action, reward, next_state, done):
        """Train on a single transition."""
        self.trainer.train_step(state, action, reward, next_state, done)

    def get_action(self, state):
        """Choose the next move and return it as a one-hot list of length 3.

        Epsilon-greedy: with a threshold of ``80 - n_games`` a random move is
        taken when ``randint(0, 200)`` falls below it, so random exploration
        ends once 80 games have been played. Otherwise the move with the
        highest predicted Q-value is taken.
        """
        self.epsilon = 80 - self.n_games
        final_move = [0, 0, 0]
        if random.randint(0, 200) < self.epsilon:
            move = random.randint(0, 2)
        else:
            state0 = torch.tensor(state, dtype=torch.float)
            # Inference only: no_grad avoids building an autograd graph that
            # would be discarded immediately.
            with torch.no_grad():
                prediction = self.model(state0)
            move = torch.argmax(prediction).item()
        final_move[move] = 1
        return final_move

    def save_metrics_excel(self, file_name='metrics.xlsx'):
        """Write game metrics, hardware info and elapsed time to an Excel file.

        Sheets: 'Metrics' (one row per entry of ``self.metrics``), 'Hardware'
        and 'Training Time'. NOTE: pd.ExcelWriter needs an Excel writer engine
        (e.g. openpyxl) installed in the same Python environment.
        """
        import os  # only needed for cpu_count

        cuda = torch.cuda.is_available()
        hardware_info = {
            'system': platform.system(),
            'machine': platform.machine(),
            'platform': platform.platform(),
            'processor': platform.processor(),
            # This key previously held the CUDA device count; report logical
            # CPUs here and the GPU count under its own key.
            'cpu_count': os.cpu_count(),
            'gpu_count': torch.cuda.device_count() if cuda else 0,
            'gpu': torch.cuda.get_device_name(0) if cuda else 'None',
        }
        total_time = time.time() - self.start_time

        # Flatten data for the Excel file
        metrics_df = pd.DataFrame(self.metrics)
        hardware_df = pd.DataFrame([hardware_info])
        training_time_df = pd.DataFrame([{'total_training_time': total_time}])
        with pd.ExcelWriter(file_name) as writer:
            metrics_df.to_excel(writer, sheet_name='Metrics', index=False)
            hardware_df.to_excel(writer, sheet_name='Hardware', index=False)
            training_time_df.to_excel(writer, sheet_name='Training Time', index=False)
def train(max_games=None):
    """Train a fresh Agent by playing SnakeGameAI games.

    After every step the agent trains on that single transition and stores
    it in replay memory. After every finished game it:
    - resets the board and trains on replay memory;
    - calls ``model.save()`` on a new record and updates the score plot;
    - appends a metrics row, writing metrics to Excel every 10 games.

    Args:
        max_games: Optional number of finished games after which training
            stops. ``None`` (the default) keeps the original behaviour of
            training forever, in which case this function never returns.

    Returns:
        A summary string with the number of games played, the score of the
        last game and the best score.
    """
    plot_scores = []
    plot_mean_scores = []
    total_score = 0
    record = 0
    score = 0  # defined up front so the summary also works when max_games == 0
    agent = Agent()
    game = SnakeGameAI()

    while max_games is None or agent.n_games < max_games:
        # Observe, act, observe the outcome.
        state_old = agent.get_state(game)
        final_move = agent.get_action(state_old)
        reward, done, score = game.play_step(final_move)
        state_new = agent.get_state(game)

        # Train on this single step, then keep it for replay.
        agent.train_short_memory(state_old, final_move, reward, state_new, done)
        agent.remember(state_old, final_move, reward, state_new, done)

        if not done:
            continue

        # Game over: reset the board and train on replay memory.
        game.reset()
        agent.n_games += 1
        agent.train_long_memory()

        if score > record:
            record = score
            agent.model.save()

        print('Game', agent.n_games, 'Score', score, 'Record:', record)

        plot_scores.append(score)
        total_score += score
        mean_score = total_score / agent.n_games
        plot_mean_scores.append(mean_score)
        plot(plot_scores, plot_mean_scores)

        # Collect metrics
        agent.metrics.append({
            'Game': agent.n_games,
            'Score': score,
            'Record': record,
            'Mean Score': mean_score,
            'Timestamp': time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()),
        })

        # Save metrics to Excel periodically
        if agent.n_games % 10 == 0:  # Save every 10 games
            agent.save_metrics_excel()

    # Only reached when max_games is set: flush metrics of the last (< 10)
    # games that the periodic save did not cover.
    if agent.n_games % 10 != 0:
        agent.save_metrics_excel()

    # Return a summary message
    return f'Training completed. Total games: {agent.n_games}, Final score: {score}, Record: {record}'
if __name__ == "__main__":
    # Start training only when this file runs as the main program
    # (e.g. `python agent.py`), not when it is imported.
    # NOTE(review): an embedder that evaluates this file as the __main__
    # module (e.g. a GraalVM Context.eval) would also trigger this call.
    train()
Spring Boot 端的代码：
package com.rocksoft.denemeGraalVM;
import org.graalvm.polyglot.Context;
import org.graalvm.polyglot.Source;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.scheduling.annotation.Scheduled;
import org.springframework.stereotype.Service;
import java.io.File;
import java.io.IOException;
@Service
public class SnakeGameService {
@Value("${python.script.path}")
private String scriptPath;
public void runGame() {
try {
// GraalVM context for Python execution
Context context = Context.newBuilder("python")
.allowAllAccess(true)
.build();
// Load and evaluate the Python script
Source source = Source.newBuilder("python", new File(scriptPath)).build();
context.eval(source);
org.graalvm.polyglot.Value result = context.getBindings("python").getMember("train").execute();
String gameStatus = result.asString();
// Handle the result
System.out.println("Game executed successfully. Status: " + gameStatus);
} catch (IOException e) {
e.printStackTrace();
System.err.println("Failed to load the Python script: " + e.getMessage());
} catch (Exception e) {
e.printStackTrace();
System.err.println("Error during Python script execution: " + e.getMessage());
}
}
@Scheduled(fixedRate = 60000) // runs every 60 seconds
public void scheduledGameRun() {
runGame();
}
}
以及我使用的依赖项：
<dependencies>
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-web</artifactId>
    </dependency>
    <dependency>
        <groupId>org.springframework.boot</groupId>
        <artifactId>spring-boot-starter-test</artifactId>
        <scope>test</scope>
    </dependency>
    <!-- Polyglot embedding API (Context, Source, Value). Since GraalVM 23.1 it is published
         as org.graalvm.polyglot:polyglot; the older graal-sdk artifact is only an aggregate. -->
    <dependency>
        <groupId>org.graalvm.polyglot</groupId>
        <artifactId>polyglot</artifactId>
        <version>23.1.2</version>
    </dependency>
    <!-- GraalPy runtime: https://mvnrepository.com/artifact/org.graalvm.polyglot/python-community
         Keep its version equal to the polyglot version above and to the GraalPy release used to
         create the virtual environment. It ships GraalPy and its standard library only; packages
         such as torch must be installed into a GraalPy venv with GraalPy's own pip. -->
    <dependency>
        <groupId>org.graalvm.polyglot</groupId>
        <artifactId>python-community</artifactId>
        <version>23.1.2</version>
        <type>pom</type>
        <scope>runtime</scope>
    </dependency>
</dependencies>
以及报错信息：
ModuleNotFoundError: No module named 'torch'
at <python> <module>(Unknown)
at org.graalvm.polyglot.Context.eval(Context.java:402)
at com.rocksoft.denemeGraalVM.SnakeGameService.runGame(SnakeGameService.java:27)
at com.rocksoft.denemeGraalVM.SnakeGameService.scheduledGameRun(SnakeGameService.java:46)
at java.base/jdk.internal.reflect.DirectMethodHandleAccessor.invoke(DirectMethodHandleAccessor.java:103)
at java.base/java.lang.reflect.Method.invoke(Method.java:580)
at org.springframework.scheduling.support.ScheduledMethodRunnable.runInternal(ScheduledMethodRunnable.java:130)
at org.springframework.scheduling.support.ScheduledMethodRunnable.lambda$run$2(ScheduledMethodRunnable.java:124)
at io.micrometer.observation.Observation.observe(Observation.java:499)
at org.springframework.scheduling.support.ScheduledMethodRunnable.run(ScheduledMethodRunnable.java:124)
at org.springframework.scheduling.support.DelegatingErrorHandlingRunnable.run(DelegatingErrorHandlingRunnable.java:54)
at java.base/java.util.concurrent.Executors$RunnableAdapter.call(Executors.java:572)
at java.base/java.util.concurrent.FutureTask.runAndReset(FutureTask.java:358)
at java.base/java.util.concurrent.ScheduledThreadPoolExecutor$ScheduledFutureTask.run(ScheduledThreadPoolExecutor.java:305)
at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1144)
at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:642)
at java.base/java.lang.Thread.run(Thread.java:1583)
Error during Python script execution: ModuleNotFoundError: No module named 'torch'
我愿意接受任何建议。
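这个错误信息表示 GraalVM 的 Python 上下文（GraalPy）找不到 'torch' 模块。尽管你在 conda 环境中安装了 torch，GraalVM 并不会自动使用你的 conda 环境。更重要的是，GraalPy 也无法直接加载由 CPython（conda 或系统自带的 python3）安装的包，因为像 torch 这样的 C 扩展是针对 CPython 编译的。因此，你需要用 GraalPy 自己的 pip 把依赖安装到一个由 GraalPy 创建的虚拟环境中，然后告诉 GraalVM 上下文去使用这个环境。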
以下是如何解决这个问题,并通过 GraalVM 在 Spring Boot 项目中运行你的 Python 脚本的步骤:
1. 使用虚拟环境
最好为你的项目创建一个专用的虚拟环境，以避免依赖项冲突。注意：这个环境必须以 GraalPy 作为解释器，并用 GraalPy 的 pip 安装依赖；用 CPython 创建的环境无法被 GraalVM 上下文使用。你可以使用 conda 或 venv：
使用 conda:
bash
# The env's interpreter must be GraalPy: packages installed by CPython (conda's default
# python) are built for CPython and cannot be imported by the embedded GraalPy.
# conda-forge publishes a "graalpy" package (verify availability for your platform); pick
# the release matching the python-community version in pom.xml.
conda create -n myenv -c conda-forge graalpy
conda activate myenv
# NOTE: torch support in GraalPy is limited and may require a source build;
# check GraalPy's package compatibility list first.
graalpy -m pip install torch pandas numpy
使用 venv:
bash
# Create the venv with GraalPy (not CPython's python3) so pip installs packages GraalPy can load.
# Use the GraalPy release matching the python-community version in pom.xml (23.1.2 here).
graalpy -m venv myenv
source myenv/bin/activate
# NOTE: torch support in GraalPy is limited and may require a source build;
# check GraalPy's package compatibility list first.
graalpy -m pip install torch pandas numpy
2. 安装 GraalVM 的 Python 支持
确保你已经安装了 GraalPy（GraalVM 的 Python 实现；你通过 pom.xml 中的 python-community 依赖已经引入了它的运行时）。用来创建虚拟环境的独立 GraalPy 的版本应与该依赖的版本（此处为 23.1.2）保持一致。
3. 修改你的 Spring Boot 代码
你需要修改你的 Spring Boot 代码,以告诉 GraalVM 上下文使用你的虚拟环境:
```java package com.rocksoft.denemeGraalVM;
import org.graalvm.polyglot.Context; import org.graalvm.polyglot.Source; import org.springframework.beans.factory.annotation.Value; import org.springframework.scheduling.annotation.Scheduled; import org.springframework.stereotype.Service;
import java.io.File; import java.io.IOException;
@Service
public class SnakeGameService {

    /** Filesystem path of the Python script (agent.py) to evaluate. */
    @Value("${python.script.path}")
    private String scriptPath;

    /** Root directory of a virtual environment created with GraalPy ({@code graalpy -m venv <dir>}). */
    @Value("${python.env.path}")
    private String pythonEnvPath;

    /**
     * Evaluates the script in a new GraalPy context bound to the configured venv, calls its
     * {@code train()} function and prints the returned status string. Errors are printed,
     * not rethrown.
     */
    public void runGame() {
        // try-with-resources: a new Context is built on every scheduled run; closing it
        // releases the interpreter instead of leaving it open.
        try (Context context = buildContext()) {
            // 加载并执行 Python 脚本 / load and evaluate the Python script
            Source source = Source.newBuilder("python", new File(scriptPath)).build();
            // NOTE(review): if GraalPy evaluates this source as the __main__ module, the
            // script's `if __name__ == '__main__': train()` guard also starts training here.
            context.eval(source);
            // NOTE(review): agent.py's train() runs a `while True` loop when called without
            // arguments, so this call may never return.
            org.graalvm.polyglot.Value result = context.getBindings("python").getMember("train").execute();
            String gameStatus = result.asString();
            System.out.println("Game executed successfully. Status: " + gameStatus);
        } catch (IOException e) {
            e.printStackTrace();
            System.err.println("Failed to load the Python script: " + e.getMessage());
        } catch (Exception e) {
            e.printStackTrace();
            System.err.println("Error during Python script execution: " + e.getMessage());
        }
    }

    /**
     * Builds a Python context that uses the GraalPy venv. Pointing PYTHONHOME at a venv does
     * not work (a venv directory does not contain the standard library); instead GraalPy runs
     * as the venv's own executable and imports `site`, which puts the venv's site-packages
     * on sys.path.
     */
    private Context buildContext() {
        return Context.newBuilder("python")
                .allowAllAccess(true)
                // Some python.* options are flagged experimental in certain GraalPy releases.
                .allowExperimentalOptions(true)
                .option("python.Executable", resolveGraalPyExecutable())
                .option("python.ForceImportSite", "true")
                .build();
    }

    /** Returns the path of the venv's graalpy launcher, warning if it does not exist. */
    private String resolveGraalPyExecutable() {
        boolean windows = System.getProperty("os.name", "")
                .toLowerCase(java.util.Locale.ROOT).startsWith("windows");
        // Standard venv layout (Scripts\graalpy.exe on Windows -- verify on your platform).
        java.nio.file.Path exe = windows
                ? java.nio.file.Path.of(pythonEnvPath, "Scripts", "graalpy.exe")
                : java.nio.file.Path.of(pythonEnvPath, "bin", "graalpy");
        if (!java.nio.file.Files.isRegularFile(exe)) {
            System.err.println("GraalPy executable not found at " + exe
                    + "; create the environment with 'graalpy -m venv " + pythonEnvPath + "'");
        }
        return exe.toString();
    }

    @Scheduled(fixedRate = 60000) // runs every 60 seconds
    public void scheduledGameRun() {
        runGame();
    }
}
```
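- 将 `${python.env.path}` 对应的配置值设为你的虚拟环境路径（例如 `/home/user/myenv`）。
- 将 `python.ForceImportSite` 选项设为 `true`：它让上下文在启动时导入 `site` 模块，而 `site` 模块负责把虚拟环境的 site-packages 目录加入 `sys.path`。它本身并不会改变 Python home；把 `PYTHONHOME` 指向虚拟环境目录也不是正确做法，因为虚拟环境目录中并不包含标准库。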
4. 配置应用属性
在你的 Spring Boot 应用程序的配置文件（`application.properties` 或 `application.yml`）中设置 `python.env.path` 属性：
properties
python.script.path=path/to/your/agent.py
# Replace with the path of your GraalPy virtual environment.
# Keep comments on their own line: in .properties files "#" only starts a comment at the
# beginning of a line, so a trailing "# ..." after a value becomes part of that value.
python.env.path=/home/user/myenv
5. 重新构建和运行你的应用程序
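完成这些步骤后，重新构建并运行你的 Spring Boot 应用程序。只要 torch 已成功安装到这个 GraalPy 虚拟环境中，GraalVM 就应该能够找到 'torch' 模块并执行你的 Python 脚本。请注意，GraalPy 对 torch 这类大型原生扩展的支持有限（可能需要从源码构建，且并非所有版本都能安装），建议先查阅 GraalPy 的包兼容性说明。如果无法安装，可以改为让 Spring Boot 通过子进程或 HTTP 服务调用普通的 CPython 来运行训练脚本。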
标签:python,spring-boot,reinforcement-learning,graalvm,graalpython From: 78819853