Source code for isaaclab_tasks.utils.hydra

# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

"""Hydra utilities with REPLACE-only preset system.

This module bypasses Hydra's default MERGE behavior for config groups.
Instead, when a preset is selected, the entire config section is REPLACED
with the preset -- no field merging.

Presets are declared by subclassing :class:`PresetCfg` (or using the
:func:`preset` factory for scalars). The system recursively discovers all
presets and their paths automatically, including inside dict-valued fields.

Override categories (applied in order):
    1. Global presets: ``presets=inference,newton`` -- apply everywhere matching
    2. Path presets: ``env.backend=newton`` -- REPLACE specific section
    3. Preset-path scalars: ``env.backend.dt=0.001`` -- handled by us
    4. Global scalars: ``env.decimation=10`` -- handled by Hydra

Example usage::

    presets=newton env.backend.dt=0.001 env.decimation=10
"""

import functools
import sys
from collections.abc import Callable, Mapping

try:
    import hydra
    from hydra.core.config_store import ConfigStore
    from omegaconf import OmegaConf
except ImportError:
    raise ImportError("Hydra not installed. Run: pip install hydra-core")

from isaaclab.envs.utils.spaces import replace_env_cfg_spaces_with_strings, replace_strings_with_env_cfg_spaces
from isaaclab.utils import configclass, replace_slices_with_strings, replace_strings_with_slices

from isaaclab_tasks.utils.parse_cfg import load_cfg_from_registry


