与 RL Agent 进行训练#
在之前的教程中,我们介绍了如何定义一个 RL 任务环境、将其注册到 gym
注册表中,并使用一个随机 agent 与其交互。现在我们继续进行下一步: 训练一个 RL agent 来解决这个任务。
尽管 envs.ManagerBasedRLEnv
符合 gymnasium.Env
接口,但它并不完全是一个 gym
环境。环境的输入和输出不是 numpy 数组,而是基于torch tensors,其中第一个维度是环境实例的数量。
此外,大多数 RL 库都期望其自己的环境接口变体。例如, Stable-Baselines3 期望环境符合其 VecEnv API ,该 API 期望接收一个 numpy 数组列表而不是一个单一的张量。类似地, RSL-RL 、RL-Games 和 SKRL 也预期另一个接口。由于没有一种适合所有情况的解决方案,我们不将 envs.ManagerBasedRLEnv
基于任何特定的学习库。相反,我们实现了包装器来将环境转换为所期望的接口。这些包装器在 omni.isaac.lab_tasks.utils.wrappers
模块中指定。
在本教程中,我们将使用 Stable-Baselines3 来训练一个 RL agent 来解决 cartpole 平衡任务。
小心
在最后,使用所对应的学习框架的包装器对环境进行包装。这是因为学习框架的包装器修改了环境 API 的解释,这可能不再与 gymnasium.Env
兼容。
代码#
在本教程中,我们使用 source/standalone/workflows/sb3
目录中的 Stable-Baselines3 workflow 的训练脚本。
train.py 代码
1# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
2# All rights reserved.
3#
4# SPDX-License-Identifier: BSD-3-Clause
5
6"""Script to train RL agent with Stable Baselines3.
7
8Since Stable-Baselines3 does not support buffers living on GPU directly,
9we recommend using smaller number of environments. Otherwise,
10there will be significant overhead in GPU->CPU transfer.
11"""
12
13"""Launch Isaac Sim Simulator first."""
14
15import argparse
16import sys
17
18from omni.isaac.lab.app import AppLauncher
19
20# add argparse arguments
21parser = argparse.ArgumentParser(description="Train an RL agent with Stable-Baselines3.")
22parser.add_argument("--video", action="store_true", default=False, help="Record videos during training.")
23parser.add_argument("--video_length", type=int, default=200, help="Length of the recorded video (in steps).")
24parser.add_argument("--video_interval", type=int, default=2000, help="Interval between video recordings (in steps).")
25parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.")
26parser.add_argument("--task", type=str, default=None, help="Name of the task.")
27parser.add_argument("--seed", type=int, default=None, help="Seed used for the environment")
28parser.add_argument("--max_iterations", type=int, default=None, help="RL Policy training iterations.")
29# append AppLauncher cli args
30AppLauncher.add_app_launcher_args(parser)
31# parse the arguments
32args_cli, hydra_args = parser.parse_known_args()
33# always enable cameras to record video
34if args_cli.video:
35 args_cli.enable_cameras = True
36
37# clear out sys.argv for Hydra
38sys.argv = [sys.argv[0]] + hydra_args
39
40# launch omniverse app
41app_launcher = AppLauncher(args_cli)
42simulation_app = app_launcher.app
43
44"""Rest everything follows."""
45
46import gymnasium as gym
47import numpy as np
48import os
49import random
50from datetime import datetime
51
52from stable_baselines3 import PPO
53from stable_baselines3.common.callbacks import CheckpointCallback
54from stable_baselines3.common.logger import configure
55from stable_baselines3.common.vec_env import VecNormalize
56
57from omni.isaac.lab.envs import (
58 DirectMARLEnv,
59 DirectMARLEnvCfg,
60 DirectRLEnvCfg,
61 ManagerBasedRLEnvCfg,
62 multi_agent_to_single_agent,
63)
64from omni.isaac.lab.utils.dict import print_dict
65from omni.isaac.lab.utils.io import dump_pickle, dump_yaml
66
67import omni.isaac.lab_tasks # noqa: F401
68from omni.isaac.lab_tasks.utils.hydra import hydra_task_config
69from omni.isaac.lab_tasks.utils.wrappers.sb3 import Sb3VecEnvWrapper, process_sb3_cfg
70
71
72@hydra_task_config(args_cli.task, "sb3_cfg_entry_point")
73def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agent_cfg: dict):
74 """Train with stable-baselines agent."""
75 # randomly sample a seed if seed = -1
76 if args_cli.seed == -1:
77 args_cli.seed = random.randint(0, 10000)
78
79 # override configurations with non-hydra CLI arguments
80 env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs
81 agent_cfg["seed"] = args_cli.seed if args_cli.seed is not None else agent_cfg["seed"]
82 # max iterations for training
83 if args_cli.max_iterations is not None:
84 agent_cfg["n_timesteps"] = args_cli.max_iterations * agent_cfg["n_steps"] * env_cfg.scene.num_envs
85
86 # set the environment seed
87 # note: certain randomizations occur in the environment initialization so we set the seed here
88 env_cfg.seed = agent_cfg["seed"]
89 env_cfg.sim.device = args_cli.device if args_cli.device is not None else env_cfg.sim.device
90
91 # directory for logging into
92 log_dir = os.path.join("logs", "sb3", args_cli.task, datetime.now().strftime("%Y-%m-%d_%H-%M-%S"))
93 # dump the configuration into log-directory
94 dump_yaml(os.path.join(log_dir, "params", "env.yaml"), env_cfg)
95 dump_yaml(os.path.join(log_dir, "params", "agent.yaml"), agent_cfg)
96 dump_pickle(os.path.join(log_dir, "params", "env.pkl"), env_cfg)
97 dump_pickle(os.path.join(log_dir, "params", "agent.pkl"), agent_cfg)
98
99 # post-process agent configuration
100 agent_cfg = process_sb3_cfg(agent_cfg)
101 # read configurations about the agent-training
102 policy_arch = agent_cfg.pop("policy")
103 n_timesteps = agent_cfg.pop("n_timesteps")
104
105 # create isaac environment
106 env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None)
107 # wrap for video recording
108 if args_cli.video:
109 video_kwargs = {
110 "video_folder": os.path.join(log_dir, "videos", "train"),
111 "step_trigger": lambda step: step % args_cli.video_interval == 0,
112 "video_length": args_cli.video_length,
113 "disable_logger": True,
114 }
115 print("[INFO] Recording videos during training.")
116 print_dict(video_kwargs, nesting=4)
117 env = gym.wrappers.RecordVideo(env, **video_kwargs)
118
119 # convert to single-agent instance if required by the RL algorithm
120 if isinstance(env.unwrapped, DirectMARLEnv):
121 env = multi_agent_to_single_agent(env)
122
123 # wrap around environment for stable baselines
124 env = Sb3VecEnvWrapper(env)
125
126 if "normalize_input" in agent_cfg:
127 env = VecNormalize(
128 env,
129 training=True,
130 norm_obs="normalize_input" in agent_cfg and agent_cfg.pop("normalize_input"),
131 norm_reward="normalize_value" in agent_cfg and agent_cfg.pop("normalize_value"),
132 clip_obs="clip_obs" in agent_cfg and agent_cfg.pop("clip_obs"),
133 gamma=agent_cfg["gamma"],
134 clip_reward=np.inf,
135 )
136
137 # create agent from stable baselines
138 agent = PPO(policy_arch, env, verbose=1, **agent_cfg)
139 # configure the logger
140 new_logger = configure(log_dir, ["stdout", "tensorboard"])
141 agent.set_logger(new_logger)
142
143 # callbacks for agent
144 checkpoint_callback = CheckpointCallback(save_freq=1000, save_path=log_dir, name_prefix="model", verbose=2)
145 # train the agent
146 agent.learn(total_timesteps=n_timesteps, callback=checkpoint_callback)
147 # save the final model
148 agent.save(os.path.join(log_dir, "model"))
149
150 # close the simulator
151 env.close()
152
153
154if __name__ == "__main__":
155 # run the main function
156 main()
157 # close sim app
158 simulation_app.close()
代码解释#
上面的大部分代码是创建日志目录、保存解析的配置和设置不同的 Stable-Baselines3 组件的样板代码。对于本教程,重要的部分是创建环境并使用 Stable-Baselines3 包装器对其进行包装。
代码中使用了三个包装器:
gymnasium.wrappers.RecordVideo
: 这个包装器记录环境的视频并将其保存到指定目录。这对于在训练过程中可视化 agent 的行为非常有用。wrappers.sb3.Sb3VecEnvWrapper
: 这个包装器将环境转换为 Stable-Baselines3 兼容的环境。stable_baselines3.common.vec_env.VecNormalize: 这个包装器对环境的观察和奖励进行标准化。
这些包装器中的每一个都通过反复执行 env = wrapper(env, *args, **kwargs)
来包装前一个包装器。然后使用最终的环境来训练 agent。有关这些包装器如何工作的更多信息,请参考 包装环境 文档。
代码执行#
我们训练一个从 Stable-Baselines3 学习的 PPO agent 来解决 cartpole 平衡任务。
训练 agent#
训练 agent 有三种主要方法。每种方法都有其自己的优点和缺点。根据您的用例,您可以决定使用哪种方法。
无界面执行#
如果设置了 --headless
标志,则在训练过程中不会呈现模拟。当在远程服务器上进行训练或者不想看到模拟时,这很有用。通常情况下,此操作会加快训练过程,因为只执行物理模拟步骤。
./isaaclab.sh -p source/standalone/workflows/sb3/train.py --task Isaac-Cartpole-v0 --num_envs 64 --headless
无界面执行与离屏渲染#
由于上述命令不会呈现模拟,所以无法在训练过程中看到 agent 的行为。要可视化 agent 的行为,我们传递 --enable_cameras
,这会启用离屏渲染。此外,我们传递标志 --video
,这会记录 agent 在训练期间的行为视频。
./isaaclab.sh -p source/standalone/workflows/sb3/train.py --task Isaac-Cartpole-v0 --num_envs 64 --headless --video
视频保存在 logs/sb3/Isaac-Cartpole-v0/<run-dir>/videos/train
目录中。您可以使用任何视频播放器打开这些视频。
交互式执行#
虽然上述两种方法对于训练 agent 很有用,但不能让您与模拟进行交互以查看发生了什么。在这种情况下,您可以忽略 --headless
标志并按如下方式运行训练脚本:
./isaaclab.sh -p source/standalone/workflows/sb3/train.py --task Isaac-Cartpole-v0 --num_envs 64
这将打开 Isaac Sim 窗口,您可以看到 agent 在环境中进行训练。然而,这会减慢训练过程,因为模拟会在屏幕上呈现。作为变通方法,您可以在屏幕右下角停靠的 "Isaac Lab"
窗口中在不同的渲染模式之间切换。要了解更多有关这些渲染模式的信息,请查看 sim.SimulationContext.RenderMode
类。
查看日志#
在单独的终端中,您可以通过执行以下命令监视训练进度:
# execute from the root directory of the repository
./isaaclab.sh -p -m tensorboard.main --logdir logs/sb3/Isaac-Cartpole-v0
播放经过训练的 agent#
一旦训练完成,您可以通过执行以下命令来可视化经过训练的 agent:
# execute from the root directory of the repository
./isaaclab.sh -p source/standalone/workflows/sb3/play.py --task Isaac-Cartpole-v0 --num_envs 32 --use_last_checkpoint
上述命令将从 logs/sb3/Isaac-Cartpole-v0
目录加载最新的检查点。您也可以通过传递 --checkpoint
标志指定特定的检查点。