# Source code for omni.isaac.lab.managers.action_manager
# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
"""Action manager for processing actions sent to the environment."""
from __future__ import annotations
import inspect
import torch
import weakref
from abc import abstractmethod
from collections.abc import Sequence
from prettytable import PrettyTable
from typing import TYPE_CHECKING
import omni.kit.app
from omni.isaac.lab.assets import AssetBase
from .manager_base import ManagerBase, ManagerTermBase
from .manager_term_cfg import ActionTermCfg
if TYPE_CHECKING:
from omni.isaac.lab.envs import ManagerBasedEnv
class ActionTerm(ManagerTermBase):
    """Base class for action terms.

    The action term is responsible for processing the raw actions sent to the environment
    and applying them to the asset managed by the term. The action term is comprised of two
    operations:

    * Processing of actions: This operation is performed once per **environment step** and
      is responsible for pre-processing the raw actions sent to the environment.
    * Applying actions: This operation is performed once per **simulation step** and is
      responsible for applying the processed actions to the asset managed by the term.
    """

    def __init__(self, cfg: ActionTermCfg, env: ManagerBasedEnv):
        """Initialize the action term.

        Args:
            cfg: The configuration object.
            env: The environment instance.
        """
        # call the base class constructor
        super().__init__(cfg, env)
        # parse config to obtain asset to which the term is applied
        self._asset: AssetBase = self._env.scene[self.cfg.asset_name]
        # add handle for debug visualization (this is set to a valid handle inside set_debug_vis)
        self._debug_vis_handle = None
        # set initial state of debug visualization
        self.set_debug_vis(self.cfg.debug_vis)

    def __del__(self):
        """Unsubscribe from the callbacks."""
        # The attribute does not exist if ``__init__`` raised before creating it
        # (e.g. the asset lookup failed). Fall back to None so that the finalizer
        # does not raise a secondary AttributeError on a half-built object.
        handle = getattr(self, "_debug_vis_handle", None)
        if handle:
            handle.unsubscribe()
            self._debug_vis_handle = None

    """
    Properties.
    """

    @property
    @abstractmethod
    def action_dim(self) -> int:
        """Dimension of the action term."""
        raise NotImplementedError

    @property
    @abstractmethod
    def raw_actions(self) -> torch.Tensor:
        """The input/raw actions sent to the term."""
        raise NotImplementedError

    @property
    @abstractmethod
    def processed_actions(self) -> torch.Tensor:
        """The actions computed by the term after applying any processing."""
        raise NotImplementedError

    @property
    def has_debug_vis_implementation(self) -> bool:
        """Whether the action term has a debug visualization implemented.

        The check is textual: the source of :meth:`_set_debug_vis_impl` is inspected and
        the implementation is considered missing if it mentions ``NotImplementedError``
        (as the default implementation in this class does).
        """
        # check if function raises NotImplementedError
        source_code = inspect.getsource(self._set_debug_vis_impl)
        return "NotImplementedError" not in source_code

    """
    Operations.
    """

    def set_debug_vis(self, debug_vis: bool) -> bool:
        """Sets whether to visualize the action term data.

        Args:
            debug_vis: Whether to visualize the action term data.

        Returns:
            Whether the debug visualization was successfully set. False if the action term does
            not support debug visualization.
        """
        # check if debug visualization is supported
        if not self.has_debug_vis_implementation:
            return False
        # toggle debug visualization objects
        self._set_debug_vis_impl(debug_vis)
        # toggle debug visualization handles
        if debug_vis:
            # create a subscriber for the post update event if it doesn't exist
            if self._debug_vis_handle is None:
                app_interface = omni.kit.app.get_app_interface()
                # a weak proxy is captured so that the subscription itself does not
                # hold a strong reference to this term
                self._debug_vis_handle = app_interface.get_post_update_event_stream().create_subscription_to_pop(
                    lambda event, obj=weakref.proxy(self): obj._debug_vis_callback(event)
                )
        else:
            # remove the subscriber if it exists
            if self._debug_vis_handle is not None:
                self._debug_vis_handle.unsubscribe()
                self._debug_vis_handle = None
        # return success
        return True

    @abstractmethod
    def process_actions(self, actions: torch.Tensor):
        """Processes the actions sent to the environment.

        Note:
            This function is called once per environment step by the manager.

        Args:
            actions: The actions to process.
        """
        raise NotImplementedError

    @abstractmethod
    def apply_actions(self):
        """Applies the actions to the asset managed by the term.

        Note:
            This is called at every simulation step by the manager.
        """
        raise NotImplementedError

    def _set_debug_vis_impl(self, debug_vis: bool):
        """Set debug visualization into visualization objects.

        This function is responsible for creating the visualization objects if they don't exist
        and input ``debug_vis`` is True. If the visualization objects exist, the function should
        set their visibility into the stage.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError(f"Debug visualization is not implemented for {self.__class__.__name__}.")

    def _debug_vis_callback(self, event):
        """Callback for debug visualization.

        This function calls the visualization objects and sets the data to visualize into them.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError(f"Debug visualization is not implemented for {self.__class__.__name__}.")
