From dda5ad00e5128dd0265cc7fe98ba0a13f5587fa6 Mon Sep 17 00:00:00 2001 From: Andrew Gu Date: Wed, 10 Jul 2024 07:36:06 -0700 Subject: [PATCH] Removed `_experimental_support_context_fn_in_torch_utils_checkpoint` [ghstack-poisoned] --- torchtitan/parallelisms/parallelize_llama.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py index 6cafa4abe4..7becb731bc 100644 --- a/torchtitan/parallelisms/parallelize_llama.py +++ b/torchtitan/parallelisms/parallelize_llama.py @@ -441,13 +441,6 @@ def apply_compile(model, job_config: JobConfig): transformer_block = torch.compile(transformer_block, dynamic=False) model.layers.register_module(layer_id, transformer_block) - ac_config = job_config.activation_checkpoint - if ac_config.mode == "selective" and ac_config.selective_ac_option == "op": - # some temp flags for torch.compile enablement + SAC - torch._dynamo.config._experimental_support_context_fn_in_torch_utils_checkpoint = ( - True - ) - logger.info("Compiled each TransformerBlock with torch.compile") return model