Source code for isaaclab_mimic.datagen.selection_strategy

# Copyright (c) 2024-2025, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0

"""
Selection strategies used by Isaac Lab Mimic to select subtask segments from
source human demonstrations.
"""
import abc  # for abstract base class definitions
import torch

import isaaclab.utils.math as PoseUtils

# Global registry that maps a selection strategy's NAME to its class.
# Entries are added when strategy classes are created (SelectionStrategyMeta.__new__
# calls register_selection_strategy), and read by make_selection_strategy.
REGISTERED_SELECTION_STRATEGIES = dict()


def make_selection_strategy(name, *args, **kwargs):
    """
    Instantiate the selection strategy class registered under @name.

    Args:
        name (str): registry key, i.e. the ``NAME`` attribute of the strategy class.
        *args: positional arguments forwarded to the strategy class constructor.
        **kwargs: keyword arguments forwarded to the strategy class constructor.

    Returns:
        A new instance of the registered strategy class.

    Raises:
        Exception: if no strategy is registered under @name.
    """
    assert_selection_strategy_exists(name)
    strategy_cls = REGISTERED_SELECTION_STRATEGIES[name]
    return strategy_cls(*args, **kwargs)


def register_selection_strategy(cls):
    """
    Register a selection strategy class into the global registry under ``cls.NAME``.

    Two kinds of class are not registered:

    - the abstract base class ``SelectionStrategy`` itself;
    - any subclass that does not override ``NAME``. For such a class, ``cls.NAME`` still
      resolves to the base class's ``property`` object. That object can never be looked up
      by a string name, and because it is not a ``str`` it would break the ``", ".join(...)``
      used to list registered names in the not-found error.

    Args:
        cls (type): the class to register.

    Raises:
        AttributeError: if @cls has no ``NAME`` attribute at all.
    """
    ignore_classes = ("SelectionStrategy",)
    if cls.__name__ in ignore_classes:
        return
    name = cls.NAME
    if isinstance(name, property):
        # NAME was not overridden; nothing meaningful to register it under.
        return
    REGISTERED_SELECTION_STRATEGIES[name] = cls


def assert_selection_strategy_exists(name):
    """
    Raise if no selection strategy is registered under @name.

    Args:
        name (str): registry key to check.

    Raises:
        Exception: when @name is not registered; the message lists every registered name.
    """
    if name in REGISTERED_SELECTION_STRATEGIES:
        return
    known_names = ", ".join(REGISTERED_SELECTION_STRATEGIES)
    raise Exception(
        f"assert_selection_strategy_exists: name {name} not found. Make sure it is a registered selection strategy"
        f" among {known_names}"
    )


class SelectionStrategyMeta(type):
    """
    Metaclass that hands every newly created selection strategy class to
    register_selection_strategy, so that defining a subclass is enough to make it
    available by name.
    """

    def __new__(mcs, name, bases, class_dict):
        new_cls = super().__new__(mcs, name, bases, class_dict)
        register_selection_strategy(new_cls)
        return new_cls


class SelectionStrategy(metaclass=SelectionStrategyMeta):
    """
    Defines methods and functions for selection strategies to implement.

    Subclasses set a class attribute ``NAME`` (str), which is used as the registry key,
    and implement :meth:`select_source_demo`.
    """

    def __init__(self):
        pass

    @property
    def NAME(self):
        """
        This name (str) will be used to register the selection strategy class in the global
        registry. Subclasses override it with a plain class attribute.
        """
        # Deliberately a plain property, not a property stacked on a classmethod. A property
        # whose getter is a ``classmethod`` object fails on instance access with
        # "TypeError: 'classmethod' object is not callable" before ever reaching the
        # NotImplementedError below. Class-level access (``cls.NAME``) returns the property
        # object either way.
        raise NotImplementedError

    # NOTE: abc.abstractmethod is only enforced by abc.ABCMeta. SelectionStrategyMeta derives
    # from ``type``, so here the decorator is advisory; the ``raise`` in the body is what
    # signals a missing override.
    @abc.abstractmethod
    def select_source_demo(
        self,
        eef_pose,
        object_pose,
        src_subtask_datagen_infos,
    ):
        """
        Selects source demonstration index using the current robot pose, relevant object pose
        for the current subtask, and relevant information from the source demonstrations for the
        current subtask.

        Args:
            eef_pose (torch.Tensor): current 4x4 eef pose
            object_pose (torch.Tensor): current 4x4 object pose, for the object in this subtask
            src_subtask_datagen_infos (list): DatagenInfo instances for the relevant subtask segment,
                one per source demonstration

        Returns:
            source_demo_ind (int): index of source demonstration - indicates which source subtask segment to use
        """
        raise NotImplementedError


