swap_tensors
1 parent 49f9784 commit dca7657
train.py
@@ -197,9 +197,7 @@ def main(job_config: JobConfig):
     model = models_parallelize_fns[model_name](
         model, world_mesh, parallel_dims, job_config
     )
-    # set this as required by DTensor to work with `to_empty`
-    # TODO: remove in the future when enabled by default for wrapper subclasses
-    torch.__future__.set_swap_module_params_on_conversion(True)
+    # allocate sharded model on GPU and initialize weights via DTensor
     model.to_empty(device="cuda")
     model.init_weights()
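
For context, here is a minimal, self-contained sketch of the meta-device initialization pattern this hunk relies on. ToyModel, its init_weights method, and the CPU fallback are hypothetical stand-ins for illustration; they are not torchtitan's actual model, world_mesh, or models_parallelize_fns code.

import torch
import torch.nn as nn

# Hypothetical toy model standing in for torchtitan's model; in train.py the
# module is also parallelized (e.g. parameters wrapped as DTensors) before
# storage is allocated.
class ToyModel(nn.Module):
    def __init__(self, dim: int = 16):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def init_weights(self):
        # Re-initialize parameters in place once real storage exists.
        nn.init.trunc_normal_(self.proj.weight, std=0.02)
        nn.init.zeros_(self.proj.bias)

# Build the module on the meta device: parameters carry shapes and dtypes but
# allocate no real storage.
with torch.device("meta"):
    model = ToyModel()

# (Parallelization of the meta-device module would happen here in train.py.)

# Allocate real storage on the target device, then initialize the weights.
# On older PyTorch versions, DTensor parameters needed
# torch.__future__.set_swap_module_params_on_conversion(True) before to_empty;
# this commit drops that call, in line with the removed TODO about swap_tensors
# becoming the default for wrapper subclasses.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to_empty(device=device)
model.init_weights()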