# Copyright (c) 2022-2025, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
"""Wrapper to configure a :class:`ManagerBasedRLEnv` or :class:`DirectRLEnv` instance as a Stable-Baselines3 vectorized environment.

The following example shows how to wrap an environment for Stable-Baselines3:

.. code-block:: python

    from omni.isaac.lab_tasks.utils.wrappers.sb3 import Sb3VecEnvWrapper

    env = Sb3VecEnvWrapper(env)

"""
# needed to import for allowing type-hinting: torch.Tensor | dict[str, torch.Tensor]
from __future__ import annotations
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn # noqa: F401
from typing import Any
from stable_baselines3.common.utils import constant_fn
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs, VecEnvStepReturn
from omni.isaac.lab.envs import DirectRLEnv, ManagerBasedRLEnv
"""
Configuration Parser.
"""
def process_sb3_cfg(cfg: dict) -> dict:
    """Convert simple YAML types to Stable-Baselines classes/components.

    The dictionary is walked recursively and converted **in place**. At every nesting level:

    * Non-dictionary values stored under ``policy_kwargs``, ``replay_buffer_class`` or
      ``replay_buffer_kwargs`` are passed to :func:`eval`.
    * Values stored under ``learning_rate``, ``clip_range``, ``clip_range_vf`` or ``delta_std``
      are converted into schedules:

      - a string of the form ``"<name>_<float>"`` becomes a callable that returns
        ``progress_remaining * <float>`` (the ``<name>`` prefix is not inspected);
      - a non-negative number becomes ``constant_fn(float(value))``;
      - a negative number is left untouched (e.g. to disable clipping).

    * All other entries are left unchanged.

    Args:
        cfg: A configuration dictionary.

    Returns:
        The same dictionary object, containing the converted configuration.

    Raises:
        ValueError: When a schedule entry is neither a string nor a number, or when a
            schedule string cannot be split into exactly two ``_``-separated parts or its
            second part cannot be parsed as a float.

    Reference:
        https://github.com/DLR-RM/rl-baselines3-zoo/blob/0e5eb145faefa33e7d79c7f8c179788574b20da5/utils/exp_manager.py#L358
    """
    evaluated_keys = ("policy_kwargs", "replay_buffer_class", "replay_buffer_kwargs")
    schedule_keys = ("learning_rate", "clip_range", "clip_range_vf", "delta_std")

    def linear_schedule(initial_value: float):
        """Return a schedule that scales ``initial_value`` by the remaining progress."""

        # A dedicated factory gives each schedule its own closure. Creating a lambda directly
        # inside the loop below would make every schedule of a dictionary share the loop
        # variable and thus silently use the last parsed initial value.
        def schedule(progress_remaining: float) -> float:
            return progress_remaining * initial_value

        return schedule

    def update_dict(hyperparams: dict[str, Any]) -> dict[str, Any]:
        """Convert the entries of ``hyperparams`` (and of nested dictionaries) in place."""
        # note: only values of existing keys are re-assigned, so the dictionary size does not
        #   change while iterating over it.
        for key, value in hyperparams.items():
            if isinstance(value, dict):
                update_dict(value)
                continue

            if key in evaluated_keys:
                # NOTE(security): ``eval`` executes arbitrary Python code. Only load agent
                #   configurations from trusted sources.
                # The expression is evaluated with this module's globals (presumably why
                # ``torch.nn`` is imported as ``nn`` at file level -- TODO confirm).
                hyperparams[key] = eval(value)
            elif key in schedule_keys:
                if isinstance(value, str):
                    _, initial_value = value.split("_")
                    hyperparams[key] = linear_schedule(float(initial_value))
                elif isinstance(value, (float, int)):
                    # Negative value: ignore (ex: for clipping)
                    if value < 0:
                        continue
                    hyperparams[key] = constant_fn(float(value))
                else:
                    raise ValueError(f"Invalid value for {key}: {hyperparams[key]}")

        return hyperparams

    # parse agent configuration and convert to classes
    return update_dict(cfg)
