# Source code for omni.isaac.lab_tasks.utils.wrappers.rsl_rl.exporter

# Copyright (c) 2022-2024, The Isaac Lab Project Developers.
# All rights reserved.
#
# SPDX-License-Identifier: BSD-3-Clause

import copy
import os
import torch


def export_policy_as_jit(actor_critic: object, normalizer: object | None, path: str, filename="policy.pt"):
    """Save the policy of an actor-critic module as a TorchScript (JIT) file.

    The work is delegated to :class:`_TorchPolicyExporter`, which also creates
    the target directory when it does not exist yet.

    Args:
        actor_critic: The actor-critic torch module.
        normalizer: The empirical normalizer module. If None, Identity is used.
        path: The path to the saving directory.
        filename: The name of exported JIT file. Defaults to "policy.pt".
    """
    _TorchPolicyExporter(actor_critic, normalizer).export(path, filename)
def export_policy_as_onnx(
    actor_critic: object,
    path: str,
    normalizer: object | None = None,
    filename: str = "policy.onnx",
    verbose: bool = False,
) -> None:
    """Export policy into a Torch ONNX file.

    Args:
        actor_critic: The actor-critic torch module.
        path: The path to the saving directory. It is created if missing.
        normalizer: The empirical normalizer module. If None, Identity is used.
        filename: The name of exported ONNX file. Defaults to "policy.onnx".
        verbose: Whether to print the model summary. Defaults to False.
    """
    # ``exist_ok=True`` already tolerates an existing directory, so no separate
    # existence check is needed (and none that could race with another process).
    os.makedirs(path, exist_ok=True)
    policy_exporter = _OnnxPolicyExporter(actor_critic, normalizer, verbose)
    policy_exporter.export(path, filename)
""" Helper Classes - Private. """ class _TorchPolicyExporter(torch.nn.Module): """Exporter of actor-critic into JIT file.""" def __init__(self, actor_critic, normalizer=None): super().__init__() self.actor = copy.deepcopy(actor_critic.actor) self.is_recurrent = actor_critic.is_recurrent if self.is_recurrent: self.rnn = copy.deepcopy(actor_critic.memory_a.rnn) self.rnn.cpu() self.register_buffer("hidden_state", torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size)) self.register_buffer("cell_state", torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size)) self.forward = self.forward_lstm self.reset = self.reset_memory # copy normalizer if exists if normalizer: self.normalizer = copy.deepcopy(normalizer) else: self.normalizer = torch.nn.Identity() def forward_lstm(self, x): x = self.normalizer(x) x, (h, c) = self.rnn(x.unsqueeze(0), (self.hidden_state, self.cell_state)) self.hidden_state[:] = h self.cell_state[:] = c x = x.squeeze(0) return self.actor(x) def forward(self, x): return self.actor(self.normalizer(x)) @torch.jit.export def reset(self): pass def reset_memory(self): self.hidden_state[:] = 0.0 self.cell_state[:] = 0.0 def export(self, path, filename): os.makedirs(path, exist_ok=True) path = os.path.join(path, filename) self.to("cpu") traced_script_module = torch.jit.script(self) traced_script_module.save(path) class _OnnxPolicyExporter(torch.nn.Module): """Exporter of actor-critic into ONNX file.""" def __init__(self, actor_critic, normalizer=None, verbose=False): super().__init__() self.verbose = verbose self.actor = copy.deepcopy(actor_critic.actor) self.is_recurrent = actor_critic.is_recurrent if self.is_recurrent: self.rnn = copy.deepcopy(actor_critic.memory_a.rnn) self.rnn.cpu() self.forward = self.forward_lstm # copy normalizer if exists if normalizer: self.normalizer = copy.deepcopy(normalizer) else: self.normalizer = torch.nn.Identity() def forward_lstm(self, x_in, h_in, c_in): x_in = self.normalizer(x_in) x, (h, c) = 
self.rnn(x_in.unsqueeze(0), (h_in, c_in)) x = x.squeeze(0) return self.actor(x), h, c def forward(self, x): return self.actor(self.normalizer(x)) def export(self, path, filename): self.to("cpu") if self.is_recurrent: obs = torch.zeros(1, self.rnn.input_size) h_in = torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size) c_in = torch.zeros(self.rnn.num_layers, 1, self.rnn.hidden_size) actions, h_out, c_out = self(obs, h_in, c_in) torch.onnx.export( self, (obs, h_in, c_in), os.path.join(path, filename), export_params=True, opset_version=11, verbose=self.verbose, input_names=["obs", "h_in", "c_in"], output_names=["actions", "h_out", "c_out"], dynamic_axes={}, ) else: obs = torch.zeros(1, self.actor[0].in_features) torch.onnx.export( self, obs, os.path.join(path, filename), export_params=True, opset_version=11, verbose=self.verbose, input_names=["obs"], output_names=["actions"], dynamic_axes={}, )