Commit 9edf083

Merge branch 'release/1.2-dev' into refactor/legacy-accel-plug

2 parents 545276b + 8c55a08

File tree

5 files changed (+832, -14 lines)

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -62,6 +62,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `PyTorchProfiler` ([#5560](https://github.com/PyTorchLightning/pytorch-lightning/pull/5560))
 
 
+- Added compositional metrics ([#5464](https://github.com/PyTorchLightning/pytorch-lightning/pull/5464))
+
+
 ### Changed
 
 - Changed `stat_scores` metric now calculates stat scores over all classes and gains new parameters, in line with the new `StatScores` metric ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839))

docs/source/metrics.rst

Lines changed: 45 additions & 0 deletions
@@ -258,6 +258,51 @@ In practise this means that:
     val = metric(pred, target) # this value can be backpropagated
     val = metric.compute() # this value cannot be backpropagated
 
+******************
+Metric Arithmetics
+******************
+
+Metrics support most of Python's built-in operators for arithmetic, logic, and bitwise operations.
+
+For example, for a metric that should return the sum of two different metrics, implementing a new metric is unnecessary overhead.
+It can now be done with:
+
+.. code-block:: python
+
+    first_metric = MyFirstMetric()
+    second_metric = MySecondMetric()
+
+    new_metric = first_metric + second_metric
+
+``new_metric.update(*args, **kwargs)`` now calls ``update`` of ``first_metric`` and ``second_metric``. It forwards all positional arguments but
+forwards only the keyword arguments that appear in the respective metric's ``update`` declaration.
+
+Similarly, ``new_metric.compute()`` now calls ``compute`` of ``first_metric`` and ``second_metric`` and adds up the results.
+
+This pattern is implemented for the following operators (with ``a`` being a metric and ``b`` being a metric, tensor, integer, or float):
+
+* Addition (``a + b``)
+* Bitwise AND (``a & b``)
+* Equality (``a == b``)
+* Floor Division (``a // b``)
+* Greater Equal (``a >= b``)
+* Greater (``a > b``)
+* Less Equal (``a <= b``)
+* Less (``a < b``)
+* Matrix Multiplication (``a @ b``)
+* Modulo (``a % b``)
+* Multiplication (``a * b``)
+* Inequality (``a != b``)
+* Bitwise OR (``a | b``)
+* Power (``a ** b``)
+* Subtraction (``a - b``)
+* True Division (``a / b``)
+* Bitwise XOR (``a ^ b``)
+* Absolute Value (``abs(a)``)
+* Inversion (``~a``)
+* Negative Value (``neg(a)``)
+* Positive Value (``pos(a)``)
+
 
 ****************
 MetricCollection
 ****************
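To make the arithmetic above concrete, here is a minimal end-to-end sketch. It assumes the existing ``Accuracy`` metric from ``pytorch_lightning.metrics``; the prediction and target tensors are illustrative dummy data:

.. code-block:: python

    import torch

    from pytorch_lightning.metrics import Accuracy

    acc_a = Accuracy()
    acc_b = Accuracy()

    # Composing metrics yields a new metric object; no subclass is needed.
    blended = acc_a * 0.5 + acc_b * 0.5

    preds = torch.tensor([0, 1, 1, 0])
    target = torch.tensor([0, 1, 0, 0])

    # update() forwards the positional arguments to both underlying metrics.
    blended.update(preds, target)

    # compute() evaluates 0.5 * acc_a.compute() + 0.5 * acc_b.compute().
    print(blended.compute())  # tensor(0.7500)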
pytorch_lightning/metrics/compositional.py

Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
from typing import Callable, Union

import torch

from pytorch_lightning.metrics.metric import Metric


class CompositionalMetric(Metric):
    """Composition of two metrics with a specific operator,
    which will be executed upon the metrics' compute.
    """

    def __init__(
        self,
        operator: Callable,
        metric_a: Union[Metric, int, float, torch.Tensor],
        metric_b: Union[Metric, int, float, torch.Tensor, None],
    ):
        """
        Args:
            operator: the operator taking in one (if metric_b is None)
                or two arguments. Will be applied to the outputs of metric_a.compute()
                and (optionally, if metric_b is not None) metric_b.compute()
            metric_a: first metric whose compute() result is the first argument of operator
            metric_b: second metric whose compute() result is the second argument of operator.
                For operators taking in only one input, this should be None.
        """
        super().__init__()

        self.op = operator

        # Tensor operands are registered as buffers so they move with the
        # module (e.g. across devices); metrics and numbers are stored as-is.
        if isinstance(metric_a, torch.Tensor):
            self.register_buffer("metric_a", metric_a)
        else:
            self.metric_a = metric_a

        if isinstance(metric_b, torch.Tensor):
            self.register_buffer("metric_b", metric_b)
        else:
            self.metric_b = metric_b

    def _sync_dist(self, dist_sync_fn=None):
        # No syncing required here; syncing will be done in metric_a and metric_b.
        pass

    def update(self, *args, **kwargs):
        # Forward all positional arguments, but only the keyword arguments
        # that each sub-metric's update() signature actually accepts.
        if isinstance(self.metric_a, Metric):
            self.metric_a.update(*args, **self.metric_a._filter_kwargs(**kwargs))

        if isinstance(self.metric_b, Metric):
            self.metric_b.update(*args, **self.metric_b._filter_kwargs(**kwargs))

    def compute(self):
        # Resolve each operand: a Metric contributes its computed value,
        # while tensors and plain numbers are used as-is.
        if isinstance(self.metric_a, Metric):
            val_a = self.metric_a.compute()
        else:
            val_a = self.metric_a

        if isinstance(self.metric_b, Metric):
            val_b = self.metric_b.compute()
        else:
            val_b = self.metric_b

        if val_b is None:
            return self.op(val_a)

        return self.op(val_a, val_b)

    def reset(self):
        if isinstance(self.metric_a, Metric):
            self.metric_a.reset()

        if isinstance(self.metric_b, Metric):
            self.metric_b.reset()

    def persistent(self, mode: bool = False):
        if isinstance(self.metric_a, Metric):
            self.metric_a.persistent(mode=mode)
        if isinstance(self.metric_b, Metric):
            self.metric_b.persistent(mode=mode)

    def __repr__(self):
        repr_str = (
            self.__class__.__name__
            + f"(\n {self.op.__name__}(\n {repr(self.metric_a)},\n {repr(self.metric_b)}\n )\n)"
        )

        return repr_str
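For reference, the operator overloads on ``Metric`` described in the docs above are expected to desugar into direct constructions of this class. A minimal sketch, assuming the module path ``pytorch_lightning.metrics.compositional`` and using ``Accuracy`` as a stand-in metric:

.. code-block:: python

    import torch

    from pytorch_lightning.metrics import Accuracy
    from pytorch_lightning.metrics.compositional import CompositionalMetric  # assumed module path

    first = Accuracy()
    second = Accuracy()

    # Two-argument composition: behaves like `first + second`.
    summed = CompositionalMetric(torch.add, first, second)

    # One-argument composition: with metric_b=None, the operator is applied
    # to metric_a's computed value only.
    negated = CompositionalMetric(torch.neg, first, None)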