"""
Vectorized environment wrapper.
"""
class Sb3VecEnvWrapper(VecEnv):
    """Wraps around Isaac Lab environment for Stable Baselines3.

    Isaac Sim internally implements a vectorized environment. However, since it is
    still considered a single environment instance, Stable Baselines tries to wrap
    around it using the :class:`DummyVecEnv`. This is only done if the environment
    is not inheriting from their :class:`VecEnv`. Thus, this class thinly wraps
    over the environment from :class:`ManagerBasedRLEnv` or :class:`DirectRLEnv`.

    Note:
        While Stable-Baselines3 supports Gym 0.26+ API, their vectorized environment
        still uses the old API (i.e. it is closer to Gym 0.21). Thus, we implement
        the old API for the vectorized environment.

    We also add monitoring functionality that computes the un-discounted episode
    return and length. This information is added to the info dicts under key `episode`.

    In contrast to the Isaac Lab environment, stable-baselines expect the following:

    1. numpy datatype for MDP signals
    2. a list of info dicts for each sub-environment (instead of a dict)
    3. when environment has terminated, the observations from the environment should correspond
       to the one after reset. The "real" final observation is passed using the info dicts
       under the key ``terminal_observation``.

    .. warning::

        By the nature of physics stepping in Isaac Sim, it is not possible to forward the
        simulation buffers without performing a physics step. Thus, reset is performed
        inside the :meth:`step()` function after the actual physics step is taken.
        Thus, the returned observations for terminated environments is the one after the reset.

    .. caution::

        This class must be the last wrapper in the wrapper chain. This is because the wrapper does not follow
        the :class:`gym.Wrapper` interface. Any subsequent wrappers will need to be modified to work with this
        wrapper.

    Reference:

    1. https://stable-baselines3.readthedocs.io/en/master/guide/vec_envs.html
    2. https://stable-baselines3.readthedocs.io/en/master/common/monitor.html

    """

    def __init__(self, env: ManagerBasedRLEnv | DirectRLEnv):
        """Initialize the wrapper.

        Args:
            env: The environment to wrap around.

        Raises:
            ValueError: When the environment is not an instance of :class:`ManagerBasedRLEnv` or :class:`DirectRLEnv`.
        """
        # check that input is valid
        if not isinstance(env.unwrapped, ManagerBasedRLEnv) and not isinstance(env.unwrapped, DirectRLEnv):
            raise ValueError(
                "The environment must be inherited from ManagerBasedRLEnv or DirectRLEnv. Environment type:"
                f" {type(env)}"
            )
        # initialize the wrapper
        self.env = env
        # collect common information
        self.num_envs = self.unwrapped.num_envs
        self.sim_device = self.unwrapped.device
        self.render_mode = self.unwrapped.render_mode

        # obtain gym spaces
        # note: stable-baselines3 does not like when we have unbounded action space so
        #   we set it to some high value here. Maybe this is not general but something to think about.
        observation_space = self.unwrapped.single_observation_space["policy"]
        action_space = self.unwrapped.single_action_space
        if isinstance(action_space, gym.spaces.Box) and not action_space.is_bounded("both"):
            action_space = gym.spaces.Box(low=-100, high=100, shape=action_space.shape)

        # initialize vec-env
        VecEnv.__init__(self, self.num_envs, observation_space, action_space)
        # add buffer for logging episodic information
        self._ep_rew_buf = torch.zeros(self.num_envs, device=self.sim_device)
        self._ep_len_buf = torch.zeros(self.num_envs, device=self.sim_device)

    def __str__(self):
        """Returns the wrapper name and the :attr:`env` representation string."""
        return f"<{type(self).__name__}{self.env}>"

    def __repr__(self):
        """Returns the string representation of the wrapper."""
        return str(self)

    """
    Properties -- Gym.Wrapper
    """

    @classmethod
    def class_name(cls) -> str:
        """Returns the class name of the wrapper."""
        return cls.__name__

    @property
    def unwrapped(self) -> ManagerBasedRLEnv | DirectRLEnv:
        """Returns the base environment of the wrapper.

        This will be the bare :class:`gymnasium.Env` environment, underneath all layers of wrappers.
        """
        return self.env.unwrapped

    """
    Properties
    """

    def get_episode_rewards(self) -> list[float]:
        """Returns the un-discounted return accumulated so far in the ongoing episode of each sub-environment."""
        return self._ep_rew_buf.cpu().tolist()

    def get_episode_lengths(self) -> list[int]:
        """Returns the number of time-steps taken so far in the ongoing episode of each sub-environment."""
        return self._ep_len_buf.cpu().tolist()

    """
    Operations - MDP
    """

    def seed(self, seed: int | None = None) -> list[int | None]:  # noqa: D102
        # the underlying environment is seeded once; the result is repeated for every sub-environment
        return [self.unwrapped.seed(seed)] * self.unwrapped.num_envs

    def reset(self) -> VecEnvObs:  # noqa: D102
        obs_dict, _ = self.env.reset()
        # reset episodic information buffers
        self._ep_rew_buf.zero_()
        self._ep_len_buf.zero_()
        # convert data types to numpy depending on backend
        return self._process_obs(obs_dict)

    def step_async(self, actions):  # noqa: D102
        # convert input to a float32 tensor on the simulation device
        if not isinstance(actions, torch.Tensor):
            actions = np.asarray(actions)
            actions = torch.from_numpy(actions).to(device=self.sim_device, dtype=torch.float32)
        else:
            actions = actions.to(device=self.sim_device, dtype=torch.float32)
        # store the actions until :meth:`step_wait` is called
        self._async_actions = actions

    def step_wait(self) -> VecEnvStepReturn:  # noqa: D102
        # record step information
        obs_dict, rew, terminated, truncated, extras = self.env.step(self._async_actions)
        # update episode un-discounted return and length
        self._ep_rew_buf += rew
        self._ep_len_buf += 1
        # compute reset ids
        dones = terminated | truncated
        reset_ids = (dones > 0).nonzero(as_tuple=False)

        # convert data types to numpy depending on backend
        # note: ManagerBasedRLEnv uses torch backend (by default).
        obs = self._process_obs(obs_dict)
        rew = rew.detach().cpu().numpy()
        terminated = terminated.detach().cpu().numpy()
        truncated = truncated.detach().cpu().numpy()
        dones = dones.detach().cpu().numpy()
        # convert extra information to list of dicts
        # note: this must happen before the episodic buffers are cleared below since the
        #   episode statistics are read from them.
        infos = self._process_extras(obs, terminated, truncated, extras, reset_ids)

        # reset info for terminated environments
        self._ep_rew_buf[reset_ids] = 0
        self._ep_len_buf[reset_ids] = 0

        return obs, rew, dones, infos

    def close(self):  # noqa: D102
        self.env.close()

    def get_attr(self, attr_name, indices=None):  # noqa: D102
        # resolve indices
        # note: Stable-Baselines3 allows ``indices`` to be None, a single integer or an iterable of integers.
        if indices is None:
            indices = slice(None)
            num_indices = self.num_envs
        else:
            if isinstance(indices, (int, np.integer)):
                # a single index is treated as a one-element list
                indices = [int(indices)]
            elif not hasattr(indices, "__len__"):
                # materialize iterables that have no length (e.g. generators)
                indices = list(indices)
            num_indices = len(indices)
        # obtain attribute value
        attr_val = getattr(self.env, attr_name)
        # return the value
        if not isinstance(attr_val, torch.Tensor):
            # non-tensor attributes are shared by all sub-environments: repeat them
            return [attr_val] * num_indices
        else:
            return attr_val[indices].detach().cpu().numpy()

    def set_attr(self, attr_name, value, indices=None):  # noqa: D102
        raise NotImplementedError("Setting attributes is not supported.")

    def env_method(self, method_name: str, *method_args, indices=None, **method_kwargs):  # noqa: D102
        if method_name == "render":
            # gymnasium does not support changing render mode at runtime
            return self.env.render()
        else:
            # this isn't properly implemented but it is not necessary.
            # mostly done for completeness.
            env_method = getattr(self.env, method_name)
            return env_method(*method_args, indices=indices, **method_kwargs)

    def env_is_wrapped(self, wrapper_class, indices=None):  # noqa: D102
        raise NotImplementedError("Checking if environment is wrapped is not supported.")

    def get_images(self):  # noqa: D102
        raise NotImplementedError("Getting images is not supported.")

    """
    Helper functions.
    """

    def _process_obs(self, obs_dict: torch.Tensor | dict[str, torch.Tensor]) -> np.ndarray | dict[str, np.ndarray]:
        """Convert observations into NumPy data type.

        Args:
            obs_dict: The observations returned by the environment. Only the entry stored
                under the key ``"policy"`` is used.

        Returns:
            The policy observations as a NumPy array, or as a new dictionary of NumPy arrays
            when the policy observations are a dictionary.

        Raises:
            NotImplementedError: When the policy observations are neither a dictionary nor a tensor.
        """
        # Sb3 doesn't support asymmetric observation spaces, so we only use "policy"
        obs = obs_dict["policy"]
        # note: ManagerBasedRLEnv uses torch backend (by default).
        if isinstance(obs, dict):
            # build a new dictionary so that the dictionary owned by the environment keeps its tensors
            obs = {key: value.detach().cpu().numpy() for key, value in obs.items()}
        elif isinstance(obs, torch.Tensor):
            obs = obs.detach().cpu().numpy()
        else:
            raise NotImplementedError(f"Unsupported data type: {type(obs)}")
        return obs

    def _process_extras(
        self,
        obs: np.ndarray | dict[str, np.ndarray],
        terminated: np.ndarray,
        truncated: np.ndarray,
        extras: dict,
        reset_ids: torch.Tensor | np.ndarray,
    ) -> list[dict[str, Any]]:
        """Convert miscellaneous information into dictionary for each sub-environment.

        Args:
            obs: The (already converted) observations returned by the step.
            terminated: The per-environment termination flags.
            truncated: The per-environment truncation flags.
            extras: The extra information returned by the environment. The entry ``"log"`` is
                merged into the ``episode`` info of the environments that are reset; every
                other entry is indexed with the sub-environment index.
            reset_ids: The indices of the sub-environments that are reset in this step.

        Returns:
            A list with one info dictionary per sub-environment.
        """
        # resolve the reset indices into a set once: this gives O(1) membership tests inside the
        # loop and avoids comparing against a (possibly on-device) tensor for every sub-environment
        if isinstance(reset_ids, torch.Tensor):
            reset_ids = reset_ids.detach().cpu().numpy()
        reset_idx_set = set(np.asarray(reset_ids).reshape(-1).tolist())
        # copy the episodic buffers to the host once instead of once per reset environment
        if reset_idx_set:
            ep_returns = self._ep_rew_buf.detach().cpu().tolist()
            ep_lengths = self._ep_len_buf.detach().cpu().tolist()
        else:
            ep_returns, ep_lengths = [], []

        # create empty list of dictionaries to fill
        infos: list[dict[str, Any]] = [dict.fromkeys(extras.keys()) for _ in range(self.num_envs)]
        # fill-in information for each sub-environment
        # note: This loop becomes slow when number of environments is large.
        for idx in range(self.num_envs):
            is_reset = idx in reset_idx_set
            # fill-in episode monitoring info
            if is_reset:
                infos[idx]["episode"] = dict()
                infos[idx]["episode"]["r"] = float(ep_returns[idx])
                infos[idx]["episode"]["l"] = float(ep_lengths[idx])
            else:
                infos[idx]["episode"] = None
            # fill-in bootstrap information
            infos[idx]["TimeLimit.truncated"] = bool(truncated[idx] and not terminated[idx])
            # fill-in information from extras
            for key, value in extras.items():
                # 1. remap extra episodes information safely
                # 2. for others just store their values
                if key == "log":
                    # only log this data for episodes that are terminated
                    if infos[idx]["episode"] is not None:
                        for sub_key, sub_value in value.items():
                            infos[idx]["episode"][sub_key] = sub_value
                else:
                    infos[idx][key] = value[idx]
            # add information about terminal observation separately
            if is_reset:
                # extract terminal observations
                if isinstance(obs, dict):
                    terminal_obs = dict.fromkeys(obs.keys())
                    for key, value in obs.items():
                        terminal_obs[key] = value[idx]
                else:
                    terminal_obs = obs[idx]
                # add info to dict
                infos[idx]["terminal_observation"] = terminal_obs
            else:
                infos[idx]["terminal_observation"] = None
        # return list of dictionaries
        return infos