Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Reset metrics before each task starts ([#9410](https://github.com/PyTorchLightning/pytorch-lightning/pull/9410))


- `seed_everything` now fails when an invalid seed value is passed instead of selecting a random seed ([#8787](https://github.com/PyTorchLightning/pytorch-lightning/pull/8787))


### Deprecated

- Deprecated `LightningModule.summarize()` in favor of `pytorch_lightning.utilities.model_summary.summarize()`
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ module = [
"pytorch_lightning.utilities.memory",
"pytorch_lightning.utilities.model_summary",
"pytorch_lightning.utilities.parsing",
"pytorch_lightning.utilities.seed",
"pytorch_lightning.utilities.xla_device",
]
ignore_errors = "False"
20 changes: 13 additions & 7 deletions pytorch_lightning/utilities/seed.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,19 @@ def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int:
max_seed_value = np.iinfo(np.uint32).max
min_seed_value = np.iinfo(np.uint32).min

try:
if seed is None:
seed = os.environ.get("PL_GLOBAL_SEED")
if seed is None:
env_seed = os.environ.get("PL_GLOBAL_SEED")
if env_seed is None:
seed = _select_seed_randomly(min_seed_value, max_seed_value)
rank_zero_warn(f"No seed found, seed set to {seed}")
else:
try:
seed = int(env_seed)
except ValueError:
seed = _select_seed_randomly(min_seed_value, max_seed_value)
rank_zero_warn(f"Invalid seed found: {repr(env_seed)}, seed set to {seed}")
elif not isinstance(seed, int):
seed = int(seed)
except (TypeError, ValueError):
seed = _select_seed_randomly(min_seed_value, max_seed_value)
rank_zero_warn(f"No correct seed found, seed set to {seed}")

if not (min_seed_value <= seed <= max_seed_value):
rank_zero_warn(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}")
Expand Down Expand Up @@ -87,7 +93,7 @@ def reset_seed() -> None:
seed_everything(int(seed), workers=bool(workers))


def pl_worker_init_function(worker_id: int, rank: Optional = None) -> None: # pragma: no cover
def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: # pragma: no cover
"""The worker_init_fn that Lightning automatically adds to your dataloader if you previously set set the seed
with ``seed_everything(seed, workers=True)``.

Expand Down
4 changes: 2 additions & 2 deletions tests/utilities/test_seed.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
@mock.patch.dict(os.environ, {}, clear=True)
def test_seed_stays_same_with_multiple_seed_everything_calls():
"""Ensure that after the initial seed everything, the seed stays the same for the same run."""
with pytest.warns(UserWarning, match="No correct seed found"):
with pytest.warns(UserWarning, match="No seed found"):
seed_utils.seed_everything()
initial_seed = os.environ.get("PL_GLOBAL_SEED")

Expand All @@ -32,7 +32,7 @@ def test_correct_seed_with_environment_variable():
@mock.patch.object(seed_utils, attribute="_select_seed_randomly", new=lambda *_: 123)
def test_invalid_seed():
"""Ensure that we still fix the seed even if an invalid seed is given."""
with pytest.warns(UserWarning, match="No correct seed found"):
with pytest.warns(UserWarning, match="Invalid seed found"):
seed = seed_utils.seed_everything()
assert seed == 123

Expand Down