Training with an RL Agent#

In the previous tutorials, we covered how to define an RL task environment, register it into the gym registry, and interact with it using a random agent. We now move on to the next step: training an RL agent to solve the task.

Although the envs.ManagerBasedRLEnv conforms to the gymnasium.Env interface, it is not exactly a gym environment. The inputs and outputs of the environment are not numpy arrays, but torch tensors whose first dimension is the number of environment instances.
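For illustration, here is a minimal sketch (not part of the workflow script) that creates the cartpole task from the previous tutorials and inspects these shapes. It launches the simulator headless and assumes 64 environment instances; the printed shapes are indicative:

from isaaclab.app import AppLauncher

# every Isaac Lab script must launch the simulator before importing the rest
app_launcher = AppLauncher(headless=True)
simulation_app = app_launcher.app

import gymnasium as gym
import torch

import isaaclab_tasks  # noqa: F401  # registers the Isaac-* tasks in the gym registry
from isaaclab_tasks.utils import parse_env_cfg

# create the cartpole task with 64 parallel environment instances
env_cfg = parse_env_cfg("Isaac-Cartpole-v0", num_envs=64)
env = gym.make("Isaac-Cartpole-v0", cfg=env_cfg)

# reset and step: inputs and outputs are torch tensors batched over the environments
obs, _ = env.reset()
actions = 2 * torch.rand(env.action_space.shape, device=env.unwrapped.device) - 1
obs, rew, terminated, truncated, info = env.step(actions)
print(obs["policy"].shape)  # e.g. torch.Size([64, 4]) -- (num_envs, obs_dim), not a numpy array
print(rew.shape)            # torch.Size([64])

env.close()
simulation_app.close()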

Additionally, most RL libraries expect their own variant of an environment interface. For example, Stable-Baselines3 expects the environment to conform to its VecEnv API, which expects a list of numpy arrays instead of a single tensor. Similarly, RSL-RL, RL-Games, and SKRL each expect yet another interface. Since there is no one-size-fits-all solution, we do not base envs.ManagerBasedRLEnv on any particular learning library. Instead, we implement wrappers that convert the environment into the expected interface. These wrappers are specified in the isaaclab_rl module.

In this tutorial, we use Stable-Baselines3 to train an RL agent to solve the cartpole balancing task.

Caution

Wrapping the environment with the corresponding learning framework's wrapper should happen at the very end, i.e., after all other wrappers have been applied. This is because the learning framework's wrapper modifies the interpretation of the environment's APIs, which may no longer be compatible with gymnasium.Env.
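As a rough sketch of this ordering (assuming the simulator app has already been launched and the task configuration env_cfg has been loaded, e.g. as in the previous snippet or via Hydra in the script below):

import gymnasium as gym

from isaaclab_rl.sb3 import Sb3VecEnvWrapper

# create the Isaac Lab environment first
env = gym.make("Isaac-Cartpole-v0", cfg=env_cfg, render_mode="rgb_array")
# gymnasium wrappers (e.g. video recording) must be applied while this is still a gymnasium.Env
env = gym.wrappers.RecordVideo(env, video_folder="videos/train")
# the learning-framework wrapper comes last: afterwards the object follows the SB3 VecEnv API,
# so gymnasium wrappers can no longer be applied on top of it
env = Sb3VecEnvWrapper(env)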

The Code#

For this tutorial, we use the training script of the Stable-Baselines3 workflow from the scripts/reinforcement_learning/sb3 directory.

Code for train.py
# Copyright (c) 2022-2025, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause


"""Script to train RL agent with Stable Baselines3."""

"""Launch Isaac Sim Simulator first."""

import argparse
import contextlib
import signal
import sys
from pathlib import Path

from isaaclab.app import AppLauncher

# add argparse arguments
parser = argparse.ArgumentParser(description="Train an RL agent with Stable-Baselines3.")
parser.add_argument("--video", action="store_true", default=False, help="Record videos during training.")
parser.add_argument("--video_length", type=int, default=200, help="Length of the recorded video (in steps).")
parser.add_argument("--video_interval", type=int, default=2000, help="Interval between video recordings (in steps).")
parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.")
parser.add_argument("--task", type=str, default=None, help="Name of the task.")
parser.add_argument(
    "--agent", type=str, default="sb3_cfg_entry_point", help="Name of the RL agent configuration entry point."
)
parser.add_argument("--seed", type=int, default=None, help="Seed used for the environment")
parser.add_argument("--log_interval", type=int, default=100_000, help="Log data every n timesteps.")
parser.add_argument("--checkpoint", type=str, default=None, help="Continue the training from checkpoint.")
parser.add_argument("--max_iterations", type=int, default=None, help="RL Policy training iterations.")
parser.add_argument("--export_io_descriptors", action="store_true", default=False, help="Export IO descriptors.")
parser.add_argument(
    "--keep_all_info",
    action="store_true",
    default=False,
    help="Use a slower SB3 wrapper but keep all the extra training info.",
)
parser.add_argument(
    "--ray-proc-id", "-rid", type=int, default=None, help="Automatically configured by Ray integration, otherwise None."
)
# append AppLauncher cli args
AppLauncher.add_app_launcher_args(parser)
# parse the arguments
args_cli, hydra_args = parser.parse_known_args()
# always enable cameras to record video
if args_cli.video:
    args_cli.enable_cameras = True

# clear out sys.argv for Hydra
sys.argv = [sys.argv[0]] + hydra_args

# launch omniverse app
app_launcher = AppLauncher(args_cli)
simulation_app = app_launcher.app


def cleanup_pbar(*args):
    """
    A small helper to stop training and
    cleanup progress bar properly on ctrl+c
    """
    import gc

    tqdm_objects = [obj for obj in gc.get_objects() if "tqdm" in type(obj).__name__]
    for tqdm_object in tqdm_objects:
        if "tqdm_rich" in type(tqdm_object).__name__:
            tqdm_object.close()
    raise KeyboardInterrupt


# disable KeyboardInterrupt override
signal.signal(signal.SIGINT, cleanup_pbar)

"""Rest everything follows."""

import gymnasium as gym
import logging
import numpy as np
import os
import random
import time
from datetime import datetime

from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback, LogEveryNTimesteps
from stable_baselines3.common.vec_env import VecNormalize

from isaaclab.envs import (
    DirectMARLEnv,
    DirectMARLEnvCfg,
    DirectRLEnvCfg,
    ManagerBasedRLEnvCfg,
    multi_agent_to_single_agent,
)
from isaaclab.utils.dict import print_dict
from isaaclab.utils.io import dump_yaml

from isaaclab_rl.sb3 import Sb3VecEnvWrapper, process_sb3_cfg

import isaaclab_tasks  # noqa: F401
from isaaclab_tasks.utils.hydra import hydra_task_config

# import logger
logger = logging.getLogger(__name__)
# PLACEHOLDER: Extension template (do not remove this comment)


@hydra_task_config(args_cli.task, args_cli.agent)
def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agent_cfg: dict):
    """Train with stable-baselines agent."""
    # randomly sample a seed if seed = -1
    if args_cli.seed == -1:
        args_cli.seed = random.randint(0, 10000)

    # override configurations with non-hydra CLI arguments
    env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs
    agent_cfg["seed"] = args_cli.seed if args_cli.seed is not None else agent_cfg["seed"]
    # max iterations for training
    if args_cli.max_iterations is not None:
        agent_cfg["n_timesteps"] = args_cli.max_iterations * agent_cfg["n_steps"] * env_cfg.scene.num_envs

    # set the environment seed
    # note: certain randomizations occur in the environment initialization so we set the seed here
    env_cfg.seed = agent_cfg["seed"]
    env_cfg.sim.device = args_cli.device if args_cli.device is not None else env_cfg.sim.device

    # directory for logging into
    run_info = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    log_root_path = os.path.abspath(os.path.join("logs", "sb3", args_cli.task))
    print(f"[INFO] Logging experiment in directory: {log_root_path}")
    # The Ray Tune workflow extracts experiment name using the logging line below, hence, do not change it (see PR #2346, comment-2819298849)
    print(f"Exact experiment name requested from command line: {run_info}")
    log_dir = os.path.join(log_root_path, run_info)
    # dump the configuration into log-directory
    dump_yaml(os.path.join(log_dir, "params", "env.yaml"), env_cfg)
    dump_yaml(os.path.join(log_dir, "params", "agent.yaml"), agent_cfg)

    # save command used to run the script
    command = " ".join(sys.orig_argv)
    (Path(log_dir) / "command.txt").write_text(command)

    # post-process agent configuration
    agent_cfg = process_sb3_cfg(agent_cfg, env_cfg.scene.num_envs)
    # read configurations about the agent-training
    policy_arch = agent_cfg.pop("policy")
    n_timesteps = agent_cfg.pop("n_timesteps")

    # set the IO descriptors export flag if requested
    if isinstance(env_cfg, ManagerBasedRLEnvCfg):
        env_cfg.export_io_descriptors = args_cli.export_io_descriptors
    else:
        logger.warning(
            "IO descriptors are only supported for manager based RL environments. No IO descriptors will be exported."
        )

    # set the log directory for the environment (works for all environment types)
    env_cfg.log_dir = log_dir

    # create isaac environment
    env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None)

    # convert to single-agent instance if required by the RL algorithm
    if isinstance(env.unwrapped, DirectMARLEnv):
        env = multi_agent_to_single_agent(env)

    # wrap for video recording
    if args_cli.video:
        video_kwargs = {
            "video_folder": os.path.join(log_dir, "videos", "train"),
            "step_trigger": lambda step: step % args_cli.video_interval == 0,
            "video_length": args_cli.video_length,
            "disable_logger": True,
        }
        print("[INFO] Recording videos during training.")
        print_dict(video_kwargs, nesting=4)
        env = gym.wrappers.RecordVideo(env, **video_kwargs)

    start_time = time.time()

    # wrap around environment for stable baselines
    env = Sb3VecEnvWrapper(env, fast_variant=not args_cli.keep_all_info)

    norm_keys = {"normalize_input", "normalize_value", "clip_obs"}
    norm_args = {}
    for key in norm_keys:
        if key in agent_cfg:
            norm_args[key] = agent_cfg.pop(key)

    if norm_args and norm_args.get("normalize_input"):
        print(f"Normalizing input, {norm_args=}")
        env = VecNormalize(
            env,
            training=True,
            norm_obs=norm_args["normalize_input"],
            norm_reward=norm_args.get("normalize_value", False),
            clip_obs=norm_args.get("clip_obs", 100.0),
            gamma=agent_cfg["gamma"],
            clip_reward=np.inf,
        )

    # create agent from stable baselines
    agent = PPO(policy_arch, env, verbose=1, tensorboard_log=log_dir, **agent_cfg)
    if args_cli.checkpoint is not None:
        agent = agent.load(args_cli.checkpoint, env, print_system_info=True)

    # callbacks for agent
    checkpoint_callback = CheckpointCallback(save_freq=1000, save_path=log_dir, name_prefix="model", verbose=2)
    callbacks = [checkpoint_callback, LogEveryNTimesteps(n_steps=args_cli.log_interval)]

    # train the agent
    with contextlib.suppress(KeyboardInterrupt):
        agent.learn(
            total_timesteps=n_timesteps,
            callback=callbacks,
            progress_bar=True,
            log_interval=None,
        )
    # save the final model
    agent.save(os.path.join(log_dir, "model"))
    print("Saving to:")
    print(os.path.join(log_dir, "model.zip"))

    if isinstance(env, VecNormalize):
        print("Saving normalization")
        env.save(os.path.join(log_dir, "model_vecnormalize.pkl"))

    print(f"Training time: {round(time.time() - start_time, 2)} seconds")

    # close the simulator
    env.close()


if __name__ == "__main__":
    # run the main function
    main()
    # close sim app
    simulation_app.close()

The Code Explained#

Most of the code above is boilerplate for creating the logging directories, saving the parsed configurations, and setting up the different Stable-Baselines3 components. For this tutorial, the important part is creating the environment and wrapping it with the Stable-Baselines3 wrapper.

There are three wrappers used in the code above:

  1. gymnasium.wrappers.RecordVideo: This wrapper records a video of the environment and saves it to the specified directory. This is useful for visualizing the agent's behavior during training.

  2. isaaclab_rl.sb3.Sb3VecEnvWrapper: This wrapper converts the environment into a Stable-Baselines3 compatible environment.

  3. stable_baselines3.common.vec_env.VecNormalize: This wrapper normalizes the environment's observations and rewards.

Each of these wrappers wraps around the previous one by repeatedly applying env = wrapper(env, *args, **kwargs). The final environment is then used to train the agent. For more information on how these wrappers work, please refer to the Wrapping environments documentation.
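Condensed from the training script above, the chain can be summarized roughly as follows; the helper name make_sb3_env and its arguments are illustrative only, not part of the Isaac Lab API:

import gymnasium as gym

from stable_baselines3.common.vec_env import VecNormalize

from isaaclab_rl.sb3 import Sb3VecEnvWrapper


def make_sb3_env(task_name: str, env_cfg, video_kwargs: dict | None = None):
    """Create an Isaac Lab environment and wrap it, innermost to outermost."""
    env = gym.make(task_name, cfg=env_cfg, render_mode="rgb_array" if video_kwargs else None)
    if video_kwargs is not None:
        env = gym.wrappers.RecordVideo(env, **video_kwargs)  # 1. record rollout videos
    env = Sb3VecEnvWrapper(env)  # 2. expose the SB3 VecEnv API (numpy in/out)
    env = VecNormalize(env, training=True, norm_obs=True, clip_obs=100.0)  # 3. normalize observations
    return env

The environment returned by the outermost wrapper is what gets passed to the PPO agent in the script.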

The Code Execution#

We train a PPO agent from Stable-Baselines3 to solve the cartpole balancing task.

Training the agent#

There are three main ways of training the agent. Each of them has its own advantages and disadvantages; which one to use is up to you and depends on your use case.

Headless execution#

If the --headless flag is set, the simulation is not rendered during training. This is useful when training on a remote server or when you do not want to see the simulation. Typically, it speeds up the training process since only the physics simulation steps are performed.

./isaaclab.sh -p scripts/reinforcement_learning/sb3/train.py --task Isaac-Cartpole-v0 --num_envs 64 --headless

Headless execution with off-screen rendering#

Since the above command does not render the simulation, it is not possible to see the agent's behavior during training. To visualize the agent's behavior, we pass the --video flag, which records videos of the agent during training and automatically enables off-screen rendering (equivalent to also passing --enable_cameras).

./isaaclab.sh -p scripts/reinforcement_learning/sb3/train.py --task Isaac-Cartpole-v0 --num_envs 64 --headless --video

The videos are saved to the logs/sb3/Isaac-Cartpole-v0/<run-dir>/videos/train directory. You can open these videos using any video player.

Interactive execution#

While the above two methods are useful for training the agent, they do not allow you to interact with the simulation to see what is happening. In that case, you can drop the --headless flag and run the training script as follows:

./isaaclab.sh -p scripts/reinforcement_learning/sb3/train.py --task Isaac-Cartpole-v0 --num_envs 64

This opens the Isaac Sim window, and you can see the agent being trained in the environment. However, it slows down the training process since the simulation is rendered on the screen. As a workaround, you can switch between different render modes in the "Isaac Lab" window that is docked at the bottom-right corner of the screen. To learn more about these render modes, please check the sim.SimulationContext.RenderMode class.

Viewing the logs#

On a separate terminal, you can monitor the training progress by executing the following command:

# execute from the root directory of the repository
./isaaclab.sh -p -m tensorboard.main --logdir logs/sb3/Isaac-Cartpole-v0

Playing the trained agent#

Once training is complete, you can visualize the trained agent by executing the following command:

# execute from the root directory of the repository
./isaaclab.sh -p scripts/reinforcement_learning/sb3/play.py --task Isaac-Cartpole-v0 --num_envs 32 --use_last_checkpoint

The above command loads the latest checkpoint from the logs/sb3/Isaac-Cartpole-v0 directory. You can also specify a particular checkpoint by passing the --checkpoint flag.