环境:half_cheetah.py
from os import path
import numpy as np
from gymnasium import utils
from gymnasium.envs.mujoco import MujocoEnv
from gymnasium.spaces import Box
# Default camera settings handed to MujocoEnv for rendering.
DEFAULT_CAMERA_CONFIG = dict(distance=4.0)
class MOHalfCheetahEnv(MujocoEnv, utils.EzPickle):
    """Multi-objective HalfCheetah whose reward is a 2-vector.

    The reward vector returned by ``step`` is ``[reward_run, reward_energy]``:

    * ``reward_run``    = ``min(4.0, x_velocity) + 1``
    * ``reward_energy`` = ``4.0 - sum(action ** 2) + 1``

    Actions are clipped to ``[-1, 1]`` before simulation (as in the
    PGMORL / PD-MORL setups). The observation is ``qpos`` without its first
    entry (the root x coordinate) followed by ``qvel``: 17 values in total.

    Args:
        terminate_when_unhealthy: If True, the episode terminates once
            ``abs(qpos[2])`` is no longer below 50 degrees, mirroring the
            termination used by PGMORL / PD-MORL. Defaults to False: the env
            itself never terminates or truncates, so episodes end only via an
            external time limit.
        **kwargs: Forwarded to ``MujocoEnv.__init__`` (e.g. ``render_mode``).
    """

    metadata = {
        "render_modes": [
            "human",
            "rgb_array",
            "depth_array",
        ],
        # Must equal round(1 / self.dt) for frame_skip=5; MujocoEnv asserts this.
        "render_fps": 20,
    }

    # Threshold for the optional termination rule. qpos[2] is presumably the
    # root "rooty" hinge (torso pitch) in half_cheetah.xml -- confirm.
    _MAX_ABS_ROOT_ANGLE = np.deg2rad(50)

    def __init__(self, *, terminate_when_unhealthy: bool = False, **kwargs):
        # Record every constructor argument so EzPickle can rebuild the env.
        utils.EzPickle.__init__(
            self,
            terminate_when_unhealthy=terminate_when_unhealthy,
            **kwargs,
        )
        self._terminate_when_unhealthy = terminate_when_unhealthy

        # 17 = (nq - 1) + nv: _get_obs drops qpos[0] and appends all of qvel.
        observation_space = Box(
            low=-np.inf, high=np.inf, shape=(17,), dtype=np.float64
        )

        MujocoEnv.__init__(
            self,
            "half_cheetah.xml",  # use the model file shipped with the library
            5,  # frame_skip
            observation_space=observation_space,
            default_camera_config=DEFAULT_CAMERA_CONFIG,
            **kwargs,
        )

        # Multi-objective (MO) attributes: two reward components.
        self.reward_space = Box(low=-np.inf, high=np.inf, shape=(2,))
        self.reward_dim = 2

    def step(self, action):
        """Apply ``action`` for one env step (``frame_skip`` physics steps).

        Returns:
            ``(observation, vec_reward, terminated, truncated, info)`` where
            ``vec_reward`` is a float32 array ``[reward_run, reward_energy]``
            and ``info`` contains ``x_position`` and ``x_velocity``.
        """
        # As in PGMORL / PD-MORL, clip the action here before simulating.
        action = np.clip(action, -1.0, 1.0)

        # Forward velocity from root x displacement over one env step.
        x_position_before = self.data.qpos[0]
        self.do_simulation(action, self.frame_skip)
        x_position_after = self.data.qpos[0]
        x_velocity = (x_position_after - x_position_before) / self.dt

        observation = self._get_obs()

        alive_bonus = 1
        # Objective 1: forward speed, capped at 4.0.
        reward_run = min(4.0, x_velocity) + alive_bonus
        # Objective 2: energy efficiency, penalizing squared control effort.
        reward_energy = 4.0 - 1.0 * np.square(action).sum() + alive_bonus
        vec_reward = np.array([reward_run, reward_energy], dtype=np.float32)

        # PGMORL / PD-MORL terminate on a large root angle; opt-in here.
        if self._terminate_when_unhealthy:
            terminated = not (abs(self.data.qpos[2]) < self._MAX_ABS_ROOT_ANGLE)
        else:
            terminated = False
        truncated = False  # time limits are left to an external wrapper

        info = {"x_position": x_position_after, "x_velocity": x_velocity}

        if self.render_mode == "human":
            self.render()
        return observation, vec_reward, terminated, truncated, info

    def _get_obs(self):
        """Return ``qpos[1:]`` followed by ``qvel`` as one flat array (17 values)."""
        position = self.data.qpos.flat.copy()[1:]  # drop root x coordinate
        velocity = self.data.qvel.flat.copy()
        return np.concatenate((position, velocity)).ravel()

    def reset_model(self):
        """Reset to the initial state plus noise and return the observation.

        Positions get uniform noise in ``[-0.1, 0.1]``; velocities get
        Gaussian noise with standard deviation 0.1.
        """
        qpos = self.init_qpos + self.np_random.uniform(
            low=-0.1, high=0.1, size=self.model.nq
        )
        qvel = self.init_qvel + self.np_random.standard_normal(self.model.nv) * 0.1
        self.set_state(qpos, qvel)
        return self._get_obs()
注册环境，并在创建时关闭环境检查器（disable_env_checker=True）
from gymnasium.envs.registration import register
import mo_gymnasium as mo_gym
from half_cheetah import MOHalfCheetahEnv
# Make the env available via gym.make("wx-half-v1"); gym.make wraps it in a
# TimeLimit so every episode is truncated after 500 steps.
register(max_episode_steps=500, entry_point=MOHalfCheetahEnv, id="wx-half-v1")
if __name__ == '__main__':
    # Smoke test: run one random-action episode and report its length.
    import gymnasium as gym

    # Alternatives for comparison:
    #   MOHalfCheetahEnv(render_mode="human")  # direct instantiation, rendered
    #   mo_gym.make("mo-halfcheetah-v4")       # never done; runs 1000 steps
    #   gym.make("HalfCheetah-v4")             # never done; runs 1000 steps
    env = gym.make("wx-half-v1", disable_env_checker=True)
    try:
        env.reset(seed=5)
        env.action_space.seed(5)
        env.observation_space.seed(5)
        print(type(env))

        steps = 0
        done = False
        while not done:
            action = env.action_space.sample()
            _, _, terminated, truncated, _ = env.step(action)
            done = terminated or truncated
            steps += 1
        print(steps)
    finally:
        # Close the environment so any viewer/renderer it opened is released.
        env.close()
标签:observation,自定义,gym,环境,action,env,np,import,self
From: https://www.cnblogs.com/Twobox/p/18361816