多GPU和多节点训练#
Isaac Lab支持多GPU和多节点的强化学习。目前,此功能仅适用于RL-Games和skrl库工作流程。我们正在努力将此功能扩展到其他工作流程中。
注意
多GPU和多节点训练仅在Linux上受支持。目前不支持Windows。这是由于Windows上NCCL库的限制。
多GPU训练#
对于复杂的强化学习环境,可能希望跨多个GPU扩展训练。在Isaac Lab中可以通过分别使用 PyTorch分布式 框架或者 JAX distributed 模块来实现这一点。
torch.distributed()
在PyTorch中,API用于启动多个训练进程,其中进程的数量必须等于或小于可用的GPU数量。每个进程在专用GPU上运行,并启动其自己的Isaac Sim和Isaac Lab环境实例。在训练过程中,梯度在进程之间汇总,并在周期结束时广播回进程。
在JAX中,由于机器学习框架不会自动从单个程序调用中启动多个进程,skrl库提供了一个模块来启动它们。
要使用多个GPU进行训练,请使用以下命令,其中 --proc_per_node
表示可用的GPU数量:
python -m torch.distributed.run --nnodes=1 --nproc_per_node=2 source/standalone/workflows/rl_games/train.py --task=Isaac-Cartpole-v0 --headless --distributed
python -m torch.distributed.run --nnodes=1 --nproc_per_node=2 source/standalone/workflows/skrl/train.py --task=Isaac-Cartpole-v0 --headless --distributed
python -m skrl.utils.distributed.jax --nnodes=1 --nproc_per_node=2 source/standalone/workflows/skrl/train.py --task=Isaac-Cartpole-v0 --headless --distributed --ml_framework jax
多节点训练#
要在单台计算机上跨多个GPU扩展训练,还可以在多个节点上训练。要在多个节点/机器上训练,需要在每个节点上启动一个单独的进程。
对于主节点,请使用以下命令,其中 --nproc_per_node
表示可用的GPU数量, --nnodes
表示节点的数量:
python -m torch.distributed.run --nproc_per_node=2 --nnodes=2 --node_rank=0 --rdzv_id=123 --rdzv_backend=c10d --rdzv_endpoint=localhost:5555 source/standalone/workflows/rl_games/train.py --task=Isaac-Cartpole-v0 --headless --distributed
python -m torch.distributed.run --nproc_per_node=2 --nnodes=2 --node_rank=0 --rdzv_id=123 --rdzv_backend=c10d --rdzv_endpoint=localhost:5555 source/standalone/workflows/skrl/train.py --task=Isaac-Cartpole-v0 --headless --distributed
python -m skrl.utils.distributed.jax --nproc_per_node=2 --nnodes=2 --node_rank=0 --coordinator_address=ip_of_master_machine:5555 source/standalone/workflows/skrl/train.py --task=Isaac-Cartpole-v0 --headless --distributed --ml_framework jax
注意,端口( 5555
)可以更换为任何其他可用端口。
对于非主节点,请使用以下命令,将 --node_rank
替换为每台机器的索引:
python -m torch.distributed.run --nproc_per_node=2 --nnodes=2 --node_rank=1 --rdzv_id=123 --rdzv_backend=c10d --rdzv_endpoint=ip_of_master_machine:5555 source/standalone/workflows/rl_games/train.py --task=Isaac-Cartpole-v0 --headless --distributed
python -m torch.distributed.run --nproc_per_node=2 --nnodes=2 --node_rank=1 --rdzv_id=123 --rdzv_backend=c10d --rdzv_endpoint=ip_of_master_machine:5555 source/standalone/workflows/skrl/train.py --task=Isaac-Cartpole-v0 --headless --distributed
python -m skrl.utils.distributed.jax --nproc_per_node=2 --nnodes=2 --node_rank=1 --coordinator_address=ip_of_master_machine:5555 source/standalone/workflows/skrl/train.py --task=Isaac-Cartpole-v0 --headless --distributed --ml_framework jax
有关使用PyTorch进行多节点训练的更多详细信息,请访问 PyTorch 文档 。有关使用JAX进行多节点训练的更多详细信息,请访问 skrl 文档 和 JAX 文档 。
备注
如PyTorch文档中所述, 多节点训练受到节点间通信延迟的瓶颈
。当这种延迟较高时,多节点训练可能表现不如在单节点实例上运行。