1 file changed: +5 −11 lines changed

@@ -185,9 +185,6 @@ def apply_tp(
     if enable_async_tp:
         from torch.distributed._symmetric_memory import enable_symm_mem_for_group
 
-        # TODO: remove cache_size_limit adjustment after 2D compile is fixed
-        torch._dynamo.config.cache_size_limit = 10000
-
         torch._inductor.config._micro_pipeline_tp = True
         enable_symm_mem_for_group(tp_mesh.get_group().group_name)
 
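For reference, a minimal sketch of the async-TP enablement path as it reads after this change (the cache_size_limit bump, per the removed TODO, was only a workaround pending the 2D-compile fix). The helper name enable_async_tp_for_mesh is hypothetical; the flag and calls mirror the diff context above:

# Hypothetical helper consolidating the async-TP setup shown in the hunk above.
import torch


def enable_async_tp_for_mesh(tp_mesh, enable_async_tp: bool = True):
    if not enable_async_tp:
        return
    from torch.distributed._symmetric_memory import enable_symm_mem_for_group

    # Let inductor micro-pipeline TP matmuls with their collectives.
    torch._inductor.config._micro_pipeline_tp = True
    # Register the TP process group for symmetric-memory collectives
    # (assumes torch.distributed and the device mesh are already initialized).
    enable_symm_mem_for_group(tp_mesh.get_group().group_name)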
@@ -280,18 +277,15 @@ def apply_ac(model: nn.Module, ac_config):
 
 
 def apply_compile(model: nn.Module):
-    """Apply torch.compile to each transformer block."""
-
-    # the following flag can be used to to accelarate per-TransformerBlock compilation
-    # TODO(bdhirsh): turning it off because it's currently not working with 2D
-    # TODO(anijain): remove it after it's enabled in pytorch by default
-    # torch._dynamo.config.inline_inbuilt_nn_modules = True
-
+    """
+    Apply torch.compile to each TransformerBlock, which makes compilation efficient due to
+    repeated structure. Alternatively one can compile the whole model (after applying DP).
+    """
     for layer_id, transformer_block in model.layers.named_children():
         transformer_block = torch.compile(transformer_block, fullgraph=True)
         model.layers.register_module(layer_id, transformer_block)
 
-    logger.info("Compiled each TransformerBlock with torch.compile")
+    logger.info("Compiling each TransformerBlock with torch.compile")
     return model
 
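To illustrate the per-block compilation pattern the new docstring describes, here is a self-contained sketch on a toy model. ToyBlock and ToyModel are illustrative stand-ins, not the repository's actual model classes; only the loop in apply_compile mirrors the diff:

# Minimal sketch: compile each block separately; because the blocks share the
# same structure, the compiled artifact is reused across layers, keeping
# compile time low compared to compiling one huge graph.
import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    def __init__(self, dim: int = 64):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))

    def forward(self, x):
        return x + self.ff(x)


class ToyModel(nn.Module):
    def __init__(self, n_layers: int = 4, dim: int = 64):
        super().__init__()
        self.layers = nn.ModuleList(ToyBlock(dim) for _ in range(n_layers))

    def forward(self, x):
        for block in self.layers:
            x = block(x)
        return x


def apply_compile(model: nn.Module) -> nn.Module:
    # Replace each block with its compiled version, as in the diff above.
    for layer_id, block in model.layers.named_children():
        model.layers.register_module(layer_id, torch.compile(block, fullgraph=True))
    return model


if __name__ == "__main__":
    model = apply_compile(ToyModel())
    out = model(torch.randn(2, 8, 64))
    print(out.shape)  # torch.Size([2, 8, 64])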
297291