Skip to content

Commit d0511d3

Browse files
committed
Update on "Add Sequence Parallelism to llama"
Somehow the torch.compile not working although eager sequence parallelism working, so currently don't turn it on by default [ghstack-poisoned]
2 parents ff2d82b + e88d386 commit d0511d3

File tree

5 files changed

+261
-2
lines changed

5 files changed

+261
-2
lines changed

.github/workflows/unit_test.yaml

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
name: Unit Test
2+
3+
on:
4+
push:
5+
branches: [ main ]
6+
pull_request:
7+
8+
concurrency:
9+
group: unit-test${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_number || github.ref }}
10+
cancel-in-progress: true
11+
12+
defaults:
13+
run:
14+
shell: bash -l -eo pipefail {0}
15+
16+
jobs:
17+
unit_tests:
18+
runs-on: ubuntu-latest
19+
strategy:
20+
matrix:
21+
python-version: ['3.10']
22+
steps:
23+
- name: Check out repo
24+
uses: actions/checkout@v3
25+
- name: Setup conda env
26+
uses: conda-incubator/setup-miniconda@v2
27+
with:
28+
auto-update-conda: true
29+
miniconda-version: "latest"
30+
activate-environment: test
31+
python-version: ${{ matrix.python-version }}
32+
- name: Update pip
33+
run: python -m pip install --upgrade pip
34+
- name: Install dependencies
35+
run: |
36+
pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
37+
python -m pip install -r requirements.txt
38+
python -m pip install -r dev-requirements.txt
39+
python -m pip install -e .
40+
- name: Run unit tests with coverage
41+
run: pytest test --cov=. --cov-report=xml --durations=20 -vv
42+
- name: Upload Coverage to Codecov
43+
uses: codecov/codecov-action@v3

run_llama_train.sh

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,11 @@ TRAINER_DIR=${1:-/home/$USER/local/torchtrain}
77
MODEL="debugmodel"
88
NGPU=8
99
MP=4
10+
# Change this string to a meaningful one to enable checkpoint
11+
CHECKPOINT_FOLDER=""
12+
# Please adjust this to a longer interval period. The unit of measurement is in steps.
13+
CHECKPOINT_INTERVAL=5
1014

1115
torchrun --nproc_per_node=${NGPU} \
12-
train.py --steps 10 --compile
16+
train.py --steps 10 --compile \
17+
--checkpoint-folder=${CHECKPOINT_FOLDER} --checkpoint-interval=${CHECKPOINT_INTERVAL}

test/test_test.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
3+
4+
# delete me after adding real tests..
5+
class Test:
6+
def test_test(self):
7+
assert True

torchtrain/checkpoint.py

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
3+
4+
import enum
5+
import os
6+
import re
7+
import time
8+
from typing import Any, Dict
9+
10+
import torch
11+
import torch.distributed as dist
12+
import torch.distributed.checkpoint as dcp
13+
import torch.nn as nn
14+
from torch.distributed.checkpoint.state_dict import (
15+
get_model_state_dict,
16+
get_optimizer_state_dict,
17+
set_model_state_dict,
18+
set_optimizer_state_dict,
19+
)
20+
from torchtrain.logging_utils import rank0_log
21+
22+
23+
class IntervalType(enum.Enum):
24+
SECONDS = enum.auto()
25+
STEPS = enum.auto()
26+
27+
28+
class ModelWrapper:
29+
def __init__(self, model: nn.Module) -> None:
30+
self.model = model
31+
32+
def state_dict(self) -> None:
33+
return get_model_state_dict(self.model)
34+
35+
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
36+
set_model_state_dict(self.model, state_dict)
37+
38+
39+
class OptimizerWrapper:
40+
def __init__(self, model: nn.Module, optim: torch.optim.Optimizer) -> None:
41+
self.model = model
42+
self.optim = optim
43+
44+
def state_dict(self) -> None:
45+
return get_optimizer_state_dict(self.model, self.optim)
46+
47+
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
48+
set_optimizer_state_dict(self.model, self.optim, optim_state_dict=state_dict)
49+
50+
51+
class CheckpointManager:
52+
def __init__(
53+
self,
54+
model: nn.Module,
55+
optimizer: torch.optim.Optimizer,
56+
states: Dict[str, Any],
57+
folder: str,
58+
interval_type: IntervalType,
59+
interval: int,
60+
) -> None:
61+
self.folder = folder
62+
self.states = states
63+
self.states.update(
64+
{
65+
"model": ModelWrapper(model),
66+
"optimizer": OptimizerWrapper(model, optimizer),
67+
}
68+
)
69+
self.interval_type = interval_type
70+
self.interval = interval
71+
self.begin = 0
72+
self.work = None
73+
self.pg = dist.new_group(backend="gloo")
74+
self.doit = None
75+
76+
def reset(self) -> None:
77+
self.begin = time.monotonic()
78+
79+
def create_checkpoint_id(self, step: int) -> str:
80+
return os.path.join(self.folder, f"step-{step}")
81+
82+
def save(self, curr_step: int, force: bool = False) -> None:
83+
if not self.folder:
84+
return
85+
86+
if not force:
87+
if self.interval_type == IntervalType.STEPS and not (
88+
curr_step % self.interval == 0
89+
):
90+
return
91+
if self.interval_type == IntervalType.SECONDS:
92+
doit = (time.monotonic() - self.begin) >= self.interval
93+
self.doit = torch.tensor(int(doit))
94+
if self.work is None:
95+
self.work = dist.all_reduce(self.doit, group=self.pg, async_op=True)
96+
return
97+
elif curr_step % 5 == 4:
98+
self.work.wait()
99+
self.work = None
100+
doit = self.doit.item()
101+
self.doit = None
102+
if doit == 0:
103+
return
104+
else:
105+
return
106+
107+
if self.work:
108+
self.work.wait()
109+
self.work = None
110+
self.doit = None
111+
112+
rank0_log(f"Saving a checkpoint in step {curr_step}.")
113+
begin = time.monotonic()
114+
dcp.save(self.states, checkpoint_id=self.create_checkpoint_id(curr_step))
115+
self.reset()
116+
rank0_log(
117+
f"Finish saving the checkpoint in step {curr_step}. "
118+
f"{time.monotonic() - begin} seconds"
119+
)
120+
121+
def load(self, step: int = -1) -> bool:
122+
if not self.folder:
123+
return False
124+
if not os.path.isdir(self.folder):
125+
return False
126+
if step != -1 and not os.path.isdir(self.create_checkpoint_id(step)):
127+
return False
128+
129+
if step == -1:
130+
step_counts = []
131+
for filename in os.listdir(self.folder):
132+
match = re.search(r"step-(\d+)", filename)
133+
if match:
134+
step_counts.append(int(match.group(1)))
135+
if not step_counts:
136+
return False
137+
step = max(step_counts)
138+
139+
rank0_log("Loading a checkpoint.")
140+
begin = time.monotonic()
141+
dcp.load(
142+
self.states,
143+
checkpoint_id=self.create_checkpoint_id(step),
144+
)
145+
rank0_log(f"Finish loading a checkpoint. {time.monotonic() - begin} seconds.")
146+
return True

