Commit 180b213

add selective activation checkpointing

ghstack-source-id: 2fbae95
Pull Request resolved: #97

1 parent 96d1cb1

File tree: 3 files changed, +63 -13 lines

torchtrain/config_manager.py (10 additions, 5 deletions)

@@ -94,17 +94,17 @@ def init_args_from_command_line(
             help="collect profiler traces every x iterations",
         )
         # metrics configs
+        parser.add_argument(
+            "--metrics.enable_tensorboard",
+            action="store_true",
+            help="whether to log metrics to TensorBoard",
+        )
         parser.add_argument(
             "--metrics.log_freq",
             type=int,
             default=10,
             help="how often to log metrics to TensorBoard",
         )
-        parser.add_argument(
-            "--metrics.enable_tensorboard",
-            action="store_true",
-            help="how often to log metrics to TensorBoard",
-        )
         parser.add_argument(
             "--metrics.save_tb_folder",
             type=str,
@@ -215,4 +215,9 @@ def init_args_from_command_line(
                 "is an empty string, checkpointing is disabled."
             ),
         )
+        parser.add_argument(
+            "--metrics.enable_selective_ac",
+            action="store_false",
+            help="whether to enable selective activation checkpointing",
+        )
        return parser.parse_args(args_list)
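Note: the new --metrics.enable_selective_ac flag is registered with action="store_false", so it defaults to True and passing the flag turns selective AC off, whereas --metrics.enable_tensorboard (action="store_true") defaults to False. Below is a minimal standalone argparse sketch, not taken from the repo, that illustrates these defaults and shows that the dotted option names end up as Namespace attributes which must be read with getattr:

# Standalone argparse sketch (not part of torchtrain) of the two new flags.
# "store_true" flags default to False; "store_false" flags default to True.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--metrics.enable_tensorboard",
    action="store_true",
    help="whether to log metrics to TensorBoard",
)
parser.add_argument(
    "--metrics.enable_selective_ac",
    action="store_false",
    help="whether to enable selective activation checkpointing",
)

args = parser.parse_args([])  # no flags passed on the command line
# Dots are kept in the dest, so the attributes need getattr.
print(getattr(args, "metrics.enable_tensorboard"))   # False
print(getattr(args, "metrics.enable_selective_ac"))  # True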

torchtrain/parallelisms/parallelize_llama.py (52 additions, 8 deletions)

@@ -33,7 +33,6 @@
     RowwiseParallel,
 )
 from torchtrain.config_manager import JobConfig
-
 from torchtrain.logging_utils import rank0_log
 
 logger = logging.getLogger(__name__)
@@ -67,12 +66,54 @@ def partition_fn(name, module, device_mesh):
     )
 
 
+# AC/selective AC
+no_recompute_list = {
+    torch.ops.aten.mm.default,
+    torch.ops.aten._scaled_dot_product_efficient_attention.default,
+    torch.ops.aten._scaled_dot_product_flash_attention.default,
+    torch.ops.c10d_functional.reduce_scatter_tensor.default,
+}
+
 # Uses PTD FSDP AC wrapper
-# TODO: why is config needed here?
-def checkpoint_wrapper(module, job_config: JobConfig):
-    return ptd_checkpoint_wrapper(
-        module, checkpoint_impl=CheckpointImpl.NO_REENTRANT, preserve_rng_state=False
-    )
+def checkpoint_wrapper(module, enable_selective_ac=False):
+    if enable_selective_ac:
+        from torch.utils.checkpoint import (
+            _pt2_selective_checkpoint_context_fn_gen,
+            checkpoint,
+        )
+
+        def _get_custom_policy(meta):
+            def _custom_policy(mode, func, *args, **kwargs):
+                mm_count_key = f"{mode}_mm_count"
+                if mm_count_key not in meta:
+                    meta[mm_count_key] = 0
+                if func == torch.ops.aten.mm.default:
+                    meta[mm_count_key] += 1
+                # Saves output of all compute ops, except every second mm
+                return func in no_recompute_list and not (
+                    func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
+                )
+
+            return _custom_policy
+
+        def selective_checkpointing_context_fn():
+            meta = {}
+            return _pt2_selective_checkpoint_context_fn_gen(_get_custom_policy(meta))
+
+        return ptd_checkpoint_wrapper(
+            module,
+            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
+            checkpoint_fn=checkpoint,
+            context_fn=selective_checkpointing_context_fn,
+            use_reentrant=False,
+            preserve_rng_state=False,
+        )
+    else:
+        return ptd_checkpoint_wrapper(
+            module,
+            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
+            preserve_rng_state=False,
+        )
 
 
 def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
@@ -168,10 +209,13 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
 
         with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
             for layer_id, transformer_block in enumerate(model.layers):
-                # apply AC to each layer
                 # before wrapping with FSDP, we need to make sure the layer is on GPU
                 transformer_block = transformer_block.cuda()
-                transformer_block = checkpoint_wrapper(transformer_block, job_config)
+
+                # apply selective AC
+                transformer_block = checkpoint_wrapper(
+                    transformer_block, job_config.training.enable_selective_ac
+                )
 
                 # Wraps each layer with FSDP
                 model.layers[layer_id] = wrap(transformer_block)
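For reference, here is a minimal standalone sketch of the alternating-mm policy added above. It is not part of torchtrain and deliberately avoids the private _pt2_selective_checkpoint_context_fn_gen helper (an internal API at the time of this commit); it drives the policy with a hand-written list of ops purely to show which outputs would be saved versus recomputed.

# Standalone sketch (not torchtrain code) of the selective-AC policy above:
# save the outputs of ops in no_recompute_list, except every second mm,
# which is recomputed in the backward pass instead.
import torch

no_recompute_list = {
    torch.ops.aten.mm.default,
    torch.ops.aten._scaled_dot_product_flash_attention.default,
}


def get_custom_policy(meta):
    def custom_policy(mode, func, *args, **kwargs):
        mm_count_key = f"{mode}_mm_count"
        meta.setdefault(mm_count_key, 0)
        if func == torch.ops.aten.mm.default:
            meta[mm_count_key] += 1
        # True -> save this op's output; False -> recompute it in backward.
        return func in no_recompute_list and not (
            func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
        )

    return custom_policy


meta = {}
policy = get_custom_policy(meta)
ops = [torch.ops.aten.mm.default] * 4 + [torch.ops.aten.relu.default]
for op in ops:
    print(op, "save" if policy("forward", op) else "recompute")
# mm #1: save, mm #2: recompute, mm #3: save, mm #4: recompute,
# relu: recompute (not in no_recompute_list)

Compared with full activation checkpointing, this policy keeps the attention and reduce-scatter outputs resident and only recomputes half of the matmuls, cutting backward recompute at the cost of holding on to some extra activations.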

train_configs/debug_model.toml (1 addition, 0 deletions)

@@ -37,3 +37,4 @@ checkpoint_interval = 3600
 checkpoint_interval_type = "steps"
 checkpoint_folder = ""
 dataset = "alpaca"
+enable_selective_ac = false
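The new key sits alongside the other training entries in the debug config, matching the job_config.training.enable_selective_ac lookup above. A small sketch, assuming Python 3.11+ for the stdlib tomllib parser (torchtrain's JobConfig does its own config loading and merging), of checking the value:

# Illustrative check of the new TOML key with the stdlib parser (Python 3.11+).
# Assumes the key lives in the [training] table, as the code's
# job_config.training.enable_selective_ac access suggests.
import tomllib

with open("train_configs/debug_model.toml", "rb") as f:
    cfg = tomllib.load(f)

print(cfg["training"]["enable_selective_ac"])  # False: selective AC is off in the debug config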
