Skip to content

Commit e49be59

Browse files
awaelchli authored and lexierule committed
fix gpus default for Trainer.add_argparse_args (#6898)
(cherry picked from commit 9c9e2a0)
1 parent dceabcc commit e49be59

File tree

6 files changed

+33
-47
lines changed

6 files changed

+33
-47
lines changed

CHANGELOG.md

Lines changed: 9 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -5,10 +5,12 @@ All notable changes to this project will be documented in this file.
55
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
66

77

8-
## [UnReleased] - 2021-MM-DD
8+
## [1.3.0] - 2021-MM-DD
99

1010
### Added
1111

12+
- Added utils for NaN/Inf detection for gradients and parameters ([#6834](https://github.com/PyTorchLightning/pytorch-lightning/pull/6834/))
13+
1214

1315
- Added more explicit exception message when trying to execute `trainer.test()` or `trainer.validate()` with `fast_dev_run=True` ([#6667](https://github.com/PyTorchLightning/pytorch-lightning/pull/6667))
1416

@@ -81,8 +83,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
8183

8284
- Added support for `precision=64`, enabling training with double precision ([#6595](https://github.com/PyTorchLightning/pytorch-lightning/pull/6595))
8385

86+
8487
- Added support for DDP communication hooks ([#6736](https://github.com/PyTorchLightning/pytorch-lightning/issues/6736))
8588

89+
8690
- Added `artifact_location` argument to `MLFlowLogger` which will be passed to the `MlflowClient.create_experiment` call ([#6677](https://github.com/PyTorchLightning/pytorch-lightning/pull/6677))
8791

8892

@@ -111,6 +115,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
111115

112116
### Deprecated
113117

118+
- Deprecated `TrainerTrainingTricksMixin` in favor of a separate utilities module for NaN/Inf detection for gradients and parameters ([#6834](https://github.com/PyTorchLightning/pytorch-lightning/pull/6834/))
119+
120+
114121
- `period` has been deprecated in favor of `every_n_val_epochs` in the `ModelCheckpoint` callback ([#6146](https://github.com/PyTorchLightning/pytorch-lightning/pull/6146))
115122

116123

@@ -221,19 +228,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
221228
- Fixed `EarlyStopping` logic when `min_epochs` or `min_steps` requirement is not met ([#6705](https://github.com/PyTorchLightning/pytorch-lightning/pull/6705))
222229

223230

224-
## [1.2.8] - 2021-04-13
225-
226-
227-
### Changed
228-
229-
230-
### Removed
231-
232-
233-
### Fixed
234-
235-
236-
- Fixed `EarlyStopping` logic when `min_epochs` or `min_steps` requirement is not met ([#6705](https://github.com/PyTorchLightning/pytorch-lightning/pull/6705))
231+
- Fixed `--gpus` default for parser returned by `Trainer.add_argparse_args` ([#6898](https://github.com/PyTorchLightning/pytorch-lightning/pull/6898))
237232

238233

239234
## [1.2.7] - 2021-04-06

pytorch_lightning/utilities/argparse.py

Lines changed: 0 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -186,7 +186,6 @@ def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
186186

187187
if arg == 'gpus' or arg == 'tpu_cores':
188188
use_type = _gpus_allowed_type
189-
arg_default = _gpus_arg_default
190189

191190
# hack for types in (int, float)
192191
if len(arg_types) == 2 and int in set(arg_types) and float in set(arg_types):
@@ -238,13 +237,6 @@ def _gpus_allowed_type(x) -> Union[int, str]:
238237
return int(x)
239238

240239

241-
def _gpus_arg_default(x) -> Union[int, str]:
242-
if ',' in x:
243-
return str(x)
244-
else:
245-
return int(x)
246-
247-
248240
def _int_or_float_type(x) -> Union[int, float]:
249241
if '.' in str(x):
250242
return float(x)

pytorch_lightning/utilities/device_parser.py

Lines changed: 0 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -59,11 +59,6 @@ def parse_gpu_ids(gpus: Optional[Union[int, str, List[int]]]) -> Optional[List[i
5959
If no GPUs are available but the value of gpus variable indicates request for GPUs
6060
then a MisconfigurationException is raised.
6161
"""
62-
63-
# nothing was passed into the GPUs argument
64-
if callable(gpus):
65-
return None
66-
6762
# Check that gpus param is None, Int, String or List
6863
_check_data_type(gpus)
6964

@@ -97,10 +92,6 @@ def parse_tpu_cores(tpu_cores: Union[int, str, List]) -> Optional[Union[List[int
9792
Returns:
9893
a list of tpu_cores to be used or ``None`` if no TPU cores were requested
9994
"""
100-
101-
if callable(tpu_cores):
102-
return None
103-
10495
_check_data_type(tpu_cores)
10596

10697
if isinstance(tpu_cores, str):

tests/loggers/test_wandb.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,6 @@
1313
# limitations under the License.
1414
import os
1515
import pickle
16-
import types
1716
from argparse import ArgumentParser
1817
from unittest import mock
1918

@@ -172,11 +171,10 @@ def wrapper_something():
172171
params.wrapper_something_wo_name = lambda: lambda: '1'
173172
params.wrapper_something = wrapper_something
174173

175-
assert isinstance(params.gpus, types.FunctionType)
176174
params = WandbLogger._convert_params(params)
177175
params = WandbLogger._flatten_dict(params)
178176
params = WandbLogger._sanitize_callable_params(params)
179-
assert params["gpus"] == '_gpus_arg_default'
177+
assert params["gpus"] == "None"
180178
assert params["something"] == "something"
181179
assert params["wrapper_something"] == "wrapper_something"
182180
assert params["wrapper_something_wo_name"] == "<lambda>"

tests/trainer/test_trainer_cli.py

Lines changed: 12 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,7 @@
2323
import tests.helpers.utils as tutils
2424
from pytorch_lightning import Trainer
2525
from pytorch_lightning.utilities import argparse
26+
from tests.helpers.runif import RunIf
2627

2728

2829
@mock.patch('argparse.ArgumentParser.parse_args')
@@ -45,7 +46,7 @@ def test_default_args(mock_argparse, tmpdir):
4546

4647

4748
@pytest.mark.parametrize('cli_args', [['--accumulate_grad_batches=22'], ['--weights_save_path=./'], []])
48-
def test_add_argparse_args_redefined(cli_args):
49+
def test_add_argparse_args_redefined(cli_args: list):
4950
"""Redefines some default Trainer arguments via the cli and
5051
tests the Trainer initialization correctness.
5152
"""
@@ -90,7 +91,7 @@ def test_get_init_arguments_and_types():
9091

9192

9293
@pytest.mark.parametrize('cli_args', [['--callbacks=1', '--logger'], ['--foo', '--bar=1']])
93-
def test_add_argparse_args_redefined_error(cli_args, monkeypatch):
94+
def test_add_argparse_args_redefined_error(cli_args: list, monkeypatch):
9495
"""Asserts thar an error raised in case of passing not default cli arguments."""
9596

9697
class _UnkArgError(Exception):
@@ -171,27 +172,26 @@ def test_argparse_args_parsing(cli_args, expected):
171172
assert Trainer.from_argparse_args(args)
172173

173174

174-
@pytest.mark.parametrize(['cli_args', 'expected_gpu'], [
175-
pytest.param('--gpus 1', [0]),
176-
pytest.param('--gpus 0,', [0]),
175+
@pytest.mark.parametrize(['cli_args', 'expected_parsed', 'expected_device_ids'], [
176+
pytest.param('', None, None),
177+
pytest.param('--gpus 1', 1, [0]),
178+
pytest.param('--gpus 0,', '0,', [0]),
177179
])
178-
@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
179-
def test_argparse_args_parsing_gpus(cli_args, expected_gpu):
180+
@RunIf(min_gpus=1)
181+
def test_argparse_args_parsing_gpus(cli_args, expected_parsed, expected_device_ids):
180182
"""Test multi type argument with bool."""
181183
cli_args = cli_args.split(' ') if cli_args else []
182184
with mock.patch("argparse._sys.argv", ["any.py"] + cli_args):
183185
parser = ArgumentParser(add_help=False)
184186
parser = Trainer.add_argparse_args(parent_parser=parser)
185187
args = Trainer.parse_argparser(parser)
186188

189+
assert args.gpus == expected_parsed
187190
trainer = Trainer.from_argparse_args(args)
188-
assert trainer.data_parallel_device_ids == expected_gpu
191+
assert trainer.data_parallel_device_ids == expected_device_ids
189192

190193

191-
@pytest.mark.skipif(
192-
sys.version_info < (3, 7),
193-
reason="signature inspection while mocking is not working in Python < 3.7 despite autospec"
194-
)
194+
@RunIf(min_python="3.7.0")
195195
@pytest.mark.parametrize(['cli_args', 'extra_args'], [
196196
pytest.param({}, {}),
197197
pytest.param({'logger': False}, {}),

tests/utilities/test_argparse_utils.py

Lines changed: 11 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
from pytorch_lightning.utilities.argparse import parse_args_from_docstring
1+
from pytorch_lightning.utilities.argparse import parse_args_from_docstring, _gpus_allowed_type, _int_or_float_type
22

33

44
def test_parse_args_from_docstring_normal():
@@ -48,3 +48,13 @@ def test_parse_args_from_docstring_empty():
4848
"""
4949
)
5050
assert len(args_help.keys()) == 0
51+
52+
53+
def test_gpus_allowed_type():
54+
assert _gpus_allowed_type('1,2') == '1,2'
55+
assert _gpus_allowed_type('1') == 1
56+
57+
58+
def test_int_or_float_type():
59+
assert isinstance(_int_or_float_type('0.0'), float)
60+
assert isinstance(_int_or_float_type('0'), int)

0 commit comments

Comments (0)