
Conversation

mrshenli (Contributor) commented Apr 12, 2020

This example shows how to use RPC to implement pipeline parallelism. It can be viewed as a distributed version of single-machine, multi-GPU pipeline parallelism.

The numbers below show how the total execution time decreases as num_split increases.

$ python main.py 
Processing batch 0
Processing batch 1
Processing batch 2
number of splits = 1, execution time = 16.45062756538391
Processing batch 0
Processing batch 1
Processing batch 2
number of splits = 2, execution time = 12.329529762268066
Processing batch 0
Processing batch 1
Processing batch 2
number of splits = 4, execution time = 10.164430618286133
Processing batch 0
Processing batch 1
Processing batch 2
number of splits = 8, execution time = 9.076049566268921
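
As a rough illustration of what num_split controls (a sketch for readers, not code from this PR): each mini-batch is chunked into num_split micro-batches, and it is these micro-batches that flow through the pipeline so the two model shards can work on different pieces of the batch at the same time.

import torch

batch = torch.randn(64, 3, 128, 128)  # hypothetical mini-batch
num_split = 4

# torch.chunk yields num_split micro-batches along the batch dimension; each
# one can enter the first shard while an earlier one still occupies the second.
for micro_batch in torch.chunk(batch, num_split, dim=0):
    print(micro_batch.shape)  # torch.Size([16, 3, 128, 128])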

mrshenli changed the title from "[WIP] Adding distributed pipeline parallelism example" to "Adding distributed pipeline parallelism example" on Apr 13, 2020
osalpekar (Member) left a comment

Left a few suggestions for comments. Thanks for the great example @mrshenli!

labels = torch.zeros(batch_size, num_classes) \
    .scatter_(1, one_hot_indices, 1)

with dist_autograd.context() as context_id:

Should we add a short comment about what dist_autograd/dist_optimizer is doing here?
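
For reference, a minimal self-contained sketch of the dist_autograd/DistributedOptimizer pattern this line belongs to; the toy nn.Linear model and the single local worker are stand-ins for the sharded ResNet and the remote workers in the actual example.

import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed.autograd as dist_autograd
import torch.distributed.rpc as rpc
from torch.distributed.optim import DistributedOptimizer
from torch.distributed.rpc import RRef

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
rpc.init_rpc("worker0", rank=0, world_size=1)

model = nn.Linear(8, 2)                 # stand-in for the two ResNet shards
inputs = torch.randn(4, 8)
labels = torch.randint(0, 2, (4,))
loss_fn = nn.CrossEntropyLoss()

# The distributed optimizer holds RRefs to (possibly remote) parameters,
# so a single step() updates every shard.
param_rrefs = [RRef(p) for p in model.parameters()]
opt = DistributedOptimizer(optim.SGD, param_rrefs, lr=0.05)

with dist_autograd.context() as context_id:
    loss = loss_fn(model(inputs), labels)
    # dist_autograd.backward walks the autograd graph (across RPC boundaries
    # in the real example) and records gradients under context_id, not .grad.
    dist_autograd.backward(context_id, [loss])
    # step(context_id) applies the gradients recorded in that context.
    opt.step(context_id)

rpc.shutdown()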


return nn.Sequential(*layers)



Maybe we should include some comments about what we're doing here at a high level (defining ResNet with 2 partitions so we can place them on separate machines). Also, should we call these Partitions or Shards instead of parts?
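
To make the idea concrete, a rough sketch of that partitioning (illustrative only; the PR defines its own shard modules): torchvision's resnet50 cut into two halves that could each be built on a separate RPC worker.

import torch
import torch.nn as nn
from torchvision.models import resnet50

model = resnet50(num_classes=1000)

# Shard 1: stem through layer2; Shard 2: layer3 through the classifier head.
shard1 = nn.Sequential(model.conv1, model.bn1, model.relu, model.maxpool,
                       model.layer1, model.layer2)
shard2 = nn.Sequential(model.layer3, model.layer4, model.avgpool,
                       nn.Flatten(1), model.fc)

x = torch.randn(2, 3, 224, 224)
print(shard2(shard1(x)).shape)  # torch.Size([2, 1000]), same result as model(x)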

Distributed Pipeline Parallel Example

This example shows how to distribute a ResNet50 model on two RPC workers and
then implement distributed pipeline parallelism using RPC.

Should we include a quick description of the pipelining strategy (pipelining micro-batches within a batch and then synchronously running the optimizer step)? Since this is like GPipe, should we also link the paper here?
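
For context, a condensed runnable sketch of that GPipe-style schedule (paper: https://arxiv.org/abs/1811.06965). The tiny linear "shards" and the single self-RPC worker here are stand-ins for the two ResNet shards and two workers used by the actual example.

import os
import torch
import torch.nn as nn
import torch.distributed.rpc as rpc

shard1 = nn.Linear(16, 32)   # stand-in for the first ResNet shard
shard2 = nn.Linear(32, 10)   # stand-in for the second ResNet shard

def run_shard1(x):
    return shard1(x)

def run_shard2(y):
    return shard2(y)

def pipelined_forward(xs, num_split, worker):
    out_futs = []
    for micro_batch in torch.chunk(xs, num_split, dim=0):
        # Shard 1 runs synchronously; shard 2 is kicked off asynchronously so
        # the next micro-batch can enter shard 1 while this one is in shard 2.
        y = rpc.rpc_sync(worker, run_shard1, args=(micro_batch,))
        out_futs.append(rpc.rpc_async(worker, run_shard2, args=(y,)))
    # Gather the micro-batch outputs back into one batch before the loss and
    # the synchronous optimizer step.
    return torch.cat([fut.wait() for fut in out_futs], dim=0)

if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    rpc.init_rpc("worker0", rank=0, world_size=1)
    out = pipelined_forward(torch.randn(8, 16), num_split=4, worker="worker0")
    print(out.shape)  # torch.Size([8, 10])
    rpc.shutdown()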

soumith merged commit d431037 into pytorch:master on Apr 23, 2020
mrshenli deleted the pipeline branch on July 1, 2020 at 20:01
YinZhengxun pushed a commit to YinZhengxun/mt-exercise-02 that referenced this pull request Mar 30, 2025