This repository was archived by the owner on Aug 5, 2025. It is now read-only.
Support device dispatching during stage creation #923
Merged
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.



Description
This PR adds support to a case where the user creates model and trace model on CPU, then creates pipeline stage on GPU.
PiPPy would move only the stage module to the corresponding GPU.
Test
Update:
Sometimes, the
forwardfunction of user code may create constant tensors based on input device:As of now, PT2 tracer does not treat
input_ids.deviceas a symbolic device. As a result,device="cpu"got burned in the generated code:To workaround this, this PR added call in
PipelineStagecreation:After this call, the
device=kwarg oftorch.oneswill be modified to thenew_device.This call is hidden from user, thus when symbolic device support is added, we can silently remove this and not involve user code change.
We also checked native_functions.yaml, all APIs involving the "device" kwarg are generator ops, which are safe to change the device value. (And we should).
Real Example
Cc: @muellerzr @SunMarc