Configuring an RL Agent#
In the previous tutorial, we saw how to train an RL agent to solve the cartpole balancing task using the Stable-Baselines3 library. In this tutorial, we will see how to configure the training run to use a different RL library and a different training algorithm.
In the directory scripts/reinforcement_learning, you will find the scripts for the different RL libraries. These are organized into sub-directories named after the library. Each sub-directory contains the training and playing scripts for that library.
To configure a learning library for a specific task, you need to create a configuration file for the learning agent. This configuration file is used to create an instance of the learning agent and to configure the training process. Similar to the environment registration shown in the :ref:`tutorial-register-rl-env-gym` tutorial, you register the learning agent's configuration with the ``gymnasium.register`` method.
The Code#
As an example, we will look at the configuration included for the task ``Isaac-Cartpole-v0`` in the ``isaaclab_tasks`` package. This is the same task that we used in the :ref:`tutorial-run-rl-training` tutorial.
gym.register(
    id="Isaac-Cartpole-v0",
    entry_point="isaaclab.envs:ManagerBasedRLEnv",
    disable_env_checker=True,
    kwargs={
        "env_cfg_entry_point": f"{__name__}.cartpole_env_cfg:CartpoleEnvCfg",
        "rl_games_cfg_entry_point": f"{agents.__name__}:rl_games_ppo_cfg.yaml",
        "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:CartpolePPORunnerCfg",
        "rsl_rl_with_symmetry_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:CartpolePPORunnerWithSymmetryCfg",
        "skrl_cfg_entry_point": f"{agents.__name__}:skrl_ppo_cfg.yaml",
        "sb3_cfg_entry_point": f"{agents.__name__}:sb3_ppo_cfg.yaml",
    },
)
The Code Explained#
Under the ``kwargs`` parameter, we can see the configurations for the different learning libraries. The keys are the names of the libraries and the values are the paths to the configuration instances. A configuration instance can be a string, a class, or an instance of a class. For example, the value of the key ``"rl_games_cfg_entry_point"`` is a string pointing to the configuration YAML file for the RL-Games library. Meanwhile, the value of the key ``"rsl_rl_cfg_entry_point"`` points to the configuration class for the RSL-RL library.
The pattern used to specify an agent configuration class closely resembles the one used to specify the environment configuration entry point. This means that the following are equivalent:
Specifying the configuration entry point as a string
from . import agents

gym.register(
    id="Isaac-Cartpole-v0",
    entry_point="isaaclab.envs:ManagerBasedRLEnv",
    disable_env_checker=True,
    kwargs={
        "env_cfg_entry_point": f"{__name__}.cartpole_env_cfg:CartpoleEnvCfg",
        "rsl_rl_cfg_entry_point": f"{agents.__name__}.rsl_rl_ppo_cfg:CartpolePPORunnerCfg",
    },
)
Specifying the configuration entry point as a class
from . import agents

gym.register(
    id="Isaac-Cartpole-v0",
    entry_point="isaaclab.envs:ManagerBasedRLEnv",
    disable_env_checker=True,
    kwargs={
        "env_cfg_entry_point": f"{__name__}.cartpole_env_cfg:CartpoleEnvCfg",
        "rsl_rl_cfg_entry_point": agents.rsl_rl_ppo_cfg.CartpolePPORunnerCfg,
    },
)
The first code block is the preferred way of specifying the configuration entry point. The second block is equivalent to the first, but it forces the configuration class to be imported at registration time, which slows down the import. This is why we recommend using strings as configuration entry points.
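To see why the string form avoids this cost, note that an entry point of the form ``"module.path:AttributeName"`` can be resolved lazily. The helper below is a minimal sketch of this idea, not the actual Isaac Lab implementation; the module is only imported when the configuration is actually requested.

import importlib


def resolve_entry_point(entry_point: str):
    """Resolve a "module.path:attribute" string to the object it names.

    The module is imported only when this function is called, so merely
    registering the string carries no import cost.
    """
    module_name, attr_name = entry_point.split(":")
    module = importlib.import_module(module_name)
    return getattr(module, attr_name)


# for example, given the registration above one could later call:
#   resolve_entry_point(f"{agents.__name__}.rsl_rl_ppo_cfg:CartpolePPORunnerCfg")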
All the scripts in the scripts/reinforcement_learning directory are configured by default to read the ``<library_name>_cfg_entry_point`` from the ``kwargs`` dictionary to obtain the configuration instance.
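Conceptually, this lookup goes through the gymnasium registry: the key selected via ``--agent`` is read from the task's registered ``kwargs``. The snippet below is a minimal sketch of that lookup; the actual scripts delegate it to Isaac Lab's Hydra utilities, and it assumes the task has already been registered (for instance by importing ``isaaclab_tasks`` inside a running app).

import gymnasium as gym

task_name = "Isaac-Cartpole-v0"
entry_point_key = "sb3_cfg_entry_point"  # the default value of --agent in the SB3 script

# read the reference to the registered agent configuration for this task
agent_cfg_ref = gym.spec(task_name).kwargs[entry_point_key]
print(agent_cfg_ref)  # a string such as "<agents module>:sb3_ppo_cfg.yaml", or a class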
For example, the following code block shows how the train.py script reads the configuration instance for the Stable-Baselines3 library:
Code for train.py with SB3
# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause


"""Script to train RL agent with Stable Baselines3."""

"""Launch Isaac Sim Simulator first."""

import argparse
import contextlib
import signal
import sys
from pathlib import Path

from isaaclab.app import AppLauncher

# add argparse arguments
parser = argparse.ArgumentParser(description="Train an RL agent with Stable-Baselines3.")
parser.add_argument("--video", action="store_true", default=False, help="Record videos during training.")
parser.add_argument("--video_length", type=int, default=200, help="Length of the recorded video (in steps).")
parser.add_argument("--video_interval", type=int, default=2000, help="Interval between video recordings (in steps).")
parser.add_argument("--num_envs", type=int, default=None, help="Number of environments to simulate.")
parser.add_argument("--task", type=str, default=None, help="Name of the task.")
parser.add_argument(
    "--agent", type=str, default="sb3_cfg_entry_point", help="Name of the RL agent configuration entry point."
)
parser.add_argument("--seed", type=int, default=None, help="Seed used for the environment")
parser.add_argument("--log_interval", type=int, default=100_000, help="Log data every n timesteps.")
parser.add_argument("--checkpoint", type=str, default=None, help="Continue the training from checkpoint.")
parser.add_argument("--max_iterations", type=int, default=None, help="RL Policy training iterations.")
parser.add_argument("--export_io_descriptors", action="store_true", default=False, help="Export IO descriptors.")
parser.add_argument(
    "--keep_all_info",
    action="store_true",
    default=False,
    help="Use a slower SB3 wrapper but keep all the extra training info.",
)
parser.add_argument(
    "--ray-proc-id", "-rid", type=int, default=None, help="Automatically configured by Ray integration, otherwise None."
)
# append AppLauncher cli args
AppLauncher.add_app_launcher_args(parser)
# parse the arguments
args_cli, hydra_args = parser.parse_known_args()
# always enable cameras to record video
if args_cli.video:
    args_cli.enable_cameras = True

# clear out sys.argv for Hydra
sys.argv = [sys.argv[0]] + hydra_args

# launch omniverse app
app_launcher = AppLauncher(args_cli)
simulation_app = app_launcher.app


def cleanup_pbar(*args):
    """
    A small helper to stop training and
    cleanup progress bar properly on ctrl+c
    """
    import gc

    tqdm_objects = [obj for obj in gc.get_objects() if "tqdm" in type(obj).__name__]
    for tqdm_object in tqdm_objects:
        if "tqdm_rich" in type(tqdm_object).__name__:
            tqdm_object.close()
    raise KeyboardInterrupt


# disable KeyboardInterrupt override
signal.signal(signal.SIGINT, cleanup_pbar)

"""Rest everything follows."""

import logging
import os
import random
import time
from datetime import datetime

import gymnasium as gym
import numpy as np
from stable_baselines3 import PPO
from stable_baselines3.common.callbacks import CheckpointCallback, LogEveryNTimesteps
from stable_baselines3.common.vec_env import VecNormalize

from isaaclab.envs import (
    DirectMARLEnv,
    DirectMARLEnvCfg,
    DirectRLEnvCfg,
    ManagerBasedRLEnvCfg,
    multi_agent_to_single_agent,
)
from isaaclab.utils.dict import print_dict
from isaaclab.utils.io import dump_yaml

from isaaclab_rl.sb3 import Sb3VecEnvWrapper, process_sb3_cfg

import isaaclab_tasks  # noqa: F401
from isaaclab_tasks.utils.hydra import hydra_task_config

# import logger
logger = logging.getLogger(__name__)
# PLACEHOLDER: Extension template (do not remove this comment)


@hydra_task_config(args_cli.task, args_cli.agent)
def main(env_cfg: ManagerBasedRLEnvCfg | DirectRLEnvCfg | DirectMARLEnvCfg, agent_cfg: dict):
    """Train with stable-baselines agent."""
    # randomly sample a seed if seed = -1
    if args_cli.seed == -1:
        args_cli.seed = random.randint(0, 10000)

    # override configurations with non-hydra CLI arguments
    env_cfg.scene.num_envs = args_cli.num_envs if args_cli.num_envs is not None else env_cfg.scene.num_envs
    agent_cfg["seed"] = args_cli.seed if args_cli.seed is not None else agent_cfg["seed"]
    # max iterations for training
    if args_cli.max_iterations is not None:
        agent_cfg["n_timesteps"] = args_cli.max_iterations * agent_cfg["n_steps"] * env_cfg.scene.num_envs

    # set the environment seed
    # note: certain randomizations occur in the environment initialization so we set the seed here
    env_cfg.seed = agent_cfg["seed"]
    env_cfg.sim.device = args_cli.device if args_cli.device is not None else env_cfg.sim.device

    # directory for logging into
    run_info = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    log_root_path = os.path.abspath(os.path.join("logs", "sb3", args_cli.task))
    print(f"[INFO] Logging experiment in directory: {log_root_path}")
    # The Ray Tune workflow extracts experiment name using the logging line below, hence,
    # do not change it (see PR #2346, comment-2819298849)
    print(f"Exact experiment name requested from command line: {run_info}")
    log_dir = os.path.join(log_root_path, run_info)
    # dump the configuration into log-directory
    dump_yaml(os.path.join(log_dir, "params", "env.yaml"), env_cfg)
    dump_yaml(os.path.join(log_dir, "params", "agent.yaml"), agent_cfg)

    # save command used to run the script
    command = " ".join(sys.orig_argv)
    (Path(log_dir) / "command.txt").write_text(command)

    # post-process agent configuration
    agent_cfg = process_sb3_cfg(agent_cfg, env_cfg.scene.num_envs)
    # read configurations about the agent-training
    policy_arch = agent_cfg.pop("policy")
    n_timesteps = agent_cfg.pop("n_timesteps")

    # set the IO descriptors export flag if requested
    if isinstance(env_cfg, ManagerBasedRLEnvCfg):
        env_cfg.export_io_descriptors = args_cli.export_io_descriptors
    else:
        logger.warning(
            "IO descriptors are only supported for manager based RL environments. No IO descriptors will be exported."
        )

    # set the log directory for the environment (works for all environment types)
    env_cfg.log_dir = log_dir

    # create isaac environment
    env = gym.make(args_cli.task, cfg=env_cfg, render_mode="rgb_array" if args_cli.video else None)

    # convert to single-agent instance if required by the RL algorithm
    if isinstance(env.unwrapped, DirectMARLEnv):
        env = multi_agent_to_single_agent(env)

    # wrap for video recording
    if args_cli.video:
        video_kwargs = {
            "video_folder": os.path.join(log_dir, "videos", "train"),
            "step_trigger": lambda step: step % args_cli.video_interval == 0,
            "video_length": args_cli.video_length,
            "disable_logger": True,
        }
        print("[INFO] Recording videos during training.")
        print_dict(video_kwargs, nesting=4)
        env = gym.wrappers.RecordVideo(env, **video_kwargs)

    start_time = time.time()

    # wrap around environment for stable baselines
    env = Sb3VecEnvWrapper(env, fast_variant=not args_cli.keep_all_info)

    norm_keys = {"normalize_input", "normalize_value", "clip_obs"}
    norm_args = {}
    for key in norm_keys:
        if key in agent_cfg:
            norm_args[key] = agent_cfg.pop(key)

    if norm_args and norm_args.get("normalize_input"):
        print(f"Normalizing input, {norm_args=}")
        env = VecNormalize(
            env,
            training=True,
            norm_obs=norm_args["normalize_input"],
            norm_reward=norm_args.get("normalize_value", False),
            clip_obs=norm_args.get("clip_obs", 100.0),
            gamma=agent_cfg["gamma"],
            clip_reward=np.inf,
        )

    # create agent from stable baselines
    agent = PPO(policy_arch, env, verbose=1, tensorboard_log=log_dir, **agent_cfg)
    if args_cli.checkpoint is not None:
        agent = agent.load(args_cli.checkpoint, env, print_system_info=True)

    # callbacks for agent
    checkpoint_callback = CheckpointCallback(save_freq=1000, save_path=log_dir, name_prefix="model", verbose=2)
    callbacks = [checkpoint_callback, LogEveryNTimesteps(n_steps=args_cli.log_interval)]

    # train the agent
    with contextlib.suppress(KeyboardInterrupt):
        agent.learn(
            total_timesteps=n_timesteps,
            callback=callbacks,
            progress_bar=True,
            log_interval=None,
        )
    # save the final model
    agent.save(os.path.join(log_dir, "model"))
    print("Saving to:")
    print(os.path.join(log_dir, "model.zip"))

    if isinstance(env, VecNormalize):
        print("Saving normalization")
        env.save(os.path.join(log_dir, "model_vecnormalize.pkl"))

    print(f"Training time: {round(time.time() - start_time, 2)} seconds")

    # close the simulator
    env.close()


if __name__ == "__main__":
    # run the main function
    main()
    # close sim app
    simulation_app.close()
The ``--agent`` argument specifies which learning-agent configuration to use. Its value is the key used to retrieve the configuration instance from the ``kwargs`` dictionary. You can manually select an alternative configuration instance by passing a different ``--agent`` argument.
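Nothing restricts the ``kwargs`` dictionary to the library-default keys. The sketch below registers an additional agent configuration under a custom key so it can later be selected with ``--agent``; the task id, module paths, and key name here are hypothetical placeholders.

import gymnasium as gym

# hypothetical registration: two agent configurations for the same task
gym.register(
    id="Isaac-Cartpole-Custom-v0",
    entry_point="isaaclab.envs:ManagerBasedRLEnv",
    disable_env_checker=True,
    kwargs={
        "env_cfg_entry_point": "my_project.cartpole_env_cfg:CartpoleEnvCfg",
        # default configuration picked up by the SB3 train.py script
        "sb3_cfg_entry_point": "my_project.agents:sb3_ppo_cfg.yaml",
        # alternative configuration, selected with --agent my_ppo_cfg_entry_point
        "my_ppo_cfg_entry_point": "my_project.agents:sb3_ppo_tuned_cfg.yaml",
    },
)

Training would then be launched with ``--task Isaac-Cartpole-Custom-v0 --agent my_ppo_cfg_entry_point``, without any change to the training script.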
The Code Execution#
Since the RSL-RL library provides two configuration instances for the cartpole balancing task, we can use the ``--agent`` argument to specify which configuration instance to use.
Training with the standard PPO configuration:
# standard PPO training
./isaaclab.sh -p scripts/reinforcement_learning/rsl_rl/train.py --task Isaac-Cartpole-v0 --headless \
  --run_name ppo
Training with the PPO configuration with symmetry augmentation:
# PPO training with symmetry augmentation
./isaaclab.sh -p scripts/reinforcement_learning/rsl_rl/train.py --task Isaac-Cartpole-v0 --headless \
  --agent rsl_rl_with_symmetry_cfg_entry_point \
  --run_name ppo_with_symmetry_data_augmentation

# you can use hydra to disable symmetry augmentation but enable mirror loss computation
./isaaclab.sh -p scripts/reinforcement_learning/rsl_rl/train.py --task Isaac-Cartpole-v0 --headless \
  --agent rsl_rl_with_symmetry_cfg_entry_point \
  --run_name ppo_without_symmetry_data_augmentation \
  agent.algorithm.symmetry_cfg.use_data_augmentation=false
The ``--run_name`` argument specifies the name of the run. It is used to create a directory for the run inside the ``logs/rsl_rl/cartpole`` directory.