Skip to content

Commit 0c1b4ff

Browse files
committed
Add unconditional import when type checking
1 parent 680e83a commit 0c1b4ff

File tree

1 file changed

+29
-11
lines changed

1 file changed

+29
-11
lines changed

pytorch_lightning/__init__.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import os
55
import sys
66
import time
7+
import typing
78

89
_this_year = time.strftime("%Y")
910
__version__ = '1.3.0dev'
@@ -50,18 +51,8 @@
5051
_PACKAGE_ROOT = os.path.dirname(__file__)
5152
_PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT)
5253

53-
try:
54-
# This variable is injected in the __builtins__ by the build
55-
# process. It used to enable importing subpackages of skimage when
56-
# the binaries are not built
57-
_ = None if __LIGHTNING_SETUP__ else None
58-
except NameError:
59-
__LIGHTNING_SETUP__: bool = False
6054

61-
if __LIGHTNING_SETUP__: # pragma: no-cover
62-
sys.stdout.write(f'Partial import of `{__name__}` during the build process.\n') # pragma: no-cover
63-
# We are not importing the rest of the lightning during the build process, as it may not be compiled yet
64-
else:
55+
if typing.TYPE_CHECKING:
6556
from pytorch_lightning import metrics
6657
from pytorch_lightning.callbacks import Callback
6758
from pytorch_lightning.core import LightningDataModule, LightningModule
@@ -76,6 +67,33 @@
7667
'seed_everything',
7768
'metrics',
7869
]
70+
else:
71+
try:
72+
# This variable is injected in the __builtins__ by the build
73+
# process. It used to enable importing subpackages of skimage when
74+
# the binaries are not built
75+
_ = None if __LIGHTNING_SETUP__ else None
76+
except NameError:
77+
__LIGHTNING_SETUP__: bool = False
78+
79+
if __LIGHTNING_SETUP__: # pragma: no-cover
80+
sys.stdout.write(f'Partial import of `{__name__}` during the build process.\n') # pragma: no-cover
81+
# We are not importing the rest of the lightning during the build process, as it may not be compiled yet
82+
else:
83+
from pytorch_lightning import metrics
84+
from pytorch_lightning.callbacks import Callback
85+
from pytorch_lightning.core import LightningDataModule, LightningModule
86+
from pytorch_lightning.trainer import Trainer
87+
from pytorch_lightning.utilities.seed import seed_everything
88+
89+
__all__ = [
90+
'Trainer',
91+
'LightningDataModule',
92+
'LightningModule',
93+
'Callback',
94+
'seed_everything',
95+
'metrics',
96+
]
7997

8098
# for compatibility with namespace packages
8199
__import__('pkg_resources').declare_namespace(__name__)

0 commit comments

Comments
 (0)