Commit 678f642

Merge branch 'master' into bug/fixfix_ddp_manual

2 parents 3751352 + 68ba493
File tree

7 files changed: +1101 −231 lines changed

docs/source/metrics.rst

Lines changed: 145 additions & 84 deletions
@@ -196,53 +196,76 @@ Metric API
.. autoclass:: pytorch_lightning.metrics.Metric
    :noindex:

***************************
Class vs Functional Metrics
***************************

The functional metrics follow the simple paradigm: input in, output out. This means they don't provide any advanced mechanisms for syncing across DDP nodes or aggregating over batches; they simply compute the metric value based on the given inputs.

Also, the integration with other parts of PyTorch Lightning will never be as tight as with the class-based interface.
If you just need to compute the values, the functional metrics are the way to go. However, if you are looking for the best integration and user experience, please consider also using the class interface.
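The distinction can be sketched with a hand-rolled accuracy in plain PyTorch (a simplified illustration, not Lightning's actual ``Metric`` base class): the functional style returns a per-batch value, while the class style accumulates state across ``update`` calls so that ``compute`` yields the aggregate.

```python
import torch

def accuracy_fn(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Functional style: stateless, computes the value for this batch only
    return (preds == target).float().mean()

class RunningAccuracy:
    # Class style (sketch): keeps state so results can be aggregated
    # over batches (and, in Lightning, synced across DDP nodes)
    def __init__(self):
        self.correct = 0
        self.total = 0

    def update(self, preds: torch.Tensor, target: torch.Tensor) -> None:
        self.correct += (preds == target).sum().item()
        self.total += target.numel()

    def compute(self) -> float:
        return self.correct / self.total

metric = RunningAccuracy()
for preds, target in [(torch.tensor([0, 1]), torch.tensor([0, 0])),
                      (torch.tensor([1, 1]), torch.tensor([1, 1]))]:
    batch_acc = accuracy_fn(preds, target)  # per-batch value only
    metric.update(preds, target)            # accumulates state

total_acc = metric.compute()  # 3 correct out of 4 -> 0.75
```

Lightning's ``Metric`` class additionally handles the DDP synchronization of this state, which the functional versions leave to the user.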
**********************
Classification Metrics
**********************

Input types
-----------

For the purposes of classification metrics, inputs (predictions and targets) are split
into these categories (``N`` stands for the batch size and ``C`` for the number of classes):

.. csv-table:: \*dtype ``binary`` means integers that are either 0 or 1
    :header: "Type", "preds shape", "preds dtype", "target shape", "target dtype"
    :widths: 20, 10, 10, 10, 10

    "Binary", "(N,)", "``float``", "(N,)", "``binary``\*"
    "Multi-class", "(N,)", "``int``", "(N,)", "``int``"
    "Multi-class with probabilities", "(N, C)", "``float``", "(N,)", "``int``"
    "Multi-label", "(N, ...)", "``float``", "(N, ...)", "``binary``\*"
    "Multi-dimensional multi-class", "(N, ...)", "``int``", "(N, ...)", "``int``"
    "Multi-dimensional multi-class with probabilities", "(N, C, ...)", "``float``", "(N, ...)", "``int``"

.. note::
    All dimensions of size 1 (except ``N``) are "squeezed out" at the beginning, so
    that, for example, a tensor of shape ``(N, 1)`` is treated as ``(N, )``.

When predictions or targets are integers, it is assumed that class labels start at 0, i.e.
the possible class labels are 0, 1, 2, 3, etc. Below are some examples of the different input types:

.. testcode::

    import torch

    # Binary inputs: float predictions, targets that are either 0 or 1
    binary_preds = torch.tensor([0.6, 0.1, 0.9])
    binary_target = torch.tensor([1, 0, 1])

    # Multi-class inputs: integer class labels for predictions and targets
    mc_preds = torch.tensor([0, 2, 1])
    mc_target = torch.tensor([0, 1, 2])

    # Multi-class inputs with probabilities: (N, C) float predictions, integer targets
    mc_preds_probs = torch.tensor([[0.8, 0.2, 0.0], [0.1, 0.2, 0.7], [0.3, 0.6, 0.1]])
    mc_target_probs = torch.tensor([0, 1, 2])

    # Multi-label inputs: float predictions, binary target per label
    ml_preds = torch.tensor([[0.2, 0.8, 0.9], [0.5, 0.6, 0.1], [0.3, 0.1, 0.1]])
    ml_target = torch.tensor([[0, 1, 1], [1, 0, 0], [0, 0, 0]])

In some rare cases, you might have inputs which appear to be (multi-dimensional) multi-class
but are actually binary/multi-label: for example, if both predictions and targets are 1d
binary tensors. Or it could be the other way around: you want to treat binary/multi-label
inputs as 2-class (multi-dimensional) multi-class inputs.

For these cases, the metrics where this distinction would make a difference expose the
``is_multiclass`` argument.
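To relate the probability-based formats to the label-based ones: a ``(N, C)`` probability tensor reduces to ``(N,)`` integer labels via ``argmax``, and multi-label probabilities reduce to binary decisions via a threshold (shown here in plain PyTorch, purely for illustration):

```python
import torch

mc_preds_probs = torch.tensor([[0.8, 0.2, 0.0], [0.1, 0.2, 0.7], [0.3, 0.6, 0.1]])

# (N, C) float probabilities -> (N,) int labels, the plain multi-class format
mc_labels = torch.argmax(mc_preds_probs, dim=1)  # tensor([0, 2, 1])

ml_preds = torch.tensor([[0.2, 0.8, 0.9], [0.5, 0.6, 0.1], [0.3, 0.1, 0.1]])

# Multi-label float predictions -> binary decisions via a 0.5 threshold
ml_decisions = (ml_preds > 0.5).int()
# tensor([[0, 1, 1],
#         [0, 1, 0],
#         [0, 0, 0]], dtype=torch.int32)
```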

Class Metrics (Classification)
------------------------------

Accuracy
~~~~~~~~

.. autoclass:: pytorch_lightning.metrics.classification.Accuracy
    :noindex:

AveragePrecision
@@ -251,67 +274,51 @@ AveragePrecision
.. autoclass:: pytorch_lightning.metrics.classification.AveragePrecision
    :noindex:

ConfusionMatrix
~~~~~~~~~~~~~~~

.. autoclass:: pytorch_lightning.metrics.classification.ConfusionMatrix
    :noindex:

F1
~~

.. autoclass:: pytorch_lightning.metrics.classification.F1
    :noindex:

FBeta
~~~~~

.. autoclass:: pytorch_lightning.metrics.classification.FBeta
    :noindex:

Precision
~~~~~~~~~

.. autoclass:: pytorch_lightning.metrics.classification.Precision
    :noindex:

PrecisionRecallCurve
~~~~~~~~~~~~~~~~~~~~

.. autoclass:: pytorch_lightning.metrics.classification.PrecisionRecallCurve
    :noindex:

Recall
~~~~~~

.. autoclass:: pytorch_lightning.metrics.classification.Recall
    :noindex:

ROC
~~~

.. autoclass:: pytorch_lightning.metrics.classification.ROC
    :noindex:

Functional Metrics (Classification)
-----------------------------------

accuracy [func]
~~~~~~~~~~~~~~~
@@ -417,6 +424,12 @@ recall [func]
.. autofunction:: pytorch_lightning.metrics.functional.classification.recall
    :noindex:

select_topk [func]
~~~~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.utils.select_topk
    :noindex:
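The idea behind ``select_topk`` can be sketched in plain PyTorch (a rough illustration, not the actual implementation): mark the ``k`` highest entries of each row with 1 and the rest with 0, turning a probability tensor into a binary one.

```python
import torch

probs = torch.tensor([[0.1, 0.7, 0.2],
                      [0.5, 0.3, 0.2]])

k = 2
# Indices of the k largest entries per row
topk_idx = probs.topk(k, dim=1).indices
# Scatter 1s at those indices into a zero tensor of the same shape
mask = torch.zeros_like(probs, dtype=torch.int).scatter(1, topk_idx, 1)
# tensor([[0, 1, 1],
#         [1, 1, 0]], dtype=torch.int32)
```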

stat_scores [func]
~~~~~~~~~~~~~~~~~~
@@ -445,9 +458,57 @@ to_onehot [func]
.. autofunction:: pytorch_lightning.metrics.utils.to_onehot
    :noindex:
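As a rough illustration of the conversion ``to_onehot`` performs, plain PyTorch's ``torch.nn.functional.one_hot`` maps ``(N,)`` integer labels to an ``(N, C)`` one-hot tensor (the Lightning helper additionally supports multi-dimensional inputs):

```python
import torch
import torch.nn.functional as F

labels = torch.tensor([0, 2, 1])           # (N,) integer class labels
onehot = F.one_hot(labels, num_classes=3)  # (N, C) one-hot encoding
# tensor([[1, 0, 0],
#         [0, 0, 1],
#         [0, 1, 0]])
```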

******************
Regression Metrics
******************

Class Metrics (Regression)
--------------------------

ExplainedVariance
~~~~~~~~~~~~~~~~~

.. autoclass:: pytorch_lightning.metrics.regression.ExplainedVariance
    :noindex:

MeanAbsoluteError
~~~~~~~~~~~~~~~~~

.. autoclass:: pytorch_lightning.metrics.regression.MeanAbsoluteError
    :noindex:

MeanSquaredError
~~~~~~~~~~~~~~~~

.. autoclass:: pytorch_lightning.metrics.regression.MeanSquaredError
    :noindex:

MeanSquaredLogError
~~~~~~~~~~~~~~~~~~~

.. autoclass:: pytorch_lightning.metrics.regression.MeanSquaredLogError
    :noindex:

PSNR
~~~~

.. autoclass:: pytorch_lightning.metrics.regression.PSNR
    :noindex:

SSIM
~~~~

.. autoclass:: pytorch_lightning.metrics.regression.SSIM
    :noindex:

Functional Metrics (Regression)
-------------------------------

explained_variance [func]
~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -470,17 +531,17 @@ mean_squared_error [func]
    :noindex:


mean_squared_log_error [func]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.mean_squared_log_error
    :noindex:


psnr [func]
~~~~~~~~~~~

.. autofunction:: pytorch_lightning.metrics.functional.psnr
    :noindex:
@@ -490,22 +551,22 @@ ssim [func]
.. autofunction:: pytorch_lightning.metrics.functional.ssim
    :noindex:

***
NLP
***

bleu_score [func]
-----------------

.. autofunction:: pytorch_lightning.metrics.functional.nlp.bleu_score
    :noindex:

********
Pairwise
********

embedding_similarity [func]
---------------------------

.. autofunction:: pytorch_lightning.metrics.functional.self_supervised.embedding_similarity
    :noindex:
