
Conversation

@mori360 (Contributor) commented Oct 17, 2024

Resolves #620 ("Is there way to offload training memory to DRAM (using FSDP2?) for training Llama3-8B with torchtitan?")
Add config: `--training.enable_cpu_offload`

Command: `CONFIG_FILE="./train_configs/llama3_8b.toml" ./run_llama_train.sh`

For the non-PP case:
[screenshot: loss and memory stats, non-PP run, 2024-10-23]

For the PP case:
[screenshot: loss and memory stats, CPU offload + PP run]
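
For context, a rough sketch of how such a flag typically maps onto FSDP2's offload policy (a minimal sketch, assuming the `fully_shard`/`CPUOffloadPolicy` APIs from `torch.distributed._composable.fsdp`; `dp_mesh`, `mp_policy`, and the exact torchtitan wiring are assumptions, not the PR's literal diff):

```python
from torch.distributed._composable.fsdp import CPUOffloadPolicy, fully_shard

fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
if job_config.training.enable_cpu_offload:
    # Keep sharded parameters, gradients, and optimizer states in CPU
    # memory; FSDP2 copies parameters to GPU only around forward/backward.
    fsdp_config["offload_policy"] = CPUOffloadPolicy()

# Shard each transformer block, then the root module.
for transformer_block in model.layers.values():
    fully_shard(transformer_block, **fsdp_config)
fully_shard(model, **fsdp_config)
```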

@facebook-github-bot added the "CLA Signed" label (managed by the Meta Open Source bot) Oct 17, 2024
@awgu changed the title from "Enable FDSP2 cpu offloading" to "Enable FSDP2 cpu offloading" Oct 18, 2024
train.py (outdated):

```python
model.to_empty(device=init_device)
model.init_weights()
if job_config.training.enable_cpu_offload:
    with torch.device("cuda"):
```
Contributor:

I somehow think the `model.init_weights(buffer_device="cuda")` change sounds better: it is straightforward about what we want to achieve while making a minimal change to the code.
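
A minimal sketch of that suggestion, assuming `init_weights` is extended to accept a `buffer_device` argument (the exact plumbing is an assumption, not the PR's literal diff):

```python
if job_config.checkpoint.create_seed_checkpoint:
    init_device, buffer_device = "cpu", None
elif job_config.training.enable_cpu_offload:
    # Materialize parameters on CPU so FSDP2 can offload them, but put
    # buffers (e.g. RoPE frequencies) directly on GPU.
    init_device, buffer_device = "cpu", "cuda"
else:
    init_device, buffer_device = "cuda", None

model.to_empty(device=init_device)
model.init_weights(buffer_device=buffer_device)
```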

@mori360 requested a review from tianyu-l October 24, 2024 23:52
@mori360 marked this pull request as ready for review October 24, 2024 23:52
@tianyu-l (Contributor) left a comment:

Looks good in general. Had some inline comments. In particular, let's figure out whether CPU offloading and PP should coexist; if so, we should add support for that as well.

train.py (outdated):

```python
init_device = (
    "cpu"
    if job_config.checkpoint.create_seed_checkpoint
    or job_config.training.enable_cpu_offload
    else "cuda"
)
```
Contributor:

Don't we need to do the same for the PP case (several lines above)? Or are we assuming that if PP is used, CPU offloading is not an option?

@mori360 (Contributor, Author):

I tested it with PP. Based on llama2_7b.toml, CPU offload also works, but training latency is noticeably higher.
[screenshot: training stats, CPU offload + PP run, 2024-10-25]
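
If the two features are to coexist, the same device logic presumably has to apply in the PP branch, where each pipeline stage's model chunk is initialized separately. A sketch under that assumption (`model_parts` and the loop shape are assumed from torchtitan's PP path, reusing `init_device`/`buffer_device` from above):

```python
for model_part in model_parts:
    # Same rule as the non-PP path: parameters on CPU when offloading,
    # buffers directly on GPU.
    model_part.to_empty(device=init_device)
    model_part.init_weights(buffer_device=buffer_device)
    model_part.train()
```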

@mori360 marked this pull request as draft October 25, 2024 02:46
"Enable CPU Offload with PP",
"enable_cpu_offload+PP",
ngpu=4,
),
@mori360 (Contributor, Author) commented Oct 28, 2024:

Tested with PP; we could remove PP later if it's not necessary in the CI test.
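
For reference, a sketch of what the full CI entry might look like (the `OverrideDefinitions` wrapper and the exact override flags are assumptions extrapolated from the fragment above, not the PR's literal diff):

```python
OverrideDefinitions(
    [
        [
            "--training.enable_cpu_offload",
            "--experimental.pipeline_parallel_degree 2",
        ],
    ],
    "Enable CPU Offload with PP",
    "enable_cpu_offload+PP",
    ngpu=4,
),
```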

@mori360 marked this pull request as ready for review October 28, 2024 21:55
@mori360 requested a review from tianyu-l October 28, 2024 21:55
@mori360 requested a review from awgu October 28, 2024 21:57
@tianyu-l (Contributor) left a comment:

lgtm! thanks!

@mori360 merged commit 193ce98 into pytorch:main Oct 28, 2024
5 checks passed
mori360 added a commit to mori360/torchtitan that referenced this pull request Nov 26, 2024