diff --git a/CHANGELOG.md b/CHANGELOG.md index 45c1328193ab2..d85d89169a928 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -221,6 +221,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Reset metrics before each task starts ([#9410](https://github.com/PyTorchLightning/pytorch-lightning/pull/9410)) +- `seed_everything` now fails when an invalid seed value is passed instead of selecting a random seed ([#8787](https://github.com/PyTorchLightning/pytorch-lightning/pull/8787)) + + ### Deprecated - Deprecated `LightningModule.summarize()` in favor of `pytorch_lightning.utilities.model_summary.summarize()` diff --git a/pyproject.toml b/pyproject.toml index ed22f853107bb..efa71ca939a63 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,6 +82,7 @@ module = [ "pytorch_lightning.utilities.memory", "pytorch_lightning.utilities.model_summary", "pytorch_lightning.utilities.parsing", + "pytorch_lightning.utilities.seed", "pytorch_lightning.utilities.xla_device", ] ignore_errors = "False" diff --git a/pytorch_lightning/utilities/seed.py b/pytorch_lightning/utilities/seed.py index dc64ffb78bade..6a01f19ac7934 100644 --- a/pytorch_lightning/utilities/seed.py +++ b/pytorch_lightning/utilities/seed.py @@ -46,13 +46,19 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int: max_seed_value = np.iinfo(np.uint32).max min_seed_value = np.iinfo(np.uint32).min - try: - if seed is None: - seed = os.environ.get("PL_GLOBAL_SEED") + if seed is None: + env_seed = os.environ.get("PL_GLOBAL_SEED") + if env_seed is None: + seed = _select_seed_randomly(min_seed_value, max_seed_value) + rank_zero_warn(f"No seed found, seed set to {seed}") + else: + try: + seed = int(env_seed) + except ValueError: + seed = _select_seed_randomly(min_seed_value, max_seed_value) + rank_zero_warn(f"Invalid seed found: {repr(env_seed)}, seed set to {seed}") + elif not isinstance(seed, int): seed = int(seed) - except (TypeError, ValueError): - seed = _select_seed_randomly(min_seed_value, max_seed_value) - rank_zero_warn(f"No correct seed found, seed set to {seed}") if not (min_seed_value <= seed <= max_seed_value): rank_zero_warn(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}") @@ -87,7 +93,7 @@ def reset_seed() -> None: seed_everything(int(seed), workers=bool(workers)) -def pl_worker_init_function(worker_id: int, rank: Optional = None) -> None: # pragma: no cover +def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: # pragma: no cover """The worker_init_fn that Lightning automatically adds to your dataloader if you previously set set the seed with ``seed_everything(seed, workers=True)``. diff --git a/tests/utilities/test_seed.py b/tests/utilities/test_seed.py index ca103b0a2a318..f51e5143e7e9b 100644 --- a/tests/utilities/test_seed.py +++ b/tests/utilities/test_seed.py @@ -10,7 +10,7 @@ @mock.patch.dict(os.environ, {}, clear=True) def test_seed_stays_same_with_multiple_seed_everything_calls(): """Ensure that after the initial seed everything, the seed stays the same for the same run.""" - with pytest.warns(UserWarning, match="No correct seed found"): + with pytest.warns(UserWarning, match="No seed found"): seed_utils.seed_everything() initial_seed = os.environ.get("PL_GLOBAL_SEED") @@ -32,7 +32,7 @@ def test_correct_seed_with_environment_variable(): @mock.patch.object(seed_utils, attribute="_select_seed_randomly", new=lambda *_: 123) def test_invalid_seed(): """Ensure that we still fix the seed even if an invalid seed is given.""" - with pytest.warns(UserWarning, match="No correct seed found"): + with pytest.warns(UserWarning, match="Invalid seed found"): seed = seed_utils.seed_everything() assert seed == 123