Skip to content

Commit 6284a4b

Browse files
ddrevickyrohitgr7
andauthored
Apply suggestions from code review
Co-authored-by: Rohit Gupta <[email protected]>
1 parent 43ec93f commit 6284a4b

File tree

3 files changed

+7
-7
lines changed

3 files changed

+7
-7
lines changed

pytorch_lightning/trainer/connectors/profiler_connector.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,6 @@ def on_trainer_init(self, profiler: Union[BaseProfiler, bool, str]):
3535
elif profiler == "advanced":
3636
profiler = AdvancedProfiler()
3737
elif isinstance(profiler, str):
38-
raise ValueError('when passing string value for the `profiler` parameter of'
39-
' `Trainer`, it can only be `simple` or `advanced`')
38+
raise ValueError("When passing string value for the `profiler` parameter of"
39+
" `Trainer`, it can only be 'simple' or 'advanced'")
4040
self.trainer.profiler = profiler or PassThroughProfiler()

tests/test_deprecated.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,13 @@ def test_tbd_remove_in_v0_11_0_trainer_gpu():
2828
gpu_usage = GpuUsageLogger()
2929

3030

31-
@pytest.mark.parametrize(['input_arg', 'expected'], [
31+
@pytest.mark.parametrize(['profiler', 'expected'], [
3232
(True, SimpleProfiler),
3333
(False, PassThroughProfiler),
3434
])
35-
def test_trainer_profiler_remove_in_v0_11_0_trainer(input_arg, expected):
35+
def test_trainer_profiler_remove_in_v0_11_0_trainer(profiler, expected):
3636
with pytest.deprecated_call(match='will be removed in v0.11.0'):
37-
trainer = Trainer(profiler=input_arg)
37+
trainer = Trainer(profiler=profiler)
3838
assert isinstance(trainer.profiler, expected)
3939

4040

tests/trainer/test_trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1196,6 +1196,6 @@ def test_trainer_profiler_correct_args(input_arg, expected):
11961196

11971197

11981198
def test_trainer_profiler_incorrect_str_arg():
1199-
with pytest.raises(ValueError, match=r'when passing string value for the `profiler` parameter of '
1200-
'`Trainer`, it can only be `simple` or `advanced`'):
1199+
with pytest.raises(ValueError, match=r"When passing string value for the `profiler` parameter of"
1200+
"` Trainer`, it can only be 'simple' or 'advanced'"):
12011201
Trainer(profiler="unknown_profiler")

0 commit comments

Comments
 (0)