@@ -137,6 +137,56 @@ This metrics API is independent of PyTorch Lightning. Metrics can directly be us
137137 To change this, after initializing the metric, the method ``.persistent(mode) `` can
138138 be used to enable (``mode=True ``) or disable (``mode=False ``) this behaviour.
139139
140+ *******************
141+ Metrics and devices
142+ *******************
143+
144+ Metrics are simple subclasses of :class: `~torch.nn.Module ` and their metric states behave
145+ similar to buffers and parameters of modules. This means that metrics states should
146+ be moved to the same device as the input of the metric:
147+
148+ .. code-block :: python
149+
150+ import torch
151+ from pytorch_lightning.metrics import Accuracy
152+
153+ target = torch.tensor([1 , 1 , 0 , 0 ], device = torch.device(" cuda" , 0 ))
154+ preds = torch.tensor([0 , 1 , 0 , 0 ], device = torch.device(" cuda" , 0 ))
155+
156+ # Metric states are always initialized on cpu, and needs to be moved to
157+ # the correct device
158+ confmat = Accuracy(num_classes = 2 ).to(torch.device(" cuda" , 0 ))
159+ out = confmat(preds, target)
160+ print (out.device) # cuda:0
161+
162+ However, when **properly defined ** inside a :class: `~pytorch_lightning.core.lightning.LightningModule `
163+ , Lightning will automatically move the metrics to the same device as the data. Being
164+ **properly defined ** means that the metric is correctly identified as a child module of the
165+ model (check ``.children() `` attribute of the model). Therefore, metrics cannot be placed
166+ in native python ``list `` and ``dict ``, as they will not be correctly identified
167+ as child modules. Instead of ``list `` use :class: `~torch.nn.ModuleList ` and instead of
168+ ``dict `` use :class: `~torch.nn.ModuleDict `.
169+
170+ .. testcode ::
171+
172+ class MyModule(LightningModule):
173+ def __init__(self):
174+ ...
175+ # valid ways metrics will be identified as child modules
176+ self.metric1 = pl.metrics.Accuracy()
177+ self.metric2 = torch.nn.ModuleList(pl.metrics.Accuracy())
178+ self.metric3 = torch.nn.ModuleDict({'accuracy': Accuracy()})
179+
180+ def training_step(self, batch, batch_idx):
181+ # all metrics will be on the same device as the input batch
182+ data, target = batch
183+ preds = self(data)
184+ ...
185+ val1 = self.metric1(preds, target)
186+ val2 = self.metric2[0](preds, target)
187+ val3 = self.metric3['accuracy'](preds, target)
188+
189+
140190*********************
141191Implementing a Metric
142192*********************
0 commit comments