Skip to content

Commit 9dbdffc

Browse files
SkafteNickiteddykokerrohitgr7
authored
[Metrics] R2Score (#5241)
* add r2metric * change init * add test * add docs * add math * Apply suggestions from code review Co-authored-by: Teddy Koker <[email protected]> Co-authored-by: Rohit Gupta <[email protected]> * changelog * adjusted parameter * add more test * pep8 * Apply suggestions from code review Co-authored-by: Rohit Gupta <[email protected]> * add warnings for adjusted score Co-authored-by: Teddy Koker <[email protected]> Co-authored-by: Rohit Gupta <[email protected]>
1 parent 73e06fd commit 9dbdffc

File tree

8 files changed

+400
-4
lines changed

8 files changed

+400
-4
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1919

2020
- `StatScores` metric to compute the number of true positives, false positives, true negatives and false negatives ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839))
2121

22+
- Added `R2Score` metric ([#5241](https://github.com/PyTorchLightning/pytorch-lightning/pull/5241))
23+
2224

2325
### Changed
2426

docs/source/metrics.rst

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ Classification Metrics
212212
Input types
213213
-----------
214214

215-
For the purposes of classification metrics, inputs (predictions and targets) are split
215+
For the purposes of classification metrics, inputs (predictions and targets) are split
216216
into these categories (``N`` stands for the batch size and ``C`` for number of classes):
217217

218218
.. csv-table:: \*dtype ``binary`` means integers that are either 0 or 1
@@ -227,10 +227,10 @@ into these categories (``N`` stands for the batch size and ``C`` for number of c
227227
"Multi-dimensional multi-class with probabilities", "(N, C, ...)", "``float``", "(N, ...)", "``int``"
228228

229229
.. note::
230-
All dimensions of size 1 (except ``N``) are "squeezed out" at the beginning, so
230+
All dimensions of size 1 (except ``N``) are "squeezed out" at the beginning, so
231231
that, for example, a tensor of shape ``(N, 1)`` is treated as ``(N, )``.
232232

233-
When predictions or targets are integers, it is assumed that class labels start at 0, i.e.
233+
When predictions or targets are integers, it is assumed that class labels start at 0, i.e.
234234
the possible class labels are 0, 1, 2, 3, etc. Below are some examples of different input types
235235

236236
.. testcode::
@@ -573,6 +573,12 @@ SSIM
573573
:noindex:
574574

575575

576+
R2Score
577+
~~~~~~~
578+
579+
.. autoclass:: pytorch_lightning.metrics.regression.R2Score
580+
:noindex:
581+
576582
Functional Metrics (Regression)
577583
-------------------------------
578584

@@ -617,6 +623,13 @@ ssim [func]
617623
.. autofunction:: pytorch_lightning.metrics.functional.ssim
618624
:noindex:
619625

626+
r2score [func]
627+
~~~~~~~~~~~~~~
628+
629+
.. autofunction:: pytorch_lightning.metrics.functional.r2score
630+
:noindex:
631+
632+
620633
***
621634
NLP
622635
***

pytorch_lightning/metrics/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,5 @@
3434
ExplainedVariance,
3535
PSNR,
3636
SSIM,
37+
R2Score
3738
)

pytorch_lightning/metrics/functional/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@
2828
to_categorical,
2929
to_onehot,
3030
)
31-
from pytorch_lightning.metrics.functional.confusion_matrix import confusion_matrix # noqa: F401
3231
# TODO: unify metrics between class and functional, add below
3332
from pytorch_lightning.metrics.functional.accuracy import accuracy # noqa: F401
33+
from pytorch_lightning.metrics.functional.confusion_matrix import confusion_matrix # noqa: F401
3434
from pytorch_lightning.metrics.functional.explained_variance import explained_variance # noqa: F401
3535
from pytorch_lightning.metrics.functional.f_beta import fbeta, f1 # noqa: F401
3636
from pytorch_lightning.metrics.functional.hamming_distance import hamming_distance # noqa: F401
@@ -40,6 +40,7 @@
4040
from pytorch_lightning.metrics.functional.nlp import bleu_score # noqa: F401
4141
from pytorch_lightning.metrics.functional.precision_recall_curve import precision_recall_curve # noqa: F401
4242
from pytorch_lightning.metrics.functional.psnr import psnr # noqa: F401
43+
from pytorch_lightning.metrics.functional.r2score import r2score # noqa: F401
4344
from pytorch_lightning.metrics.functional.roc import roc # noqa: F401
4445
from pytorch_lightning.metrics.functional.self_supervised import embedding_similarity # noqa: F401
4546
from pytorch_lightning.metrics.functional.ssim import ssim # noqa: F401
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
# Copyright The PyTorch Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from typing import Tuple
15+
16+
import torch
17+
18+
from pytorch_lightning.metrics.utils import _check_same_shape
19+
from pytorch_lightning.utilities import rank_zero_warn
20+
21+
22+
def _r2score_update(
23+
preds: torch.tensor,
24+
target: torch.Tensor,
25+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
26+
_check_same_shape(preds, target)
27+
if preds.ndim > 2:
28+
raise ValueError('Expected both prediction and target to be 1D or 2D tensors,'
29+
f' but recevied tensors with dimension {preds.shape}')
30+
if len(preds) < 2:
31+
raise ValueError('Needs atleast two samples to calculate r2 score.')
32+
33+
sum_error = torch.sum(target, dim=0)
34+
sum_squared_error = torch.sum(torch.pow(target, 2.0), dim=0)
35+
residual = torch.sum(torch.pow(target - preds, 2.0), dim=0)
36+
total = target.size(0)
37+
38+
return sum_squared_error, sum_error, residual, total
39+
40+
41+
def _r2score_compute(sum_squared_error: torch.Tensor,
42+
sum_error: torch.Tensor,
43+
residual: torch.Tensor,
44+
total: torch.Tensor,
45+
adjusted: int = 0,
46+
multioutput: str = "uniform_average") -> torch.Tensor:
47+
mean_error = sum_error / total
48+
diff = sum_squared_error - sum_error * mean_error
49+
raw_scores = 1 - (residual / diff)
50+
51+
if multioutput == "raw_values":
52+
r2score = raw_scores
53+
elif multioutput == "uniform_average":
54+
r2score = torch.mean(raw_scores)
55+
elif multioutput == "variance_weighted":
56+
diff_sum = torch.sum(diff)
57+
r2score = torch.sum(diff / diff_sum * raw_scores)
58+
else:
59+
raise ValueError('Argument `multioutput` must be either `raw_values`,'
60+
f' `uniform_average` or `variance_weighted`. Received {multioutput}.')
61+
62+
if adjusted < 0 or not isinstance(adjusted, int):
63+
raise ValueError('`adjusted` parameter should be an integer larger or'
64+
' equal to 0.')
65+
66+
if adjusted != 0:
67+
if adjusted > total - 1:
68+
rank_zero_warn("More independent regressions than datapoints in"
69+
" adjusted r2 score. Falls back to standard r2 score.",
70+
UserWarning)
71+
elif adjusted == total - 1:
72+
rank_zero_warn("Division by zero in adjusted r2 score. Falls back to"
73+
" standard r2 score.", UserWarning)
74+
else:
75+
r2score = 1 - (1 - r2score) * (total - 1) / (total - adjusted - 1)
76+
return r2score
77+
78+
79+
def r2score(
80+
preds: torch.Tensor,
81+
target: torch.Tensor,
82+
adjusted: int = 0,
83+
multioutput: str = "uniform_average",
84+
) -> torch.Tensor:
85+
r"""
86+
Computes r2 score also known as `coefficient of determination
87+
<https://en.wikipedia.org/wiki/Coefficient_of_determination>`_:
88+
89+
.. math:: R^2 = 1 - \frac{SS_res}{SS_tot}
90+
91+
where :math:`SS_res=\sum_i (y_i - f(x_i))^2` is the sum of residual squares, and
92+
:math:`SS_tot=\sum_i (y_i - \bar{y})^2` is total sum of squares. Can also calculate
93+
adjusted r2 score given by
94+
95+
.. math:: R^2_adj = 1 - \frac{(1-R^2)(n-1)}{n-k-1}
96+
97+
where the parameter :math:`k` (the number of independent regressors) should
98+
be provided as the ``adjusted`` argument.
99+
100+
Args:
101+
pred: estimated labels
102+
target: ground truth labels
103+
adjusted: number of independent regressors for calculating adjusted r2 score.
104+
Default 0 (standard r2 score).
105+
multioutput: Defines aggregation in the case of multiple output scores. Can be one
106+
of the following strings (default is ``'uniform_average'``.):
107+
108+
* ``'raw_values'`` returns full set of scores
109+
* ``'uniform_average'`` scores are uniformly averaged
110+
* ``'variance_weighted'`` scores are weighted by their individual variances
111+
112+
Example:
113+
114+
>>> from pytorch_lightning.metrics.functional import r2score
115+
>>> target = torch.tensor([3, -0.5, 2, 7])
116+
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
117+
>>> r2score(preds, target)
118+
tensor(0.9486)
119+
120+
>>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]])
121+
>>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]])
122+
>>> r2score(preds, target, multioutput='raw_values')
123+
tensor([0.9654, 0.9082])
124+
"""
125+
sum_squared_error, sum_error, residual, total = _r2score_update(preds, target)
126+
return _r2score_compute(sum_squared_error, sum_error, residual, total, adjusted, multioutput)

pytorch_lightning/metrics/regression/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,4 @@
1717
from pytorch_lightning.metrics.regression.explained_variance import ExplainedVariance # noqa: F401
1818
from pytorch_lightning.metrics.regression.psnr import PSNR # noqa: F401
1919
from pytorch_lightning.metrics.regression.ssim import SSIM # noqa: F401
20+
from pytorch_lightning.metrics.regression.r2score import R2Score # noqa: F401
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
# Copyright The PyTorch Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from typing import Any, Callable, Optional
15+
16+
import torch
17+
18+
from pytorch_lightning.metrics.metric import Metric
19+
from pytorch_lightning.metrics.functional.r2score import (
20+
_r2score_update,
21+
_r2score_compute
22+
)
23+
24+
25+
class R2Score(Metric):
26+
r"""
27+
Computes r2 score also known as `coefficient of determination
28+
<https://en.wikipedia.org/wiki/Coefficient_of_determination>`_:
29+
30+
.. math:: R^2 = 1 - \frac{SS_res}{SS_tot}
31+
32+
where :math:`SS_res=\sum_i (y_i - f(x_i))^2` is the sum of residual squares, and
33+
:math:`SS_tot=\sum_i (y_i - \bar{y})^2` is total sum of squares. Can also calculate
34+
adjusted r2 score given by
35+
36+
.. math:: R^2_adj = 1 - \frac{(1-R^2)(n-1)}{n-k-1}
37+
38+
where the parameter :math:`k` (the number of independent regressors) should
39+
be provided as the `adjusted` argument.
40+
41+
Forward accepts
42+
43+
- ``preds`` (float tensor): ``(N,)`` or ``(N, M)`` (multioutput)
44+
- ``target`` (float tensor): ``(N,)`` or ``(N, M)`` (multioutput)
45+
46+
In the case of multioutput, as default the variances will be uniformly
47+
averaged over the additional dimensions. Please see argument `multioutput`
48+
for changing this behavior.
49+
50+
Args:
51+
num_outputs:
52+
Number of outputs in multioutput setting (default is 1)
53+
adjusted:
54+
number of independent regressors for calculating adjusted r2 score.
55+
Default 0 (standard r2 score).
56+
multioutput:
57+
Defines aggregation in the case of multiple output scores. Can be one
58+
of the following strings (default is ``'uniform_average'``.):
59+
60+
* ``'raw_values'`` returns full set of scores
61+
* ``'uniform_average'`` scores are uniformly averaged
62+
* ``'variance_weighted'`` scores are weighted by their individual variances
63+
64+
compute_on_step:
65+
Forward only calls ``update()`` and return None if this is set to False. default: True
66+
dist_sync_on_step:
67+
Synchronize metric state across processes at each ``forward()``
68+
before returning the value at the step. default: False
69+
process_group:
70+
Specify the process group on which synchronization is called. default: None (which selects the entire world)
71+
72+
Example:
73+
74+
>>> from pytorch_lightning.metrics import R2Score
75+
>>> target = torch.tensor([3, -0.5, 2, 7])
76+
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
77+
>>> r2score = R2Score()
78+
>>> r2score(preds, target)
79+
tensor(0.9486)
80+
81+
>>> target = torch.tensor([[0.5, 1], [-1, 1], [7, -6]])
82+
>>> preds = torch.tensor([[0, 2], [-1, 2], [8, -5]])
83+
>>> r2score = R2Score(num_outputs=2, multioutput='raw_values')
84+
>>> r2score(preds, target)
85+
tensor([0.9654, 0.9082])
86+
"""
87+
def __init__(
88+
self,
89+
num_outputs: int = 1,
90+
adjusted: int = 0,
91+
multioutput: str = "uniform_average",
92+
compute_on_step: bool = True,
93+
dist_sync_on_step: bool = False,
94+
process_group: Optional[Any] = None,
95+
dist_sync_fn: Callable = None,
96+
):
97+
super().__init__(
98+
compute_on_step=compute_on_step,
99+
dist_sync_on_step=dist_sync_on_step,
100+
process_group=process_group,
101+
dist_sync_fn=dist_sync_fn,
102+
)
103+
104+
self.num_outputs = num_outputs
105+
106+
if adjusted < 0 or not isinstance(adjusted, int):
107+
raise ValueError('`adjusted` parameter should be an integer larger or'
108+
' equal to 0.')
109+
self.adjusted = adjusted
110+
111+
allowed_multioutput = ('raw_values', 'uniform_average', 'variance_weighted')
112+
if multioutput not in allowed_multioutput:
113+
raise ValueError(
114+
f'Invalid input to argument `multioutput`. Choose one of the following: {allowed_multioutput}'
115+
)
116+
self.multioutput = multioutput
117+
118+
self.add_state("sum_squared_error", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum")
119+
self.add_state("sum_error", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum")
120+
self.add_state("residual", default=torch.zeros(self.num_outputs), dist_reduce_fx="sum")
121+
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
122+
123+
def update(self, preds: torch.Tensor, target: torch.Tensor):
124+
"""
125+
Update state with predictions and targets.
126+
127+
Args:
128+
preds: Predictions from model
129+
target: Ground truth values
130+
"""
131+
sum_squared_error, sum_error, residual, total = _r2score_update(preds, target)
132+
133+
self.sum_squared_error += sum_squared_error
134+
self.sum_error += sum_error
135+
self.residual += residual
136+
self.total += total
137+
138+
def compute(self) -> torch.Tensor:
139+
"""
140+
Computes r2 score over the metric states.
141+
"""
142+
return _r2score_compute(self.sum_squared_error, self.sum_error, self.residual,
143+
self.total, self.adjusted, self.multioutput)

0 commit comments

Comments
 (0)