class RandomStrategy(SelectionStrategy):
    """
    Pick source demonstration randomly.
    """

    # name for registering this class into registry
    NAME = "random"

    def select_source_demo(
        self,
        eef_pose,
        object_pose,
        src_subtask_datagen_infos,
    ):
        """
        Selects a source demonstration index uniformly at random. The current eef and object
        poses are accepted to match the SelectionStrategy interface but are not used.

        Args:
            eef_pose (torch.Tensor): current 4x4 eef pose (unused)
            object_pose (torch.Tensor): current 4x4 object pose, for the object in this subtask (unused)
            src_subtask_datagen_infos (list): DatagenInfo instances for the relevant subtask segment,
                one per source demonstration

        Returns:
            source_demo_ind (int): index of source demonstration - indicates which source subtask segment to use
        """
        # random selection; .item() converts the 1-element tensor to a Python int
        n_src_demo = len(src_subtask_datagen_infos)
        return torch.randint(0, n_src_demo, (1,)).item()
class NearestNeighborObjectStrategy(SelectionStrategy):
    """
    Pick source demonstration to be the one with the closest object pose to the object in the
    current scene.
    """

    # name for registering this class into registry
    NAME = "nearest_neighbor_object"

    def select_source_demo(
        self,
        eef_pose,
        object_pose,
        src_subtask_datagen_infos,
        pos_weight=1.0,
        rot_weight=1.0,
        nn_k=3,
    ):
        """
        Selects a source demonstration index by comparing the current object pose with the object
        pose at the start of each source subtask segment. The eef pose is not used by this strategy.

        The distance minimized is ``pos_weight * ||p_src - p|| + rot_weight * angle(R_src R^T)``,
        where the angle comes from the trace of the delta rotation matrix. One of the @nn_k closest
        source demos is returned, chosen uniformly at random.

        Args:
            eef_pose (torch.Tensor): current 4x4 eef pose (unused)
            object_pose (torch.Tensor): current 4x4 object pose, for the object in this subtask
            src_subtask_datagen_infos (list): DatagenInfo instances for the relevant subtask segment,
                one per source demonstration; each must reference exactly one object
            pos_weight (float): weight on position for minimizing pose distance
            rot_weight (float): weight on rotation for minimizing pose distance
            nn_k (int): pick source demo index uniformly at random from the top @nn_k nearest neighbors

        Returns:
            source_demo_ind (int): index of source demonstration - indicates which source subtask segment to use
        """
        # collect object poses from start of subtask source segments into tensor of shape [N, 4, 4]
        src_object_poses = []
        for di in src_subtask_datagen_infos:
            src_obj_pose = list(di.object_poses.values())
            assert len(src_obj_pose) == 1, f"expected exactly one object pose per subtask segment, got {len(src_obj_pose)}"
            # use object pose at start of subtask segment
            src_object_poses.append(src_obj_pose[0][0])
        src_object_poses = torch.stack(src_object_poses)

        # split into positions and rotations
        all_src_obj_pos, all_src_obj_rot = PoseUtils.unmake_pose(src_object_poses)
        obj_pos, obj_rot = PoseUtils.unmake_pose(object_pose)

        # prepare for broadcasting against the N source poses. Transposing the last two dims
        # (instead of dims 0 and 1) gives R^T both for a [4, 4] and a [1, 4, 4] object_pose;
        # with transpose(0, 1) a batched input would silently keep R untransposed.
        obj_pos = obj_pos.reshape(-1, 3)
        obj_rot_T = obj_rot.transpose(-2, -1).reshape(-1, 3, 3)

        # pos dist is just L2 between positions
        pos_dists = torch.sqrt(((all_src_obj_pos - obj_pos) ** 2).sum(dim=-1))

        # get angle (in axis-angle representation of delta rotation matrix) using the following formula
        # (see http://www.boris-belousov.net/2016/12/01/quat-dist/)

        # batched matrix mult, [N, 3, 3] x [1, 3, 3] -> [N, 3, 3]
        delta_R = torch.matmul(all_src_obj_rot, obj_rot_T)
        arc_cos_in = (torch.diagonal(delta_R, dim1=-2, dim2=-1).sum(dim=-1) - 1.0) / 2.0
        arc_cos_in = torch.clamp(arc_cos_in, -1.0, 1.0)  # clip for numerical stability
        rot_dists = torch.acos(arc_cos_in)

        # weight distances with coefficients
        dists_to_minimize = pos_weight * pos_dists + rot_weight * rot_dists

        # clip top-k parameter to max possible value
        nn_k = min(nn_k, len(dists_to_minimize))

        # return one of the top-K nearest neighbors uniformly at random
        rand_k = torch.randint(0, nn_k, (1,)).item()
        top_k_neighbors_in_order = torch.argsort(dists_to_minimize)[:nn_k]
        # .item() so callers get a plain Python int, as documented and as RandomStrategy returns
        return top_k_neighbors_in_order[rand_k].item()
