Elastic Training


Bagua supports elastic training by building on TorchElastic. Elastic training is typically used for the following two types of jobs:

Fault tolerant jobs

Jobs that run on infrastructure where nodes get replaced frequently, either due to flaky hardware or by design, and mission-critical, production-grade jobs that must be resilient to failures.

Dynamic capacity management

Jobs that run on preemptible resources that can be taken away at any time (e.g. AWS spot instances), or on shared pools whose size can change dynamically based on demand.


You can find a complete example at Bagua examples.

1. Make your program recoverable

In elastic training, nodes can join or leave while the job is running. Your training program needs to save its training state regularly, so that a newly joined process can resume training from the most recent state.

For example:

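import torch

model = ...

for epoch in train_loop():
    # Save the latest state after every iteration of the loop so that
    # restarted or newly joined workers can resume from it.
    torch.save(model.state_dict(), YOUR_CHECKPOINT_PATH)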
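
The save step above covers only half of recovery: a process that starts or restarts also has to load the latest checkpoint before it trains. The following is a minimal sketch of that save-and-resume pattern, not Bagua's own API. The helper names (save_checkpoint, load_checkpoint, train), the toy nn.Linear model, and the checkpoint layout are all assumptions made for illustration, and the sketch assumes the checkpoint path is on storage that every node can read.

```python
import os

import torch
import torch.nn as nn


def save_checkpoint(path, model, optimizer, epoch):
    """Write the checkpoint to a temp file, then rename it into place."""
    state = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "epoch": epoch,
    }
    tmp_path = path + ".tmp"
    torch.save(state, tmp_path)
    # os.replace is atomic on the same filesystem, so a process that
    # joins mid-save never reads a half-written checkpoint.
    os.replace(tmp_path, path)


def load_checkpoint(path, model, optimizer):
    """Restore state if a checkpoint exists; return the epoch to start from."""
    if not os.path.exists(path):
        return 0  # fresh start
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])
    return state["epoch"] + 1


def train(path, num_epochs):
    """Toy training loop (hypothetical) that resumes from `path` if present."""
    model = nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    start_epoch = load_checkpoint(path, model, optimizer)
    epochs_run = []
    for epoch in range(start_epoch, num_epochs):
        loss = model(torch.randn(8, 4)).pow(2).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        save_checkpoint(path, model, optimizer, epoch)
        epochs_run.append(epoch)
    return epochs_run
```

Calling train(path, 2) and then train(path, 4) with the same path runs epochs 0 and 1 in the first call and only epochs 2 and 3 in the second, which is the behavior a restarted worker needs. In a real multi-worker job you would typically let a single rank write the checkpoint; that coordination is left out here.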

2. Launch job

You can launch an elastic training job with bagua.distributed.run. For example:

Fault tolerant (fixed number of workers, no elasticity)

python -m bagua.distributed.run \
        --nnodes=NUM_NODES \
        --nproc_per_node=NUM_TRAINERS \
        --rdzv_id=JOB_ID \
        --rdzv_backend=c10d \
        --rdzv_endpoint=HOST_NODE_ADDR \
        YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)

The failure of some nodes does not cause the job to fail; the job waits for the failed nodes to recover.

HOST_NODE_ADDR, in form <host>[:<port>] (e.g. node1.example.com:29400), specifies the node and the port on which the C10d rendezvous backend should be instantiated and hosted. It can be any node in your training cluster, but ideally you should pick a node that has a high bandwidth.

If no port number is specified, HOST_NODE_ADDR defaults to <host>:29400.
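
To make the <host>[:<port>] format and its default concrete, here is a small sketch of how such an address could be split. The function parse_host_node_addr is a hypothetical helper written for this illustration, not the launcher's actual parsing code, and it deliberately ignores IPv6 literals.

```python
from typing import Tuple

DEFAULT_RDZV_PORT = 29400  # default port stated in the text above


def parse_host_node_addr(addr: str) -> Tuple[str, int]:
    """Split '<host>[:<port>]' into (host, port), applying the default port."""
    host, sep, port = addr.rpartition(":")
    if not sep:
        # No ':' present, so the whole string is the host name.
        return addr, DEFAULT_RDZV_PORT
    return host, int(port)
```

For instance, parse_host_node_addr("node1.example.com") yields the host with port 29400, while parse_host_node_addr("node1.example.com:1234") keeps the explicit port.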

Elastic training (min=1, max=4)

python -m bagua.distributed.run \
        --nnodes=1:4 \
        --nproc_per_node=NUM_TRAINERS \
        --rdzv_id=JOB_ID \
        --rdzv_backend=c10d \
        --rdzv_endpoint=HOST_NODE_ADDR \
        YOUR_TRAINING_SCRIPT.py (--arg1 ... train script args...)

In this example, the number of training nodes can change dynamically between 1 and 4 while the job is running.
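
The two launch commands differ only in the value of --nnodes: a single number fixes the node count, while MIN:MAX allows a range. The sketch below shows one way to interpret such a value. The function parse_nnodes is a hypothetical helper for illustration and should not be read as the launcher's real implementation.

```python
from typing import Tuple


def parse_nnodes(spec: str) -> Tuple[int, int]:
    """Interpret an --nnodes value as (min_nodes, max_nodes)."""
    parts = spec.split(":")
    if len(parts) == 1:
        # A single number means a fixed-size job: min and max coincide.
        min_nodes = max_nodes = int(parts[0])
    elif len(parts) == 2:
        min_nodes, max_nodes = int(parts[0]), int(parts[1])
    else:
        raise ValueError(f"expected N or MIN:MAX, got {spec!r}")
    if not 1 <= min_nodes <= max_nodes:
        raise ValueError(f"invalid node range in {spec!r}")
    return min_nodes, max_nodes
```

With this reading, "1:4" gives the range (1, 4) used in the elastic example, and a plain "2" gives (2, 2), which corresponds to the fault tolerant case with a fixed number of nodes.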


  1. PyTorch Elastic overview
  2. torch.distributed.run API Doc