@@ -196,53 +196,76 @@ Metric API
196196.. autoclass :: pytorch_lightning.metrics.Metric
197197 :noindex:
198198
199- *************
200- Class metrics
201- *************
199+ ***************************
200+ Class vs Functional Metrics
201+ ***************************
202202
203+ The functional metrics follow the simple paradigm input in, output out. This means, they don't provide any advanced mechanisms for syncing across DDP nodes or aggregation over batches. They simply compute the metric value based on the given inputs.
204+
205+ Also, the integration within other parts of PyTorch Lightning will never be as tight as with the class-based interface.
206+ If you look for just computing the values, the functional metrics are the way to go. However, if you are looking for the best integration and user experience, please consider also using the class interface.
207+
208+ **********************
203209Classification Metrics
204- ----------------------
210+ **********************
205211
206- Accuracy
207- ~~~~~~~~
212+ Input types
213+ -----------
208214
209- .. autoclass :: pytorch_lightning.metrics. classification.Accuracy
210- :noindex :
215+ For the purposes of classification metrics, inputs (predictions and targets) are split
216+ into these categories (`` N `` stands for the batch size and `` C `` for number of classes) :
211217
212- Precision
213- ~~~~~~~~~
218+ .. csv-table :: \*dtype ``binary`` means integers that are either 0 or 1
219+ :header: "Type", "preds shape", "preds dtype", "target shape", "target dtype"
220+ :widths: 20, 10, 10, 10, 10
214221
215- .. autoclass :: pytorch_lightning.metrics.classification.Precision
216- :noindex:
222+ "Binary", "(N,)", "``float ``", "(N,)", "``binary ``\* "
223+ "Multi-class", "(N,)", "``int ``", "(N,)", "``int ``"
224+ "Multi-class with probabilities", "(N, C)", "``float ``", "(N,)", "``int ``"
225+ "Multi-label", "(N, ...)", "``float ``", "(N, ...)", "``binary ``\* "
226+ "Multi-dimensional multi-class", "(N, ...)", "``int ``", "(N, ...)", "``int ``"
227+ "Multi-dimensional multi-class with probabilities", "(N, C, ...)", "``float ``", "(N, ...)", "``int ``"
217228
218- Recall
219- ~~~~~~
229+ .. note ::
230+ All dimensions of size 1 (except ``N ``) are "squeezed out" at the beginning, so
231+ that, for example, a tensor of shape ``(N, 1) `` is treated as ``(N, ) ``.
220232
221- .. autoclass :: pytorch_lightning.metrics.classification.Recall
222- :noindex:
233+ When predictions or targets are integers, it is assumed that class labels start at 0, i.e.
234+ the possible class labels are 0, 1, 2, 3, etc. Below are some examples of different input types
223235
224- FBeta
225- ~~~~~
236+ .. testcode ::
226237
227- .. autoclass :: pytorch_lightning.metrics.classification.FBeta
228- :noindex:
238+ # Binary inputs
239+ binary_preds = torch.tensor([0.6, 0.1, 0.9])
240+ binary_target = torch.tensor([1, 0, 2])
229241
230- F1
231- ~~
242+ # Multi-class inputs
243+ mc_preds = torch.tensor([0, 2, 1])
244+ mc_target = torch.tensor([0, 1, 2])
232245
233- .. autoclass :: pytorch_lightning.metrics.classification.F1
234- :noindex:
246+ # Multi-class inputs with probabilities
247+ mc_preds_probs = torch.tensor([[0.8, 0.2, 0], [0.1, 0.2, 0.7], [0.3, 0.6, 0.1]])
248+ mc_target_probs = torch.tensor([0, 1, 2])
235249
236- ConfusionMatrix
237- ~~~~~~~~~~~~~~~
250+ # Multi-label inputs
251+ ml_preds = torch.tensor([[0.2, 0.8, 0.9], [0.5, 0.6, 0.1], [0.3, 0.1, 0.1]])
252+ ml_target = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 0]])
238253
239- .. autoclass :: pytorch_lightning.metrics.classification.ConfusionMatrix
240- :noindex:
254+ In some rare cases, you might have inputs which appear to be (multi-dimensional) multi-class
255+ but are actually binary/multi-label. For example, if both predictions and targets are 1d
256+ binary tensors. Or it could be the other way around, you want to treat binary/multi-label
257+ inputs as 2-class (multi-dimensional) multi-class inputs.
241258
242- PrecisionRecallCurve
243- ~~~~~~~~~~~~~~~~~~~~
259+ For these cases, the metrics where this distinction would make a difference, expose the
260+ `` is_multiclass `` argument.
244261
245- .. autoclass :: pytorch_lightning.metrics.classification.PrecisionRecallCurve
262+ Class Metrics (Classification)
263+ ------------------------------
264+
265+ Accuracy
266+ ~~~~~~~~
267+
268+ .. autoclass :: pytorch_lightning.metrics.classification.Accuracy
246269 :noindex:
247270
248271AveragePrecision
@@ -251,67 +274,51 @@ AveragePrecision
251274.. autoclass :: pytorch_lightning.metrics.classification.AveragePrecision
252275 :noindex:
253276
254- ROC
255- ~~~
277+ ConfusionMatrix
278+ ~~~~~~~~~~~~~~~
256279
257- .. autoclass :: pytorch_lightning.metrics.classification.ROC
280+ .. autoclass :: pytorch_lightning.metrics.classification.ConfusionMatrix
258281 :noindex:
259282
260- Regression Metrics
261- ------------------
262-
263- MeanSquaredError
264- ~~~~~~~~~~~~~~~~
283+ F1
284+ ~~
265285
266- .. autoclass :: pytorch_lightning.metrics.regression.MeanSquaredError
286+ .. autoclass :: pytorch_lightning.metrics.classification.F1
267287 :noindex:
268288
289+ FBeta
290+ ~~~~~
269291
270- MeanAbsoluteError
271- ~~~~~~~~~~~~~~~~~
272-
273- .. autoclass :: pytorch_lightning.metrics.regression.MeanAbsoluteError
292+ .. autoclass :: pytorch_lightning.metrics.classification.FBeta
274293 :noindex:
275294
295+ Precision
296+ ~~~~~~~~~
276297
277- MeanSquaredLogError
278- ~~~~~~~~~~~~~~~~~~~
279-
280- .. autoclass :: pytorch_lightning.metrics.regression.MeanSquaredLogError
298+ .. autoclass :: pytorch_lightning.metrics.classification.Precision
281299 :noindex:
282300
301+ PrecisionRecallCurve
302+ ~~~~~~~~~~~~~~~~~~~~
283303
284- ExplainedVariance
285- ~~~~~~~~~~~~~~~~~
286-
287- .. autoclass :: pytorch_lightning.metrics.regression.ExplainedVariance
304+ .. autoclass :: pytorch_lightning.metrics.classification.PrecisionRecallCurve
288305 :noindex:
289306
307+ Recall
308+ ~~~~~~
290309
291- PSNR
292- ~~~~
293-
294- .. autoclass :: pytorch_lightning.metrics.regression.PSNR
310+ .. autoclass :: pytorch_lightning.metrics.classification.Recall
295311 :noindex:
296312
313+ ROC
314+ ~~~
297315
298- SSIM
299- ~~~~
300-
301- .. autoclass :: pytorch_lightning.metrics.regression.SSIM
316+ .. autoclass :: pytorch_lightning.metrics.classification.ROC
302317 :noindex:
303318
304- ******************
305- Functional Metrics
306- ******************
307-
308- The functional metrics follow the simple paradigm input in, output out. This means, they don't provide any advanced mechanisms for syncing across DDP nodes or aggregation over batches. They simply compute the metric value based on the given inputs.
309-
310- Also the integration within other parts of PyTorch Lightning will never be as tight as with the class-based interface.
311- If you look for just computing the values, the functional metrics are the way to go. However, if you are looking for the best integration and user experience, please consider also to use the class interface.
312319
313- Classification
314- --------------
320+ Functional Metrics ( Classification)
321+ -----------------------------------
315322
316323accuracy [func]
317324~~~~~~~~~~~~~~~
@@ -417,6 +424,12 @@ recall [func]
417424.. autofunction :: pytorch_lightning.metrics.functional.classification.recall
418425 :noindex:
419426
427+ select_topk [func]
428+ ~~~~~~~~~~~~~~~~~~~~~
429+
430+ .. autofunction :: pytorch_lightning.metrics.utils.select_topk
431+ :noindex:
432+
420433
421434stat_scores [func]
422435~~~~~~~~~~~~~~~~~~
@@ -445,9 +458,57 @@ to_onehot [func]
445458.. autofunction :: pytorch_lightning.metrics.utils.to_onehot
446459 :noindex:
447460
461+ ******************
462+ Regression Metrics
463+ ******************
464+
465+ Class Metrics (Regression)
466+ --------------------------
448467
449- Regression
450- ----------
468+ ExplainedVariance
469+ ~~~~~~~~~~~~~~~~~
470+
471+ .. autoclass :: pytorch_lightning.metrics.regression.ExplainedVariance
472+ :noindex:
473+
474+
475+ MeanAbsoluteError
476+ ~~~~~~~~~~~~~~~~~
477+
478+ .. autoclass :: pytorch_lightning.metrics.regression.MeanAbsoluteError
479+ :noindex:
480+
481+
482+ MeanSquaredError
483+ ~~~~~~~~~~~~~~~~
484+
485+ .. autoclass :: pytorch_lightning.metrics.regression.MeanSquaredError
486+ :noindex:
487+
488+
489+ MeanSquaredLogError
490+ ~~~~~~~~~~~~~~~~~~~
491+
492+ .. autoclass :: pytorch_lightning.metrics.regression.MeanSquaredLogError
493+ :noindex:
494+
495+
496+ PSNR
497+ ~~~~
498+
499+ .. autoclass :: pytorch_lightning.metrics.regression.PSNR
500+ :noindex:
501+
502+
503+ SSIM
504+ ~~~~
505+
506+ .. autoclass :: pytorch_lightning.metrics.regression.SSIM
507+ :noindex:
508+
509+
510+ Functional Metrics (Regression)
511+ -------------------------------
451512
452513explained_variance [func]
453514~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -470,17 +531,17 @@ mean_squared_error [func]
470531 :noindex:
471532
472533
473- psnr [func]
474- ~~~~~~~~~~~
534+ mean_squared_log_error [func]
535+ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
475536
476- .. autofunction :: pytorch_lightning.metrics.functional.psnr
537+ .. autofunction :: pytorch_lightning.metrics.functional.mean_squared_log_error
477538 :noindex:
478539
479540
480- mean_squared_log_error [func]
481- ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
541+ psnr [func]
542+ ~~~~~~~~~~~
482543
483- .. autofunction :: pytorch_lightning.metrics.functional.mean_squared_log_error
544+ .. autofunction :: pytorch_lightning.metrics.functional.psnr
484545 :noindex:
485546
486547
@@ -490,22 +551,22 @@ ssim [func]
490551.. autofunction :: pytorch_lightning.metrics.functional.ssim
491552 :noindex:
492553
493-
554+ ***
494555NLP
495- ---
556+ ***
496557
497558bleu_score [func]
498- ~~~~~~~~~~~~~~~~~
559+ -----------------
499560
500561.. autofunction :: pytorch_lightning.metrics.functional.nlp.bleu_score
501562 :noindex:
502563
503-
564+ ********
504565Pairwise
505- --------
566+ ********
506567
507568embedding_similarity [func]
508- ~~~~~~~~~~~~~~~~~~~~~~~~~~~
569+ ---------------------------
509570
510571.. autofunction :: pytorch_lightning.metrics.functional.self_supervised.embedding_similarity
511572 :noindex:
0 commit comments