Skip to content

Commit 4eeeedc

Browse files
authored
Merge 8be4ef3 into 55dd3a4
2 parents 55dd3a4 + 8be4ef3 commit 4eeeedc

File tree

3 files changed

+17
-4
lines changed

3 files changed

+17
-4
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
137137
- Fixed logger creating directory structure too early in DDP ([#6380](https://github.com/PyTorchLightning/pytorch-lightning/pull/6380))
138138

139139

140+
- Fixed comparing required versions ([#6434](https://github.com/PyTorchLightning/pytorch-lightning/pull/6434))
141+
142+
140143
## [1.2.2] - 2021-03-02
141144

142145
### Added

pl_examples/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@
1515
_DATASETS_PATH = os.path.join(_PACKAGE_ROOT, 'Datasets')
1616

1717
_TORCHVISION_AVAILABLE = _module_available("torchvision")
18-
_TORCHVISION_MNIST_AVAILABLE = True
18+
_TORCHVISION_MNIST_AVAILABLE = _TORCHVISION_AVAILABLE
1919
_DALI_AVAILABLE = _module_available("nvidia.dali")
2020

21-
if _TORCHVISION_AVAILABLE:
21+
if _TORCHVISION_MNIST_AVAILABLE:
2222
try:
2323
from torchvision.datasets.mnist import MNIST
2424
MNIST(_DATASETS_PATH, download=True)

pytorch_lightning/utilities/imports.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,15 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
"""General utilities"""
15+
import importlib
1516
import operator
1617
import platform
1718
import sys
1819
from distutils.version import LooseVersion
1920
from importlib.util import find_spec
2021

2122
import torch
22-
from pkg_resources import DistributionNotFound, get_distribution
23+
from pkg_resources import DistributionNotFound
2324

2425

2526
def _module_available(module_path: str) -> bool:
@@ -42,8 +43,17 @@ def _module_available(module_path: str) -> bool:
4243

4344

4445
def _compare_version(package: str, op, version) -> bool:
46+
"""Compare package version with some requirements
47+
48+
>>> _compare_version("torch", operator.ge, "0.1")
49+
True
50+
"""
51+
if not _module_available(package):
52+
return False
4553
try:
46-
pkg_version = LooseVersion(get_distribution(package).version)
54+
pkg = importlib.import_module(package)
55+
assert hasattr(pkg, '__version__')
56+
pkg_version = pkg.__version__
4757
return op(pkg_version, LooseVersion(version))
4858
except DistributionNotFound:
4959
return False

0 commit comments

Comments
 (0)