|
4 | 4 | import os |
5 | 5 | import sys |
6 | 6 | import time |
| 7 | +import typing |
7 | 8 |
|
8 | 9 | _this_year = time.strftime("%Y") |
9 | 10 | __version__ = '1.3.0dev' |
|
50 | 51 | _PACKAGE_ROOT = os.path.dirname(__file__) |
51 | 52 | _PROJECT_ROOT = os.path.dirname(_PACKAGE_ROOT) |
52 | 53 |
|
| 54 | + |
53 | 55 | try: |
54 | 56 | # This variable is injected in the __builtins__ by the build |
55 | 57 | # process. It used to enable importing subpackages of skimage when |
|
58 | 60 | except NameError: |
59 | 61 | __LIGHTNING_SETUP__: bool = False |
60 | 62 |
|
61 | | -if __LIGHTNING_SETUP__: # pragma: no-cover |
62 | | - sys.stdout.write(f'Partial import of `{__name__}` during the build process.\n') # pragma: no-cover |
63 | | - # We are not importing the rest of the lightning during the build process, as it may not be compiled yet |
64 | | -else: |
| 63 | + |
| 64 | +if typing.TYPE_CHECKING: |
| 65 | + # Simplify the imports for the static type checker. |
| 66 | + from pytorch_lightning import metrics |
| 67 | + from pytorch_lightning.callbacks import Callback |
| 68 | + from pytorch_lightning.core import LightningDataModule, LightningModule |
| 69 | + from pytorch_lightning.trainer import Trainer |
| 70 | + from pytorch_lightning.utilities.seed import seed_everything |
| 71 | + |
| 72 | +elif not __LIGHTNING_SETUP__: # pragma: no-cover |
65 | 73 | from pytorch_lightning import metrics |
66 | 74 | from pytorch_lightning.callbacks import Callback |
67 | 75 | from pytorch_lightning.core import LightningDataModule, LightningModule |
|
76 | 84 | 'seed_everything', |
77 | 85 | 'metrics', |
78 | 86 | ] |
| 87 | +else: |
| 88 | + sys.stdout.write(f'Partial import of `{__name__}` during the build process.\n') # pragma: no-cover |
| 89 | + # We are not importing the rest of the lightning during the build process, as it may not be compiled yet |
| 90 | + |
79 | 91 |
|
80 | 92 | # for compatibility with namespace packages |
81 | 93 | __import__('pkg_resources').declare_namespace(__name__) |
0 commit comments