class NearestNeighborRobotDistanceStrategy(SelectionStrategy):
    """
    Pick source demonstration to be the one that minimizes the distance the robot end effector
    will need to travel from the current pose to the first pose in the transformed segment.
    """

    # name for registering this class into registry
    NAME = "nearest_neighbor_robot_distance"

    def select_source_demo(
        self,
        eef_pose,
        object_pose,
        src_subtask_datagen_infos,
        pos_weight=1.0,
        rot_weight=1.0,
        nn_k=3,
    ):
        """
        Selects a source demonstration index using the current robot pose, relevant object pose
        for the current subtask, and relevant information from the source demonstrations for the
        current subtask.

        For each source demo, the first eef pose of its subtask segment is expressed relative to
        the source object and then re-attached to the current @object_pose. The resulting pose is
        compared against the current @eef_pose using
        ``pos_weight * position distance + rot_weight * rotation angle``. One of the @nn_k closest
        source demos is returned, chosen uniformly at random.

        Args:
            eef_pose (torch.Tensor): current 4x4 eef pose
            object_pose (torch.Tensor): current 4x4 object pose, for the object in this subtask
            src_subtask_datagen_infos (list): DatagenInfo instances for the relevant subtask segment,
                one per source demonstration; each must reference exactly one object
            pos_weight (float): weight on position for minimizing pose distance
            rot_weight (float): weight on rotation for minimizing pose distance
            nn_k (int): pick source demo index uniformly at random from the top @nn_k nearest neighbors

        Returns:
            source_demo_ind (int): index of source demonstration - indicates which source subtask segment to use
        """
        # collect eef and object poses from start of subtask source segments into tensors of shape [N, 4, 4]
        src_eef_poses = []
        src_object_poses = []
        for di in src_subtask_datagen_infos:
            # use eef pose at start of subtask segment
            src_eef_poses.append(di.eef_pose[0])
            # use object pose at start of subtask segment
            src_obj_pose = list(di.object_poses.values())
            assert len(src_obj_pose) == 1, f"expected exactly one object pose per subtask segment, got {len(src_obj_pose)}"
            src_object_poses.append(src_obj_pose[0][0])
        src_eef_poses = torch.stack(src_eef_poses)
        src_object_poses = torch.stack(src_object_poses)

        # Get source eef poses with respect to object frames.
        # note: frame A is world, frame B is object
        src_object_poses_inv = PoseUtils.pose_inv(src_object_poses)
        src_eef_poses_in_obj = PoseUtils.pose_in_A_to_pose_in_B(
            pose_in_A=src_eef_poses,
            pose_A_in_B=src_object_poses_inv,
        )

        # Use this to find the first pose for the transformed subtask segment for each source demo.
        # Note this is the same logic used in PoseUtils.transform_poses_from_frame_A_to_frame_B
        transformed_eef_poses = PoseUtils.pose_in_A_to_pose_in_B(
            pose_in_A=src_eef_poses_in_obj,
            pose_A_in_B=object_pose,
        )

        # split into positions and rotations
        all_transformed_eef_pos, all_transformed_eef_rot = PoseUtils.unmake_pose(transformed_eef_poses)
        eef_pos, eef_rot = PoseUtils.unmake_pose(eef_pose)

        # now measure distance from each of these transformed eef poses to our current eef pose
        # and choose the source demo that minimizes this distance

        # prepare for broadcasting. Transposing the last two dims (instead of dims 0 and 1) gives
        # R^T both for a [4, 4] and a [1, 4, 4] eef_pose; with transpose(0, 1) a batched input
        # would silently keep R untransposed.
        eef_pos = eef_pos.reshape(-1, 3)
        eef_rot_T = eef_rot.transpose(-2, -1).reshape(-1, 3, 3)

        # pos dist is just L2 between positions
        pos_dists = torch.sqrt(((all_transformed_eef_pos - eef_pos) ** 2).sum(dim=-1))

        # get angle (in axis-angle representation of delta rotation matrix) using the following formula
        # (see http://www.boris-belousov.net/2016/12/01/quat-dist/)

        # batched matrix mult, [N, 3, 3] x [1, 3, 3] -> [N, 3, 3]
        delta_R = torch.matmul(all_transformed_eef_rot, eef_rot_T)
        arc_cos_in = (torch.diagonal(delta_R, dim1=-2, dim2=-1).sum(dim=-1) - 1.0) / 2.0
        arc_cos_in = torch.clamp(arc_cos_in, -1.0, 1.0)  # clip for numerical stability
        rot_dists = torch.acos(arc_cos_in)

        # weight distances with coefficients
        dists_to_minimize = pos_weight * pos_dists + rot_weight * rot_dists

        # clip top-k parameter to max possible value
        nn_k = min(nn_k, len(dists_to_minimize))

        # return one of the top-K nearest neighbors uniformly at random
        rand_k = torch.randint(0, nn_k, (1,)).item()
        top_k_neighbors_in_order = torch.argsort(dists_to_minimize)[:nn_k]
        # .item() so callers get a plain Python int, as documented and as RandomStrategy returns
        return top_k_neighbors_in_order[rand_k].item()