class ActionManager(ManagerBase):
    """Manager for processing and applying actions for a given world.

    The action manager handles the interpretation and application of user-defined
    actions on a given world. It is comprised of different action terms that decide
    the dimension of the expected actions.

    The action manager performs operations at two stages:

    * processing of actions: It splits the input actions to each term and performs any
      pre-processing needed. This should be called once at every environment step.
    * apply actions: This operation typically sets the processed actions into the assets in the
      scene (such as robots). It should be called before every simulation step.
    """

    def __init__(self, cfg: object, env: ManagerBasedEnv):
        """Initialize the action manager.

        Args:
            cfg: The configuration object or dictionary (``dict[str, ActionTermCfg]``).
            env: The environment instance.

        Raises:
            ValueError: If the configuration is None.
        """
        # check if config is None
        if cfg is None:
            raise ValueError("Action manager configuration is None. Please provide a valid configuration.")

        # call the base class constructor (this prepares the terms)
        super().__init__(cfg, env)
        # create buffers to store actions
        self._action = torch.zeros((self.num_envs, self.total_action_dim), device=self.device)
        self._prev_action = torch.zeros_like(self._action)

        # record on the config whether any term has debug visualization enabled
        debug_vis = any(term.cfg.debug_vis for term in self._terms.values())
        # A dictionary config cannot carry attributes, and adding a "debug_vis" key
        # would mix a flag into the mapping of term configurations, so the flag is
        # only stored on object-style configs.
        if not isinstance(self.cfg, dict):
            self.cfg.debug_vis = debug_vis

    def __str__(self) -> str:
        """Returns: A string representation for action manager."""
        msg = f"<ActionManager> contains {len(self._term_names)} active terms.\n"

        # create table for term information
        table = PrettyTable()
        table.title = f"Active Action Terms (shape: {self.total_action_dim})"
        table.field_names = ["Index", "Name", "Dimension"]
        # set alignment of table columns
        table.align["Name"] = "l"
        table.align["Dimension"] = "r"
        # add info on each term
        for index, (name, term) in enumerate(self._terms.items()):
            table.add_row([index, name, term.action_dim])
        # convert table to string
        msg += table.get_string()
        msg += "\n"

        return msg

    """
    Properties.
    """

    @property
    def total_action_dim(self) -> int:
        """Total dimension of actions."""
        return sum(self.action_term_dim)

    @property
    def active_terms(self) -> list[str]:
        """Name of active action terms."""
        return self._term_names

    @property
    def action_term_dim(self) -> list[int]:
        """Shape of each action term."""
        return [term.action_dim for term in self._terms.values()]

    @property
    def action(self) -> torch.Tensor:
        """The actions sent to the environment. Shape is (num_envs, total_action_dim)."""
        return self._action

    @property
    def prev_action(self) -> torch.Tensor:
        """The previous actions sent to the environment. Shape is (num_envs, total_action_dim)."""
        return self._prev_action

    @property
    def has_debug_vis_implementation(self) -> bool:
        """Whether any of the action terms has debug visualization implemented."""
        return any(term.has_debug_vis_implementation for term in self._terms.values())

    """
    Operations.
    """

    def set_debug_vis(self, debug_vis: bool) -> bool:
        """Sets whether to visualize the action data.

        The request is forwarded to every action term, including those that come after
        a term which does not support debug visualization.

        Args:
            debug_vis: Whether to visualize the action data.

        Returns:
            Whether the debug visualization was successfully set for at least one term.
            False if no action term supports debug visualization.
        """
        success = False
        for term in self._terms.values():
            # do not short-circuit: every term must receive the request
            if term.set_debug_vis(debug_vis):
                success = True
        return success

    def reset(self, env_ids: Sequence[int] | None = None) -> dict[str, torch.Tensor]:
        """Resets the action history.

        Args:
            env_ids: The environment ids. Defaults to None, in which case
                all environments are considered.

        Returns:
            An empty dictionary.
        """
        # resolve environment ids
        if env_ids is None:
            env_ids = slice(None)
        # reset the action history
        self._prev_action[env_ids] = 0.0
        self._action[env_ids] = 0.0
        # reset all action terms
        for term in self._terms.values():
            term.reset(env_ids=env_ids)
        # nothing to log here
        return {}

    def process_action(self, action: torch.Tensor):
        """Processes the actions sent to the environment.

        Note:
            This function should be called once per environment step.

        Args:
            action: The actions to process. The second dimension must be equal to
                :attr:`total_action_dim`.

        Raises:
            ValueError: If the second dimension of the input does not match
                :attr:`total_action_dim`.
        """
        # check if action dimension is valid
        if self.total_action_dim != action.shape[1]:
            raise ValueError(f"Invalid action shape, expected: {self.total_action_dim}, received: {action.shape[1]}.")
        # move the input to the manager's device once, so that the stored buffer and
        # the slices handed to the terms refer to data on the same device
        action = action.to(self.device)
        # store the input actions
        self._prev_action[:] = self._action
        self._action[:] = action

        # split the actions and apply to each tensor
        idx = 0
        for term in self._terms.values():
            term_actions = action[:, idx : idx + term.action_dim]
            term.process_actions(term_actions)
            idx += term.action_dim

    def apply_action(self) -> None:
        """Applies the actions to the environment/simulation.

        Note:
            This should be called at every simulation step.
        """
        for term in self._terms.values():
            term.apply_actions()

    def get_term(self, name: str) -> ActionTerm:
        """Returns the action term with the specified name.

        Args:
            name: The name of the action term.

        Returns:
            The action term with the specified name.
        """
        return self._terms[name]

    """
    Helper functions.
    """

    def _prepare_terms(self):
        """Parse the configuration and instantiate the action terms.

        Entries whose configuration is None are skipped. The terms are stored in the
        order in which they appear in the configuration.

        Raises:
            TypeError: If a term configuration is not of type :class:`ActionTermCfg`.
            TypeError: If the object created from a term configuration is not of
                type :class:`ActionTerm`.
        """
        # create buffers to parse and store terms
        self._term_names: list[str] = list()
        self._terms: dict[str, ActionTerm] = dict()

        # check if config is dict already
        if isinstance(self.cfg, dict):
            cfg_items = self.cfg.items()
        else:
            cfg_items = self.cfg.__dict__.items()
        # parse action terms from the config
        for term_name, term_cfg in cfg_items:
            # check if term config is None
            if term_cfg is None:
                continue
            # check valid type
            if not isinstance(term_cfg, ActionTermCfg):
                raise TypeError(
                    f"Configuration for the term '{term_name}' is not of type ActionTermCfg."
                    f" Received: '{type(term_cfg)}'."
                )
            # create the action term
            term = term_cfg.class_type(term_cfg, self._env)
            # sanity check if term is valid type
            if not isinstance(term, ActionTerm):
                raise TypeError(f"Returned object for the term '{term_name}' is not of type ActionTerm.")
            # add term name and parameters
            self._term_names.append(term_name)
            self._terms[term_name] = term