diff --git a/train.py b/train.py index c57771df7a..849ae78498 100644 --- a/train.py +++ b/train.py @@ -197,9 +197,7 @@ def main(job_config: JobConfig): model = models_parallelize_fns[model_name]( model, world_mesh, parallel_dims, job_config ) - # set this as required by DTensor to work with `to_empty` - # TODO: remove in the future when enabled by default for wrapper subclasses - torch.__future__.set_swap_module_params_on_conversion(True) + # allocate sharded model on GPU and initialize weights via DTensor model.to_empty(device="cuda") model.init_weights()