Skip to content

Commit e296f36

Browse files
SeanNarenBordacarmoccatchaton
committed
[BUG] Check environ before selecting a seed to prevent warning message (#4743)
* Check environment var independently to selecting a seed to prevent unnecessary warning message * Add if statement to check if PL_GLOBAL_SEED has been set * Added seed test to ensure that the seed stays the same, in case * if * Delete global seed after test has finished * Fix code, add tests * Ensure seed does not exist before tests start * Refactor test based on review, add log call * Ensure we clear the os environ in patched dict Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: Carlos Mocholí <[email protected]> Co-authored-by: chaton <[email protected]> (cherry picked from commit 635df27)
1 parent fc58f66 commit e296f36

File tree

2 files changed

+62
-10
lines changed

2 files changed

+62
-10
lines changed

pytorch_lightning/utilities/seed.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020

2121
import numpy as np
2222
import torch
23-
2423
from pytorch_lightning import _logger as log
24+
from pytorch_lightning.utilities import rank_zero_warn
2525

2626

2727
def seed_everything(seed: Optional[int] = None) -> int:
@@ -41,18 +41,17 @@ def seed_everything(seed: Optional[int] = None) -> int:
4141

4242
try:
4343
if seed is None:
44-
seed = os.environ.get("PL_GLOBAL_SEED", _select_seed_randomly(min_seed_value, max_seed_value))
44+
seed = os.environ.get("PL_GLOBAL_SEED")
4545
seed = int(seed)
4646
except (TypeError, ValueError):
4747
seed = _select_seed_randomly(min_seed_value, max_seed_value)
48+
rank_zero_warn(f"No correct seed found, seed set to {seed}")
4849

49-
if (seed > max_seed_value) or (seed < min_seed_value):
50-
log.warning(
51-
f"{seed} is not in bounds, \
52-
numpy accepts from {min_seed_value} to {max_seed_value}"
53-
)
50+
if not (min_seed_value <= seed <= max_seed_value):
51+
rank_zero_warn(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}")
5452
seed = _select_seed_randomly(min_seed_value, max_seed_value)
5553

54+
log.info(f"Global seed set to {seed}")
5655
os.environ["PL_GLOBAL_SEED"] = str(seed)
5756
random.seed(seed)
5857
np.random.seed(seed)
@@ -62,6 +61,4 @@ def seed_everything(seed: Optional[int] = None) -> int:
6261

6362

6463
def _select_seed_randomly(min_seed_value: int = 0, max_seed_value: int = 255) -> int:
65-
seed = random.randint(min_seed_value, max_seed_value)
66-
log.warning(f"No correct seed found, seed set to {seed}")
67-
return seed
64+
return random.randint(min_seed_value, max_seed_value)

tests/utilities/test_seed.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import os
2+
3+
from unittest import mock
4+
import pytest
5+
6+
import pytorch_lightning.utilities.seed as seed_utils
7+
8+
9+
@mock.patch.dict(os.environ, {}, clear=True)
10+
def test_seed_stays_same_with_multiple_seed_everything_calls():
11+
"""
12+
Ensure that after the initial seed everything,
13+
the seed stays the same for the same run.
14+
"""
15+
with pytest.warns(UserWarning, match="No correct seed found"):
16+
seed_utils.seed_everything()
17+
initial_seed = os.environ.get("PL_GLOBAL_SEED")
18+
19+
with pytest.warns(None) as record:
20+
seed_utils.seed_everything()
21+
assert not record # does not warn
22+
seed = os.environ.get("PL_GLOBAL_SEED")
23+
24+
assert initial_seed == seed
25+
26+
27+
@mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "2020"}, clear=True)
28+
def test_correct_seed_with_environment_variable():
29+
"""
30+
Ensure that the PL_GLOBAL_SEED environment is read
31+
"""
32+
assert seed_utils.seed_everything() == 2020
33+
34+
35+
@mock.patch.dict(os.environ, {"PL_GLOBAL_SEED": "invalid"}, clear=True)
36+
@mock.patch.object(seed_utils, attribute='_select_seed_randomly', new=lambda *_: 123)
37+
def test_invalid_seed():
38+
"""
39+
Ensure that we still fix the seed even if an invalid seed is given
40+
"""
41+
with pytest.warns(UserWarning, match="No correct seed found"):
42+
seed = seed_utils.seed_everything()
43+
assert seed == 123
44+
45+
46+
@mock.patch.dict(os.environ, {}, clear=True)
47+
@mock.patch.object(seed_utils, attribute='_select_seed_randomly', new=lambda *_: 123)
48+
@pytest.mark.parametrize("seed", (10e9, -10e9))
49+
def test_out_of_bounds_seed(seed):
50+
"""
51+
Ensure that we still fix the seed even if an out-of-bounds seed is given
52+
"""
53+
with pytest.warns(UserWarning, match="is not in bounds"):
54+
actual = seed_utils.seed_everything(seed)
55+
assert actual == 123

0 commit comments

Comments
 (0)