Skip to content

Commit 50294a7

Browse files
puhuk and carmocca authored
Remove deprecated automatic logging of gpu metrics (#12657)
Co-authored-by: carmocca <[email protected]>
1 parent 3bd48b8 commit 50294a7

File tree

4 files changed

+5
-58
lines changed

4 files changed

+5
-58
lines changed

CHANGELOG.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
7575
- Removed the deprecated `terminate_on_nan` argument from the `Trainer` constructor ([#12553](https://github.com/PyTorchLightning/pytorch-lightning/pull/12553))
7676

7777

78-
-
79-
78+
- Removed the deprecated `log_gpu_memory` argument from the `Trainer` constructor ([#12657](https://github.com/PyTorchLightning/pytorch-lightning/pull/12657))
8079

8180
-
81+
- Removed the deprecated automatic logging of GPU stats by the logger connector ([#12657](https://github.com/PyTorchLightning/pytorch-lightning/pull/12657))
8282

8383
### Fixed
8484

pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py

Lines changed: 2 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -11,38 +11,29 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
from typing import Any, Dict, Iterable, Optional, Union
14+
from typing import Any, Iterable, Optional, Union
1515

1616
import torch
1717

1818
import pytorch_lightning as pl
19-
from pytorch_lightning.accelerators import GPUAccelerator
2019
from pytorch_lightning.loggers import LightningLoggerBase, TensorBoardLogger
2120
from pytorch_lightning.plugins.environments.slurm_environment import SLURMEnvironment
2221
from pytorch_lightning.trainer.connectors.logger_connector.result import _METRICS, _OUT_DICT, _PBAR_DICT
2322
from pytorch_lightning.trainer.states import RunningStage
24-
from pytorch_lightning.utilities import memory
2523
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
2624
from pytorch_lightning.utilities.metrics import metrics_to_scalars
2725
from pytorch_lightning.utilities.model_helpers import is_overridden
2826
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation
2927

3028

3129
class LoggerConnector:
32-
def __init__(self, trainer: "pl.Trainer", log_gpu_memory: Optional[str] = None) -> None:
30+
def __init__(self, trainer: "pl.Trainer") -> None:
3331
self.trainer = trainer
34-
if log_gpu_memory is not None:
35-
rank_zero_deprecation(
36-
"Setting `log_gpu_memory` with the trainer flag is deprecated in v1.5 and will be removed in v1.7. "
37-
"Please monitor GPU stats with the `DeviceStatsMonitor` callback directly instead."
38-
)
39-
self.log_gpu_memory = log_gpu_memory
4032
self._val_log_step: int = 0
4133
self._test_log_step: int = 0
4234
self._progress_bar_metrics: _PBAR_DICT = {}
4335
self._logged_metrics: _OUT_DICT = {}
4436
self._callback_metrics: _OUT_DICT = {}
45-
self._gpus_metrics: Dict[str, float] = {}
4637
self._epoch_end_reached = False
4738
self._current_fx: Optional[str] = None
4839
self._batch_idx: Optional[int] = None
@@ -193,9 +184,6 @@ def update_train_step_metrics(self) -> None:
193184
if self.trainer.fit_loop._should_accumulate() and self.trainer.lightning_module.automatic_optimization:
194185
return
195186

196-
# TODO: remove this call in v1.7
197-
self._log_gpus_metrics()
198-
199187
# when metrics should be logged
200188
assert not self._epoch_end_reached
201189
if self.should_update_logs or self.trainer.fast_dev_run:
@@ -210,22 +198,6 @@ def update_train_epoch_metrics(self) -> None:
210198
assert self.trainer._results is not None
211199
self.trainer._results.reset(metrics=True)
212200

213-
def _log_gpus_metrics(self) -> None:
214-
"""
215-
.. deprecated:: v1.5
216-
This function was deprecated in v1.5 in favor of
217-
`pytorch_lightning.accelerators.gpu._get_nvidia_gpu_stats` and will be removed in v1.7.
218-
"""
219-
for key, mem in self.gpus_metrics.items():
220-
if self.log_gpu_memory == "min_max":
221-
self.trainer.lightning_module.log(key, mem, prog_bar=False, logger=True)
222-
else:
223-
gpu_id = int(key.split("/")[0].split(":")[1])
224-
if gpu_id in self.trainer.device_ids:
225-
self.trainer.lightning_module.log(
226-
key, mem, prog_bar=False, logger=True, on_step=True, on_epoch=False
227-
)
228-
229201
"""
230202
Utilities and properties
231203
"""
@@ -298,17 +270,6 @@ def metrics(self) -> _METRICS:
298270
assert self.trainer._results is not None
299271
return self.trainer._results.metrics(on_step)
300272

301-
@property
302-
def gpus_metrics(self) -> Dict[str, float]:
303-
"""
304-
.. deprecated:: v1.5
305-
Will be removed in v1.7.
306-
"""
307-
if isinstance(self.trainer.accelerator, GPUAccelerator) and self.log_gpu_memory:
308-
mem_map = memory.get_memory_profile(self.log_gpu_memory)
309-
self._gpus_metrics.update(mem_map)
310-
return self._gpus_metrics
311-
312273
@property
313274
def callback_metrics(self) -> _OUT_DICT:
314275
if self.trainer._results:

pytorch_lightning/trainer/trainer.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,6 @@ def __init__(
145145
auto_select_gpus: bool = False,
146146
tpu_cores: Optional[Union[List[int], str, int]] = None,
147147
ipus: Optional[int] = None,
148-
log_gpu_memory: Optional[str] = None, # TODO: Remove in 1.7
149148
enable_progress_bar: bool = True,
150149
overfit_batches: Union[int, float] = 0.0,
151150
track_grad_norm: Union[int, float, str] = -1,
@@ -303,12 +302,6 @@ def __init__(
303302
of the individual loggers.
304303
Default: ``True``.
305304
306-
log_gpu_memory: None, 'min_max', 'all'. Might slow performance.
307-
308-
.. deprecated:: v1.5
309-
Deprecated in v1.5.0 and will be removed in v1.7.0
310-
Please use the ``DeviceStatsMonitor`` callback directly instead.
311-
312305
log_every_n_steps: How often to log within steps.
313306
Default: ``50``.
314307
@@ -461,7 +454,7 @@ def __init__(
461454
amp_level=amp_level,
462455
plugins=plugins,
463456
)
464-
self._logger_connector = LoggerConnector(self, log_gpu_memory)
457+
self._logger_connector = LoggerConnector(self)
465458
self._callback_connector = CallbackConnector(self)
466459
self._checkpoint_connector = CheckpointConnector(self, resume_from_checkpoint)
467460
self._signal_connector = SignalConnector(self)

tests/deprecated_api/test_remove_1-7.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -337,13 +337,6 @@ def test_v1_7_0_weights_summary_trainer(tmpdir):
337337
t.weights_summary = "blah"
338338

339339

340-
def test_v1_7_0_trainer_log_gpu_memory(tmpdir):
341-
with pytest.deprecated_call(
342-
match="Setting `log_gpu_memory` with the trainer flag is deprecated in v1.5 and will be removed"
343-
):
344-
_ = Trainer(log_gpu_memory="min_max")
345-
346-
347340
def test_v1_7_0_deprecated_slurm_job_id():
348341
trainer = Trainer()
349342
with pytest.deprecated_call(match="Method `slurm_job_id` is deprecated in v1.6.0 and will be removed in v1.7.0."):

0 commit comments

Comments
 (0)