
Commit fa8cdd4

address TODOs as 2D recompiles is fixed

ghstack-source-id: 2927f0a
Pull Request resolved: #508
1 parent: b99bc5e


torchtitan/parallelisms/parallelize_llama.py

Lines changed: 5 additions & 11 deletions
@@ -185,9 +185,6 @@ def apply_tp(
     if enable_async_tp:
         from torch.distributed._symmetric_memory import enable_symm_mem_for_group
 
-        # TODO: remove cache_size_limit adjustment after 2D compile is fixed
-        torch._dynamo.config.cache_size_limit = 10000
-
         torch._inductor.config._micro_pipeline_tp = True
         enable_symm_mem_for_group(tp_mesh.get_group().group_name)
 
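For reference, the async-TP path that survives this cleanup can be read as a small helper. The sketch below is a hypothetical repackaging, not torchtitan code: the function name enable_async_tp and the standalone tp_mesh argument are assumptions, while the two enabling calls are taken directly from the kept lines of the hunk above.

import torch
import torch._inductor.config
from torch.distributed.device_mesh import DeviceMesh


def enable_async_tp(tp_mesh: DeviceMesh) -> None:
    # Private PyTorch API; imported lazily, as in the hunk above.
    from torch.distributed._symmetric_memory import enable_symm_mem_for_group

    # Let inductor micro-pipeline tensor-parallel compute and communication.
    torch._inductor.config._micro_pipeline_tp = True
    # Register the TP process group for symmetric-memory based collectives.
    enable_symm_mem_for_group(tp_mesh.get_group().group_name)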
@@ -280,18 +277,15 @@ def apply_ac(model: nn.Module, ac_config):
 
 
 def apply_compile(model: nn.Module):
-    """Apply torch.compile to each transformer block."""
-
-    # the following flag can be used to to accelarate per-TransformerBlock compilation
-    # TODO(bdhirsh): turning it off because it's currently not working with 2D
-    # TODO(anijain): remove it after it's enabled in pytorch by default
-    # torch._dynamo.config.inline_inbuilt_nn_modules = True
-
+    """
+    Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
+    repeated structure. Alternatively one can compile the whole model (after applying DP).
+    """
     for layer_id, transformer_block in model.layers.named_children():
         transformer_block = torch.compile(transformer_block, fullgraph=True)
         model.layers.register_module(layer_id, transformer_block)
 
-    logger.info("Compiled each TransformerBlock with torch.compile")
+    logger.info("Compiling each TransformerBlock with torch.compile")
     return model
 
 
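To illustrate the per-block compilation pattern outside of torchtitan, here is a minimal, self-contained sketch. The ToyBlock and ToyModel classes are assumptions for demonstration only; just the apply_compile loop mirrors the code in this diff. Because every block shares the same structure, compiling each one separately lets dynamo reuse compiled code across blocks, keeping compile time low.

import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    # Stand-in for a TransformerBlock: a residual feed-forward layer.
    def __init__(self, dim: int = 64):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.ff(x)


class ToyModel(nn.Module):
    # Stand-in for the Llama model: repeated blocks stored in a ModuleDict named `layers`.
    def __init__(self, n_layers: int = 4, dim: int = 64):
        super().__init__()
        self.layers = nn.ModuleDict({str(i): ToyBlock(dim) for i in range(n_layers)})

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.layers.values():
            x = block(x)
        return x


def apply_compile(model: nn.Module) -> nn.Module:
    # Compile each block individually and swap it back in, mirroring the diff above.
    for layer_id, block in model.layers.named_children():
        block = torch.compile(block, fullgraph=True)
        model.layers.register_module(layer_id, block)
    return model


if __name__ == "__main__":
    model = apply_compile(ToyModel())
    out = model(torch.randn(2, 16, 64))
    print(out.shape)  # torch.Size([2, 16, 64])

Note that fullgraph=True asks dynamo to compile each block without graph breaks, so any unsupported construct inside a block raises an error instead of silently falling back to eager execution.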