gymnasium.Env
类提供了与 DummyVecEnv
不同的 reset
方法签名,该方法需要接收两个仅限关键字的参数 seed
和 options
:
Env.reset(self, *, seed: int | None = None, options: dict[str, Any] | None = None) -> tuple[ObsType, dict[str, Any]]
这意味着在自定义环境中,你需要实现 reset
函数来处理这些参数,并最终返回一个由有效观察值(ObsType 类型)和一个字典组成的元组。
需要注意的要点:
- 自定义环境中的
reset
方法签名需要与 gymnasium.Env
类保持一致,即需要包含 seed
和 options
参数。
reset
方法的返回值形式不正确,它应该返回一个 有效观察值和一个字典。
step
方法的返回值形式也需要修改,以便表明结果是否被截断或模型是否超出边界。
以下是纠正后的 reset
方法示例:
def reset(self, *, seed=None, options=None): # 修复输入签名
# 重置环境
self.flip_result = 0 # 注意此处的 flip_result 必须是有效观察值类型而非 None
return self._get_observation(), {} # 返回有效的观察值和空字典以修正返回签名
如果你返回 None 并且底层使用了 NumPy 数组,那么类似于 array([0])[0] = obs <- None
的操作将导致错误。
此外,step
方法需要返回五个参数:观察值(observation)、奖励(reward)、终止标志(terminated)、截断标志(truncated)以及信息(info):
def step(self, action):
# 执行动作(0 代表正面,1 代表反面)
self.flip_result = int(np.random.rand() < self.heads_probability)
# 计算奖励(正确预测得 1 分,错误预测得 -1 分)
reward = 1 if self.flip_result == action else -1
# 返回观察值、奖励、终止状态(在这里始终为 True)、截断状态(这里为 False)以及信息字典
return self._get_observation(), reward, True, False, {}
# 修复以上问题后,模型就可以顺利地进行训练了
接下来是一段模型训练的相关输出信息
-----------------------------
| time/ | |
| fps | 5608 |
| iterations | 1 |
| time_elapsed | 0 |
| total_timesteps | 2048 |
-----------------------------
-----------------------------------------
| time/ | |
| fps | 3530 |
| iterations | 2 |
| time_elapsed | 1 |
| total_timesteps | 4096 |
| train/ | |
| approx_kl | 0.020679139 |
| clip_fraction | 0.617 |
| clip_range | 0.2 |
| entropy_loss | -0.675 |
| explained_variance | 0 |
| learning_rate | 0.0003 |
| loss | 0.38 |
| n_updates | 10 |
| policy_gradient_loss | -0.107 |
| value_loss | 1 |
-----------------------------------------
-----------------------------------------
| time/ | |
| fps | 3146 |
| iterations | 3 |
| time_elapsed | 1 |
| total_timesteps | 6144 |
| train/ | |
| approx_kl | 0.032571375 |
| clip_fraction | 0.628 |
| clip_range | 0.2 |
| entropy_loss | -0.599 |
| explained_variance | 0 |
| learning_rate | 0.0003 |
| loss | 0.392 |
| n_updates | 20 |
| policy_gradient_loss | -0.104 |
| value_loss | 0.987 |
-----------------------------------------
---------------------------------------
| time/ | |
| fps | 2984 |
| iterations | 4 |
| time_elapsed | 2 |
| total_timesteps | 8192 |
| train/ | |
| approx_kl | 0.0691616 |
| clip_fraction | 0.535 |
| clip_range | 0.2 |
| entropy_loss | -0.417 |
| explained_variance | 0 |
| learning_rate | 0.0003 |
| loss | 0.335 |
| n_updates | 30 |
| policy_gradient_loss | -0.09 |
| value_loss | 0.941 |
---------------------------------------
----------------------------------------
| time/ | |
| fps | 2898 |
| iterations | 5 |
| time_elapsed | 3 |
| total_timesteps | 10240 |
| train/ | |
| approx_kl | 0.12130852 |
| clip_fraction | 0.125 |
| clip_range | 0.2 |
| entropy_loss | -0.189 |
| explained_variance | 0 |
| learning_rate | 0.0003 |
| loss | 0.536 |
| n_updates | 40 |
| policy_gradient_loss | -0.0397 |
| value_loss | 0.806 |
----------------------------------------
Action: [1], Observation: [0], Reward: [1.]
Action: [1], Observation: [0], Reward: [-1.]
Action: [1], Observation: [0], Reward: [-1.]
Action: [1], Observation: [0], Reward: [1.]
Action: [1], Observation: [0], Reward: [1.]
Action: [1], Observation: [0], Reward: [-1.]
Action: [1], Observation: [0], Reward: [1.]
Action: [1], Observation: [0], Reward: [-1.]
Action: [1], Observation: [0], Reward: [1.]
Action: [1], Observation: [0], Reward: [1.]