
Commit a8d9b5f

Remove tbptt self.log flags and other dead code [5/n] (#7644)
1 parent 33a1f52 commit a8d9b5f

File tree

5 files changed: +56 lines, -168 lines

CHANGELOG.md

Lines changed: 11 additions & 8 deletions
```diff
@@ -88,21 +88,24 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Deprecated `TrainerModelHooksMixin` in favor of `pytorch_lightning.utilities.signature_utils` ([#7422](https://github.com/PyTorchLightning/pytorch-lightning/pull/7422))


-- Deprecated `num_nodes` and `sync_batchnorm` arguments in `DDPPlugin` and `DDPSpawnPlugin` ([7026](https://github.com/PyTorchLightning/pytorch-lightning/pull/7026))
+- Deprecated `num_nodes` and `sync_batchnorm` arguments in `DDPPlugin` and `DDPSpawnPlugin` ([#7026](https://github.com/PyTorchLightning/pytorch-lightning/pull/7026))


 ### Removed

-- Prune deprecated classif. metrics from `pytorch_lightning.metrics.functional.classification` ([7499](https://github.com/PyTorchLightning/pytorch-lightning/pull/7499))
+- Prune deprecated classif. metrics from `pytorch_lightning.metrics.functional.classification` ([#7499](https://github.com/PyTorchLightning/pytorch-lightning/pull/7499))


-- Removed deprecated data parallel classes `LightningDataParallel` and `LightningDistributedDataParallel` from `pytorch_lightning.overrides.data_parallel` ([7510](https://github.com/PyTorchLightning/pytorch-lightning/pull/7510))
+- Removed deprecated data parallel classes `LightningDataParallel` and `LightningDistributedDataParallel` from `pytorch_lightning.overrides.data_parallel` ([#7510](https://github.com/PyTorchLightning/pytorch-lightning/pull/7510))


-- Removed deprecated trainer attributes - `get_model` and `accelerator_backend` ([7502](https://github.com/PyTorchLightning/pytorch-lightning/pull/7502))
+- Removed deprecated trainer attributes - `get_model` and `accelerator_backend` ([#7502](https://github.com/PyTorchLightning/pytorch-lightning/pull/7502))


-- Removed deprecated utils modules `model_utils`, `warning_utils`, `xla_device_utils` and partially `argparse_utils` ([7503](https://github.com/PyTorchLightning/pytorch-lightning/pull/7503))
+- Removed support for `self.log(tbptt_reduce_fx)` and `self.log(tbptt_pad_token)`. Please, open a discussion explaining your use-case if you relied on these. ([#7644](https://github.com/PyTorchLightning/pytorch-lightning/pull/7644))
+
+
+- Removed deprecated utils modules `model_utils`, `warning_utils`, `xla_device_utils` and partially `argparse_utils` ([#7503](https://github.com/PyTorchLightning/pytorch-lightning/pull/7503))


 - Removed deprecated trainer attributes - `on_cpu`, `on_tpu`, `use_tpu`, `on_gpu`, `use_dp`, `use_ddp`, `use_ddp2`, `use_horovod`, `use_single_gpu` ([#7501](https://github.com/PyTorchLightning/pytorch-lightning/pull/7501))
@@ -1340,7 +1343,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed getting `experiment_id` from MLFlow only once instead of each training loop ([#3394](https://github.com/PyTorchLightning/pytorch-lightning/pull/3394))
 - Fixed `overfit_batches` which now correctly disables shuffling for the training loader. ([#3501](https://github.com/PyTorchLightning/pytorch-lightning/pull/3501))
 - Fixed gradient norm tracking for `row_log_interval > 1` ([#3489](https://github.com/PyTorchLightning/pytorch-lightning/pull/3489))
-- Fixed `ModelCheckpoint` name formatting ([3164](https://github.com/PyTorchLightning/pytorch-lightning/pull/3163))
+- Fixed `ModelCheckpoint` name formatting ([#3164](https://github.com/PyTorchLightning/pytorch-lightning/pull/3163))
 - Fixed example implementation of AutoEncoder ([#3190](https://github.com/PyTorchLightning/pytorch-lightning/pull/3190))
 - Fixed invalid paths when remote logging with TensorBoard ([#3236](https://github.com/PyTorchLightning/pytorch-lightning/pull/3236))
 - Fixed change `t()` to `transpose()` as XLA devices do not support `.t()` on 1-dim tensor ([#3252](https://github.com/PyTorchLightning/pytorch-lightning/pull/3252))
@@ -1600,8 +1603,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added option `save_last` to save the model at the end of every epoch in `ModelCheckpoint` ([#1908](https://github.com/PyTorchLightning/pytorch-lightning/pull/1908))
 - Early stopping checks `on_validation_end` ([#1458](https://github.com/PyTorchLightning/pytorch-lightning/pull/1458))
 - Speed up single-core TPU training by loading data using `ParallelLoader` ([#2033](https://github.com/PyTorchLightning/pytorch-lightning/pull/2033))
-- Added a model hook `transfer_batch_to_device` that enables moving custom data structures to the target device ([1756](https://github.com/PyTorchLightning/pytorch-lightning/pull/1756))
-- Added [black](https://black.readthedocs.io/en/stable/) formatter for the code with code-checker on pull ([1610](https://github.com/PyTorchLightning/pytorch-lightning/pull/1610))
+- Added a model hook `transfer_batch_to_device` that enables moving custom data structures to the target device ([#1756](https://github.com/PyTorchLightning/pytorch-lightning/pull/1756))
+- Added [black](https://black.readthedocs.io/en/stable/) formatter for the code with code-checker on pull ([#1610](https://github.com/PyTorchLightning/pytorch-lightning/pull/1610))
 - Added back the slow spawn ddp implementation as `ddp_spawn` ([#2115](https://github.com/PyTorchLightning/pytorch-lightning/pull/2115))
 - Added loading checkpoints from URLs ([#1667](https://github.com/PyTorchLightning/pytorch-lightning/pull/1667))
 - Added a callback method `on_keyboard_interrupt` for handling KeyboardInterrupt events during training ([#2134](https://github.com/PyTorchLightning/pytorch-lightning/pull/2134))
```
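
For context, here is a minimal before/after sketch of the user-facing change. The model, metric name, and loss computation below are hypothetical placeholders, not taken from the commit:

```python
import torch
import pytorch_lightning as pl


class MyModel(pl.LightningModule):  # hypothetical model, for illustration only
    def training_step(self, batch, batch_idx):
        # Placeholder loss; a real model would compute it from `batch`.
        loss = torch.as_tensor(0.0, dtype=torch.float32)

        # Before: per-metric TBPTT flags were accepted by self.log, e.g.
        # self.log("train_loss", loss, tbptt_reduce_fx=torch.max, tbptt_pad_token=0)

        # After this commit: drop the flags. Passing them only emits a
        # deprecation warning, and values logged across truncated-backprop
        # splits are reduced with torch.mean.
        self.log("train_loss", loss)
        return loss
```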

pytorch_lightning/core/lightning.py

Lines changed: 17 additions & 10 deletions
```diff
@@ -263,8 +263,8 @@ def log(
         on_step: Optional[bool] = None,
         on_epoch: Optional[bool] = None,
         reduce_fx: Callable = torch.mean,
-        tbptt_reduce_fx: Callable = torch.mean,
-        tbptt_pad_token: int = 0,
+        tbptt_reduce_fx: Optional = None,  # noqa: Remove in 1.6
+        tbptt_pad_token: Optional = None,  # noqa: Remove in 1.6
         enable_graph: bool = False,
         sync_dist: bool = False,
         sync_dist_op: Union[Any, str] = 'mean',
@@ -299,8 +299,6 @@ def log(
             on_step: if True logs at this step. None auto-logs at the training_step but not validation/test_step
             on_epoch: if True logs epoch accumulated metrics. None auto-logs at the val/test step but not training_step
             reduce_fx: reduction function over step values for end of epoch. Torch.mean by default
-            tbptt_reduce_fx: function to reduce on truncated back prop
-            tbptt_pad_token: token to use for padding
             enable_graph: if True, will not auto detach the graph
             sync_dist: if True, reduces the metric across GPUs/TPUs
             sync_dist_op: the op to sync across GPUs/TPUs
@@ -309,6 +307,19 @@ def log(
                 the name (when using multiple). If False, user needs to give unique names for
                 each dataloader to not mix values
         """
+        if tbptt_reduce_fx is not None:
+            rank_zero_deprecation(
+                '`self.log(tbptt_reduce_fx=...)` is no longer supported. The flag will be removed in v1.6.'
+                ' Please, open a discussion explaining your use-case in'
+                ' `https://github.com/PyTorchLightning/pytorch-lightning/discussions`'
+            )
+        if tbptt_pad_token is not None:
+            rank_zero_deprecation(
+                '`self.log(tbptt_pad_token=...)` is no longer supported. The flag will be removed in v1.6.'
+                ' Please, open a discussion explaining your use-case in'
+                ' `https://github.com/PyTorchLightning/pytorch-lightning/discussions`'
+            )
+
         if self._results is not None:
             # TODO: if logged twice fail with crash

@@ -333,8 +344,6 @@ def log(
                 on_step=on_step,
                 on_epoch=on_epoch,
                 reduce_fx=reduce_fx,
-                tbptt_reduce_fx=tbptt_reduce_fx,
-                tbptt_pad_token=tbptt_pad_token,
                 enable_graph=enable_graph,
                 sync_dist=sync_dist,
                 sync_dist_op=sync_dist_op,
@@ -352,8 +361,8 @@ def log_dict(
         on_step: Optional[bool] = None,
         on_epoch: Optional[bool] = None,
         reduce_fx: Callable = torch.mean,
-        tbptt_reduce_fx: Callable = torch.mean,
-        tbptt_pad_token: int = 0,
+        tbptt_reduce_fx: Optional = None,  # noqa: Remove in 1.6
+        tbptt_pad_token: Optional = None,  # noqa: Remove in 1.6
         enable_graph: bool = False,
         sync_dist: bool = False,
         sync_dist_op: Union[Any, str] = 'mean',
@@ -375,8 +384,6 @@ def log_dict(
             on_step: if True logs for training_step but not validation/test_step
             on_epoch: if True logs epoch accumulated metrics. None auto-logs for val/test step but not training_step
             reduce_fx: reduction function over step values for end of epoch. Torch.mean by default
-            tbptt_reduce_fx: function to reduce on truncated back prop
-            tbptt_pad_token: token to use for padding
             enable_graph: if True, will not auto detach the graph
             sync_dist: if True, reduces the metric across GPUs/TPUs
             sync_dist_op: the op to sync across GPUs/TPUs
```
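
The signature change above follows a standard deprecation pattern: keep the keyword slot, default it to `None`, and warn on any explicit value. Below is a standalone sketch of that pattern, using `warnings.warn` in place of Lightning's internal `rank_zero_deprecation` helper so it runs on its own; the `log` function here is a toy stand-in, not Lightning's method:

```python
import warnings
from typing import Any, Callable, Optional


def log(name: str, value: Any, tbptt_reduce_fx: Optional[Callable] = None) -> None:
    """Toy stand-in: the removed flag keeps its keyword slot with a None
    default, and any explicit value triggers a deprecation warning."""
    if tbptt_reduce_fx is not None:
        warnings.warn(
            "`log(tbptt_reduce_fx=...)` is no longer supported and will be removed in v1.6.",
            DeprecationWarning,
        )
    # The regular logging path continues and simply ignores the flag.
    print(f"logged {name}={value}")


log("foo", 1.0)                        # no warning
log("foo", 1.0, tbptt_reduce_fx=max)   # emits a DeprecationWarning, value still logged
```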

pytorch_lightning/core/step_result.py

Lines changed: 2 additions & 58 deletions
```diff
@@ -85,8 +85,6 @@ def log(
         on_step: bool = False,
         on_epoch: bool = True,
         reduce_fx: Callable = torch.mean,
-        tbptt_reduce_fx: Callable = torch.mean,
-        tbptt_pad_token: int = 0,
         enable_graph: bool = False,
         sync_dist: bool = False,
         sync_dist_op: Union[Any, str] = 'mean',
@@ -134,8 +132,6 @@ def log(
                 on_step=True,
                 on_epoch=False,
                 reduce_fx=reduce_fx,
-                tbptt_reduce_fx=tbptt_reduce_fx,
-                tbptt_pad_token=tbptt_pad_token,
                 forked=False,
                 dataloader_idx=dataloader_idx,
             )
@@ -153,8 +149,6 @@ def log(
                 on_step=False,
                 on_epoch=True,
                 reduce_fx=reduce_fx,
-                tbptt_reduce_fx=tbptt_reduce_fx,
-                tbptt_pad_token=tbptt_pad_token,
                 forked=False,
                 dataloader_idx=dataloader_idx,
             )
@@ -169,8 +163,6 @@ def log(
             on_step,
             on_epoch,
             reduce_fx,
-            tbptt_reduce_fx=tbptt_reduce_fx,
-            tbptt_pad_token=tbptt_pad_token,
             forked=was_forked,
             dataloader_idx=dataloader_idx,
         )
@@ -187,8 +179,6 @@ def __set_meta(
         on_step: bool,
         on_epoch: bool,
         reduce_fx: Callable,
-        tbptt_pad_token: int,
-        tbptt_reduce_fx: Callable,
         forked: bool,
         dataloader_idx: Union[int, None],
     ):
@@ -201,8 +191,6 @@ def __set_meta(
             on_epoch=on_epoch,
             reduce_fx=reduce_fx,
             value=meta_value,
-            tbptt_reduce_fx=tbptt_reduce_fx,
-            tbptt_pad_token=tbptt_pad_token,
             forked=forked,
             dataloader_idx=dataloader_idx,
         )
@@ -424,47 +412,6 @@ def unpack_batch_size(sample):
             size = 1
         return size

-    @classmethod
-    def gather(cls, outputs):
-        meta = outputs[0].get('meta')
-        result = cls()
-        result = recursive_gather(outputs, result)
-        recursive_stack(result)
-
-        if meta:
-            result['meta'] = meta
-        return result
-
-    @classmethod
-    def padded_gather(cls, outputs):
-        meta = outputs[0].get('meta')
-        result = cls()
-        result = recursive_gather(outputs, result)
-
-        # find the padding used for other values
-        default_padding_idx = 0
-        for name, value in result.items():
-            if (
-                name != 'minimize' and isinstance(value, list) and len(value) > 0
-                and isinstance(value[0], torch.Tensor)
-            ):
-                default_padding_idx = meta[name]['tbptt_pad_token']
-                break
-
-        # pad across each key individually
-        for name, value in result.items():
-            if (isinstance(value, list) and len(value) > 0 and isinstance(value[0], torch.Tensor)):
-                padding_key = default_padding_idx if name == 'minimize' else meta[name]['tbptt_pad_token']
-                padded = torch.nn.utils.rnn.pad_sequence(value, batch_first=True, padding_value=padding_key)
-                result[name] = padded
-
-                # also update the result
-                if meta and name != "minimize":
-                    meta[name]['value'] = padded
-        if meta:
-            result['meta'] = meta
-        return result
-
     @classmethod
     def reduce_on_epoch_end(cls, outputs):
         # get the batch sizes for all outputs
@@ -522,17 +469,14 @@ def reduce_across_time(cls, time_outputs):
             if k in ['meta', 'extra'] or isinstance(value, Metric):
                 continue

-            # pick the reduce fx
-            tbptt_reduce_fx = torch.mean if k == "minimize" else meta[k]['tbptt_reduce_fx']
-
             if isinstance(value, list):
                 value = torch.tensor(value)

             if isinstance(value, dict):
                 # TODO: recursive reduce:
-                _recursive_fx_apply(value, tbptt_reduce_fx)
+                _recursive_fx_apply(value, torch.mean)
             else:
-                result[k] = tbptt_reduce_fx(value.float())
+                result[k] = torch.mean(value.float())

         result['meta'] = meta
         return result
```
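
The removed `padded_gather` padded variable-length step outputs with `torch.nn.utils.rnn.pad_sequence` before stacking, and `reduce_across_time` now applies `torch.mean` unconditionally instead of a per-metric `tbptt_reduce_fx`. A hedged sketch of how a caller who relied on `tbptt_pad_token` could reproduce the padding manually; the key name and pad value below are illustrative only:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Step outputs whose "foo" tensors differ in length along dim 0 (hypothetical).
step_outputs = [
    {"foo": torch.zeros(4, 5)},
    {"foo": torch.zeros(8, 5)},
    {"foo": torch.zeros(3, 5)},
]
values = [out["foo"] for out in step_outputs]

# pad_sequence pads to the longest sequence; the padding_value plays the role
# the removed `tbptt_pad_token` used to play.
padded = pad_sequence(values, batch_first=True, padding_value=0)
assert padded.shape == (3, 8, 5)

# After this commit, reductions across truncated-backprop splits always use
# torch.mean rather than a custom per-metric reduce function.
split_values = torch.tensor([0.5, 1.5, 2.5])
assert torch.mean(split_values) == 1.5
```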

tests/core/test_results.py

Lines changed: 0 additions & 92 deletions
```diff
@@ -182,98 +182,6 @@ def test_dataloader(self):
         assert len(predictions) == len(dm.random_test)


-def test_result_gather_stack():
-    """ Test that tensors get concatenated when they all have the same shape. """
-    outputs = [
-        {
-            "foo": torch.zeros(4, 5)
-        },
-        {
-            "foo": torch.zeros(4, 5)
-        },
-        {
-            "foo": torch.zeros(4, 5)
-        },
-    ]
-    result = Result.gather(outputs)
-    assert isinstance(result["foo"], torch.Tensor)
-    assert list(result["foo"].shape) == [12, 5]
-
-
-def test_result_gather_concatenate():
-    """ Test that tensors get concatenated when they have varying size in first dimension. """
-    outputs = [
-        {
-            "foo": torch.zeros(4, 5)
-        },
-        {
-            "foo": torch.zeros(8, 5)
-        },
-        {
-            "foo": torch.zeros(3, 5)
-        },
-    ]
-    result = Result.gather(outputs)
-    assert isinstance(result["foo"], torch.Tensor)
-    assert list(result["foo"].shape) == [15, 5]
-
-
-def test_result_gather_scalar():
-    """ Test that 0-dim tensors get gathered and stacked correctly. """
-    outputs = [
-        {
-            "foo": torch.tensor(1)
-        },
-        {
-            "foo": torch.tensor(2)
-        },
-        {
-            "foo": torch.tensor(3)
-        },
-    ]
-    result = Result.gather(outputs)
-    assert isinstance(result["foo"], torch.Tensor)
-    assert list(result["foo"].shape) == [3]
-
-
-def test_result_gather_different_shapes():
-    """ Test that tensors of varying shape get gathered into a list. """
-    outputs = [
-        {
-            "foo": torch.tensor(1)
-        },
-        {
-            "foo": torch.zeros(2, 3)
-        },
-        {
-            "foo": torch.zeros(1, 2, 3)
-        },
-    ]
-    result = Result.gather(outputs)
-    expected = [torch.tensor(1), torch.zeros(2, 3), torch.zeros(1, 2, 3)]
-    assert isinstance(result["foo"], list)
-    assert all(torch.eq(r, e).all() for r, e in zip(result["foo"], expected))
-
-
-def test_result_gather_mixed_types():
-    """ Test that a collection of mixed types gets gathered into a list. """
-    outputs = [
-        {
-            "foo": 1.2
-        },
-        {
-            "foo": ["bar", None]
-        },
-        {
-            "foo": torch.tensor(1)
-        },
-    ]
-    result = Result.gather(outputs)
-    expected = [1.2, ["bar", None], torch.tensor(1)]
-    assert isinstance(result["foo"], list)
-    assert result["foo"] == expected
-
-
 def test_result_retrieve_last_logged_item():
     result = Result()
     result.log('a', 5., on_step=True, on_epoch=True)
```
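
The deleted tests documented `Result.gather`'s behaviour: equal-shape tensors were merged into one tensor along the first dimension, varying first dimensions were concatenated, 0-dim tensors were stacked, and incompatible shapes fell back to a plain list. A small reproduction of the tensor cases with plain `torch` ops, under the assumption that concatenation along dim 0 is the behaviour a caller would want to keep:

```python
import torch

# Three step outputs with identical shapes -> one (12, 5) tensor, as in the
# removed test_result_gather_stack.
same_shape = [torch.zeros(4, 5), torch.zeros(4, 5), torch.zeros(4, 5)]
assert list(torch.cat(same_shape, dim=0).shape) == [12, 5]

# Varying first dimensions -> concatenated to (15, 5), as in the removed
# test_result_gather_concatenate.
varying = [torch.zeros(4, 5), torch.zeros(8, 5), torch.zeros(3, 5)]
assert list(torch.cat(varying, dim=0).shape) == [15, 5]

# 0-dim tensors -> stacked into a (3,) tensor, as in test_result_gather_scalar.
scalars = [torch.tensor(1), torch.tensor(2), torch.tensor(3)]
assert list(torch.stack(scalars).shape) == [3]
```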

tests/deprecated_api/test_remove_1-6.py

Lines changed: 26 additions & 0 deletions
```diff
@@ -61,3 +61,29 @@ def test_v1_6_0_ddp_spawn_num_nodes():
 def test_v1_6_0_ddp_spawn_sync_batchnorm():
     with pytest.deprecated_call(match="Argument `sync_batchnorm` in `DDPPlugin` is deprecated in v1.4"):
         DDPSpawnPlugin(sync_batchnorm=False)
+
+
+def test_v1_6_0_tbptt_reduce_fx(tmpdir):
+
+    class TestModel(BoringModel):
+
+        def training_step(self, *args):
+            self.log("foo", 1, tbptt_reduce_fx=lambda x: x)
+            return super().training_step(*args)
+
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
+    with pytest.deprecated_call(match=r"tbptt_reduce_fx=...\)` is no longer supported"):
+        trainer.fit(TestModel())
+
+
+def test_v1_6_0_tbptt_pad_token(tmpdir):
+
+    class TestModel(BoringModel):
+
+        def training_step(self, *args):
+            self.log("foo", 1, tbptt_pad_token=0)
+            return super().training_step(*args)
+
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
+    with pytest.deprecated_call(match=r"tbptt_pad_token=...\)` is no longer supported"):
+        trainer.fit(TestModel())
```
