From a7764061d27a5e74407aa4893e8f4703e04d115d Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 11 Nov 2021 22:26:01 +0000 Subject: [PATCH 01/20] update --- pytorch_lightning/utilities/meta.py | 17 +++++++++++++++++ tests/utilities/test_meta.py | 1 + 2 files changed, 18 insertions(+) diff --git a/pytorch_lightning/utilities/meta.py b/pytorch_lightning/utilities/meta.py index 60e6cc791b7ae..da3be04989c4f 100644 --- a/pytorch_lightning/utilities/meta.py +++ b/pytorch_lightning/utilities/meta.py @@ -272,6 +272,15 @@ def add_subclasses(subclass): if subclass.__bases__[0] != torch.nn.modules.module.Module: _MetaClass.add_subclasses(subclass.__bases__[0]) + def __subclasscheck__(cls, sub): + breakpoint() + + def __subclasshook__(cls, C): + breakpoint() + if cls is _MetaClass: + return isinstance(subclass, cls.__bases__[0]) + return False + def __new__(cls, *args, **kwargs): subclass = cls.__bases__[0] cls.add_subclasses(subclass) @@ -312,6 +321,12 @@ def search(mod: ModuleType) -> List[ModuleType]: setattr(mod, subclass.__name__, _MetaClass) +def mock_isinstance(A, B, isinstance=None): + if isinstance(B, type) and "_MetaClass" in B.__name__: + return isinstance(A, B.__bases__[0]) + return isinstance(A, B) + + @contextmanager def init_meta_context() -> Generator: rank_zero_warn( @@ -319,5 +334,7 @@ def init_meta_context() -> Generator: "where it can internal assert and/or crash. A more stable version is to be expected from PyTorch 1.11." ) _set_meta_device() + __builtins__["isinstance"] = partial(mock_isinstance, isinstance=isinstance) yield + __builtins__["isinstance"] = isinstance.keywords["isinstance"] _unset_meta_device() diff --git a/tests/utilities/test_meta.py b/tests/utilities/test_meta.py index 8e36a86c3beef..7dfb8bbeb78a8 100644 --- a/tests/utilities/test_meta.py +++ b/tests/utilities/test_meta.py @@ -36,6 +36,7 @@ def test_init_meta_context(): with init_meta_context(): m = nn.Linear(in_features=1, out_features=1) + assert isinstance(m, nn.Linear) assert m.weight.device.type == "meta" mlp = MLP(4) assert mlp.layer[0].weight.device.type == "meta" From 299f0493820ada70a734eb2146fefaed1936a894 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 11 Nov 2021 22:27:03 +0000 Subject: [PATCH 02/20] update --- pytorch_lightning/utilities/meta.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/pytorch_lightning/utilities/meta.py b/pytorch_lightning/utilities/meta.py index da3be04989c4f..4ee5148050587 100644 --- a/pytorch_lightning/utilities/meta.py +++ b/pytorch_lightning/utilities/meta.py @@ -272,15 +272,6 @@ def add_subclasses(subclass): if subclass.__bases__[0] != torch.nn.modules.module.Module: _MetaClass.add_subclasses(subclass.__bases__[0]) - def __subclasscheck__(cls, sub): - breakpoint() - - def __subclasshook__(cls, C): - breakpoint() - if cls is _MetaClass: - return isinstance(subclass, cls.__bases__[0]) - return False - def __new__(cls, *args, **kwargs): subclass = cls.__bases__[0] cls.add_subclasses(subclass) From 2c08d3a22d970bd163e3d7eeb2996942febab45a Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Fri, 12 Nov 2021 07:20:55 -0500 Subject: [PATCH 03/20] update --- .../core/mixins/device_dtype_mixin.py | 6 +++++- pytorch_lightning/trainer/trainer.py | 15 ++++++++++++++- pytorch_lightning/utilities/meta.py | 10 +++++----- 3 files changed, 24 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/core/mixins/device_dtype_mixin.py b/pytorch_lightning/core/mixins/device_dtype_mixin.py index e02790edddd1e..9ecd2556f016d 100644 --- 
a/pytorch_lightning/core/mixins/device_dtype_mixin.py +++ b/pytorch_lightning/core/mixins/device_dtype_mixin.py @@ -17,6 +17,8 @@ import torch from torch.nn import Module +import pytorch_lightning as pl + class DeviceDtypeModuleMixin(Module): __jit_unused_properties__ = ["device", "dtype"] @@ -177,7 +179,9 @@ def __update_properties( self, device: Optional[torch.device] = None, dtype: Optional[Union[str, torch.dtype]] = None ) -> None: def apply_fn(module: Union["DeviceDtypeModuleMixin", Module]) -> None: - if not isinstance(module, DeviceDtypeModuleMixin): + # TODO: Find why `isinstance(module, DeviceDtypeModuleMixin)` doesn't + # work when using `init_meta_device`. + if not isinstance(module, (DeviceDtypeModuleMixin, pl.LightningModule)): return if device is not None: module._device = device diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index b6dfcbfee8bc6..72510339b6f8c 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1404,10 +1404,23 @@ def _call_setup_hook(self) -> None: def _call_configure_sharded_model(self) -> None: with self.accelerator.model_sharded_context(): - materialize_module(self.lightning_module) + self._handle_meta_model() self.call_hook("configure_sharded_model") self.call_hook("on_configure_sharded_model") + def _handle_meta_model(self) -> None: + param = next(self.lightning_module.parameters()) + if param.device.type != "meta": + return + + if isinstance(self.training_type_plugin, DDPSpawnPlugin): + raise MisconfigurationException("LightningModule on meta device isn't supported with spawn.") + + materialize_module(self.lightning_module) + self.lightning_module.trainer = proxy(self) + # TODO: Find a better place to move the newly materialized model to the device + self.training_type_plugin.model_to_device() + def _call_teardown_hook(self) -> None: fn = self.state.fn._setup_fn diff --git a/pytorch_lightning/utilities/meta.py b/pytorch_lightning/utilities/meta.py index 4ee5148050587..a173d040b114a 100644 --- a/pytorch_lightning/utilities/meta.py +++ b/pytorch_lightning/utilities/meta.py @@ -25,6 +25,7 @@ from torch.nn import Module from torch.nn.modules.container import ModuleDict, ModuleList, Sequential +import pytorch_lightning as pl from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_10 @@ -112,7 +113,6 @@ def init_meta(module_fn: Callable[..., Module], *args, **kwargs) -> Module: def create_instance(module=None) -> Module: if module: module.__init__(*args, **kwargs) - return module return module_fn(*args, **kwargs) if _tls.in_call: @@ -185,13 +185,12 @@ def materialize_module(root_module: nn.Module) -> nn.Module: if not materialize_fn or isinstance(child, (Sequential, ModuleList, ModuleDict)): materialize_module(child) else: - setattr(child, name, materialize_fn()) + materialize_fn() return root_module # cache subclasses to optimize the search when resetting the meta device later on. 
__STORAGE_META__ = {} - __CREATED_MODULES__ = set() @@ -237,7 +236,7 @@ def _set_meta_device() -> None: for subclass in get_all_subclasses(torch.nn.modules.module.Module): - if isinstance(subclass, (Sequential, ModuleList, ModuleDict)): + if subclass in (Sequential, ModuleList, ModuleDict, pl.LightningModule): continue # if a subclass has already been stored, we should use the cache @@ -268,7 +267,8 @@ def materialize(cls, materialize_fn: Callable): @staticmethod def add_subclasses(subclass): """This is used to unrol the instantion tree while creating the modules.""" - __CREATED_MODULES__.add(subclass) + if subclass != pl.LightningModule: + __CREATED_MODULES__.add(subclass) if subclass.__bases__[0] != torch.nn.modules.module.Module: _MetaClass.add_subclasses(subclass.__bases__[0]) From 8077db8ba37a703ff5ab07f4416960920fa31efc Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 12 Nov 2021 12:26:33 +0000 Subject: [PATCH 04/20] update changelog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0082201aa1cf9..e580108a31749 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -121,6 +121,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed `CombinedLoader` and `max_size_cycle` didn't receive a `DistributedSampler` ([#10374](https://github.com/PyTorchLightning/pytorch-lightning/issues/10374)) +- Fix `isinstance` not working with `init_meta_context`, materialize model not being moved to the device ([#10493](https://github.com/PyTorchLightning/metrics/pull/10493)) + + - From 9f749a89c6c94659cbb5aa40498204380d6039c1 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 12 Nov 2021 12:42:03 +0000 Subject: [PATCH 05/20] update on comments --- pytorch_lightning/trainer/trainer.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 72510339b6f8c..17d81ff88441a 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -38,7 +38,14 @@ from pytorch_lightning.loops import PredictionLoop, TrainingBatchLoop, TrainingEpochLoop from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop from pytorch_lightning.loops.fit_loop import FitLoop -from pytorch_lightning.plugins import DDPSpawnPlugin, ParallelPlugin, PLUGIN_INPUT, PrecisionPlugin, TrainingTypePlugin +from pytorch_lightning.plugins import ( + DDPSpawnPlugin, + ParallelPlugin, + PLUGIN_INPUT, + PrecisionPlugin, + TPUSpawnPlugin, + TrainingTypePlugin, +) from pytorch_lightning.profiler import ( AdvancedProfiler, BaseProfiler, @@ -1413,7 +1420,7 @@ def _handle_meta_model(self) -> None: if param.device.type != "meta": return - if isinstance(self.training_type_plugin, DDPSpawnPlugin): + if isinstance(self.training_type_plugin, (DDPSpawnPlugin, TPUSpawnPlugin)): raise MisconfigurationException("LightningModule on meta device isn't supported with spawn.") materialize_module(self.lightning_module) From 0d2c68b3b4ec9194ef864930fd4a4815241a734d Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 12 Nov 2021 12:43:13 +0000 Subject: [PATCH 06/20] udpate --- pytorch_lightning/utilities/meta.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/meta.py b/pytorch_lightning/utilities/meta.py index a173d040b114a..a6d7f190d9730 100644 --- a/pytorch_lightning/utilities/meta.py +++ b/pytorch_lightning/utilities/meta.py @@ -113,6 +113,7 @@ def init_meta(module_fn: Callable[..., Module], 
*args, **kwargs) -> Module: def create_instance(module=None) -> Module: if module: module.__init__(*args, **kwargs) + return module return module_fn(*args, **kwargs) if _tls.in_call: @@ -185,7 +186,7 @@ def materialize_module(root_module: nn.Module) -> nn.Module: if not materialize_fn or isinstance(child, (Sequential, ModuleList, ModuleDict)): materialize_module(child) else: - materialize_fn() + setattr(child, name, materialize_fn()) return root_module From 4e9d805b45f1c6854de2ba36c5d561a5833234f0 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 12 Nov 2021 12:44:44 +0000 Subject: [PATCH 07/20] update --- pytorch_lightning/utilities/meta.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/meta.py b/pytorch_lightning/utilities/meta.py index a6d7f190d9730..386ea7d4d90f2 100644 --- a/pytorch_lightning/utilities/meta.py +++ b/pytorch_lightning/utilities/meta.py @@ -267,7 +267,8 @@ def materialize(cls, materialize_fn: Callable): @staticmethod def add_subclasses(subclass): - """This is used to unrol the instantion tree while creating the modules.""" + """This is used to unroll the instantion tree while creating the modules.""" + # Don't store the LightningModule as skiped from the Meta process. if subclass != pl.LightningModule: __CREATED_MODULES__.add(subclass) if subclass.__bases__[0] != torch.nn.modules.module.Module: From 35ffbf8e779db9e62b1c88a461ad724dfc029016 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 12 Nov 2021 12:47:16 +0000 Subject: [PATCH 08/20] update --- pytorch_lightning/utilities/meta.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/utilities/meta.py b/pytorch_lightning/utilities/meta.py index 386ea7d4d90f2..6d0f2da8659fa 100644 --- a/pytorch_lightning/utilities/meta.py +++ b/pytorch_lightning/utilities/meta.py @@ -18,7 +18,7 @@ from functools import partial from itertools import chain from types import ModuleType -from typing import Callable, Dict, Generator, Iterator, List, Optional, Set, Type +from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Set, Type import torch from torch import nn, Tensor @@ -314,7 +314,7 @@ def search(mod: ModuleType) -> List[ModuleType]: setattr(mod, subclass.__name__, _MetaClass) -def mock_isinstance(A, B, isinstance=None): +def mock_isinstance(A: Any, B: Any, isinstance: Callable) -> bool: if isinstance(B, type) and "_MetaClass" in B.__name__: return isinstance(A, B.__bases__[0]) return isinstance(A, B) From cbf424a659f8f2b2f030ecb18d9fb9c669227911 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 12 Nov 2021 12:51:33 +0000 Subject: [PATCH 09/20] update --- pytorch_lightning/trainer/trainer.py | 11 ++--------- pytorch_lightning/utilities/meta.py | 6 ++++-- 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 17d81ff88441a..72510339b6f8c 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -38,14 +38,7 @@ from pytorch_lightning.loops import PredictionLoop, TrainingBatchLoop, TrainingEpochLoop from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop from pytorch_lightning.loops.fit_loop import FitLoop -from pytorch_lightning.plugins import ( - DDPSpawnPlugin, - ParallelPlugin, - PLUGIN_INPUT, - PrecisionPlugin, - TPUSpawnPlugin, - TrainingTypePlugin, -) +from pytorch_lightning.plugins import DDPSpawnPlugin, ParallelPlugin, PLUGIN_INPUT, PrecisionPlugin, TrainingTypePlugin from 
pytorch_lightning.profiler import ( AdvancedProfiler, BaseProfiler, @@ -1420,7 +1413,7 @@ def _handle_meta_model(self) -> None: if param.device.type != "meta": return - if isinstance(self.training_type_plugin, (DDPSpawnPlugin, TPUSpawnPlugin)): + if isinstance(self.training_type_plugin, DDPSpawnPlugin): raise MisconfigurationException("LightningModule on meta device isn't supported with spawn.") materialize_module(self.lightning_module) diff --git a/pytorch_lightning/utilities/meta.py b/pytorch_lightning/utilities/meta.py index 6d0f2da8659fa..a02f0aafe0f1e 100644 --- a/pytorch_lightning/utilities/meta.py +++ b/pytorch_lightning/utilities/meta.py @@ -314,7 +314,9 @@ def search(mod: ModuleType) -> List[ModuleType]: setattr(mod, subclass.__name__, _MetaClass) -def mock_isinstance(A: Any, B: Any, isinstance: Callable) -> bool: +def _mock_isinstance(A: Any, B: Any, isinstance: Callable) -> bool: + # This functions enables to builtins `isinstance` function to work as expected within + # the context of `init_meta_context` as the nn.Module are replace by their Meta version. if isinstance(B, type) and "_MetaClass" in B.__name__: return isinstance(A, B.__bases__[0]) return isinstance(A, B) @@ -327,7 +329,7 @@ def init_meta_context() -> Generator: "where it can internal assert and/or crash. A more stable version is to be expected from PyTorch 1.11." ) _set_meta_device() - __builtins__["isinstance"] = partial(mock_isinstance, isinstance=isinstance) + __builtins__["isinstance"] = partial(_mock_isinstance, isinstance=isinstance) yield __builtins__["isinstance"] = isinstance.keywords["isinstance"] _unset_meta_device() From 84ab16ff2509dd861505cf0f15d224e65d3801e7 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Fri, 12 Nov 2021 09:51:21 -0500 Subject: [PATCH 10/20] update --- pytorch_lightning/core/mixins/device_dtype_mixin.py | 6 ++---- pytorch_lightning/trainer/trainer.py | 2 -- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/core/mixins/device_dtype_mixin.py b/pytorch_lightning/core/mixins/device_dtype_mixin.py index 9ecd2556f016d..49204582d9c54 100644 --- a/pytorch_lightning/core/mixins/device_dtype_mixin.py +++ b/pytorch_lightning/core/mixins/device_dtype_mixin.py @@ -17,8 +17,6 @@ import torch from torch.nn import Module -import pytorch_lightning as pl - class DeviceDtypeModuleMixin(Module): __jit_unused_properties__ = ["device", "dtype"] @@ -180,8 +178,8 @@ def __update_properties( ) -> None: def apply_fn(module: Union["DeviceDtypeModuleMixin", Module]) -> None: # TODO: Find why `isinstance(module, DeviceDtypeModuleMixin)` doesn't - # work when using `init_meta_device`. - if not isinstance(module, (DeviceDtypeModuleMixin, pl.LightningModule)): + # work when using `init_meta_context`. 
+ if not isinstance(module, (DeviceDtypeModuleMixin, Module)): return if device is not None: module._device = device diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 72510339b6f8c..4ca1498cb25c5 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1418,8 +1418,6 @@ def _handle_meta_model(self) -> None: materialize_module(self.lightning_module) self.lightning_module.trainer = proxy(self) - # TODO: Find a better place to move the newly materialized model to the device - self.training_type_plugin.model_to_device() def _call_teardown_hook(self) -> None: fn = self.state.fn._setup_fn From 3deda2f937dd85583af0f391f002a09debeb6a44 Mon Sep 17 00:00:00 2001 From: Thomas Chaton Date: Fri, 12 Nov 2021 09:52:29 -0500 Subject: [PATCH 11/20] add comment --- pytorch_lightning/trainer/trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 4ca1498cb25c5..1101eacb3951a 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -1417,6 +1417,7 @@ def _handle_meta_model(self) -> None: raise MisconfigurationException("LightningModule on meta device isn't supported with spawn.") materialize_module(self.lightning_module) + # the trainer reference is lost during materialization self.lightning_module.trainer = proxy(self) def _call_teardown_hook(self) -> None: From 46ffd6291ed671dddc425a4de7d58f38d517cb81 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 12 Nov 2021 17:13:56 +0000 Subject: [PATCH 12/20] update --- pytorch_lightning/trainer/trainer.py | 5 ++--- pytorch_lightning/utilities/meta.py | 8 ++++++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 72510339b6f8c..2c5dc4bcc2f7d 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -84,7 +84,7 @@ from pytorch_lightning.utilities.distributed import distributed_available from pytorch_lightning.utilities.exceptions import ExitGracefullyException, MisconfigurationException from pytorch_lightning.utilities.imports import _fault_tolerant_training -from pytorch_lightning.utilities.meta import materialize_module +from pytorch_lightning.utilities.meta import is_meta_device, materialize_module from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.seed import reset_seed from pytorch_lightning.utilities.types import ( @@ -1409,8 +1409,7 @@ def _call_configure_sharded_model(self) -> None: self.call_hook("on_configure_sharded_model") def _handle_meta_model(self) -> None: - param = next(self.lightning_module.parameters()) - if param.device.type != "meta": + if not is_meta_device(self.lightning_module): return if isinstance(self.training_type_plugin, DDPSpawnPlugin): diff --git a/pytorch_lightning/utilities/meta.py b/pytorch_lightning/utilities/meta.py index a02f0aafe0f1e..421abc4c00d14 100644 --- a/pytorch_lightning/utilities/meta.py +++ b/pytorch_lightning/utilities/meta.py @@ -333,3 +333,11 @@ def init_meta_context() -> Generator: yield __builtins__["isinstance"] = isinstance.keywords["isinstance"] _unset_meta_device() + + +def is_meta_device(module: nn.Module) -> bool: + try: + param = next(module.parameters()) + return param.device.type == "meta" + except StopIteration: + return False From e277a4d45bb63068172d8d44e6176f1e39c4851d Mon Sep 17 00:00:00 2001 From: tchaton Date: 
Fri, 12 Nov 2021 17:16:40 +0000 Subject: [PATCH 13/20] update --- tests/utilities/test_meta.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/utilities/test_meta.py b/tests/utilities/test_meta.py index 7dfb8bbeb78a8..7c218c8d90bdd 100644 --- a/tests/utilities/test_meta.py +++ b/tests/utilities/test_meta.py @@ -14,7 +14,7 @@ from torch import nn from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.utilities.meta import init_meta_context, materialize_module +from pytorch_lightning.utilities.meta import init_meta_context, is_meta_device, materialize_module from tests.helpers.runif import RunIf @@ -38,12 +38,16 @@ def test_init_meta_context(): m = nn.Linear(in_features=1, out_features=1) assert isinstance(m, nn.Linear) assert m.weight.device.type == "meta" + assert is_meta_device(m) mlp = MLP(4) assert mlp.layer[0].weight.device.type == "meta" mlp = materialize_module(mlp) assert mlp.layer[0].weight.device.type == "cpu" + assert not is_meta_device(mlp) + assert not is_meta_device(nn.Module()) + model = BoringModel(4) assert model.layer[0].weight.device.type == "meta" materialize_module(model) From e7244834def314bb4561f525a07fcb63ca492f89 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 12 Nov 2021 18:20:08 +0100 Subject: [PATCH 14/20] Minor fixes --- pytorch_lightning/utilities/meta.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/utilities/meta.py b/pytorch_lightning/utilities/meta.py index 421abc4c00d14..2bfc60b6d7106 100644 --- a/pytorch_lightning/utilities/meta.py +++ b/pytorch_lightning/utilities/meta.py @@ -242,7 +242,7 @@ def _set_meta_device() -> None: # if a subclass has already been stored, we should use the cache if str(subclass) in __STORAGE_META__: - # reset the class import package to its rightfull state. + # reset the class import package to its rightful state. mods, subclass, meta_class = __STORAGE_META__[subclass] for mod in mods: setattr(mod, subclass.__name__, meta_class) @@ -254,21 +254,21 @@ def _set_meta_device() -> None: class _MetaClass(subclass): @classmethod @contextmanager - def instantiation_context(cls, materialize: bool): + def instantiation_context(cls): _unset_meta_device(from_created=True) yield _set_meta_device_populated(from_created=True) @classmethod def materialize(cls, materialize_fn: Callable): - with cls.instantiation_context(materialize=True): + with cls.instantiation_context(): obj = materialize_fn() return obj @staticmethod def add_subclasses(subclass): - """This is used to unroll the instantion tree while creating the modules.""" - # Don't store the LightningModule as skiped from the Meta process. + """This is used to unroll the instantiation tree while creating the modules.""" + # Don't store the LightningModule as skipped from the Meta process. 
if subclass != pl.LightningModule: __CREATED_MODULES__.add(subclass) if subclass.__bases__[0] != torch.nn.modules.module.Module: @@ -277,7 +277,7 @@ def add_subclasses(subclass): def __new__(cls, *args, **kwargs): subclass = cls.__bases__[0] cls.add_subclasses(subclass) - with cls.instantiation_context(materialize=False): + with cls.instantiation_context(): obj = init_meta(subclass, *args, **kwargs) obj.materialize = partial(cls.materialize, materialize_fn=obj.materialize) @@ -297,8 +297,7 @@ def search(mod: ModuleType) -> List[ModuleType]: # Example: torch.nn.Linear is actually torch.nn.modules.linear.Linear # Therefore, torch.nn.Linear, torch.nn.modules.Linear, torch.nn.modules.linear.Linear # needs to be replaced by the torch.nn.linear.modules.Linear _MetaClass - out = [] - out.append(search(mod)) + out = [search(mod)] for name in submodules[1:]: mod = getattr(mod, name) out.append(search(mod)) From a2a1084fea96231f0a311ed2f6bc130dc376af70 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 12 Nov 2021 18:33:46 +0100 Subject: [PATCH 15/20] __instancecheck__ magic --- pytorch_lightning/utilities/meta.py | 25 ++++++++++--------------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/pytorch_lightning/utilities/meta.py b/pytorch_lightning/utilities/meta.py index 2bfc60b6d7106..70a7320a0a1a0 100644 --- a/pytorch_lightning/utilities/meta.py +++ b/pytorch_lightning/utilities/meta.py @@ -248,10 +248,15 @@ def _set_meta_device() -> None: setattr(mod, subclass.__name__, meta_class) continue + class _IsinstanceMetaclass(type(subclass)): + def __instancecheck__(self, instance: Any) -> bool: + """Overrides the ``isinstance`` check on ``_MaterializerModule`` objects.""" + return isinstance(instance, self.__bases__[0]) + # Create a class subclassing current `subclass` overriding its new method. # this will enable use to use `torch.distributed.nn.utils.init_meta` to create a `meta` # version of the current subclass module - class _MetaClass(subclass): + class _MaterializerModule(subclass, metaclass=_IsinstanceMetaclass): @classmethod @contextmanager def instantiation_context(cls): @@ -272,7 +277,7 @@ def add_subclasses(subclass): if subclass != pl.LightningModule: __CREATED_MODULES__.add(subclass) if subclass.__bases__[0] != torch.nn.modules.module.Module: - _MetaClass.add_subclasses(subclass.__bases__[0]) + _MaterializerModule.add_subclasses(subclass.__bases__[0]) def __new__(cls, *args, **kwargs): subclass = cls.__bases__[0] @@ -296,7 +301,7 @@ def search(mod: ModuleType) -> List[ModuleType]: # nn.Module class can be imported at different level and they all need to be mocked. 
# Example: torch.nn.Linear is actually torch.nn.modules.linear.Linear # Therefore, torch.nn.Linear, torch.nn.modules.Linear, torch.nn.modules.linear.Linear - # needs to be replaced by the torch.nn.linear.modules.Linear _MetaClass + # needs to be replaced by the torch.nn.linear.modules.Linear _MaterializerModule out = [search(mod)] for name in submodules[1:]: mod = getattr(mod, name) @@ -306,19 +311,11 @@ def search(mod: ModuleType) -> List[ModuleType]: mods = [mod for mod in chain(*out) if mod] # store the modules search so it doesn't have to be performed again for this class - __STORAGE_META__[subclass] = (mods, subclass, _MetaClass) + __STORAGE_META__[subclass] = (mods, subclass, _MaterializerModule) # replace all subclass by its meta form for mod in mods: - setattr(mod, subclass.__name__, _MetaClass) - - -def _mock_isinstance(A: Any, B: Any, isinstance: Callable) -> bool: - # This functions enables to builtins `isinstance` function to work as expected within - # the context of `init_meta_context` as the nn.Module are replace by their Meta version. - if isinstance(B, type) and "_MetaClass" in B.__name__: - return isinstance(A, B.__bases__[0]) - return isinstance(A, B) + setattr(mod, subclass.__name__, _MaterializerModule) @contextmanager @@ -328,9 +325,7 @@ def init_meta_context() -> Generator: "where it can internal assert and/or crash. A more stable version is to be expected from PyTorch 1.11." ) _set_meta_device() - __builtins__["isinstance"] = partial(_mock_isinstance, isinstance=isinstance) yield - __builtins__["isinstance"] = isinstance.keywords["isinstance"] _unset_meta_device() From dfb42e4f1c54510122a7715e4e7a30f5c741fe4b Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 12 Nov 2021 21:25:52 +0000 Subject: [PATCH 16/20] update --- pytorch_lightning/core/mixins/device_dtype_mixin.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/core/mixins/device_dtype_mixin.py b/pytorch_lightning/core/mixins/device_dtype_mixin.py index 49204582d9c54..e8b122989cd9c 100644 --- a/pytorch_lightning/core/mixins/device_dtype_mixin.py +++ b/pytorch_lightning/core/mixins/device_dtype_mixin.py @@ -17,6 +17,8 @@ import torch from torch.nn import Module +import pytorch_lightning as pl + class DeviceDtypeModuleMixin(Module): __jit_unused_properties__ = ["device", "dtype"] @@ -179,7 +181,7 @@ def __update_properties( def apply_fn(module: Union["DeviceDtypeModuleMixin", Module]) -> None: # TODO: Find why `isinstance(module, DeviceDtypeModuleMixin)` doesn't # work when using `init_meta_context`. - if not isinstance(module, (DeviceDtypeModuleMixin, Module)): + if not isinstance(module, (DeviceDtypeModuleMixin, pl.LightningModule)): return if device is not None: module._device = device From e08ece3a8c97a7e584f34ae6fad6fe2503682fde Mon Sep 17 00:00:00 2001 From: Rohit Gupta Date: Mon, 15 Nov 2021 14:31:46 +0530 Subject: [PATCH 17/20] Update CHANGELOG.md --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fee2fa7fc65cd..7cd6fb457bde5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -127,7 +127,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Fixed `CombinedLoader` and `max_size_cycle` didn't receive a `DistributedSampler` ([#10374](https://github.com/PyTorchLightning/pytorch-lightning/issues/10374)) -- Fix `isinstance` not working with `init_meta_context`, materialize model not being moved to the device ([#10493](https://github.com/PyTorchLightning/metrics/pull/10493)) +- Fixed `isinstance` not working with `init_meta_context`, materialize model not being moved to the device ([#10493](https://github.com/PyTorchLightning/metrics/pull/10493)) - Squeeze the early stopping monitor to remove empty tensor dimensions ([#10461](https://github.com/PyTorchLightning/pytorch-lightning/issues/10461)) From c12cd1730f35771e4e9e2bf85923713d59190826 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 15 Nov 2021 09:03:10 +0000 Subject: [PATCH 18/20] update --- tests/utilities/test_meta.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utilities/test_meta.py b/tests/utilities/test_meta.py index 7c218c8d90bdd..016f9123ef891 100644 --- a/tests/utilities/test_meta.py +++ b/tests/utilities/test_meta.py @@ -31,7 +31,7 @@ def __init__(self, num_layers: int): self.layer = nn.Sequential(*[nn.Linear(1, 1) for _ in range(self.hparams.num_layers)]) -@RunIf(min_torch="1.10.0") +@RunIf(min_torch="1.10.0", special=True) def test_init_meta_context(): with init_meta_context(): From 2ce60ca9b1ec7d92dedbb9041ac7fbd1a17e95e1 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 15 Nov 2021 11:46:33 +0000 Subject: [PATCH 19/20] update --- tests/utilities/test_meta.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utilities/test_meta.py b/tests/utilities/test_meta.py index 016f9123ef891..2e3f62d55e870 100644 --- a/tests/utilities/test_meta.py +++ b/tests/utilities/test_meta.py @@ -31,7 +31,7 @@ def __init__(self, num_layers: int): self.layer = nn.Sequential(*[nn.Linear(1, 1) for _ in range(self.hparams.num_layers)]) -@RunIf(min_torch="1.10.0", special=True) +@RunIf(special=True, min_torch="1.10.0") def test_init_meta_context(): with init_meta_context(): From e5df66e47134cd1a3dfbda36214832aa667b3e79 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 15 Nov 2021 15:42:55 +0000 Subject: [PATCH 20/20] update on comments --- CHANGELOG.md | 2 +- pytorch_lightning/trainer/trainer.py | 4 ++-- pytorch_lightning/utilities/meta.py | 2 +- tests/utilities/test_meta.py | 8 ++++---- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6f67ff1fd9dcb..2336564824f43 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -131,7 +131,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Fixed `CombinedLoader` and `max_size_cycle` didn't receive a `DistributedSampler` ([#10374](https://github.com/PyTorchLightning/pytorch-lightning/issues/10374)) -- Fixed `isinstance` not working with `init_meta_context`, materialize model not being moved to the device ([#10493](https://github.com/PyTorchLightning/metrics/pull/10493)) +- Fixed `isinstance` not working with `init_meta_context`, materialized model not being moved to the device ([#10493](https://github.com/PyTorchLightning/metrics/pull/10493)) - Fixed an issue that prevented the Trainer to shutdown workers when execution is interrupted due to failure([#10463](https://github.com/PyTorchLightning/pytorch-lightning/issues/10463)) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 774b908acf3a2..8e41aac3da840 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -84,7 +84,7 @@ from pytorch_lightning.utilities.distributed import distributed_available from pytorch_lightning.utilities.exceptions import ExitGracefullyException, MisconfigurationException from pytorch_lightning.utilities.imports import _fault_tolerant_training -from pytorch_lightning.utilities.meta import is_meta_device, materialize_module +from pytorch_lightning.utilities.meta import is_on_meta_device, materialize_module from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.seed import reset_seed from pytorch_lightning.utilities.types import ( @@ -1411,7 +1411,7 @@ def _call_configure_sharded_model(self) -> None: self.call_hook("on_configure_sharded_model") def _handle_meta_model(self) -> None: - if not is_meta_device(self.lightning_module): + if not is_on_meta_device(self.lightning_module): return if isinstance(self.training_type_plugin, DDPSpawnPlugin): diff --git a/pytorch_lightning/utilities/meta.py b/pytorch_lightning/utilities/meta.py index 70a7320a0a1a0..6d3c1d6b5f11b 100644 --- a/pytorch_lightning/utilities/meta.py +++ b/pytorch_lightning/utilities/meta.py @@ -329,7 +329,7 @@ def init_meta_context() -> Generator: _unset_meta_device() -def is_meta_device(module: nn.Module) -> bool: +def is_on_meta_device(module: nn.Module) -> bool: try: param = next(module.parameters()) return param.device.type == "meta" diff --git a/tests/utilities/test_meta.py b/tests/utilities/test_meta.py index 2e3f62d55e870..581b949d9167f 100644 --- a/tests/utilities/test_meta.py +++ b/tests/utilities/test_meta.py @@ -14,7 +14,7 @@ from torch import nn from pytorch_lightning.core.lightning import LightningModule -from pytorch_lightning.utilities.meta import init_meta_context, is_meta_device, materialize_module +from pytorch_lightning.utilities.meta import init_meta_context, is_on_meta_device, materialize_module from tests.helpers.runif import RunIf @@ -38,15 +38,15 @@ def test_init_meta_context(): m = nn.Linear(in_features=1, out_features=1) assert isinstance(m, nn.Linear) assert m.weight.device.type == "meta" - assert is_meta_device(m) + assert is_on_meta_device(m) mlp = MLP(4) assert mlp.layer[0].weight.device.type == "meta" mlp = materialize_module(mlp) assert mlp.layer[0].weight.device.type == "cpu" - assert not is_meta_device(mlp) - assert not is_meta_device(nn.Module()) + assert not is_on_meta_device(mlp) + assert not is_on_meta_device(nn.Module()) model = BoringModel(4) assert model.layer[0].weight.device.type == "meta"
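
Note on the technique the series converges on: patches 14-15 drop the `__builtins__["isinstance"]` monkey-patch introduced in patch 01 and instead give each generated stand-in class a metaclass overriding `__instancecheck__`, so an `isinstance` check against the stand-in is answered by the original `nn.Module` subclass it wraps. The sketch below is illustrative only and is not code from the PR: it hard-codes a single stand-in for `nn.Linear` instead of generating one per `nn.Module` subclass and swapping it into the `torch.nn` namespaces, and it uses the `device="meta"` factory argument (PyTorch >= 1.10) in place of the PR's `init_meta` machinery; the names `_IsinstanceMetaclass` and `_MaterializerLinear` only loosely mirror those in the patch.

    import torch
    from torch import nn


    class _IsinstanceMetaclass(type(nn.Linear)):
        def __instancecheck__(self, instance):
            # Answer isinstance(x, <stand-in>) against the wrapped base class,
            # so plain nn.Linear objects created outside the context still pass.
            return isinstance(instance, self.__bases__[0])


    class _MaterializerLinear(nn.Linear, metaclass=_IsinstanceMetaclass):
        """Illustrative stand-in for nn.Linear that allocates parameters on the meta device."""

        def __init__(self, in_features: int, out_features: int, bias: bool = True):
            super().__init__(in_features, out_features, bias=bias, device="meta")


    meta_linear = _MaterializerLinear(1, 1)
    assert meta_linear.weight.device.type == "meta"
    assert isinstance(meta_linear, nn.Linear)            # ordinary subclass relationship

    real_linear = nn.Linear(1, 1)
    assert isinstance(real_linear, _MaterializerLinear)  # True via __instancecheck__

Because the check lives on the stand-in's metaclass rather than on the builtin, it only affects the classes installed by `_set_meta_device` inside `init_meta_context`, which is what allows patch 15 to remove the global `isinstance` patching and its restore logic on context exit.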