     RowwiseParallel,
 )
 from torchtrain.config_manager import JobConfig
-
 from torchtrain.logging_utils import rank0_log
 
 logger = logging.getLogger(__name__)
@@ -67,12 +66,54 @@ def partition_fn(name, module, device_mesh): |
     )
 
 
+# AC/selective AC
+no_recompute_list = {
+    torch.ops.aten.mm.default,
+    torch.ops.aten._scaled_dot_product_efficient_attention.default,
+    torch.ops.aten._scaled_dot_product_flash_attention.default,
+    torch.ops.c10d_functional.reduce_scatter_tensor.default,
+}
+
 # Uses PTD FSDP AC wrapper
-# TODO: why is config needed here?
-def checkpoint_wrapper(module, job_config: JobConfig):
-    return ptd_checkpoint_wrapper(
-        module, checkpoint_impl=CheckpointImpl.NO_REENTRANT, preserve_rng_state=False
-    )
+def checkpoint_wrapper(module, enable_selective_ac=False):
+    if enable_selective_ac:
+        from torch.utils.checkpoint import (
+            _pt2_selective_checkpoint_context_fn_gen,
+            checkpoint,
+        )
+
+        def _get_custom_policy(meta):
+            def _custom_policy(mode, func, *args, **kwargs):
+                mm_count_key = f"{mode}_mm_count"
+                if mm_count_key not in meta:
+                    meta[mm_count_key] = 0
+                if func == torch.ops.aten.mm.default:
+                    meta[mm_count_key] += 1
+                # Save the output of each op in no_recompute_list, except every second mm
+                return func in no_recompute_list and not (
+                    func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
+                )
+
+            return _custom_policy
+
+        def selective_checkpointing_context_fn():
+            meta = {}
+            return _pt2_selective_checkpoint_context_fn_gen(_get_custom_policy(meta))
+
+        return ptd_checkpoint_wrapper(
+            module,
+            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
+            checkpoint_fn=checkpoint,
+            context_fn=selective_checkpointing_context_fn,
+            use_reentrant=False,
+            preserve_rng_state=False,
+        )
+    else:
+        return ptd_checkpoint_wrapper(
+            module,
+            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
+            preserve_rng_state=False,
+        )
 
 
 def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
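
For intuition, here is a minimal, self-contained sketch of the decision logic that `_custom_policy` implements, with the torch ops replaced by plain strings so it runs without any particular torch build; the names below are illustrative, not part of the PR.

```python
# Illustrative stand-ins for the ops in no_recompute_list above.
no_recompute = {"mm", "sdpa_efficient", "sdpa_flash", "reduce_scatter"}


def make_policy(meta):
    def policy(mode, func):
        key = f"{mode}_mm_count"
        meta.setdefault(key, 0)
        if func == "mm":
            meta[key] += 1
        # Save every op in no_recompute, except every second mm (those get recomputed).
        return func in no_recompute and not (func == "mm" and meta[key] % 2 == 0)

    return policy


policy = make_policy({})
print([policy("fwd", f) for f in ["mm", "relu", "mm", "sdpa_flash", "mm"]])
# -> [True, False, False, True, True]: the 2nd mm is recomputed in backward.
```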
@@ -168,10 +209,13 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig): |
 
         with enable_wrap(wrapper_cls=FSDP, **fsdp_config):
             for layer_id, transformer_block in enumerate(model.layers):
-                # apply AC to each layer
                 # before wrapping with FSDP, we need to make sure the layer is on GPU
                 transformer_block = transformer_block.cuda()
-                transformer_block = checkpoint_wrapper(transformer_block, job_config)
+
+                # apply selective AC
+                transformer_block = checkpoint_wrapper(
+                    transformer_block, job_config.training.enable_selective_ac
+                )
 
                 # Wraps each layer with FSDP
                 model.layers[layer_id] = wrap(transformer_block)
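
And a hedged sketch of the per-layer wrapping pattern from `parallelize_llama`, using a toy block in place of a transformer block and only the non-selective PTD wrapper path, so it runs on a single CPU process without FSDP or a process group; `ToyBlock` and the shapes are assumptions made for illustration.

```python
import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointImpl,
    checkpoint_wrapper as ptd_checkpoint_wrapper,
)


class ToyBlock(nn.Module):
    """Illustrative stand-in for a transformer block."""

    def __init__(self, dim: int = 32):
        super().__init__()
        self.w1 = nn.Linear(dim, dim)
        self.w2 = nn.Linear(dim, dim)

    def forward(self, x):
        return self.w2(torch.relu(self.w1(x)))


layers = nn.ModuleList(ToyBlock() for _ in range(2))
for layer_id, block in enumerate(layers):
    # Same shape as the diff's loop: wrap each layer with AC before (eventually) FSDP.
    layers[layer_id] = ptd_checkpoint_wrapper(
        block, checkpoint_impl=CheckpointImpl.NO_REENTRANT, preserve_rng_state=False
    )

x = torch.randn(4, 32, requires_grad=True)
out = x
for layer in layers:
    out = layer(out)
out.sum().backward()  # activations inside each block are recomputed here
```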