6 changes: 5 additions & 1 deletion tests/trainer/test_dataloaders.py
@@ -803,8 +803,12 @@ def _user_worker_init_fn(_):
     pass


+@RunIf(max_torch="1.8.9")

Collaborator

Shall we rather have an arg that specifies the max version as < instead of <=?

Collaborator

Or just set the max as 1.8, without the patch number?

Contributor Author

< would be more practical, I guess, but it would not be obvious from the name whether it means < or <=.
Maybe a different approach would be to have torch_gt, torch_ge, torch_lt, torch_le, but that would mean more arguments ...
Not sure.

Collaborator

I'd say that setting the level would be more practical:

  • max 1.8.1 will keep 1.8.0 and 1.8.1 and ignore 1.8.2 and above
  • max 1.8 will keep 1.8.0, 1.8.1, ..., 1.8.x and ignore 1.9.x and above
  • max 1 will keep 1.8.0, ..., 1.x and ignore 2.x and above
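
A minimal sketch of that prefix-level matching, assuming a helper built on `packaging.version.Version` (the helper name and its use inside `RunIf` are illustrative, not the existing code):

```python
from packaging.version import Version


def max_torch_allows(current: str, max_torch: str) -> bool:
    """Keep the test if ``current`` matches ``max_torch`` up to the precision given.

    ``max_torch="1.8.1"`` keeps <= 1.8.1, ``"1.8"`` keeps any 1.8.x, ``"1"`` keeps any 1.x.
    """
    cur, limit = Version(current), Version(max_torch)
    depth = len(limit.release)  # compare only as many release fields as the user specified
    return cur.release[:depth] <= limit.release


assert max_torch_allows("1.8.1", "1.8")        # 1.8.x kept
assert not max_torch_allows("1.9.0", "1.8")    # 1.9.x ignored
assert not max_torch_allows("1.8.2", "1.8.1")  # above the patch-level max
```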

Member

+1 for the gt, ge, lt and le. Another option would be to have something like `torch_req='<1.9'`. But this would involve more parsing logic.
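
For reference, `packaging` can already handle such requirement strings, so the parsing could stay fairly small (the `torch_req` keyword and the helper below are hypothetical, not an existing API):

```python
import torch
from packaging.specifiers import SpecifierSet
from packaging.version import Version


def torch_satisfies(torch_req: str) -> bool:
    """Return True if the installed torch matches e.g. "<1.9" or ">=1.7,<1.9"."""
    installed = Version(torch.__version__.split("+")[0])  # drop local tags like "+cu111"
    return installed in SpecifierSet(torch_req)


# e.g. a RunIf-style skip condition
skip = not torch_satisfies("<1.9")
```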

Contributor

@awaelchli should the pl worker init function, when used with PT 1.9, also be a no-op, given that PyTorch now sets the seed correctly?

Contributor Author

For numpy, yes, we don't need to worry about it in 1.9+.

But another part of pl_worker_init_fn is that it derives not only a unique seed for each worker, it also makes the seed unique across all distributed processes (we do that by incorporating the global_rank into the seed sequence). I don't think 1.9 does that.
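
Roughly, the idea looks like this (a simplified sketch of the rank-aware seeding, not the exact `pl_worker_init_fn` implementation; the function and argument names are illustrative):

```python
import random

import numpy as np
import torch


def worker_init_sketch(worker_id: int, global_rank: int) -> None:
    # Fold the per-process torch seed, the worker id and the distributed rank into
    # one seed sequence, so every (rank, worker) pair draws from a distinct stream.
    base_seed = torch.initial_seed() % 2**32
    ss = np.random.SeedSequence([base_seed, worker_id, global_rank])
    np.random.seed(ss.generate_state(4))  # 4 words of entropy for numpy's global state
    torch_ss, stdlib_ss = ss.spawn(2)     # independent child sequences per RNG library
    torch.manual_seed(int(torch_ss.generate_state(1, dtype=np.uint64)[0]))
    random.seed(int(stdlib_ss.generate_state(1, dtype=np.uint64)[0]))


# usage: bind the rank before handing the function to the DataLoader, e.g.
# DataLoader(dataset, num_workers=2,
#            worker_init_fn=functools.partial(worker_init_sketch, global_rank=rank))
```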

 def test_missing_worker_init_fn():
-    """ Test that naive worker seed initialization leads to undesired random state in subprocesses. """
+    """
+    Test that naive worker seed initialization leads to undesired random state in subprocesses.
+    PyTorch 1.9+ does not have this issue.
+    """
     dataset = NumpyRandomDataset()

     seed_everything(0)