# Copyright (c) 2022-2026, The Isaac Lab Project Developers (https://github.com/isaac-sim/IsaacLab/blob/main/CONTRIBUTORS.md).
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause
from __future__ import annotations
import logging
import numpy as np
import torch
import warp as wp
from pxr import Gf, Sdf, Usd, UsdGeom, Vt
import isaaclab.sim as sim_utils
from isaaclab.utils.warp import ProxyArray
from .base_frame_view import BaseFrameView
logger = logging.getLogger(__name__)
class UsdFrameView(BaseFrameView):
    """Batched interface for reading and writing transforms of multiple USD prims.

    Provides batch operations for getting and setting poses (position and orientation)
    of multiple prims at once via USD's ``XformCache``.

    The class supports both world-space and local-space pose operations:

    - **World poses**: Positions and orientations in the global world frame
    - **Local poses**: Positions and orientations relative to each prim's parent

    For GPU-accelerated Fabric operations, use the PhysX backend variant
    obtained via :class:`~isaaclab.sim.views.FrameView`.

    Pose getters return :class:`~isaaclab.utils.warp.ProxyArray`. Setters accept ``wp.array``.

    .. note::
        **Transform Requirements:**

        All prims in the view must be Xformable and have standardized transform operations:
        ``[translate, orient, scale]``. Non-standard prims will raise a ValueError during
        initialization if :attr:`validate_xform_ops` is True. Please use the function
        :func:`isaaclab.sim.utils.standardize_xform_ops` to prepare prims before using this view.

    .. warning::
        This class operates at the USD default time code. Any animation or time-sampled data
        will not be affected by write operations. For animated transforms, you need to handle
        time-sampled keyframes separately.
    """

    def __init__(
        self,
        prim_path: str,
        device: str = "cpu",
        validate_xform_ops: bool = True,
        stage: Usd.Stage | None = None,
        **kwargs,
    ):
        """Initialize the view with matching prims.

        Args:
            prim_path: USD prim path pattern to match prims. Supports wildcards (``*``) and
                regex patterns (e.g., ``"/World/Env_.*/Robot"``). See
                :func:`isaaclab.sim.utils.find_matching_prims` for pattern syntax.
            device: Device to place arrays on. Can be ``"cpu"`` or CUDA devices like
                ``"cuda:0"``. Defaults to ``"cpu"``.
            validate_xform_ops: Whether to validate that the prims have standard xform operations.
                When True, :func:`isaaclab.sim.utils.standardize_xform_ops` is called on every
                matched prim first, and the result is then validated. Defaults to True.
            stage: USD stage to search for prims. Defaults to None, in which case the current
                active stage from the simulation context is used.
            **kwargs: Additional keyword arguments (ignored). Allows forward-compatible
                construction when callers pass backend-specific options like
                ``sync_usd_on_fabric_write``.

        Raises:
            ValueError: If any matched prim is not Xformable or doesn't have standardized
                transform operations (translate, orient, scale in that order).
        """
        self._prim_path = prim_path
        self._device = device

        if stage is None:
            stage = sim_utils.get_current_stage()
        self._prims: list[Usd.Prim] = sim_utils.find_matching_prims(prim_path, stage=stage)

        if validate_xform_ops:
            for prim in self._prims:
                # Try to bring the prim into the standard layout first, then verify it.
                sim_utils.standardize_xform_ops(prim)
                if not sim_utils.validate_standard_xform_ops(prim):
                    raise ValueError(
                        f"Prim at path '{prim.GetPath().pathString}' is not a xformable prim with standard transform"
                        f" operations [translate, orient, scale]. Received type: '{prim.GetTypeName()}'."
                        " Use sim_utils.standardize_xform_ops() to prepare the prim."
                    )

        # Cached "select everything" index list, reused whenever ``indices`` is None.
        self._ALL_INDICES = list(range(len(self._prims)))
        # Lazily populated by :attr:`prim_paths`.
        self._prim_paths: list[str] | None = None

    # ------------------------------------------------------------------
    # Properties
    # ------------------------------------------------------------------

    @property
    def count(self) -> int:
        """Number of prims in this view."""
        return len(self._prims)

    @property
    def device(self) -> str:
        """Device where arrays are allocated (cpu or cuda)."""
        return self._device

    @property
    def prims(self) -> list[Usd.Prim]:
        """List of USD prims being managed by this view."""
        return self._prims

    @property
    def prim_paths(self) -> list[str]:
        """List of prim paths (as strings) for all prims being managed by this view.

        The conversion is performed lazily on first access and cached.
        """
        # ``getattr`` keeps this working for instances created without running ``__init__``.
        if getattr(self, "_prim_paths", None) is None:
            self._prim_paths = [prim.GetPath().pathString for prim in self._prims]
        return self._prim_paths

    # ------------------------------------------------------------------
    # Setters
    # ------------------------------------------------------------------

    def set_world_poses(
        self,
        positions: wp.array | None = None,
        orientations: wp.array | None = None,
        indices: wp.array | None = None,
    ):
        """Set world-space poses for prims in the view.

        Converts the desired world pose to local-space relative to each prim's
        parent before writing to USD xform ops. For a prim that has a non-root
        parent, both ``xformOp:translate`` and ``xformOp:orient`` are written even
        when only one of ``positions`` / ``orientations`` is given; the missing
        component is taken from the prim's current (orthonormalized) world transform.

        Args:
            positions: World-space positions of shape ``(M, 3)``.
            orientations: World-space quaternions of shape ``(M, 4)``, documented as
                ``(w, x, y, z)``. NOTE(review): rows are handed unchanged to
                ``Vt.QuatdArray.FromNumpy``; confirm the component order it assumes.
            indices: Indices of prims to set poses for. Defaults to None (all prims).
        """
        indices_list = self._resolve_indices(indices)
        positions_array = Vt.Vec3dArray.FromNumpy(self._to_numpy(positions)) if positions is not None else None
        orientations_array = (
            Vt.QuatdArray.FromNumpy(self._to_numpy(orientations)) if orientations is not None else None
        )
        # When one component is missing, the current world transform supplies it.
        partial_update = positions_array is None or orientations_array is None

        xform_cache = UsdGeom.XformCache(Usd.TimeCode.Default())
        with Sdf.ChangeBlock():
            for idx, prim_idx in enumerate(indices_list):
                prim = self._prims[prim_idx]
                parent_prim = prim.GetParent()
                world_pos = positions_array[idx] if positions_array is not None else None
                world_quat = orientations_array[idx] if orientations_array is not None else None

                if parent_prim.IsValid() and parent_prim.GetPath() != Sdf.Path.absoluteRootPath:
                    if partial_update:
                        # Start from the current world pose and overwrite what was provided.
                        prim_tf = xform_cache.GetLocalToWorldTransform(prim)
                        # Orthonormalize so the rotation part yields a proper quaternion.
                        prim_tf.Orthonormalize()
                        if world_pos is not None:
                            prim_tf.SetTranslateOnly(world_pos)
                        if world_quat is not None:
                            prim_tf.SetRotateOnly(world_quat)
                    else:
                        # Both components provided: build the target transform from scratch.
                        prim_tf = Gf.Matrix4d()
                        prim_tf.SetTranslateOnly(world_pos)
                        prim_tf.SetRotateOnly(world_quat)

                    # Express the target world transform in the parent's frame.
                    parent_world_tf = xform_cache.GetLocalToWorldTransform(parent_prim)
                    local_tf = prim_tf * parent_world_tf.GetInverse()
                    local_pos = local_tf.ExtractTranslation()
                    local_quat = local_tf.ExtractRotationQuat()
                else:
                    # Root-level prim: world == local
                    local_pos = world_pos
                    local_quat = world_quat

                if local_pos is not None:
                    prim.GetAttribute("xformOp:translate").Set(local_pos)
                if local_quat is not None:
                    prim.GetAttribute("xformOp:orient").Set(local_quat)

    def set_local_poses(
        self,
        translations: wp.array | None = None,
        orientations: wp.array | None = None,
        indices: wp.array | None = None,
    ):
        """Set local-space poses for prims in the view.

        Args:
            translations: Local-space translations of shape ``(M, 3)``.
            orientations: Local-space quaternions of shape ``(M, 4)``, documented as
                ``(w, x, y, z)``. NOTE(review): rows are handed unchanged to
                ``Vt.QuatdArray.FromNumpy``; confirm the component order it assumes.
            indices: Indices of prims to set poses for. Defaults to None (all prims).
        """
        indices_list = self._resolve_indices(indices)
        translations_array = (
            Vt.Vec3dArray.FromNumpy(self._to_numpy(translations)) if translations is not None else None
        )
        orientations_array = (
            Vt.QuatdArray.FromNumpy(self._to_numpy(orientations)) if orientations is not None else None
        )

        with Sdf.ChangeBlock():
            for idx, prim_idx in enumerate(indices_list):
                prim = self._prims[prim_idx]
                if translations_array is not None:
                    prim.GetAttribute("xformOp:translate").Set(translations_array[idx])
                if orientations_array is not None:
                    prim.GetAttribute("xformOp:orient").Set(orientations_array[idx])

    def set_scales(self, scales: wp.array, indices: wp.array | None = None):
        """Set scales for prims in the view.

        Args:
            scales: Scales of shape ``(M, 3)``.
            indices: Indices of prims to set scales for. Defaults to None (all prims).
        """
        indices_list = self._resolve_indices(indices)
        scales_array = Vt.Vec3dArray.FromNumpy(self._to_numpy(scales))

        with Sdf.ChangeBlock():
            for idx, prim_idx in enumerate(indices_list):
                self._prims[prim_idx].GetAttribute("xformOp:scale").Set(scales_array[idx])

    def set_visibility(self, visibility: torch.Tensor, indices: wp.array | None = None):
        """Set visibility for prims in the view.

        Args:
            visibility: Visibility as a boolean tensor of shape ``(M,)``.
            indices: Indices of prims to set visibility for. Defaults to None (all prims).

        Raises:
            ValueError: If ``visibility`` does not have shape ``(M,)`` where ``M`` is the
                number of selected prims.
        """
        indices_list = self._resolve_indices(indices)
        if visibility.shape != (len(indices_list),):
            raise ValueError(f"Expected visibility shape ({len(indices_list)},), got {visibility.shape}.")

        # One bulk transfer to host instead of one tensor read per prim.
        visible_flags = visibility.tolist()

        with Sdf.ChangeBlock():
            for prim_idx, visible in zip(indices_list, visible_flags):
                imageable = UsdGeom.Imageable(self._prims[prim_idx])
                if visible:
                    imageable.MakeVisible()
                else:
                    imageable.MakeInvisible()

    # ------------------------------------------------------------------
    # Getters
    # ------------------------------------------------------------------

    def get_world_poses(self, indices: wp.array | None = None) -> tuple[ProxyArray, ProxyArray]:
        """Get world-space poses for prims in the view.

        Args:
            indices: Indices of prims to get poses for. Defaults to None (all prims).

        Returns:
            A tuple ``(positions, orientations)`` of :class:`~isaaclab.utils.warp.ProxyArray`
            wrappers. Use ``.warp`` for the underlying ``wp.array`` or ``.torch`` for a
            cached zero-copy ``torch.Tensor`` view.
        """
        indices_list = self._resolve_indices(indices)
        positions = Vt.Vec3dArray(len(indices_list))
        orientations = Vt.QuatdArray(len(indices_list))

        xform_cache = UsdGeom.XformCache(Usd.TimeCode.Default())
        for idx, prim_idx in enumerate(indices_list):
            prim_tf = xform_cache.GetLocalToWorldTransform(self._prims[prim_idx])
            # Orthonormalize before extracting the rotation as a quaternion.
            prim_tf.Orthonormalize()
            positions[idx] = prim_tf.ExtractTranslation()
            orientations[idx] = prim_tf.ExtractRotationQuat()

        return ProxyArray(self._vt_to_warp(positions)), ProxyArray(self._vt_to_warp(orientations))

    def get_local_poses(self, indices: wp.array | None = None) -> tuple[ProxyArray, ProxyArray]:
        """Get local-space poses for prims in the view.

        Args:
            indices: Indices of prims to get poses for. Defaults to None (all prims).

        Returns:
            A tuple ``(translations, orientations)`` of :class:`~isaaclab.utils.warp.ProxyArray`
            wrappers. Use ``.warp`` for the underlying ``wp.array`` or ``.torch`` for a
            cached zero-copy ``torch.Tensor`` view.
        """
        indices_list = self._resolve_indices(indices)
        translations = Vt.Vec3dArray(len(indices_list))
        orientations = Vt.QuatdArray(len(indices_list))

        xform_cache = UsdGeom.XformCache(Usd.TimeCode.Default())
        for idx, prim_idx in enumerate(indices_list):
            # ``GetLocalTransformation`` returns a tuple; element 0 is the matrix.
            prim_tf = xform_cache.GetLocalTransformation(self._prims[prim_idx])[0]
            # Orthonormalize before extracting the rotation as a quaternion.
            prim_tf.Orthonormalize()
            translations[idx] = prim_tf.ExtractTranslation()
            orientations[idx] = prim_tf.ExtractRotationQuat()

        return ProxyArray(self._vt_to_warp(translations)), ProxyArray(self._vt_to_warp(orientations))

    def get_scales(self, indices: wp.array | None = None) -> wp.array:
        """Get scales for prims in the view.

        Reads the ``xformOp:scale`` attribute of each selected prim.

        Args:
            indices: Indices of prims to get scales for. Defaults to None (all prims).

        Returns:
            A ``wp.array`` of shape ``(M, 3)``.
        """
        indices_list = self._resolve_indices(indices)
        scales = Vt.Vec3dArray(len(indices_list))
        for idx, prim_idx in enumerate(indices_list):
            scales[idx] = self._prims[prim_idx].GetAttribute("xformOp:scale").Get()
        return self._vt_to_warp(scales)

    def get_visibility(self, indices: wp.array | None = None) -> torch.Tensor:
        """Get visibility for prims in the view.

        A prim counts as visible when its computed visibility is anything other
        than ``UsdGeom.Tokens.invisible``.

        Args:
            indices: Indices of prims to get visibility for. Defaults to None (all prims).

        Returns:
            A tensor of shape ``(M,)`` containing the visibility of each prim (bool).
        """
        indices_list = self._resolve_indices(indices)
        # Collect on the host and upload once, rather than writing the device tensor per prim.
        visible_flags = [
            UsdGeom.Imageable(self._prims[prim_idx]).ComputeVisibility() != UsdGeom.Tokens.invisible
            for prim_idx in indices_list
        ]
        return torch.tensor(visible_flags, dtype=torch.bool, device=self._device)

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _resolve_indices(self, indices: wp.array | None):
        """Resolve indices to a list of Python ints for per-prim USD operations.

        Args:
            indices: ``None`` or ``slice(None)`` selects every prim. Other slices are
                applied to the full index list. ``wp.array``, ``torch.Tensor`` and any
                array-like accepted by ``numpy.asarray`` are flattened and converted.

        Returns:
            A list of integer prim indices.
        """
        if indices is None:
            return self._ALL_INDICES
        if isinstance(indices, slice):
            # Type-check first: ``==`` against an array-like is an elementwise
            # comparison and must not be used as a plain truth value.
            return self._ALL_INDICES if indices == slice(None) else self._ALL_INDICES[indices]
        if isinstance(indices, wp.array):
            return indices.numpy().reshape(-1).tolist()
        if isinstance(indices, torch.Tensor):
            return indices.detach().cpu().reshape(-1).tolist()
        return np.asarray(indices).reshape(-1).tolist()

    def _vt_to_warp(self, values) -> wp.array:
        """Copy a Vt array into a float32 ``wp.array`` on this view's device."""
        return wp.array(np.array(values, dtype=np.float32), dtype=wp.float32, device=self._device)

    @staticmethod
    def _to_numpy(data: wp.array | torch.Tensor) -> np.ndarray:
        """Convert a ``wp.array`` or ``torch.Tensor`` to a numpy array on CPU.

        NumPy arrays are returned unchanged. Tensors are detached first so that
        inputs which require grad can be converted.
        """
        if isinstance(data, wp.array):
            return data.numpy()
        if isinstance(data, torch.Tensor):
            return data.detach().cpu().numpy()
        if isinstance(data, np.ndarray):
            return data
        # Tensor-like objects exposing ``cpu()`` keep the original conversion path.
        to_cpu = getattr(data, "cpu", None)
        if callable(to_cpu):
            return to_cpu().numpy()
        return np.asarray(data)