# Copyright (c) 2022-2025, The Isaac Lab Project Developers.# All rights reserved.## SPDX-License-Identifier: BSD-3-Clauseimportcopyimportosimporttorch
def export_policy_as_jit(actor_critic: object, normalizer: object | None, path: str, filename="policy.pt"):
    """Save the policy part of an actor-critic as a TorchScript (JIT) file.

    Args:
        actor_critic: The actor-critic torch module whose policy is exported.
        normalizer: The empirical normalizer module. If None, Identity is used.
        path: Directory the file is written into.
        filename: Name of the JIT file inside ``path``. Defaults to "policy.pt".
    """
    # The exporter copies the modules it needs and handles directory creation itself.
    _TorchPolicyExporter(actor_critic, normalizer).export(path, filename)
def export_policy_as_onnx(
    actor_critic: object, path: str, normalizer: object | None = None, filename="policy.onnx", verbose=False
):
    """Export policy into a Torch ONNX file.

    Args:
        actor_critic: The actor-critic torch module.
        path: The path to the saving directory. It is created if it does not exist.
        normalizer: The empirical normalizer module. If None, Identity is used.
        filename: The name of exported ONNX file. Defaults to "policy.onnx".
        verbose: Whether to print the model summary. Defaults to False.
    """
    # ``exist_ok=True`` already tolerates an existing directory, so a separate
    # ``os.path.exists`` check is redundant (and races with concurrent creation).
    os.makedirs(path, exist_ok=True)
    policy_exporter = _OnnxPolicyExporter(actor_critic, normalizer, verbose)
    policy_exporter.export(path, filename)
"""Helper Classes - Private."""class_TorchPolicyExporter(torch.nn.Module):"""Exporter of actor-critic into JIT file."""def__init__(self,actor_critic,normalizer=None):super().__init__()self.actor=copy.deepcopy(actor_critic.actor)self.is_recurrent=actor_critic.is_recurrentifself.is_recurrent:self.rnn=copy.deepcopy(actor_critic.memory_a.rnn)self.rnn.cpu()self.register_buffer("hidden_state",torch.zeros(self.rnn.num_layers,1,self.rnn.hidden_size))self.register_buffer("cell_state",torch.zeros(self.rnn.num_layers,1,self.rnn.hidden_size))self.forward=self.forward_lstmself.reset=self.reset_memory# copy normalizer if existsifnormalizer:self.normalizer=copy.deepcopy(normalizer)else:self.normalizer=torch.nn.Identity()defforward_lstm(self,x):x=self.normalizer(x)x,(h,c)=self.rnn(x.unsqueeze(0),(self.hidden_state,self.cell_state))self.hidden_state[:]=hself.cell_state[:]=cx=x.squeeze(0)returnself.actor(x)defforward(self,x):returnself.actor(self.normalizer(x))@torch.jit.exportdefreset(self):passdefreset_memory(self):self.hidden_state[:]=0.0self.cell_state[:]=0.0defexport(self,path,filename):os.makedirs(path,exist_ok=True)path=os.path.join(path,filename)self.to("cpu")traced_script_module=torch.jit.script(self)traced_script_module.save(path)class_OnnxPolicyExporter(torch.nn.Module):"""Exporter of actor-critic into ONNX file."""def__init__(self,actor_critic,normalizer=None,verbose=False):super().__init__()self.verbose=verboseself.actor=copy.deepcopy(actor_critic.actor)self.is_recurrent=actor_critic.is_recurrentifself.is_recurrent:self.rnn=copy.deepcopy(actor_critic.memory_a.rnn)self.rnn.cpu()self.forward=self.forward_lstm# copy normalizer if 
existsifnormalizer:self.normalizer=copy.deepcopy(normalizer)else:self.normalizer=torch.nn.Identity()defforward_lstm(self,x_in,h_in,c_in):x_in=self.normalizer(x_in)x,(h,c)=self.rnn(x_in.unsqueeze(0),(h_in,c_in))x=x.squeeze(0)returnself.actor(x),h,cdefforward(self,x):returnself.actor(self.normalizer(x))defexport(self,path,filename):self.to("cpu")ifself.is_recurrent:obs=torch.zeros(1,self.rnn.input_size)h_in=torch.zeros(self.rnn.num_layers,1,self.rnn.hidden_size)c_in=torch.zeros(self.rnn.num_layers,1,self.rnn.hidden_size)actions,h_out,c_out=self(obs,h_in,c_in)torch.onnx.export(self,(obs,h_in,c_in),os.path.join(path,filename),export_params=True,opset_version=11,verbose=self.verbose,input_names=["obs","h_in","c_in"],output_names=["actions","h_out","c_out"],dynamic_axes={},)else:obs=torch.zeros(1,self.actor[0].in_features)torch.onnx.export(self,obs,os.path.join(path,filename),export_params=True,opset_version=11,verbose=self.verbose,input_names=["obs"],output_names=["actions"],dynamic_axes={},)