@@ -70,7 +70,6 @@
     XLAProfiler,
 )
 from pytorch_lightning.strategies import ParallelStrategy, Strategy
-from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
 from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
 from pytorch_lightning.trainer.configuration_validator import verify_loop_configurations
 from pytorch_lightning.trainer.connectors.accelerator_connector import _LITERAL_WARN, AcceleratorConnector
@@ -106,8 +105,7 @@
 from pytorch_lightning.utilities.data import _auto_add_worker_init_fn, has_len_all_ranks
 from pytorch_lightning.utilities.distributed import distributed_available
 from pytorch_lightning.utilities.exceptions import ExitGracefullyException, MisconfigurationException
-from pytorch_lightning.utilities.imports import _fault_tolerant_training
-from pytorch_lightning.utilities.meta import is_on_meta_device, materialize_module
+from pytorch_lightning.utilities.imports import _fault_tolerant_training, _module_available
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn
 from pytorch_lightning.utilities.seed import isolate_rng
@@ -1469,20 +1467,14 @@ def _call_setup_hook(self) -> None:
 
     def _call_configure_sharded_model(self) -> None:
         with self.strategy.model_sharded_context():
-            self._handle_meta_model()
-            self._call_lightning_module_hook("configure_sharded_model")
-            self._call_callback_hooks("on_configure_sharded_model")
-
-    def _handle_meta_model(self) -> None:
-        if not is_on_meta_device(self.lightning_module):
-            return
+            # experimental support for torchdistx
+            if _module_available("torchdistx.deferred_init"):
+                from torchdistx.deferred_init import materialize_module
 
-        if isinstance(self.strategy, DDPSpawnStrategy):
-            raise MisconfigurationException("LightningModule on meta device isn't supported with spawn.")
+                materialize_module(self.lightning_module)
 
-        materialize_module(self.lightning_module)
-        # the trainer reference is lost during materialization
-        self.lightning_module.trainer = proxy(self)
+            self._call_lightning_module_hook("configure_sharded_model")
+            self._call_callback_hooks("on_configure_sharded_model")
 
     def _call_teardown_hook(self) -> None:
         fn = self.state.fn._setup_fn
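For context, the new branch assumes the torchdistx deferred-initialization workflow: the user constructs the model with deferred_init(), and the Trainer materializes it once the sharding context is active. A minimal sketch of that flow, pairing the materialize_module imported in the diff with its counterpart deferred_init (both from torchdistx.deferred_init; the shapes below are illustrative):

import torch
from torchdistx.deferred_init import deferred_init, materialize_module

# Construct the module without allocating real storage; torchdistx records
# the construction on fake tensors instead of executing it.
model = deferred_init(torch.nn.Linear, 1024, 1024)

# Replay the recorded construction and allocate real parameters in place.
# In the Trainer this now happens inside strategy.model_sharded_context().
materialize_module(model)

out = model(torch.randn(2, 1024))  # behaves like an ordinary nn.Module from here on

The _module_available("torchdistx.deferred_init") guard keeps torchdistx an optional dependency: when the package is not importable, the branch is skipped and configure_sharded_model runs exactly as before.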