Skip to content

Commit dabfeca

Browse files
SkafteNickis-rogShreeyakrohitgr7awaelchli
authored
[Metrics] [Docs] Add section about device placement (#5280)
* update docs * Update docs/source/metrics.rst Co-authored-by: Shreeyak <[email protected]> * Update docs/source/metrics.rst Co-authored-by: Shreeyak <[email protected]> * Update docs/source/metrics.rst Co-authored-by: Shreeyak <[email protected]> * Update docs/source/metrics.rst * Update docs/source/metrics.rst * Update docs/source/metrics.rst * Apply suggestions from code review Co-authored-by: Adrian Wälchli <[email protected]> * Update docs/source/metrics.rst Co-authored-by: Jirka Borovec <[email protected]> * Update docs/source/metrics.rst * Update docs/source/metrics.rst * try fix failing doc test Co-authored-by: Roger Shieh <[email protected]> Co-authored-by: Shreeyak <[email protected]> Co-authored-by: Rohit Gupta <[email protected]> Co-authored-by: Adrian Wälchli <[email protected]> Co-authored-by: Jirka Borovec <[email protected]>
1 parent 0c7c9e8 commit dabfeca

File tree

1 file changed

+50
-0
lines changed

1 file changed

+50
-0
lines changed

docs/source/metrics.rst

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,56 @@ This metrics API is independent of PyTorch Lightning. Metrics can directly be us
137137
To change this, after initializing the metric, the method ``.persistent(mode)`` can
138138
be used to enable (``mode=True``) or disable (``mode=False``) this behaviour.
139139

140+
*******************
141+
Metrics and devices
142+
*******************
143+
144+
Metrics are simple subclasses of :class:`~torch.nn.Module` and their metric states behave
145+
similar to buffers and parameters of modules. This means that metrics states should
146+
be moved to the same device as the input of the metric:
147+
148+
.. code-block:: python
149+
150+
import torch
151+
from pytorch_lightning.metrics import Accuracy
152+
153+
target = torch.tensor([1, 1, 0, 0], device=torch.device("cuda", 0))
154+
preds = torch.tensor([0, 1, 0, 0], device=torch.device("cuda", 0))
155+
156+
# Metric states are always initialized on cpu, and needs to be moved to
157+
# the correct device
158+
confmat = Accuracy(num_classes=2).to(torch.device("cuda", 0))
159+
out = confmat(preds, target)
160+
print(out.device) # cuda:0
161+
162+
However, when **properly defined** inside a :class:`~pytorch_lightning.core.lightning.LightningModule`
163+
, Lightning will automatically move the metrics to the same device as the data. Being
164+
**properly defined** means that the metric is correctly identified as a child module of the
165+
model (check ``.children()`` attribute of the model). Therefore, metrics cannot be placed
166+
in native python ``list`` and ``dict``, as they will not be correctly identified
167+
as child modules. Instead of ``list`` use :class:`~torch.nn.ModuleList` and instead of
168+
``dict`` use :class:`~torch.nn.ModuleDict`.
169+
170+
.. testcode::
171+
172+
class MyModule(LightningModule):
173+
def __init__(self):
174+
...
175+
# valid ways metrics will be identified as child modules
176+
self.metric1 = pl.metrics.Accuracy()
177+
self.metric2 = torch.nn.ModuleList(pl.metrics.Accuracy())
178+
self.metric3 = torch.nn.ModuleDict({'accuracy': Accuracy()})
179+
180+
def training_step(self, batch, batch_idx):
181+
# all metrics will be on the same device as the input batch
182+
data, target = batch
183+
preds = self(data)
184+
...
185+
val1 = self.metric1(preds, target)
186+
val2 = self.metric2[0](preds, target)
187+
val3 = self.metric3['accuracy'](preds, target)
188+
189+
140190
*********************
141191
Implementing a Metric
142192
*********************

0 commit comments

Comments
 (0)