This repository was archived by the owner on Aug 5, 2025. It is now read-only.

Conversation

@kwen2501
Contributor

@kwen2501 kwen2501 commented Dec 26, 2023

Description

This PR adds support for the case where the user creates and traces the model on CPU, then creates the pipeline stage on GPU.
PiPPy moves only the stage module to the corresponding GPU.
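The selective move can be pictured with a small sketch. `FakeModule`, `materialize_stage`, and the stage list are illustrative stand-ins, not PiPPy API; the real code operates on traced `nn.Module` submodules:

```python
class FakeModule:
    """Stand-in for a traced nn.Module stage (the real code uses torch)."""
    def __init__(self, name):
        self.name = name
        self.device = "cpu"  # traced on CPU

    def to(self, device):
        self.device = device
        return self

def materialize_stage(stage_mods, rank, device):
    # Move only the stage owned by this rank; all other stages stay on CPU.
    return stage_mods[rank].to(device)

stages = [FakeModule(f"stage{i}") for i in range(2)]
mine = materialize_stage(stages, rank=1, device="cuda:1")
print(mine.device, stages[0].device)  # cuda:1 cpu
```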

Test

```
torchrun --nproc-per-node 2 test_cpu_init.py
```

Update:

Sometimes, the `forward` function of user code creates constant tensors based on the input device:

```
device = input_ids.device
attention_mask = torch.ones(…, device=device)
```

As of now, the PT2 tracer does not treat `input_ids.device` as a symbolic device. As a result, `device="cpu"` gets burned into the generated code:

```
ones = torch.ones(…, device = device(type='cpu'))
```

To work around this, this PR adds a call during `PipelineStage` creation:

```
def _move_ops_to_device(new_device)
```

After this call, the `device=` kwarg of `torch.ones` is rewritten to the `new_device`.
This call is hidden from the user, so when symbolic device support is added, we can silently remove it without any user code change.
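The rewrite can be sketched as follows, using plain dicts to stand in for `torch.fx` graph nodes. This is illustrative only, not PiPPy's actual implementation, which walks the traced graph and recompiles it:

```python
def move_ops_to_device(nodes, new_device):
    """Rewrite the device= kwarg of every op that has one.

    `nodes` is a list of dicts standing in for fx.Graph nodes.
    """
    for node in nodes:
        if "device" in node.get("kwargs", {}):
            # Factory ops (torch.ones, torch.zeros, ...) get the new device.
            node["kwargs"]["device"] = new_device
    return nodes

graph = [
    {"op": "torch.ones", "kwargs": {"device": "cpu"}},  # burned-in CPU device
    {"op": "torch.add", "kwargs": {}},                  # no device kwarg: untouched
]
move_ops_to_device(graph, "cuda:0")
print(graph[0]["kwargs"]["device"])  # cuda:0
```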

We also checked `native_functions.yaml`: all APIs that take a `device` kwarg are generator ops, for which it is safe (and correct) to change the device value.

Real Example

```
cd examples/cpu_init
torchrun --nproc-per-node 4 bert_cpu_init.py
```

Cc: @muellerzr @SunMarc

Contributor

@lessw2020 lessw2020 left a comment
Looks good overall!
I did find an issue on the latest nightlies that will require a change in IR.py to make it work there, but it isn't directly related to this PR.
Specifically, a recent tracing refactor triggers this error:

```
AttributeError: module 'torch._export' has no attribute '_export_to_torch_ir'
```

Modifying the import to use the new location (`torch.export` rather than `torch._export`) resolves it; a simple fix:

(screenshot: updated_torch_export)
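One way to keep both old and new nightlies working is an import fallback. `import_first` is a hypothetical helper, not part of PiPPy, and the exact symbol location across nightlies is an assumption:

```python
import importlib

def import_first(*module_names):
    """Return the first importable module among the candidates.

    Hypothetical helper mirroring the fix above: prefer the new
    torch.export location, fall back to torch._export.
    """
    last_err = None
    for name in module_names:
        try:
            return importlib.import_module(name)
        except ImportError as err:
            last_err = err
    raise last_err

# Intended usage (assumed, matching the comment above):
#   export_mod = import_first("torch.export", "torch._export")

# Demonstration with stdlib modules only:
mod = import_first("no_such_module_xyz", "json")
print(mod.__name__)  # json
```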

Not sure how you want to integrate that but that would be the one item to be fixed to ensure code works on latest nightlies.
With that, both runs pass:

(screenshots: cpu_completion, cpu_test_equivalence)

@lessw2020
Copy link
Contributor

Note: to separate the import issue I hit from this PR, as it is independent, I made a new PR expressly for that:
#924

@kwen2501 kwen2501 merged commit a4cc35f into main Dec 28, 2023
kwen2501 added a commit that referenced this pull request Jan 2, 2024
