# Copyright (c) 2022-2025, The Isaac Lab Project Developers.# All rights reserved.## SPDX-License-Identifier: BSD-3-Clause"""Termination manager for computing done signals for a given world."""from__future__importannotationsimporttorchfromcollections.abcimportSequencefromprettytableimportPrettyTablefromtypingimportTYPE_CHECKINGfrom.manager_baseimportManagerBase,ManagerTermBasefrom.manager_term_cfgimportTerminationTermCfgifTYPE_CHECKING:fromomni.isaac.lab.envsimportManagerBasedRLEnv
class TerminationManager(ManagerBase):
    """Manager for computing done signals for a given world.

    The termination manager computes the termination signal (also called dones) as a combination
    of termination terms. Each termination term is a function which takes the environment as an
    argument and returns a boolean tensor of shape (num_envs,). The termination manager
    computes the termination signal as the union (logical or) of all the termination terms.

    Following the `Gymnasium API <https://gymnasium.farama.org/tutorials/gymnasium_basics/handling_time_limits/>`_,
    the termination signal is computed as the logical OR of the following signals:

    * **Time-out**: This signal is set to true if the environment has ended after an externally defined condition
      (that is outside the scope of an MDP). For example, the environment may be terminated if the episode has
      timed out (i.e. reached max episode length).
    * **Terminated**: This signal is set to true if the environment has reached a terminal state defined by the
      environment. This state may correspond to task success, task failure, robot falling, etc.

    These signals can be individually accessed using the :attr:`time_outs` and :attr:`terminated` properties.

    The termination terms are parsed from a config class containing the manager's settings and each term's
    parameters. Each termination term should instantiate the :class:`TerminationTermCfg` class. The term's
    configuration :attr:`TerminationTermCfg.time_out` decides whether the term is a timeout or a termination term.
    """

    _env: ManagerBasedRLEnv
    """The environment instance."""

    def __init__(self, cfg: object, env: ManagerBasedRLEnv):
        """Initializes the termination manager.

        Args:
            cfg: The configuration object or dictionary (``dict[str, TerminationTermCfg]``).
            env: An environment object.
        """
        # create buffers to parse and store terms; they must exist before the base-class constructor runs,
        # because parsing the config (see :meth:`_prepare_terms`) appends to them
        self._term_names: list[str] = list()
        self._term_cfgs: list[TerminationTermCfg] = list()
        self._class_term_cfgs: list[TerminationTermCfg] = list()

        # call the base class constructor (this will parse the terms config)
        super().__init__(cfg, env)
        # per-term boolean buffers holding each term's value from the latest :meth:`compute` call
        self._term_dones: dict[str, torch.Tensor] = {
            term_name: torch.zeros(self.num_envs, device=self.device, dtype=torch.bool)
            for term_name in self._term_names
        }
        # create buffers for managing termination per environment
        self._truncated_buf = torch.zeros(self.num_envs, device=self.device, dtype=torch.bool)
        self._terminated_buf = torch.zeros_like(self._truncated_buf)

    def __str__(self) -> str:
        """Returns: A string representation for termination manager."""
        msg = f"<TerminationManager> contains {len(self._term_names)} active terms.\n"

        # create table for term information
        table = PrettyTable()
        table.title = "Active Termination Terms"
        table.field_names = ["Index", "Name", "Time Out"]
        # set alignment of table columns
        table.align["Name"] = "l"
        # add info on each term
        for index, (name, term_cfg) in enumerate(zip(self._term_names, self._term_cfgs)):
            table.add_row([index, name, term_cfg.time_out])
        # convert table to string
        msg += table.get_string()
        msg += "\n"

        return msg

    """
    Properties.
    """

    @property
    def active_terms(self) -> list[str]:
        """Name of active termination terms."""
        return self._term_names

    @property
    def dones(self) -> torch.Tensor:
        """The net termination signal. Shape is (num_envs,)."""
        return self._truncated_buf | self._terminated_buf

    @property
    def time_outs(self) -> torch.Tensor:
        """The timeout signal (reaching max episode length). Shape is (num_envs,).

        This signal is set to true if the environment has ended after an externally defined condition
        (that is outside the scope of an MDP). For example, the environment may be terminated if the episode has
        timed out (i.e. reached max episode length).
        """
        return self._truncated_buf

    @property
    def terminated(self) -> torch.Tensor:
        """The terminated signal (reaching a terminal state). Shape is (num_envs,).

        This signal is set to true if the environment has reached a terminal state defined by the environment.
        This state may correspond to task success, task failure, robot falling, etc.
        """
        return self._terminated_buf

    """
    Operations.
    """

    def reset(self, env_ids: Sequence[int] | None = None) -> dict[str, torch.Tensor]:
        """Returns the episodic counts of individual termination terms and resets class-based terms.

        Args:
            env_ids: The environment ids. Defaults to None, in which case all environments are considered.

        Returns:
            Dictionary mapping ``"Episode_Termination/<term_name>"`` to the number of selected environments for
            which that termination term was set in the most recent :meth:`compute` call. The values are Python
            numbers obtained via :meth:`torch.Tensor.item`.
        """
        # resolve environment ids
        if env_ids is None:
            env_ids = slice(None)
        # count, per termination term, how many of the selected environments it flagged
        extras = {
            "Episode_Termination/" + term_name: torch.count_nonzero(term_done[env_ids]).item()
            for term_name, term_done in self._term_dones.items()
        }
        # reset all the class-based termination terms
        # NOTE(review): when ``env_ids`` is None a ``slice`` (not a sequence of ids) is forwarded here --
        # confirm that the ManagerTermBase.reset implementations accept it.
        for term_cfg in self._class_term_cfgs:
            term_cfg.func.reset(env_ids=env_ids)
        # return logged information
        return extras

    def compute(self) -> torch.Tensor:
        """Computes the termination signal as union of individual terms.

        This function calls each termination term managed by the class and performs a logical OR operation
        to compute the net termination signal. Time-out terms are accumulated into :attr:`time_outs`, all
        other terms into :attr:`terminated`, and each term's raw value is stored for :meth:`get_term`.

        Returns:
            The combined termination signal of shape (num_envs,).
        """
        # reset computation
        self._truncated_buf[:] = False
        self._terminated_buf[:] = False
        # iterate over all the termination terms
        for name, term_cfg in zip(self._term_names, self._term_cfgs):
            value = term_cfg.func(self._env, **term_cfg.params)
            # store timeout signal separately
            if term_cfg.time_out:
                self._truncated_buf |= value
            else:
                self._terminated_buf |= value
            # add to episode dones (in-place copy keeps the preallocated buffer)
            self._term_dones[name][:] = value
        # return combined termination signal
        return self._truncated_buf | self._terminated_buf

    def get_term(self, name: str) -> torch.Tensor:
        """Returns the termination term with the specified name.

        Args:
            name: The name of the termination term.

        Returns:
            The corresponding termination term value. Shape is (num_envs,).
        """
        return self._term_dones[name]

    def get_active_iterable_terms(self, env_idx: int) -> Sequence[tuple[str, Sequence[float]]]:
        """Returns the active terms as iterable sequence of tuples.

        The first element of the tuple is the name of the term and the second element is the raw value(s) of the term.

        Args:
            env_idx: The specific environment to pull the active terms from.

        Returns:
            The active terms.
        """
        return [
            (term_name, [term_done[env_idx].float().cpu().item()])
            for term_name, term_done in self._term_dones.items()
        ]

    """
    Operations - Term settings.
    """

    def set_term_cfg(self, term_name: str, cfg: TerminationTermCfg):
        """Sets the configuration of the specified term into the manager.

        Args:
            term_name: The name of the termination term.
            cfg: The configuration for the termination term.

        Raises:
            ValueError: If the term name is not found.
        """
        if term_name not in self._term_names:
            raise ValueError(f"Termination term '{term_name}' not found.")
        # set the configuration
        # NOTE(review): unlike :meth:`_prepare_terms`, the new cfg is not passed through
        # ``_resolve_common_term_cfg`` and ``_class_term_cfgs`` is not updated, so :meth:`reset` keeps using
        # the previously registered class-based cfg -- confirm this is intended.
        self._term_cfgs[self._term_names.index(term_name)] = cfg

    def get_term_cfg(self, term_name: str) -> TerminationTermCfg:
        """Gets the configuration for the specified term.

        Args:
            term_name: The name of the termination term.

        Returns:
            The configuration of the termination term.

        Raises:
            ValueError: If the term name is not found.
        """
        if term_name not in self._term_names:
            raise ValueError(f"Termination term '{term_name}' not found.")
        # return the configuration
        return self._term_cfgs[self._term_names.index(term_name)]

    """
    Helper functions.
    """

    def _prepare_terms(self):
        """Parses the manager config and registers each termination term.

        Raises:
            TypeError: If a non-None config entry is not a :class:`TerminationTermCfg`.
        """
        # check if config is dict already
        if isinstance(self.cfg, dict):
            cfg_items = self.cfg.items()
        else:
            cfg_items = self.cfg.__dict__.items()
        # iterate over all the terms
        for term_name, term_cfg in cfg_items:
            # check for non config
            if term_cfg is None:
                continue
            # check for valid config type
            if not isinstance(term_cfg, TerminationTermCfg):
                raise TypeError(
                    f"Configuration for the term '{term_name}' is not of type TerminationTermCfg."
                    f" Received: '{type(term_cfg)}'."
                )
            # resolve common parameters
            self._resolve_common_term_cfg(term_name, term_cfg, min_argc=1)
            # add function to list
            self._term_names.append(term_name)
            self._term_cfgs.append(term_cfg)
            # check if the term is a class (its reset is invoked from :meth:`reset`)
            if isinstance(term_cfg.func, ManagerTermBase):
                self._class_term_cfgs.append(term_cfg)