
Commit 1d90c35

awaelchli and pre-commit-ci[bot] authored and committed
Use PrecisionType enum instead of checking raw values (#10704)
* use precision type

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent b28ab34 commit 1d90c35
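
For orientation before the diffs: the change swaps raw-value membership checks like precision in (16, "mixed") for named PrecisionType members. The following is a minimal, self-contained sketch of why such a swap can stay backwards compatible, assuming a str-backed enum whose members compare equal to the raw values callers may still pass; PrecisionKind, select_dtype and the member values here are hypothetical stand-ins, not Lightning's actual definitions.

    from enum import Enum


    class PrecisionKind(str, Enum):
        # Hypothetical stand-in for Lightning's PrecisionType; values are assumed.
        HALF = "16"
        FLOAT = "32"
        MIXED = "mixed"

        def __eq__(self, other: object) -> bool:
            # Compare against raw ints/strings so legacy call sites keep working.
            return self.value == str(other)

        def __hash__(self) -> int:
            return hash(self.value)


    def select_dtype(precision) -> str:
        # Mirrors the pattern the diff introduces: named members instead of (16, "mixed").
        if precision in (PrecisionKind.HALF, PrecisionKind.MIXED):
            return "float16"
        return "float32"


    print(select_dtype(16))       # float16 (raw int still accepted)
    print(select_dtype("mixed"))  # float16
    print(select_dtype(32))       # float32

The design point: call sites gain named constants while callers that still pass 16 or "mixed" keep working.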

5 files changed: 99 additions & 97 deletions

docs/source/extensions/logging.rst

Lines changed: 81 additions & 88 deletions
@@ -14,45 +14,78 @@
 Logging
 #######

-Lightning supports the most popular logging frameworks (TensorBoard, Comet, etc...).
+Supported Loggers
+=================
+
+The following are loggers we support:

-By default, Lightning uses `PyTorch TensorBoard <https://pytorch.org/docs/stable/tensorboard.html>`__ logging under the hood, and stores the logs to a directory (by default in ``lightning_logs/``).
+.. note::
+    The following loggers will normally plot an additional chart (**global_step VS epoch**).
+
+.. note::
+    Depending on the loggers you use, there might be some additional charts.
+
+.. currentmodule:: pytorch_lightning.loggers
+
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+    :template: classtemplate.rst
+
+    CometLogger
+    CSVLogger
+    MLFlowLogger
+    NeptuneLogger
+    TensorBoardLogger
+    TestTubeLogger
+    WandbLogger
+
+
+By default, Lightning uses ``TensorBoard`` logger under the hood, and stores the logs to a directory (by default in ``lightning_logs/``).

 .. testcode::

     from pytorch_lightning import Trainer

-    # Automatically logs to a directory
-    # (by default ``lightning_logs/``)
+    # Automatically logs to a directory (by default lightning_logs/)
     trainer = Trainer()

 To see your logs:

 .. code-block:: bash

+    # Install tensorboard
+    pip install tensorboard
     tensorboard --logdir=lightning_logs/

+To run tensorboard in a jupyter notebook environment, use the following in a jupyter cell:
+
+.. code-block:: bash
+
+    %reload_ext tensorboard
+    %tensorboard --logdir=lightning_logs/
+
 You can also pass a custom Logger to the :class:`~pytorch_lightning.trainer.trainer.Trainer`.

 .. testcode::

     from pytorch_lightning import loggers as pl_loggers

-    tb_logger = pl_loggers.TensorBoardLogger("logs/")
+    tb_logger = pl_loggers.TensorBoardLogger(save_dir="logs/")
     trainer = Trainer(logger=tb_logger)

-Choose from any of the others such as MLflow, Comet, Neptune, WandB, ...
+Choose from any of the others such as MLflow, Comet, Neptune, WandB, etc.

 .. testcode::

     comet_logger = pl_loggers.CometLogger(save_dir="logs/")
     trainer = Trainer(logger=comet_logger)

-To use multiple loggers, simply pass in a ``list`` or ``tuple`` of loggers ...
+To use multiple loggers, simply pass in a ``list`` or ``tuple`` of loggers.

 .. testcode::

-    tb_logger = pl_loggers.TensorBoardLogger("logs/")
+    tb_logger = pl_loggers.TensorBoardLogger(save_dir="logs/")
     comet_logger = pl_loggers.CometLogger(save_dir="logs/")
     trainer = Trainer(logger=[tb_logger, comet_logger])

@@ -62,8 +95,8 @@ To use multiple loggers, simply pass in a ``list`` or ``tuple`` of loggers ...

 .. note::

-    All loggers log by default to `os.getcwd()`. To change the path without creating a logger set
-    `Trainer(default_root_dir='/your/path/to/save/checkpoints')`
+    All loggers log by default to ``os.getcwd()``. To change the path without creating a logger set
+    ``Trainer(default_root_dir='/your/path/to/save/checkpoints')``

 ----------

@@ -75,55 +108,52 @@ Lightning offers automatic log functionalities for logging scalars, or manual lo

 Automatic Logging
 =================
-Use the :func:`~~pytorch_lightning.core.lightning.LightningModule.log`
+Use the :meth:`~pytorch_lightning.core.lightning.LightningModule.log`
 method to log from anywhere in a :doc:`lightning module <../common/lightning_module>` and :doc:`callbacks <../extensions/callbacks>`
-except functions with `batch_start` in their names.
+except functions with ``batch_start`` in their names.
+# TODO: check the hooks that doesn't support logging

 .. code-block:: python

     def training_step(self, batch, batch_idx):
         self.log("my_metric", x)


-    # or a dict
+    # or a dict to get multiple metrics on the same plot of the logger supports it
     def training_step(self, batch, batch_idx):
         self.log("performance", {"acc": acc, "recall": recall})

-Depending on where log is called from, Lightning auto-determines the correct logging mode for you. \
-But of course you can override the default behavior by manually setting the :func:`~~pytorch_lightning.core.lightning.LightningModule.log` parameters.
+Depending on where log is called from, Lightning auto-determines the correct logging mode for you. But of course you can
+override the default behavior by manually setting the :meth:`~pytorch_lightning.core.lightning.LightningModule.log` parameters.

 .. code-block:: python

     def training_step(self, batch, batch_idx):
         self.log("my_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)

-The :func:`~~pytorch_lightning.core.lightning.LightningModule.log` method has a few options:
-
-* `on_step`: Logs the metric at the current step. Defaults to `True` in :func:`~~pytorch_lightning.core.lightning.LightningModule.training_step`, and :func:`~pytorch_lightning.core.lightning.LightningModule.training_step_end`.
-
-* `on_epoch`: Automatically accumulates and logs at the end of the epoch. Defaults to True anywhere in validation or test loops, and in :func:`~~pytorch_lightning.core.lightning.LightningModule.training_epoch_end`.
-
-* `prog_bar`: Logs to the progress bar.
-
-* `logger`: Logs to the logger like Tensorboard, or any other custom logger passed to the :class:`~pytorch_lightning.trainer.trainer.Trainer`.
+The :meth:`~pytorch_lightning.core.lightning.LightningModule.log` method has a few options:

+* ``on_step``: Logs the metric at the current step.
+* ``on_epoch``: Automatically accumulates and logs at the end of the epoch.
+* ``prog_bar``: Logs to the progress bar.
+* ``logger``: Logs to the logger like ``Tensorboard``, or any other custom logger passed to the :class:`~pytorch_lightning.trainer.trainer.Trainer`.

 .. note::

     - Setting ``on_epoch=True`` will cache all your logged values during the full training epoch and perform a
       reduction in ``on_train_epoch_end``. We recommend using `TorchMetrics <https://torchmetrics.readthedocs.io/>`_, when working with custom reduction.

     - Setting both ``on_step=True`` and ``on_epoch=True`` will create two keys per metric you log with
-      suffix ``_step`` and ``_epoch``, respectively. You can refer to these keys e.g. in the `monitor`
+      suffix ``_step`` and ``_epoch`` respectively. You can refer to these keys e.g. in the `monitor`
       argument of :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` or in the graphs plotted to the logger of your choice.


-If your work requires to log in an unsupported function, please open an issue with a clear description of why it is blocking you.
+If your work requires to log in an unsupported method, please open an issue with a clear description of why it is blocking you.


-Manual logging
-==============
-If you want to log anything that is not a scalar, like histograms, text, images, etc... you may need to use the logger object directly.
+Manual logging Non-Scalar Artifacts
+===================================
+If you want to log anything that is not a scalar, like histograms, text, images, etc. you may need to use the logger object directly.

 .. code-block:: python

@@ -136,14 +166,6 @@ If you want to log anything that is not a scalar, like histograms, text, images,
         tensorboard.add_figure(...)


-Access your logs
-================
-Once your training starts, you can view the logs by using your favorite logger or booting up the Tensorboard logs:
-
-.. code-block:: bash
-
-    tensorboard --logdir ./lightning_logs
-
 ----------

 ********************
@@ -155,9 +177,8 @@ Use the :func:`~pytorch_lightning.loggers.base.rank_zero_experiment` and :func:`

 .. testcode::

-    from pytorch_lightning.utilities import rank_zero_only
-    from pytorch_lightning.loggers import LightningLoggerBase
-    from pytorch_lightning.loggers.base import rank_zero_experiment
+    from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
+    from pytorch_lightning.utilities.distributed import rank_zero_only


     class MyLogger(LightningLoggerBase):
@@ -217,27 +238,26 @@ Logging frequency
 =================

 It may slow training down to log every single batch. By default, Lightning logs every 50 rows, or 50 training steps.
-To change this behaviour, set the `log_every_n_steps` :class:`~pytorch_lightning.trainer.trainer.Trainer` flag.
+To change this behaviour, set the ``log_every_n_steps`` :class:`~pytorch_lightning.trainer.trainer.Trainer` flag.

 .. testcode::

     k = 10
     trainer = Trainer(log_every_n_steps=k)


-
 Log writing frequency
 =====================

 Writing to a logger can be expensive, so by default Lightning writes logs to disk or to the given logger every 100 training steps.
-To change this behaviour, set the interval at which you wish to flush logs to the filesystem using the `flush_logs_every_n_steps` :class:`~pytorch_lightning.trainer.trainer.Trainer` flag.
+To change this behaviour, set the interval at which you wish to flush logs to the filesystem using the ``flush_logs_every_n_steps`` :class:`~pytorch_lightning.trainer.trainer.Trainer` flag.

 .. testcode::

     k = 100
     trainer = Trainer(flush_logs_every_n_steps=k)

-Unlike the `log_every_n_steps`, this argument does not apply to all loggers.
+Unlike the ``log_every_n_steps``, this argument does not apply to all loggers.
 The example shown here works with :class:`~pytorch_lightning.loggers.tensorboard.TensorBoardLogger`,
 which is the default logger in Lightning.

@@ -246,8 +266,8 @@ which is the default logger in Lightning.
 ************
 Progress Bar
 ************
-You can add any metric to the progress bar using :func:`~~pytorch_lightning.core.lightning.LightningModule.log`
-method, setting `prog_bar=True`.
+You can add any metric to the progress bar using :meth:`~pytorch_lightning.core.lightning.LightningModule.log`
+method, setting ``prog_bar=True``.


 .. code-block:: python
@@ -261,15 +281,19 @@ Modifying the progress bar

 The progress bar by default already includes the training loss and version number of the experiment
 if you are using a logger. These defaults can be customized by overriding the
-:func:`~pytorch_lightning.callbacks.base.ProgressBarBase.get_metrics` hook in your module.
+:meth:`~pytorch_lightning.callbacks.progress.base.ProgressBarBase.get_metrics` hook in your logger.

 .. code-block:: python

-    def get_metrics(self):
-        # don't show the version number
-        items = super().get_metrics()
-        items.pop("v_num", None)
-        return items
+    from pytorch_lightning.callbacks.progress import Tqdm
+
+
+    class CustomProgressBar(Tqdm):
+        def get_metrics(self, *args, **kwargs):
+            # don't show the version number
+            items = super().get_metrics()
+            items.pop("v_num", None)
+            return items


 ----------
@@ -303,16 +327,16 @@ Read more about custom Python logging `here <https://docs.python.org/3/library/l
 Logging hyperparameters
 ***********************

-When training a model, it's useful to know what hyperparams went into that model.
-When Lightning creates a checkpoint, it stores a key "hyper_parameters" with the hyperparams.
+When training a model, it is useful to know what hyperparams went into that model.
+When Lightning creates a checkpoint, it stores a key ``"hyper_parameters"`` with the hyperparams.

 .. code-block:: python

     lightning_checkpoint = torch.load(filepath, map_location=lambda storage, loc: storage)
     hyperparams = lightning_checkpoint["hyper_parameters"]

 Some loggers also allow logging the hyperparams used in the experiment. For instance,
-when using the TestTubeLogger or the TensorBoardLogger, all hyperparams will show
+when using the ``TestTubeLogger`` or the ``TensorBoardLogger``, all hyperparams will show
 in the `hparams tab <https://pytorch.org/docs/stable/tensorboard.html#torch.utils.tensorboard.writer.SummaryWriter.add_hparams>`_.

 .. note::
@@ -334,7 +358,7 @@ in the `hparams tab <https://pytorch.org/docs/stable/tensorboard.html#torch.util
         self.log("hp/metric_1", some_scalar_1)
         self.log("hp/metric_2", some_scalar_2)

-In the example, using `hp/` as a prefix allows for the metrics to be grouped under "hp" in the tensorboard scalar tab where you can collapse them.
+In the example, using ``"hp/"`` as a prefix allows for the metrics to be grouped under "hp" in the tensorboard scalar tab where you can collapse them.

 ----------

@@ -343,7 +367,7 @@ Snapshot code
 *************

 Loggers also allow you to snapshot a copy of the code used in this experiment.
-For example, TestTubeLogger does this with a flag:
+For example, ``TestTubeLogger`` does this with a flag:

 .. code-block:: python

@@ -352,34 +376,3 @@ For example, TestTubeLogger does this with a flag:
     logger = TestTubeLogger(".", create_git_tag=True)

 ----------
-
-*****************
-Supported Loggers
-*****************
-
-The following are loggers we support
-
-.. note::
-    The following loggers will normally plot an additional chart (**global_step VS epoch**).
-
-.. note::
-    postfix ``_step`` and ``_epoch`` will be appended to the name you logged
-    if ``on_step`` and ``on_epoch`` are set to ``True`` in ``self.log()``.
-
-.. note::
-    Depending on the loggers you use, there might be some additional charts.
-
-.. currentmodule:: pytorch_lightning.loggers
-
-.. autosummary::
-    :toctree: generated
-    :nosignatures:
-    :template: classtemplate.rst
-
-    CometLogger
-    CSVLogger
-    MLFlowLogger
-    NeptuneLogger
-    TensorBoardLogger
-    TestTubeLogger
-    WandbLogger
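
The note in the docs diff above recommends TorchMetrics when ``on_epoch=True`` needs a proper reduction. As a hedged illustration of that pattern only: LitClassifier, its layer size and optimizer are made up, and it assumes torch, torchmetrics and pytorch_lightning from roughly this commit's era are installed (where torchmetrics.Accuracy() still takes no required arguments).

    import torch
    import torchmetrics
    from pytorch_lightning import LightningModule


    class LitClassifier(LightningModule):
        def __init__(self, num_classes: int = 10):
            super().__init__()
            self.layer = torch.nn.Linear(32, num_classes)
            # the metric object owns the accumulation, so the epoch value is an
            # exact reduction rather than an average of per-step averages
            self.train_acc = torchmetrics.Accuracy()

        def forward(self, x):
            return self.layer(x)

        def training_step(self, batch, batch_idx):
            x, y = batch
            logits = self(x)
            loss = torch.nn.functional.cross_entropy(logits, y)
            self.train_acc(logits.softmax(dim=-1), y)
            # log the metric object itself; Lightning computes and resets it at epoch end
            self.log("train_acc", self.train_acc, on_step=True, on_epoch=True)
            self.log("train_loss", loss, prog_bar=True)
            return loss

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)

Logging the metric object, rather than a value computed from it, is what lets Lightning handle the step-level and epoch-level reduction consistently.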

pytorch_lightning/plugins/training_type/deepspeed.py

Lines changed: 12 additions & 4 deletions
@@ -37,7 +37,7 @@
 from pytorch_lightning.utilities import GradClipAlgorithmType
 from pytorch_lightning.utilities.apply_func import apply_to_collection
 from pytorch_lightning.utilities.distributed import log, rank_zero_info
-from pytorch_lightning.utilities.enums import _StrategyType, AMPType
+from pytorch_lightning.utilities.enums import _StrategyType, AMPType, PrecisionType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _DEEPSPEED_AVAILABLE
 from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -445,7 +445,11 @@ def init_deepspeed(self):

         if self.zero_stage_3 and self.partition_module:
             # Ensure the entire model has been moved to the appropriate device
-            dtype = torch.float16 if self.precision_plugin.precision in (16, "mixed") else torch.float32
+            dtype = (
+                torch.float16
+                if self.precision_plugin.precision in (PrecisionType.HALF, PrecisionType.MIXED)
+                else torch.float32
+            )
             deepspeed.zero.Init(
                 module=model, remote_device=self.remote_device, pin_memory=True, config=self.config, dtype=dtype
             )
@@ -502,7 +506,11 @@ def _initialize_deepspeed_train(self, model):
     def model_sharded_context(self) -> Generator[None, None, None]:
         if self.zero_stage_3:
             assert self._config_initialized
-            dtype = torch.float16 if self.precision_plugin.precision in (16, "mixed") else torch.float32
+            dtype = (
+                torch.float16
+                if self.precision_plugin.precision in (PrecisionType.HALF, PrecisionType.MIXED)
+                else torch.float32
+            )
             model_parallel_context = deepspeed.zero.Init(
                 remote_device=self.remote_device, pin_memory=True, config=self.config, dtype=dtype
             )
@@ -629,7 +637,7 @@ def _auto_select_batch_size(self):
             return batch_size

     def _format_precision_config(self) -> None:
-        if self.precision_plugin.precision in (16, "mixed"):
+        if self.precision_plugin.precision in (PrecisionType.HALF, PrecisionType.MIXED):
             if "fp16" not in self.config and self.precision_plugin.amp_type == AMPType.NATIVE:
                 # FP16 is a DeepSpeed standalone AMP implementation
                 rank_zero_info("Enabling DeepSpeed FP16.")
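
Both deepspeed.py hunks above now share the same dtype-selection expression. Below is a small hedged check of that expression outside the plugin: zero_init_dtype is an illustrative helper, not part of Lightning, and it assumes a pytorch_lightning release where PrecisionType is importable from pytorch_lightning.utilities.enums (as the import hunk does) and compares equal to the raw precision values the Trainer accepts.

    import torch
    from pytorch_lightning.utilities.enums import PrecisionType  # per the import hunk above


    def zero_init_dtype(precision) -> torch.dtype:
        # Same selection logic the two deepspeed.zero.Init call sites now share.
        return torch.float16 if precision in (PrecisionType.HALF, PrecisionType.MIXED) else torch.float32


    assert zero_init_dtype(16) is torch.float16      # raw value from Trainer(precision=16)
    assert zero_init_dtype("mixed") is torch.float16
    assert zero_init_dtype(32) is torch.float32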

pytorch_lightning/plugins/training_type/fully_sharded.py

Lines changed: 2 additions & 2 deletions
@@ -21,7 +21,7 @@
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
 from pytorch_lightning.utilities import _FAIRSCALE_FULLY_SHARDED_AVAILABLE
-from pytorch_lightning.utilities.enums import _StrategyType
+from pytorch_lightning.utilities.enums import _StrategyType, PrecisionType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException

 if _FAIRSCALE_FULLY_SHARDED_AVAILABLE:
@@ -139,7 +139,7 @@ def wrap_policy(*args, **kwargs):
             cpu_offload=self.cpu_offload,
             move_grads_to_cpu=self.move_grads_to_cpu,
             flatten_parameters=self.flatten_parameters,
-            mixed_precision=precision == "mixed",
+            mixed_precision=(precision == PrecisionType.MIXED),
             reshard_after_forward=self.reshard_after_forward,
             fp32_reduce_scatter=self.fp32_reduce_scatter,
             compute_dtype=self.compute_dtype,

0 commit comments
