From 43f8149a72fc542c9c7a7057541ec0c1b2b10e72 Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Sat, 7 Aug 2021 12:52:46 +0200 Subject: [PATCH 1/7] Fix mypy typing in utilities.seed --- pyproject.toml | 1 + pytorch_lightning/utilities/seed.py | 17 +++++++++-------- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d07f19ef10986..218d0f50de539 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,6 +71,7 @@ module = [ "pytorch_lightning.utilities.distributed", "pytorch_lightning.utilities.memory", "pytorch_lightning.utilities.parsing", + "pytorch_lightning.utilities.seed", "pytorch_lightning.utilities.xla_device", ] ignore_errors = "False" diff --git a/pytorch_lightning/utilities/seed.py b/pytorch_lightning/utilities/seed.py index 732f2d8136b9e..51eff50159991 100644 --- a/pytorch_lightning/utilities/seed.py +++ b/pytorch_lightning/utilities/seed.py @@ -48,13 +48,14 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int: max_seed_value = np.iinfo(np.uint32).max min_seed_value = np.iinfo(np.uint32).min - try: - if seed is None: - seed = os.environ.get("PL_GLOBAL_SEED") - seed = int(seed) - except (TypeError, ValueError): - seed = _select_seed_randomly(min_seed_value, max_seed_value) - rank_zero_warn(f"No correct seed found, seed set to {seed}") + if seed is None: + global_seed = os.environ.get("PL_GLOBAL_SEED") + if isinstance(global_seed, str): + seed = int(global_seed) + else: + rank_zero_warn(f"No correct seed found, seed set to {seed}") + seed = _select_seed_randomly(max_seed_value, max_seed_value) + seed = int(seed) if not (min_seed_value <= seed <= max_seed_value): rank_zero_warn(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}") @@ -89,7 +90,7 @@ def reset_seed() -> None: seed_everything(int(seed), workers=bool(workers)) -def pl_worker_init_function(worker_id: int, rank: Optional = None) -> None: # pragma: no cover +def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: # pragma: no cover """ The worker_init_fn that Lightning automatically adds to your dataloader if you previously set set the seed with ``seed_everything(seed, workers=True)``. From 37a7cba4dfcd7d4dc637e66a1258c9c91d5bfb22 Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Sat, 7 Aug 2021 12:58:18 +0200 Subject: [PATCH 2/7] Remove redundant seed = int(seed) --- pytorch_lightning/utilities/seed.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/utilities/seed.py b/pytorch_lightning/utilities/seed.py index 51eff50159991..0ad6a64b160c0 100644 --- a/pytorch_lightning/utilities/seed.py +++ b/pytorch_lightning/utilities/seed.py @@ -55,7 +55,6 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int: else: rank_zero_warn(f"No correct seed found, seed set to {seed}") seed = _select_seed_randomly(max_seed_value, max_seed_value) - seed = int(seed) if not (min_seed_value <= seed <= max_seed_value): rank_zero_warn(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}") From d7f616a7ff6c5ba48e68849d1368c66baf892ace Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Sat, 7 Aug 2021 15:36:51 +0200 Subject: [PATCH 3/7] Improve if statement before passing global_seed into int --- pytorch_lightning/utilities/seed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/seed.py b/pytorch_lightning/utilities/seed.py index 0ad6a64b160c0..d91e63dd1b56b 100644 --- a/pytorch_lightning/utilities/seed.py +++ b/pytorch_lightning/utilities/seed.py @@ -50,7 +50,7 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int: if seed is None: global_seed = os.environ.get("PL_GLOBAL_SEED") - if isinstance(global_seed, str): + if isinstance(global_seed, str) and all(char.isdigit() for char in global_seed): seed = int(global_seed) else: rank_zero_warn(f"No correct seed found, seed set to {seed}") From 694bb66168b9db27f41d1a896797e1809db988a6 Mon Sep 17 00:00:00 2001 From: "Stancl, Daniel" Date: Thu, 2 Sep 2021 21:03:02 +0200 Subject: [PATCH 4/7] Add typing ignore comments for seed_everything --- pytorch_lightning/utilities/seed.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/utilities/seed.py b/pytorch_lightning/utilities/seed.py index d91e63dd1b56b..aaf1134df19f6 100644 --- a/pytorch_lightning/utilities/seed.py +++ b/pytorch_lightning/utilities/seed.py @@ -48,13 +48,15 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int: max_seed_value = np.iinfo(np.uint32).max min_seed_value = np.iinfo(np.uint32).min - if seed is None: - global_seed = os.environ.get("PL_GLOBAL_SEED") - if isinstance(global_seed, str) and all(char.isdigit() for char in global_seed): - seed = int(global_seed) - else: - rank_zero_warn(f"No correct seed found, seed set to {seed}") - seed = _select_seed_randomly(max_seed_value, max_seed_value) + try: + # Mypy typing is ignored below as the code simplicity is prefered to mypy correctness. Also, possible errors + # are handled by the exception. + if seed is None: + seed = os.environ.get("PL_GLOBAL_SEED") # type: ignore + seed = int(seed) # type: ignore + except (TypeError, ValueError): + seed = _select_seed_randomly(min_seed_value, max_seed_value) + rank_zero_warn(f"No correct seed found, seed set to {seed}") if not (min_seed_value <= seed <= max_seed_value): rank_zero_warn(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}") From 2fd5d5f4fc83b8fa5ef5083c866fceb4c3fe9f35 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Thu, 9 Sep 2021 13:51:40 +0100 Subject: [PATCH 5/7] Update pytorch_lightning/utilities/seed.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- pytorch_lightning/utilities/seed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/seed.py b/pytorch_lightning/utilities/seed.py index f90155d554bd6..55cfd4ba65d51 100644 --- a/pytorch_lightning/utilities/seed.py +++ b/pytorch_lightning/utilities/seed.py @@ -47,7 +47,7 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int: min_seed_value = np.iinfo(np.uint32).min try: - # Mypy typing is ignored below as the code simplicity is prefered to mypy correctness. Also, possible errors + # Mypy typing is ignored below as the code simplicity is preferred to mypy correctness. Also, possible errors # are handled by the exception. if seed is None: seed = os.environ.get("PL_GLOBAL_SEED") # type: ignore From aea2073a4b44e90de6a6175f8b4de50b9ada6d39 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 23 Sep 2021 16:35:29 +0200 Subject: [PATCH 6/7] Stricter alternative --- pytorch_lightning/utilities/seed.py | 22 +++++++++++++--------- tests/utilities/test_seed.py | 4 ++-- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/pytorch_lightning/utilities/seed.py b/pytorch_lightning/utilities/seed.py index 55cfd4ba65d51..6a01f19ac7934 100644 --- a/pytorch_lightning/utilities/seed.py +++ b/pytorch_lightning/utilities/seed.py @@ -46,15 +46,19 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int: max_seed_value = np.iinfo(np.uint32).max min_seed_value = np.iinfo(np.uint32).min - try: - # Mypy typing is ignored below as the code simplicity is preferred to mypy correctness. Also, possible errors - # are handled by the exception. - if seed is None: - seed = os.environ.get("PL_GLOBAL_SEED") # type: ignore - seed = int(seed) # type: ignore - except (TypeError, ValueError): - seed = _select_seed_randomly(min_seed_value, max_seed_value) - rank_zero_warn(f"No correct seed found, seed set to {seed}") + if seed is None: + env_seed = os.environ.get("PL_GLOBAL_SEED") + if env_seed is None: + seed = _select_seed_randomly(min_seed_value, max_seed_value) + rank_zero_warn(f"No seed found, seed set to {seed}") + else: + try: + seed = int(env_seed) + except ValueError: + seed = _select_seed_randomly(min_seed_value, max_seed_value) + rank_zero_warn(f"Invalid seed found: {repr(env_seed)}, seed set to {seed}") + elif not isinstance(seed, int): + seed = int(seed) if not (min_seed_value <= seed <= max_seed_value): rank_zero_warn(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}") diff --git a/tests/utilities/test_seed.py b/tests/utilities/test_seed.py index ca103b0a2a318..f51e5143e7e9b 100644 --- a/tests/utilities/test_seed.py +++ b/tests/utilities/test_seed.py @@ -10,7 +10,7 @@ @mock.patch.dict(os.environ, {}, clear=True) def test_seed_stays_same_with_multiple_seed_everything_calls(): """Ensure that after the initial seed everything, the seed stays the same for the same run.""" - with pytest.warns(UserWarning, match="No correct seed found"): + with pytest.warns(UserWarning, match="No seed found"): seed_utils.seed_everything() initial_seed = os.environ.get("PL_GLOBAL_SEED") @@ -32,7 +32,7 @@ def test_correct_seed_with_environment_variable(): @mock.patch.object(seed_utils, attribute="_select_seed_randomly", new=lambda *_: 123) def test_invalid_seed(): """Ensure that we still fix the seed even if an invalid seed is given.""" - with pytest.warns(UserWarning, match="No correct seed found"): + with pytest.warns(UserWarning, match="Invalid seed found"): seed = seed_utils.seed_everything() assert seed == 123 From d7e73c2c4f03a5d869b058eca66309692b097159 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 23 Sep 2021 16:39:58 +0200 Subject: [PATCH 7/7] Update CHANGELOG --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 45c1328193ab2..d85d89169a928 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -221,6 +221,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Reset metrics before each task starts ([#9410](https://github.com/PyTorchLightning/pytorch-lightning/pull/9410)) +- `seed_everything` now fails when an invalid seed value is passed instead of selecting a random seed ([#8787](https://github.com/PyTorchLightning/pytorch-lightning/pull/8787)) + + ### Deprecated - Deprecated `LightningModule.summarize()` in favor of `pytorch_lightning.utilities.model_summary.summarize()`