Commit dca7657

Author: Andrew Gu

Removed setting global flag for swap_tensors since it is not needed anymore

ghstack-source-id: 484237b
Pull Request resolved: #178

Parent: 49f9784

File tree

1 file changed: train.py (+1 line, -3 lines)

train.py

Lines changed: 1 addition & 3 deletions
@@ -197,9 +197,7 @@ def main(job_config: JobConfig):
     model = models_parallelize_fns[model_name](
         model, world_mesh, parallel_dims, job_config
     )
-    # set this as required by DTensor to work with `to_empty`
-    # TODO: remove in the future when enabled by default for wrapper subclasses
-    torch.__future__.set_swap_module_params_on_conversion(True)
+    # allocate sharded model on GPU and initialize weights via DTensor
     model.to_empty(device="cuda")
     model.init_weights()

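For context, the removed call sat in train.py's meta-device initialization path: the model is built on the meta device, to_empty() materializes (sharded) storage on the GPU without writing values, and init_weights() fills the values in. torch.__future__.set_swap_module_params_on_conversion(True) made to_empty() swap the meta parameters for real DTensor parameters; per the removed TODO, that behavior is now the default for wrapper subclasses, so the explicit global flag can go. Below is a minimal sketch of the same pattern on a plain module. The nn.Linear, the init choices, and the CUDA device are illustrative stand-ins, not taken from this repo:

    import torch
    import torch.nn as nn

    # Build the module on the meta device: parameters carry only
    # shape/dtype metadata, so no real memory is allocated yet.
    with torch.device("meta"):
        model = nn.Linear(1024, 1024)

    # No longer required before to_empty(): per this commit, newer PyTorch
    # swaps module parameters on conversion for wrapper subclasses (such as
    # DTensor) by default.
    # torch.__future__.set_swap_module_params_on_conversion(True)

    # Materialize uninitialized storage on the GPU (assumes a CUDA device)...
    model.to_empty(device="cuda")

    # ...then initialize values in place, a stand-in for the repo's
    # model.init_weights().
    with torch.no_grad():
        nn.init.trunc_normal_(model.weight, std=0.02)
        nn.init.zeros_(model.bias)

The ordering matters: initializing before to_empty() would write into meta tensors (a no-op), while to_empty() alone leaves parameter values unspecified, which is why the diff keeps init_weights() immediately after it.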