[docs] @configclass class PresetCfg: """Base class for declarative preset definitions. Subclass this and define fields as preset options. The field named ``default`` holds the config instance used when no CLI override is given. All other fields are named alternative presets. Example:: @configclass class PhysicsCfg(PresetCfg): default: PhysxCfg = PhysxCfg() newton: NewtonCfg = NewtonCfg() """ pass
[docs] def preset(**options) -> PresetCfg: """Create a :class:`PresetCfg` instance from keyword arguments. A convenience factory that dynamically builds a ``PresetCfg`` subclass with one field per keyword argument, then returns an instance of it. The caller **must** supply a ``default`` key. Example:: armature = preset(default=0.0, newton=0.01) # Equivalent to: # @configclass # class _Preset(PresetCfg): # default: float = 0.0 # newton: float = 0.01 # armature = _Preset() Args: **options: Preset alternatives keyed by name. Must include ``default``. Returns: A ``PresetCfg`` instance whose fields are the supplied options. Raises: ValueError: If ``default`` is not provided. """ if "default" not in options: raise ValueError("preset() requires a 'default' keyword argument.") annotations = {k: type(v) if v is not None else object for k, v in options.items()} ns = {"__annotations__": annotations, **options} cls = configclass(type("_Preset", (PresetCfg,), ns)) return cls()
def collect_presets(cfg, path: str = "") -> dict: """Recursively walk config tree and collect :class:`PresetCfg` fields. Presets are defined by subclassing :class:`PresetCfg` with typed fields for each alternative, or by using the :func:`preset` factory for scalar values. This function discovers them at every level of the config tree, including inside dict-valued fields (e.g. ``actuators``). Args: cfg: A configclass instance to walk. path: Current path prefix (used during recursion). Returns: Dict mapping dotted paths to preset dicts, e.g.: ``{"backend": {"default": PhysxCfg(), "newton": NewtonCfg()}}`` """ result = {} # Root-level PresetCfg: the cfg itself is a PresetCfg subclass if isinstance(cfg, PresetCfg) and hasattr(cfg, "__dataclass_fields__"): preset_dict = {} for field_name in cfg.__dataclass_fields__: preset_dict[field_name] = getattr(cfg, field_name) result[path] = preset_dict for alt in preset_dict.values(): if hasattr(alt, "__dataclass_fields__"): result.update(collect_presets(alt, path)) return result # Recurse into nested configclass attributes for name in dir(cfg): if name.startswith("_"): continue try: value = getattr(cfg, name) except Exception: continue child_path = f"{path}.{name}" if path else name if hasattr(value, "__dataclass_fields__"): if isinstance(value, PresetCfg): preset_dict = {} for field_name in value.__dataclass_fields__: preset_dict[field_name] = getattr(value, field_name) result[child_path] = preset_dict for alt in preset_dict.values(): if hasattr(alt, "__dataclass_fields__"): result.update(collect_presets(alt, child_path)) else: result.update(collect_presets(value, child_path)) elif isinstance(value, dict): for dict_key, dict_val in value.items(): if hasattr(dict_val, "__dataclass_fields__"): result.update(collect_presets(dict_val, f"{child_path}.{dict_key}")) return result
[docs] def resolve_task_config(task_name: str, agent_cfg_entry_point: str): """Resolve env and agent configs with Hydra overrides, presets, and scalars fully applied. Safe to call before Kit is launched — callable config values are stored as :class:`~isaaclab.utils.string.ResolvableString` and resolved lazily on first use, so no implementation modules are imported eagerly. Args: task_name: Task name (e.g., "Isaac-Velocity-Flat-Anymal-C-v0"). agent_cfg_entry_point: Agent config entry point key (e.g., "rsl_rl_cfg_entry_point"). Returns: Tuple of (env_cfg, agent_cfg) fully resolved. """ task = task_name.split(":")[-1] env_cfg, agent_cfg, presets = register_task(task, agent_cfg_entry_point) global_presets, preset_sel, preset_scalar, global_scalar = parse_overrides(sys.argv[1:], presets) original_argv, sys.argv = sys.argv, [sys.argv[0]] + global_scalar resolved = {} @hydra.main(config_path=None, config_name=task, version_base="1.3") def hydra_main(hydra_cfg, env_cfg=env_cfg, agent_cfg=agent_cfg): hydra_cfg = replace_strings_with_slices(OmegaConf.to_container(hydra_cfg, resolve=True)) env_cfg, agent_cfg = apply_overrides( env_cfg, agent_cfg, hydra_cfg, global_presets, preset_sel, preset_scalar, presets ) env_cfg.from_dict(hydra_cfg["env"]) env_cfg = replace_strings_with_env_cfg_spaces(env_cfg) if isinstance(agent_cfg, dict) or agent_cfg is None: agent_cfg = hydra_cfg["agent"] else: agent_cfg.from_dict(hydra_cfg["agent"]) resolved["env_cfg"] = env_cfg resolved["agent_cfg"] = agent_cfg try: hydra_main() finally: sys.argv = original_argv return resolved["env_cfg"], resolved["agent_cfg"]
[docs] def hydra_task_config(task_name: str, agent_cfg_entry_point: str) -> Callable: """Decorator for Hydra config with REPLACE-only preset semantics. Args: task_name: Task name (e.g., "Isaac-Reach-Franka-v0") agent_cfg_entry_point: Agent config entry point key Returns: Decorated function receiving ``(env_cfg, agent_cfg, *args, **kwargs)`` """ def decorator(func): @functools.wraps(func) def wrapper(*args, **kwargs): task = task_name.split(":")[-1] env_cfg, agent_cfg, presets = register_task(task, agent_cfg_entry_point) # Split args: global presets, path presets, scalars global_presets, preset_sel, preset_scalar, global_scalar = parse_overrides(sys.argv[1:], presets) # Only pass global scalars to Hydra original_argv, sys.argv = sys.argv, [sys.argv[0]] + global_scalar @hydra.main(config_path=None, config_name=task, version_base="1.3") def hydra_main(hydra_cfg, env_cfg=env_cfg, agent_cfg=agent_cfg): hydra_cfg = replace_strings_with_slices(OmegaConf.to_container(hydra_cfg, resolve=True)) env_cfg, agent_cfg = apply_overrides( env_cfg, agent_cfg, hydra_cfg, global_presets, preset_sel, preset_scalar, presets ) env_cfg.from_dict(hydra_cfg["env"]) env_cfg = replace_strings_with_env_cfg_spaces(env_cfg) if isinstance(agent_cfg, dict) or agent_cfg is None: agent_cfg = hydra_cfg["agent"] else: agent_cfg.from_dict(hydra_cfg["agent"]) func(env_cfg, agent_cfg, *args, **kwargs) try: hydra_main() finally: sys.argv = original_argv return wrapper return decorator
[docs] def resolve_preset_defaults(cfg): """Replace PresetCfg fields with their 'default' value before serialization. This must be called before to_dict() so the hydra dict contains only the resolved config rather than the raw PresetCfg with all alternatives. Returns the (possibly replaced) cfg if the root itself is a PresetCfg. """ if isinstance(cfg, PresetCfg) and hasattr(cfg, "__dataclass_fields__"): default = getattr(cfg, "default", None) if default is not None: return resolve_preset_defaults(default) return cfg for name in list(getattr(cfg, "__dataclass_fields__", {}).keys()): try: value = getattr(cfg, name) except Exception: continue if isinstance(value, PresetCfg) and hasattr(value, "__dataclass_fields__"): default = getattr(value, "default", None) if default is not None: setattr(cfg, name, default) if hasattr(default, "__dataclass_fields__"): resolve_preset_defaults(default) elif hasattr(value, "__dataclass_fields__"): resolve_preset_defaults(value) elif isinstance(value, dict): for dict_val in value.values(): if hasattr(dict_val, "__dataclass_fields__"): resolve_preset_defaults(dict_val) return cfg
def register_task(task_name: str, agent_entry: str) -> tuple: """Load configs, collect presets recursively, register base config to Hydra. Presets are collected from nested configclasses and stored separately - NOT registered as Hydra groups to avoid Hydra's merge behavior. Returns: (env_cfg, agent_cfg, presets) where presets = {"env": {"path": {"name": cfg}}, "agent": {...}} """ env_cfg = load_cfg_from_registry(task_name, "env_cfg_entry_point") agent_cfg = None if agent_entry: agent_cfg = load_cfg_from_registry(task_name, agent_entry) # Collect presets recursively from the config tree presets = { "env": collect_presets(env_cfg), "agent": collect_presets(agent_cfg) if agent_cfg else {}, } # Resolve PresetCfg defaults before serialization so to_dict() doesn't # include all preset alternatives in the hydra dict. env_cfg = resolve_preset_defaults(env_cfg) if agent_cfg is not None: agent_cfg = resolve_preset_defaults(agent_cfg) # Convert to dict for Hydra (handle gym spaces and slices) env_cfg = replace_env_cfg_spaces_with_strings(env_cfg) if agent_cfg is not None and hasattr(agent_cfg, "to_dict"): agent_dict = agent_cfg.to_dict() else: agent_dict = agent_cfg env_dict = env_cfg.to_dict() # type: ignore[union-attr] cfg_dict = replace_slices_with_strings({"env": env_dict, "agent": agent_dict}) # Register plain config (no groups) - Hydra only handles global scalars ConfigStore.instance().store(name=task_name, node=OmegaConf.create(cfg_dict)) return env_cfg, agent_cfg, presets def parse_overrides(args: list[str], presets: dict) -> tuple: """Categorize command line args by type. Args: args: Command line args (without script name) presets: {"env": {"path": {"name": cfg}}, "agent": {...}} Returns: (global_presets, preset_sel, preset_scalar, global_scalar) where: - global_presets: [name, ...] - apply to all matching configs - preset_sel: [(section, path, name), ...] - REPLACE selections - preset_scalar: [(full_path, value), ...] - scalars in preset paths - global_scalar: [arg, ...] - pass to Hydra """ # Build lookup of preset group paths (e.g., "env.actions") # Root-level PresetCfg has path="" -> bare "env" or "agent" key preset_paths = {f"{s}.{p}" if p else s for s, v in presets.items() for p in v} global_presets, preset_sel, preset_scalar, global_scalar = [], [], [], [] for arg in args: if "=" not in arg: global_scalar.append(arg) continue key, val = arg.split("=", 1) if key == "presets": # Global presets: presets=name1,name2 -> apply everywhere global_presets.extend(v.strip() for v in val.split(",") if v.strip()) elif key in preset_paths: # Exact match -> preset selection if "." in key: sec, path = key.split(".", 1) else: sec, path = key, "" preset_sel.append((sec, path, val)) elif any(key.startswith(pp + ".") for pp in preset_paths): # Prefix match -> scalar within preset path preset_scalar.append((key, val)) else: # No match -> global scalar, let Hydra handle it global_scalar.append(arg) # Sort preset selections: parents before children preset_sel.sort(key=lambda x: x[1].count(".")) return global_presets, preset_sel, preset_scalar, global_scalar def apply_overrides( env_cfg, agent_cfg, hydra_cfg: dict, global_presets: list, preset_sel: list, preset_scalar: list, presets: dict, ): """Apply preset selections and scalar overrides with REPLACE semantics. This is the core function that implements REPLACE (not merge) behavior: 0. Auto-apply "default" presets for paths not explicitly selected 1. Global presets (presets=name) apply to ALL configs that have matching preset 2. Path-specific presets (env.actions=name) replace specific sections 3. Scalar overrides are applied on top Returns: (env_cfg, agent_cfg) — possibly replaced if root-level PresetCfg was resolved. Raises: ValueError: If multiple global presets conflict on the same path. """ cfgs = {"env": env_cfg, "agent": agent_cfg} def _path_reachable(sec: str, path: str) -> bool: """Check that every ancestor along *path* is non-None on the live config. For "scene.camera" we walk to ``cfg.scene`` and verify it is not None, then check ``cfg.scene.camera`` is not None. If any segment is missing or None the child preset cannot be applied. """ if not path: return cfgs[sec] is not None obj = cfgs[sec] for part in path.split("."): try: if isinstance(obj, dict): obj = obj[part] else: obj = getattr(obj, part) except (AttributeError, TypeError, KeyError): return False if obj is None: return False return True def _apply_node(sec: str, path: str, node): """Replace a config node at the given section/path, handling root (path='').""" if node is None: node_dict = None elif hasattr(node, "to_dict"): node_dict = node.to_dict() elif isinstance(node, Mapping): node_dict = dict(node) else: node_dict = node if path == "": cfgs[sec] = node hydra_cfg[sec] = node_dict else: _setattr(cfgs[sec], path, node) _setattr(hydra_cfg, f"{sec}.{path}", node_dict) # --- Phase 1: Determine selected preset name for every path --------------- # Start with explicit path selections resolved: dict[str, tuple[str, str, str]] = {} # full_path -> (sec, path, name) for sec, path, name in preset_sel: if path not in presets.get(sec, {}): raise ValueError(f"Unknown preset group: {sec}.{path}") if name not in presets[sec][path]: avail = list(presets[sec][path].keys()) raise ValueError(f"Unknown preset '{name}' for {sec}.{path}. Available: {avail}") full_path = f"{sec}.{path}" if path else sec resolved[full_path] = (sec, path, name) # Apply global presets (error on conflict) applied_by: dict[str, str] = {} for name in global_presets: for sec in ("env", "agent"): for path, path_presets in presets.get(sec, {}).items(): if name in path_presets: full_path = f"{sec}.{path}" if path else sec if full_path in applied_by: raise ValueError( f"Conflicting global presets: '{applied_by[full_path]}' and '{name}' " f"both define preset for '{full_path}'" ) applied_by[full_path] = name if full_path not in resolved: resolved[full_path] = (sec, path, name) # Fill remaining paths with "default" (if available) for sec in ("env", "agent"): for path, path_presets in presets.get(sec, {}).items(): full_path = f"{sec}.{path}" if path else sec if full_path not in resolved and "default" in path_presets: resolved[full_path] = (sec, path, "default") # --- Phase 2: Apply in depth order, pruning unreachable children ---------- for full_path in sorted(resolved, key=lambda fp: fp.count(".")): sec, path, name = resolved[full_path] if cfgs[sec] is not None and _path_reachable(sec, path): _apply_node(sec, path, presets[sec][path][name]) # 3. Apply scalar overrides within preset paths for full_path, val_str in preset_scalar: if full_path.startswith("env."): sec, path = "env", full_path[4:] elif full_path.startswith("agent."): sec, path = "agent", full_path[6:] else: continue if cfgs[sec] is not None: val = _parse_val(val_str) _setattr(cfgs[sec], path, val) _setattr(hydra_cfg, full_path, val) return cfgs["env"], cfgs["agent"] def _setattr(obj, path: str, val): """Set nested attribute/key (e.g., "actions.arm_action.scale").""" *parts, leaf = path.split(".") for p in parts: obj = obj[p] if isinstance(obj, Mapping) else getattr(obj, p) if isinstance(obj, dict): obj[leaf] = val else: setattr(obj, leaf, val) def _parse_val(s: str): """Parse string to Python value (bool, None, int, float, or str).""" low = s.lower() if low == "true": return True if low == "false": return False if low in ("none", "null"): return None try: return float(s) if "." in s else int(s) except ValueError: # Strip quotes if present if s[0] in "\"'" and s[-1] in "\"'": return s[1:-1] return s