diff --git a/pytorch_lightning/__init__.py b/pytorch_lightning/__init__.py index 569078c994ba4..d14a094b72b00 100644 --- a/pytorch_lightning/__init__.py +++ b/pytorch_lightning/__init__.py @@ -4,6 +4,7 @@ import os import sys import time +import typing _this_year = time.strftime("%Y") __version__ = '1.3.0dev' @@ -50,6 +51,7 @@ _PACKAGE_ROOT = os.path.dirname(__file__) _PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT) + try: # This variable is injected in the __builtins__ by the build # process. It used to enable importing subpackages of skimage when @@ -58,10 +60,16 @@ except NameError: __LIGHTNING_SETUP__: bool = False -if __LIGHTNING_SETUP__: # pragma: no-cover - sys.stdout.write(f'Partial import of `{__name__}` during the build process.\n') # pragma: no-cover - # We are not importing the rest of the lightning during the build process, as it may not be compiled yet -else: + +if typing.TYPE_CHECKING: + # Simplify the imports for the static type checker. + from pytorch_lightning import metrics + from pytorch_lightning.callbacks import Callback + from pytorch_lightning.core import LightningDataModule, LightningModule + from pytorch_lightning.trainer import Trainer + from pytorch_lightning.utilities.seed import seed_everything + +elif not __LIGHTNING_SETUP__: # pragma: no-cover from pytorch_lightning import metrics from pytorch_lightning.callbacks import Callback from pytorch_lightning.core import LightningDataModule, LightningModule @@ -76,6 +84,12 @@ 'seed_everything', 'metrics', ] +else: + sys.stdout.write( + f'Partial import of `{__name__}` during the build process.\n' + ) # pragma: no-cover + # We are not importing the rest of the lightning during the build process, as it may not be compiled yet + # for compatibility with namespace packages __import__('pkg_resources').declare_namespace(__name__)