train.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,16 @@
44
import argparse
55
import os
66
from dataclasses import dataclass, field
7-
from typing import List, Union
7+
from typing import Any, Dict, List, Union
88

99
# torch imports
1010
import torch
1111
import torch.nn.functional as F
1212
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
1313
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
1414

15+
from torchtrain.checkpoint import CheckpointManager, IntervalType
16+
1517
# torchtrain related
1618
from torchtrain.datasets import create_tokenizer, dataloader_fn
1719
from torchtrain.logging_utils import init_logger, rank0_log
@@ -29,6 +31,18 @@ class TrainState:
2931
current_loss: float = -1
3032
losses: List[float] = field(default_factory=list)
3133

34+
def state_dict(self) -> Dict[str, Any]:
35+
return {
36+
"step": torch.tensor(self.step, dtype=torch.int32),
37+
"current_loss": torch.tensor(self.current_loss, dtype=torch.float32),
38+
"losses": torch.tensor(self.current_loss, dtype=torch.float32),
39+
}
40+
41+
def load_state_dict(self, state_dict) -> None:
42+
self.step = state_dict["step"].item()
43+
self.current_loss = state_dict["current_loss"].item()
44+
self.losses = state_dict["losses"].tolist()
45+
3246

3347
def build_optimizer(model, args):
3448
# build optimizer
@@ -116,7 +130,22 @@ def main(args):
116130
# train loop
117131
model.train()
118132

133+
checkpoint = CheckpointManager(
134+
model=model,
135+
optimizer=optimizer,
136+
states={"train_state": train_state},
137+
folder=args.checkpoint_folder,
138+
interval_type=(
139+
IntervalType.SECONDS
140+
if args.checkpoint_interval_type == "seconds"
141+
else IntervalType.STEPS
142+
),
143+
interval=args.checkpoint_interval,
144+
)
145+
checkpoint.load()
146+
119147
with maybe_run_profiler() as torch_profiler:
148+
checkpoint.reset()
120149
while train_state.step < args.steps or args.steps == -1:
121150
train_state.step += 1
122151
# get batch
@@ -161,6 +190,8 @@ def main(args):
161190
)
162191
scheduler.step()
163192

193+
checkpoint.save(train_state.step, force=(train_state.step == args.steps))
194+
164195

165196
if __name__ == "__main__":
166197
parser = argparse.ArgumentParser(description="TorchTrain arg parser.")
@@ -224,6 +255,33 @@ def main(args):
224255
parser.add_argument(
225256
"--compile", action="store_true", help="Whether to compile the model."
226257
)
258+
parser.add_argument(
259+
"--checkpoint-interval",
260+
type=int,
261+
default=3600,
262+
help=(
263+
"Checkpointing interval. The unit of measurement is in seconds or "
264+
"steps depending on --checkpoint-internval-type."
265+
),
266+
)
267+
parser.add_argument(
268+
"--checkpoint-interval-type",
269+
type=str,
270+
default="steps",
271+
help=(
272+
"The checkpointing interval unit of measurement."
273+
"The default value is step."
274+
),
275+
)
276+
parser.add_argument(
277+
"--checkpoint-folder",
278+
type=str,
279+
default="",
280+
help=(
281+
"The folder to store the checkpoints. If this is not specified or "
282+
"is an empty string, checkpointing is disabled."
283+
),
284+
)
227285

228286
args = parser.parse_args()
229287
main(args)

0 commit comments

Comments
 (0)