2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -21,6 +21,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added `R2Score` metric ([#5241](https://github.com/PyTorchLightning/pytorch-lightning/pull/5241))

- Added `auto_reset_on_compute` argument to metrics, allowing for running accumulation computations ([#5193](https://github.com/PyTorchLightning/pytorch-lightning/pull/5193))


### Changed

5 changes: 5 additions & 0 deletions pytorch_lightning/metrics/classification/accuracy.py
@@ -77,6 +77,9 @@ class Accuracy(Metric):
dist_sync_fn:
Callback that performs the allgather operation on the metric state. When ``None``, DDP
will be used to perform the allgather. default: None
auto_reset_on_compute:
Specify whether ``reset()`` should be called automatically after each ``compute()``.
Disabling it allows for calculating running accumulated metrics. default: True

Example:

@@ -104,12 +107,14 @@ def __init__(
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
auto_reset_on_compute: bool = True
):
super().__init__(
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
dist_sync_fn=dist_sync_fn,
auto_reset_on_compute=auto_reset_on_compute
)

self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
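A minimal usage sketch of the new flag, assuming the documented semantics (the compute-triggered reset preserves the accumulated state, so successive ``compute()`` calls return a running value); the data is made up for illustration:

```python
import torch
from pytorch_lightning.metrics import Accuracy

# Disable the automatic reset so state accumulates across compute() calls
accuracy = Accuracy(auto_reset_on_compute=False)

batches = [
    (torch.tensor([0, 1]), torch.tensor([0, 0])),  # 1 of 2 correct
    (torch.tensor([1, 1]), torch.tensor([1, 1])),  # 2 of 2 correct
]
for preds, target in batches:
    accuracy.update(preds, target)
    print(accuracy.compute())  # running accuracy: 0.5, then 0.75

accuracy.reset()  # manual reset, e.g. at the end of an epoch
```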
12 changes: 11 additions & 1 deletion pytorch_lightning/metrics/classification/average_precision.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Any, Union, List
from typing import Any, Callable, List, Optional, Union

import torch

@@ -50,6 +50,12 @@ class AveragePrecision(Metric):
before returning the value at the step. default: False
process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world)
dist_sync_fn:
Callback that performs the allgather operation on the metric state. When `None`, DDP
will be used to perform the allgather. default: None
auto_reset_on_compute:
Specify whether ``reset()`` should be called automatically after each ``compute()``.
Disabling it allows for calculating running accumulated metrics. default: True

Example (binary case):

@@ -78,11 +84,15 @@ def __init__(
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
auto_reset_on_compute: bool = True
):
super().__init__(
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
dist_sync_fn=dist_sync_fn,
auto_reset_on_compute=auto_reset_on_compute
)

self.num_classes = num_classes
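The ``dist_sync_fn`` parameter threaded through above accepts any callable with the same contract as the default ``gather_all_tensors``. A minimal sketch of a custom sync callback, assuming that contract is a tensor plus a ``group`` keyword returning the list of per-process tensors:

```python
import torch
import torch.distributed as dist

def my_allgather(tensor: torch.Tensor, group=None):
    # Assumed contract (mirroring gather_all_tensors): gather this
    # process's state tensor from every process in the group.
    world_size = dist.get_world_size(group=group)
    gathered = [torch.zeros_like(tensor) for _ in range(world_size)]
    dist.all_gather(gathered, tensor, group=group)
    return gathered

# metric = AveragePrecision(num_classes=5, dist_sync_fn=my_allgather)
```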
12 changes: 11 additions & 1 deletion pytorch_lightning/metrics/classification/confusion_matrix.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional
from typing import Any, Callable, Optional

import torch

@@ -61,6 +61,12 @@ class ConfusionMatrix(Metric):
before returning the value at the step. default: False
process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world)
dist_sync_fn:
Callback that performs the allgather operation on the metric state. When `None`, DDP
will be used to perform the allgather. default: None
auto_reset_on_compute:
Specify whether ``reset()`` should be called automatically after each ``compute()``.
Disabling it allows for calculating running accumulated metrics. default: True

Example:

@@ -81,12 +87,16 @@ def __init__(
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
auto_reset_on_compute: bool = True
):

super().__init__(
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
dist_sync_fn=dist_sync_fn,
auto_reset_on_compute=auto_reset_on_compute
)
self.num_classes = num_classes
self.normalize = normalize
16 changes: 14 additions & 2 deletions pytorch_lightning/metrics/classification/f_beta.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional
from typing import Any, Callable, Optional

import torch

@@ -66,6 +66,12 @@ class FBeta(Metric):
before returning the value at the step. default: False
process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world)
dist_sync_fn:
Callback that performs the allgather operation on the metric state. When `None`, DDP
will be used to perform the allgather. default: None
auto_reset_on_compute:
Specify whether ``reset()`` should be called automatically after each ``compute()``.
Disabling it allows for calculating running accumulated metrics. default: True

Example:

@@ -88,9 +94,15 @@ def __init__(
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
auto_reset_on_compute: bool = True
):
super().__init__(
compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group,
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
dist_sync_fn=dist_sync_fn,
auto_reset_on_compute=auto_reset_on_compute
)

self.num_classes = num_classes
12 changes: 11 additions & 1 deletion pytorch_lightning/metrics/classification/precision_recall.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Optional
from typing import Any, Callable, Optional

import torch

@@ -58,6 +58,12 @@ class Precision(Metric):
before returning the value at the step. default: False
process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world)
dist_sync_fn:
Callback that performs the allgather operation on the metric state. When `None`, DDP
will be used to perform the allgather. default: None
auto_reset_on_compute:
Specify whether ``reset()`` should be called automatically after each ``compute()``.
Disabling it allows for calculating running accumulated metrics. default: True

Example:

@@ -78,11 +84,15 @@ def __init__(
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
auto_reset_on_compute: bool = True
):
super().__init__(
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
dist_sync_fn=dist_sync_fn,
auto_reset_on_compute=auto_reset_on_compute
)

self.num_classes = num_classes
pytorch_lightning/metrics/classification/precision_recall_curve.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Any, Union, Tuple, List
from typing import Any, Callable, List, Optional, Tuple, Union

import torch

@@ -50,6 +50,12 @@ class PrecisionRecallCurve(Metric):
before returning the value at the step. default: False
process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world)
dist_sync_fn:
Callback that performs the allgather operation on the metric state. When `None`, DDP
will be used to perform the allgather. default: None
auto_reset_on_compute:
Specify whether ``reset()`` should be called automatically after each ``compute()``.
Disabling it allows for calculating running accumulated metrics. default: True

Example (binary case):

@@ -89,11 +95,15 @@ def __init__(
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
auto_reset_on_compute: bool = True
):
super().__init__(
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
dist_sync_fn=dist_sync_fn,
auto_reset_on_compute=auto_reset_on_compute
)

self.num_classes = num_classes
12 changes: 11 additions & 1 deletion pytorch_lightning/metrics/classification/roc.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Any, Union, List, Tuple
from typing import Any, Callable, List, Optional, Tuple, Union

import torch

@@ -50,6 +50,12 @@ class ROC(Metric):
before returning the value at the step. default: False
process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world)
dist_sync_fn:
Callback that performs the allgather operation on the metric state. When `None`, DDP
will be used to perform the allgather. default: None
auto_reset_on_compute:
Specify whether ``reset()`` should be called automatically after each ``compute()``.
Disabling it allows for calculating running accumulated metrics. default: True

Example (binary case):

@@ -91,11 +97,15 @@ def __init__(
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
auto_reset_on_compute: bool = True
):
super().__init__(
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
dist_sync_fn=dist_sync_fn,
auto_reset_on_compute=auto_reset_on_compute
)

self.num_classes = num_classes
26 changes: 21 additions & 5 deletions pytorch_lightning/metrics/metric.py
@@ -22,7 +22,7 @@

from pytorch_lightning.metrics.utils import _flatten, dim_zero_cat, dim_zero_mean, dim_zero_sum
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.distributed import gather_all_tensors
from pytorch_lightning.utilities.distributed import gather_all_tensors, all_except_rank_zero


class Metric(nn.Module, ABC):
@@ -56,26 +56,32 @@ class Metric(nn.Module, ABC):
dist_sync_fn:
Callback that performs the allgather operation on the metric state. When `None`, DDP
will be used to perform the allgather. default: None
auto_reset_on_compute:
Specify whether ``reset()`` should be called automatically after each ``compute()``.
Disabling it allows for calculating running accumulated metrics. default: True
"""
def __init__(
self,
compute_on_step: bool = True,
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
auto_reset_on_compute: bool = True
):
super().__init__()

self.dist_sync_on_step = dist_sync_on_step
self.compute_on_step = compute_on_step
self.process_group = process_group
self.dist_sync_fn = dist_sync_fn
self.auto_reset_on_compute = auto_reset_on_compute
self._to_sync = True

self.update = self._wrap_update(self.update)
self.compute = self._wrap_compute(self.compute)
self._computed = None
self._forward_cache = None
self._internal_reset = False

# initialize state
self._defaults = {}
@@ -217,6 +223,7 @@ def wrapped_func(*args, **kwargs):
self._sync_dist(dist_sync_fn)

self._computed = compute(*args, **kwargs)
self._internal_reset = self._to_sync
self.reset()

return self._computed
@@ -238,17 +245,26 @@ def compute(self): # pylint: disable=E0202
"""
pass

def reset(self):
"""
This method automatically resets the metric state variables to their default value.
"""
def _reset_to_default(self):
""" reset metric state to their default values """
for attr, default in self._defaults.items():
current_val = getattr(self, attr)
if isinstance(current_val, torch.Tensor):
setattr(self, attr, deepcopy(default).to(current_val.device))
else:
setattr(self, attr, deepcopy(default))

def reset(self):
"""
Resets the metric state variables to their default values, unless ``auto_reset_on_compute`` is disabled and the reset was triggered internally after ``compute()``, in which case rank zero keeps its accumulated state.
"""
if self.auto_reset_on_compute or not self._internal_reset:
reset_fn = self._reset_to_default
else:
reset_fn = all_except_rank_zero(self._reset_to_default)
Review comment (Contributor):
Does this make it not reset on rank 0? Why would we not want to reset on rank 0?

reset_fn()
self._internal_reset = False

def __getstate__(self):
# ignore update and compute functions for pickling
return {k: v for k, v in self.__dict__.items() if k not in ["update", "compute"]}
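The gating in ``reset()`` above means a compute-triggered reset skips rank zero when ``auto_reset_on_compute`` is disabled, which is what the review comment asks about. ``all_except_rank_zero`` itself is not part of this diff; a plausible sketch of such a wrapper, assuming it simply no-ops on global rank zero (and therefore on a single process):

```python
from functools import wraps

from pytorch_lightning.utilities.distributed import rank_zero_only

def all_except_rank_zero(fn):
    """Assumed helper: run ``fn`` on every process except global rank zero,
    so rank zero keeps its accumulated metric state."""
    @wraps(fn)
    def wrapped_fn(*args, **kwargs):
        if rank_zero_only.rank != 0:
            return fn(*args, **kwargs)
    return wrapped_fn
```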
8 changes: 8 additions & 0 deletions pytorch_lightning/metrics/regression/explained_variance.py
@@ -57,6 +57,12 @@ class ExplainedVariance(Metric):
before returning the value at the step. default: False
process_group:
Specify the process group on which synchronization is called. default: None (which selects the entire world)
dist_sync_fn:
Callback that performs the allgather operation on the metric state. When `None`, DDP
will be used to perform the allgather. default: None
auto_reset_on_compute:
Specify whether ``reset()`` should be called automatically after each ``compute()``.
Disabling it allows for calculating running accumulated metrics. default: True

Example:

@@ -81,12 +87,14 @@ def __init__(
dist_sync_on_step: bool = False,
process_group: Optional[Any] = None,
dist_sync_fn: Callable = None,
auto_reset_on_compute: bool = True
):
super().__init__(
compute_on_step=compute_on_step,
dist_sync_on_step=dist_sync_on_step,
process_group=process_group,
dist_sync_fn=dist_sync_fn,
auto_reset_on_compute=auto_reset_on_compute
)
allowed_multioutput = ('raw_values', 'uniform_average', 'variance_weighted')
if multioutput not in allowed_multioutput:
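To close the loop, a sketch of how the flag might be used inside a ``LightningModule`` (hypothetical model; assumes the documented running-accumulation semantics):

```python
import pytorch_lightning as pl
from pytorch_lightning.metrics import ExplainedVariance

class RegressionTask(pl.LightningModule):  # hypothetical module
    def __init__(self):
        super().__init__()
        # accumulate across the whole validation run instead of per step
        self.val_ev = ExplainedVariance(auto_reset_on_compute=False)

    def validation_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)  # assumes forward() is defined elsewhere
        self.val_ev.update(preds, y)
        # running value over all validation batches seen so far
        self.log('val_ev_running', self.val_ev.compute())

    def on_validation_epoch_end(self):
        self.val_ev.reset()  # start fresh next epoch
```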