From 100485dfae66087b0978e50b0b860dce4b773674 Mon Sep 17 00:00:00 2001
From: Andrew Gu
Date: Fri, 29 Mar 2024 11:36:52 -0700
Subject: [PATCH] Removed setting global flag for `swap_tensors` since not needed anymore

[ghstack-poisoned]
---
 train.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/train.py b/train.py
index c57771df7a..849ae78498 100644
--- a/train.py
+++ b/train.py
@@ -197,9 +197,7 @@ def main(job_config: JobConfig):
     model = models_parallelize_fns[model_name](
         model, world_mesh, parallel_dims, job_config
     )
-    # set this as required by DTensor to work with `to_empty`
-    # TODO: remove in the future when enabled by default for wrapper subclasses
-    torch.__future__.set_swap_module_params_on_conversion(True)
+    # allocate sharded model on GPU and initialize weights via DTensor
     model.to_empty(device="cuda")
     model.init_weights()

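Below is a minimal sketch (not part of the patch) of the meta-device initialization flow that this change relies on: build the model without allocating storage, parallelize it, then call `to_empty` and initialize weights. It assumes a stand-in `nn.Linear` and a hand-written init in place of the real torchtitan model, `models_parallelize_fns`, and `init_weights`, and it assumes a PyTorch version where swapping module params on conversion is already the default for wrapper subclasses, which is why the patch drops the global flag.

    # Illustrative sketch only -- stand-in model, requires a CUDA device.
    import torch
    import torch.nn as nn

    # build the model on the meta device so no real memory is allocated yet
    with torch.device("meta"):
        model = nn.Linear(1024, 1024)

    # (in train.py the model is parallelized here, turning params into DTensors)

    # materialize parameter storage on the GPU, then initialize weights in place;
    # no torch.__future__.set_swap_module_params_on_conversion(True) call is
    # needed before this when swap_tensors is the default for wrapper subclasses
    model.to_empty(device="cuda")
    with torch.no_grad():
        model.weight.normal_()  # stand-in for model.init_weights()
        model.bias.zero_()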