From 25fbe37bc9c10f40e3d30f99583b48cb4c1c080a Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Fri, 18 Dec 2020 21:15:01 +0100
Subject: [PATCH 1/9] update

---
 pytorch_lightning/metrics/metric.py        | 26 ++++++++++++---
 pytorch_lightning/utilities/distributed.py | 11 +++++++
 tests/metrics/test_ddp.py                  | 38 ++++++++++++++++++++++
 tests/metrics/test_metric.py               | 37 +++++++++++++++++++++
 4 files changed, 107 insertions(+), 5 deletions(-)

diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py
index 0f61b94c55139..413ef08354311 100644
--- a/pytorch_lightning/metrics/metric.py
+++ b/pytorch_lightning/metrics/metric.py
@@ -22,7 +22,7 @@
 from pytorch_lightning.metrics.utils import _flatten, dim_zero_cat, dim_zero_mean, dim_zero_sum
 from pytorch_lightning.utilities.apply_func import apply_to_collection
-from pytorch_lightning.utilities.distributed import gather_all_tensors
+from pytorch_lightning.utilities.distributed import gather_all_tensors, all_except_rank_zero
 
 
 class Metric(nn.Module, ABC):
@@ -56,6 +56,9 @@ class Metric(nn.Module, ABC):
         dist_sync_fn:
             Callback that performs the allgather operation on the metric state. When `None`, DDP
             will be used to perform the allgather. default: None
+        auto_reset_on_compute:
+            Specify if ``reset()`` should be automatically called after each ``compute()``.
+            Disable to calculate running accumulated metrics.
     """
     def __init__(
         self,
@@ -63,6 +66,7 @@ def __init__(
         dist_sync_on_step: bool = False,
         process_group: Optional[Any] = None,
         dist_sync_fn: Callable = None,
+        auto_reset_on_compute: bool = True
     ):
         super().__init__()
 
@@ -70,12 +74,14 @@ def __init__(
         self.compute_on_step = compute_on_step
         self.process_group = process_group
         self.dist_sync_fn = dist_sync_fn
+        self.auto_reset_on_compute = auto_reset_on_compute
         self._to_sync = True
 
         self.update = self._wrap_update(self.update)
         self.compute = self._wrap_compute(self.compute)
         self._computed = None
         self._forward_cache = None
+        self._internal_reset = False
 
         # initialize state
         self._defaults = {}
@@ -217,6 +223,7 @@ def wrapped_func(*args, **kwargs):
                 self._sync_dist(dist_sync_fn)
 
             self._computed = compute(*args, **kwargs)
+            self._internal_reset = self._to_sync
             self.reset()
 
             return self._computed
@@ -238,10 +245,8 @@ def compute(self):  # pylint: disable=E0202
         """
         pass
 
-    def reset(self):
-        """
-        This method automatically resets the metric state variables to their default value.
-        """
+    def _reset_to_default(self):
+        """ Reset metric states to their default values """
         for attr, default in self._defaults.items():
             current_val = getattr(self, attr)
             if isinstance(current_val, torch.Tensor):
@@ -249,6 +254,17 @@ def reset(self):
             else:
                 setattr(self, attr, deepcopy(default))
 
+    def reset(self):
+        """
+        This method automatically resets the metric state variables to their default value.
+        """
+        if self.auto_reset_on_compute or not self._internal_reset:
+            reset_fn = self._reset_to_default
+        else:
+            reset_fn = all_except_rank_zero(self._reset_to_default)
+        reset_fn()
+        self._internal_reset = False
+
     def __getstate__(self):
         # ignore update and compute functions for pickling
         return {k: v for k, v in self.__dict__.items() if k not in ["update", "compute"]}
diff --git a/pytorch_lightning/utilities/distributed.py b/pytorch_lightning/utilities/distributed.py
index 9724f05247c00..b9ab32fd4c949 100644
--- a/pytorch_lightning/utilities/distributed.py
+++ b/pytorch_lightning/utilities/distributed.py
@@ -41,8 +41,19 @@ def wrapped_fn(*args, **kwargs):
     return wrapped_fn
 
 
+def all_except_rank_zero(fn):
+
+    @wraps(fn)
+    def wrapped_fn(*args, **kwargs):
+        if all_except_rank_zero.rank != 0:
+            return fn(*args, **kwargs)
+
+    return wrapped_fn
+
+
 # add the attribute to the function but don't overwrite in case Trainer has already set it
 rank_zero_only.rank = getattr(rank_zero_only, 'rank', int(os.environ.get('LOCAL_RANK', 0)))
+all_except_rank_zero.rank = getattr(all_except_rank_zero, 'rank', int(os.environ.get('LOCAL_RANK', 0)))
 
 
 def _warn(*args, **kwargs):
diff --git a/tests/metrics/test_ddp.py b/tests/metrics/test_ddp.py
index 4cac03cc16e2b..7bdc0531adb8f 100644
--- a/tests/metrics/test_ddp.py
+++ b/tests/metrics/test_ddp.py
@@ -69,3 +69,41 @@ def compute(self):
 def test_non_contiguous_tensors():
     """ Test that gather_all operation works for non contiguous tensors """
     torch.multiprocessing.spawn(_test_non_contiguous_tensors, args=(2,), nprocs=2)
+
+
+def _test_running_accumulation(rank, worldsize):
+    class RunningMetric(Metric):
+        def __init__(self):
+            super().__init__(auto_reset_on_compute=False)
+            self.add_state("x", torch.tensor(0.0), dist_reduce_fx="sum")
+
+        def update(self, x):
+            self.x += x
+
+        def compute(self):
+            return self.x
+
+    running_metric = RunningMetric()
+
+    for _ in range(2):
+        _ = running_metric(1.0)
+    acc_val = running_metric.compute()
+
+    assert acc_val == 4.0, "wrong accumulation of metric states"
+    if rank == 0:
+        assert running_metric.x != torch.tensor(0.)
+    else:
+        assert running_metric.x == torch.tensor(0.)
+
+    for _ in range(2):
+        val = running_metric(1.0)
+    acc_val = running_metric.compute()
+
+    assert acc_val == 8.0, "wrong accumulation of metric states"
+    if rank == 0:
+        assert running_metric.x != torch.tensor(0.)
+    else:
+        assert running_metric.x == torch.tensor(0.)
+
+    running_metric.reset()
+    assert running_metric.x == torch.tensor(0.), "metric state should have been reset"
diff --git a/tests/metrics/test_metric.py b/tests/metrics/test_metric.py
index d97cd1a176cf2..1a5008141537f 100644
--- a/tests/metrics/test_metric.py
+++ b/tests/metrics/test_metric.py
@@ -178,3 +178,40 @@ def test_state_dict(tmpdir):
     assert metric.state_dict() == OrderedDict(x=0)
     metric.persistent(False)
     assert metric.state_dict() == OrderedDict()
+
+
+def test_running_accumulation(tmpdir):
+    class RunningMetric(Metric):
+        def __init__(self):
+            super().__init__(auto_reset_on_compute=False)
+            self.add_state("x", torch.tensor(0.0), dist_reduce_fx="sum")
+
+        def update(self, x):
+            self.x += x
+
+        def compute(self):
+            return self.x
+
+    running_metric = RunningMetric()
+
+    vals = []
+    for _ in range(2):
+        val = running_metric(torch.rand(1).squeeze())
+        vals.append(val)
+    acc_val = running_metric.compute()
+
+    assert sum(vals) == acc_val, "wrong accumulation of metric states"
+    assert running_metric.x != torch.tensor(0.), "metric state should not have been reset"
+
+    for _ in range(2):
+        val = running_metric(torch.rand(1).squeeze())
+        vals.append(val)
+    acc_val = running_metric.compute()
+
+    assert sum(vals) == acc_val, "wrong accumulation of metric states"
+    assert running_metric.x != torch.tensor(0.), "metric state should not have been reset"
+
+    running_metric.reset()
+    assert running_metric.x == torch.tensor(0.), "metric state should have been reset"
+
+
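For context on what this first patch enables: with `auto_reset_on_compute=False`, a metric keeps its state across `compute()` calls, so values accumulate until an explicit `reset()`. A minimal usage sketch against this branch (single process; `RunningSum` is an illustrative metric, not part of the patch):

```python
import torch

from pytorch_lightning.metrics import Metric


class RunningSum(Metric):
    """Illustrative metric whose state survives compute() calls."""

    def __init__(self):
        # disable the automatic reset performed after each compute()
        super().__init__(auto_reset_on_compute=False)
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, value: torch.Tensor):
        self.total += value

    def compute(self):
        return self.total


metric = RunningSum()
for batch_value in (1.0, 2.0, 3.0):
    metric(torch.tensor(batch_value))  # forward() calls update()

print(metric.compute())  # tensor(6.) - the state was not cleared
metric(torch.tensor(4.0))
print(metric.compute())  # tensor(10.) - still accumulating
metric.reset()           # an explicit reset clears the state as before
```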
("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) From 640a8e02ffa90247f66208d0a0ed64a24022d750 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 18 Dec 2020 21:42:59 +0100 Subject: [PATCH 5/9] fix --- tests/metrics/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/metrics/utils.py b/tests/metrics/utils.py index 53c96a418b93a..70d7133a826c2 100644 --- a/tests/metrics/utils.py +++ b/tests/metrics/utils.py @@ -23,7 +23,7 @@ def setup_ddp(rank, world_size): """ Setup ddp enviroment """ os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = "8088" - os.environ["LOCAL_RANK"] = rank + os.environ["LOCAL_RANK"] = f"{rank}" if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) From 9d1fa0fcb1ffa152763a0e05a18f649c928ba4d3 Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Fri, 18 Dec 2020 22:01:13 +0100 Subject: [PATCH 6/9] fix --- tests/metrics/utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/metrics/utils.py b/tests/metrics/utils.py index 70d7133a826c2..a49bff6805171 100644 --- a/tests/metrics/utils.py +++ b/tests/metrics/utils.py @@ -10,6 +10,7 @@ from torch.multiprocessing import Pool, set_start_method from pytorch_lightning.metrics import Metric +from pytorch_lightning.utilities.distributed import all_except_rank_zero NUM_PROCESSES = 2 NUM_BATCHES = 10 @@ -23,7 +24,7 @@ def setup_ddp(rank, world_size): """ Setup ddp enviroment """ os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = "8088" - os.environ["LOCAL_RANK"] = f"{rank}" + all_except_rank_zero.rank = rank if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) From 10922db03840ba5d9fd45714585188c3e4e7a72f Mon Sep 17 00:00:00 2001 From: Nicki Skafte Date: Sat, 19 Dec 2020 21:44:50 +0100 Subject: [PATCH 7/9] add flag to metrics --- .../metrics/classification/accuracy.py | 5 +++++ .../metrics/classification/average_precision.py | 10 ++++++++++ .../metrics/classification/confusion_matrix.py | 10 ++++++++++ .../metrics/classification/f_beta.py | 14 +++++++++++++- .../metrics/classification/precision_recall.py | 10 ++++++++++ .../classification/precision_recall_curve.py | 10 ++++++++++ pytorch_lightning/metrics/classification/roc.py | 10 ++++++++++ pytorch_lightning/metrics/metric.py | 4 ++-- .../metrics/regression/explained_variance.py | 8 ++++++++ .../metrics/regression/mean_absolute_error.py | 8 ++++++++ .../metrics/regression/mean_squared_error.py | 8 ++++++++ .../regression/mean_squared_log_error.py | 8 ++++++++ pytorch_lightning/metrics/regression/psnr.py | 10 ++++++++++ pytorch_lightning/metrics/regression/ssim.py | 17 +++++++++++++++++ 14 files changed, 129 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/metrics/classification/accuracy.py b/pytorch_lightning/metrics/classification/accuracy.py index 330691a379574..1b3fb265b43e5 100644 --- a/pytorch_lightning/metrics/classification/accuracy.py +++ b/pytorch_lightning/metrics/classification/accuracy.py @@ -53,6 +53,9 @@ class Accuracy(Metric): dist_sync_fn: Callback that performs the allgather operation on the metric state. When `None`, DDP will be used to perform the allgather. default: None + auto_reset_on_compute: + Specify if ``reset()`` should be called automatically after each ``compute()``. 
From 10922db03840ba5d9fd45714585188c3e4e7a72f Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Sat, 19 Dec 2020 21:44:50 +0100
Subject: [PATCH 7/9] add flag to metrics

---
 .../metrics/classification/accuracy.py          |  5 +++++
 .../metrics/classification/average_precision.py | 10 ++++++++++
 .../metrics/classification/confusion_matrix.py  | 10 ++++++++++
 .../metrics/classification/f_beta.py            | 14 +++++++++++++-
 .../metrics/classification/precision_recall.py  | 10 ++++++++++
 .../classification/precision_recall_curve.py    | 10 ++++++++++
 pytorch_lightning/metrics/classification/roc.py | 10 ++++++++++
 pytorch_lightning/metrics/metric.py             |  4 ++--
 .../metrics/regression/explained_variance.py    |  8 ++++++++
 .../metrics/regression/mean_absolute_error.py   |  8 ++++++++
 .../metrics/regression/mean_squared_error.py    |  8 ++++++++
 .../regression/mean_squared_log_error.py        |  8 ++++++++
 pytorch_lightning/metrics/regression/psnr.py    | 10 ++++++++++
 pytorch_lightning/metrics/regression/ssim.py    | 17 +++++++++++++++++
 14 files changed, 129 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/metrics/classification/accuracy.py b/pytorch_lightning/metrics/classification/accuracy.py
index 330691a379574..1b3fb265b43e5 100644
--- a/pytorch_lightning/metrics/classification/accuracy.py
+++ b/pytorch_lightning/metrics/classification/accuracy.py
@@ -53,6 +53,9 @@ class Accuracy(Metric):
         dist_sync_fn:
             Callback that performs the allgather operation on the metric state. When `None`, DDP
             will be used to perform the allgather. default: None
+        auto_reset_on_compute:
+            Specify if ``reset()`` should be called automatically after each ``compute()``.
+            Disabling it allows calculating running accumulated metrics. default: True
 
     Example:
 
@@ -71,12 +74,14 @@ def __init__(
         dist_sync_on_step: bool = False,
         process_group: Optional[Any] = None,
         dist_sync_fn: Callable = None,
+        auto_reset_on_compute: bool = True
     ):
 
         super().__init__(
             compute_on_step=compute_on_step,
             dist_sync_on_step=dist_sync_on_step,
             process_group=process_group,
             dist_sync_fn=dist_sync_fn,
+            auto_reset_on_compute=auto_reset_on_compute
         )
 
         self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
diff --git a/pytorch_lightning/metrics/classification/average_precision.py b/pytorch_lightning/metrics/classification/average_precision.py
index 33878cb48965d..1af75b29e3d9c 100644
--- a/pytorch_lightning/metrics/classification/average_precision.py
+++ b/pytorch_lightning/metrics/classification/average_precision.py
@@ -50,6 +50,12 @@ class AveragePrecision(Metric):
             before returning the value at the step. default: False
         process_group:
             Specify the process group on which synchronization is called. default: None (which selects the entire world)
+        dist_sync_fn:
+            Callback that performs the allgather operation on the metric state. When `None`, DDP
+            will be used to perform the allgather. default: None
+        auto_reset_on_compute:
+            Specify if ``reset()`` should be called automatically after each ``compute()``.
+            Disabling it allows calculating running accumulated metrics. default: True
 
     Example (binary case):
 
@@ -78,11 +84,15 @@ def __init__(
         compute_on_step: bool = True,
         dist_sync_on_step: bool = False,
         process_group: Optional[Any] = None,
+        dist_sync_fn: Callable = None,
+        auto_reset_on_compute: bool = True
     ):
 
         super().__init__(
             compute_on_step=compute_on_step,
             dist_sync_on_step=dist_sync_on_step,
             process_group=process_group,
+            dist_sync_fn=dist_sync_fn,
+            auto_reset_on_compute=auto_reset_on_compute
         )
 
         self.num_classes = num_classes
diff --git a/pytorch_lightning/metrics/classification/confusion_matrix.py b/pytorch_lightning/metrics/classification/confusion_matrix.py
index b9b0c20e9e30e..518d4afa2371f 100644
--- a/pytorch_lightning/metrics/classification/confusion_matrix.py
+++ b/pytorch_lightning/metrics/classification/confusion_matrix.py
@@ -61,6 +61,12 @@ class ConfusionMatrix(Metric):
             before returning the value at the step. default: False
         process_group:
             Specify the process group on which synchronization is called. default: None (which selects the entire world)
+        dist_sync_fn:
+            Callback that performs the allgather operation on the metric state. When `None`, DDP
+            will be used to perform the allgather. default: None
+        auto_reset_on_compute:
+            Specify if ``reset()`` should be called automatically after each ``compute()``.
+            Disabling it allows calculating running accumulated metrics. default: True
 
     Example:
 
@@ -81,12 +87,16 @@ def __init__(
         compute_on_step: bool = True,
         dist_sync_on_step: bool = False,
         process_group: Optional[Any] = None,
+        dist_sync_fn: Callable = None,
+        auto_reset_on_compute: bool = True
     ):
 
         super().__init__(
             compute_on_step=compute_on_step,
             dist_sync_on_step=dist_sync_on_step,
             process_group=process_group,
+            dist_sync_fn=dist_sync_fn,
+            auto_reset_on_compute=auto_reset_on_compute
         )
         self.num_classes = num_classes
         self.normalize = normalize
diff --git a/pytorch_lightning/metrics/classification/f_beta.py b/pytorch_lightning/metrics/classification/f_beta.py
index fadfd000ebbe1..897f4368041c7 100755
--- a/pytorch_lightning/metrics/classification/f_beta.py
+++ b/pytorch_lightning/metrics/classification/f_beta.py
@@ -66,6 +66,12 @@ class FBeta(Metric):
             before returning the value at the step. default: False
         process_group:
             Specify the process group on which synchronization is called. default: None (which selects the entire world)
+        dist_sync_fn:
+            Callback that performs the allgather operation on the metric state. When `None`, DDP
+            will be used to perform the allgather. default: None
+        auto_reset_on_compute:
+            Specify if ``reset()`` should be called automatically after each ``compute()``.
+            Disabling it allows calculating running accumulated metrics. default: True
 
     Example:
 
@@ -88,9 +94,15 @@ def __init__(
         compute_on_step: bool = True,
         dist_sync_on_step: bool = False,
         process_group: Optional[Any] = None,
+        dist_sync_fn: Callable = None,
+        auto_reset_on_compute: bool = True
     ):
         super().__init__(
-            compute_on_step=compute_on_step, dist_sync_on_step=dist_sync_on_step, process_group=process_group,
+            compute_on_step=compute_on_step,
+            dist_sync_on_step=dist_sync_on_step,
+            process_group=process_group,
+            dist_sync_fn=dist_sync_fn,
+            auto_reset_on_compute=auto_reset_on_compute
         )
 
         self.num_classes = num_classes
diff --git a/pytorch_lightning/metrics/classification/precision_recall.py b/pytorch_lightning/metrics/classification/precision_recall.py
index 7e1f843b9c331..905694b01454f 100644
--- a/pytorch_lightning/metrics/classification/precision_recall.py
+++ b/pytorch_lightning/metrics/classification/precision_recall.py
@@ -58,6 +58,12 @@ class Precision(Metric):
             before returning the value at the step. default: False
         process_group:
             Specify the process group on which synchronization is called. default: None (which selects the entire world)
+        dist_sync_fn:
+            Callback that performs the allgather operation on the metric state. When `None`, DDP
+            will be used to perform the allgather. default: None
+        auto_reset_on_compute:
+            Specify if ``reset()`` should be called automatically after each ``compute()``.
+            Disabling it allows calculating running accumulated metrics. default: True
 
     Example:
 
@@ -78,11 +84,15 @@ def __init__(
         compute_on_step: bool = True,
         dist_sync_on_step: bool = False,
         process_group: Optional[Any] = None,
+        dist_sync_fn: Callable = None,
+        auto_reset_on_compute: bool = True
     ):
 
         super().__init__(
             compute_on_step=compute_on_step,
             dist_sync_on_step=dist_sync_on_step,
             process_group=process_group,
+            dist_sync_fn=dist_sync_fn,
+            auto_reset_on_compute=auto_reset_on_compute
         )
 
         self.num_classes = num_classes
diff --git a/pytorch_lightning/metrics/classification/precision_recall_curve.py b/pytorch_lightning/metrics/classification/precision_recall_curve.py
index 620904898535d..ea9d167ce35f5 100644
--- a/pytorch_lightning/metrics/classification/precision_recall_curve.py
+++ b/pytorch_lightning/metrics/classification/precision_recall_curve.py
@@ -50,6 +50,12 @@ class PrecisionRecallCurve(Metric):
             before returning the value at the step. default: False
         process_group:
             Specify the process group on which synchronization is called. default: None (which selects the entire world)
+        dist_sync_fn:
+            Callback that performs the allgather operation on the metric state. When `None`, DDP
+            will be used to perform the allgather. default: None
+        auto_reset_on_compute:
+            Specify if ``reset()`` should be called automatically after each ``compute()``.
+            Disabling it allows calculating running accumulated metrics. default: True
 
     Example (binary case):
 
@@ -88,11 +94,15 @@ def __init__(
         compute_on_step: bool = True,
         dist_sync_on_step: bool = False,
         process_group: Optional[Any] = None,
+        dist_sync_fn: Callable = None,
+        auto_reset_on_compute: bool = True
     ):
 
         super().__init__(
             compute_on_step=compute_on_step,
             dist_sync_on_step=dist_sync_on_step,
             process_group=process_group,
+            dist_sync_fn=dist_sync_fn,
+            auto_reset_on_compute=auto_reset_on_compute
         )
 
         self.num_classes = num_classes
diff --git a/pytorch_lightning/metrics/classification/roc.py b/pytorch_lightning/metrics/classification/roc.py
index 2b7d82488b491..25167bc615d6d 100644
--- a/pytorch_lightning/metrics/classification/roc.py
+++ b/pytorch_lightning/metrics/classification/roc.py
@@ -50,6 +50,12 @@ class ROC(Metric):
             before returning the value at the step. default: False
         process_group:
             Specify the process group on which synchronization is called. default: None (which selects the entire world)
+        dist_sync_fn:
+            Callback that performs the allgather operation on the metric state. When `None`, DDP
+            will be used to perform the allgather. default: None
+        auto_reset_on_compute:
+            Specify if ``reset()`` should be called automatically after each ``compute()``.
+            Disabling it allows calculating running accumulated metrics. default: True
 
     Example (binary case):
 
@@ -91,11 +97,15 @@ def __init__(
         compute_on_step: bool = True,
         dist_sync_on_step: bool = False,
         process_group: Optional[Any] = None,
+        dist_sync_fn: Callable = None,
+        auto_reset_on_compute: bool = True
     ):
 
         super().__init__(
             compute_on_step=compute_on_step,
             dist_sync_on_step=dist_sync_on_step,
             process_group=process_group,
+            dist_sync_fn=dist_sync_fn,
+            auto_reset_on_compute=auto_reset_on_compute
         )
 
         self.num_classes = num_classes
diff --git a/pytorch_lightning/metrics/metric.py b/pytorch_lightning/metrics/metric.py
index 413ef08354311..7258ff366417e 100644
--- a/pytorch_lightning/metrics/metric.py
+++ b/pytorch_lightning/metrics/metric.py
@@ -57,8 +57,8 @@ class Metric(nn.Module, ABC):
             Callback that performs the allgather operation on the metric state. When `None`, DDP
             will be used to perform the allgather. default: None
         auto_reset_on_compute:
-            Specify if ``reset()`` should be automatically called after each ``compute()``.
-            Disable to calculate running accumulated metrics.
+            Specify if ``reset()`` should be called automatically after each ``compute()``.
+            Disabling it allows calculating running accumulated metrics. default: True
     """
     def __init__(
         self,
diff --git a/pytorch_lightning/metrics/regression/explained_variance.py b/pytorch_lightning/metrics/regression/explained_variance.py
index 2b98a7b988052..3515935500047 100644
--- a/pytorch_lightning/metrics/regression/explained_variance.py
+++ b/pytorch_lightning/metrics/regression/explained_variance.py
@@ -57,6 +57,12 @@ class ExplainedVariance(Metric):
             before returning the value at the step. default: False
         process_group:
             Specify the process group on which synchronization is called. default: None (which selects the entire world)
+        dist_sync_fn:
+            Callback that performs the allgather operation on the metric state. When `None`, DDP
+            will be used to perform the allgather. default: None
+        auto_reset_on_compute:
+            Specify if ``reset()`` should be called automatically after each ``compute()``.
+            Disabling it allows calculating running accumulated metrics. default: True
 
     Example:
 
@@ -81,12 +87,14 @@ def __init__(
         dist_sync_on_step: bool = False,
         process_group: Optional[Any] = None,
         dist_sync_fn: Callable = None,
+        auto_reset_on_compute: bool = True
     ):
         super().__init__(
             compute_on_step=compute_on_step,
             dist_sync_on_step=dist_sync_on_step,
             process_group=process_group,
             dist_sync_fn=dist_sync_fn,
+            auto_reset_on_compute=auto_reset_on_compute
         )
         allowed_multioutput = ('raw_values', 'uniform_average', 'variance_weighted')
         if multioutput not in allowed_multioutput:
diff --git a/pytorch_lightning/metrics/regression/mean_absolute_error.py b/pytorch_lightning/metrics/regression/mean_absolute_error.py
index 8d0929a5792c3..e7fdb83974757 100644
--- a/pytorch_lightning/metrics/regression/mean_absolute_error.py
+++ b/pytorch_lightning/metrics/regression/mean_absolute_error.py
@@ -37,6 +37,12 @@ class MeanAbsoluteError(Metric):
             before returning the value at the step. default: False
         process_group:
             Specify the process group on which synchronization is called. default: None (which selects the entire world)
+        dist_sync_fn:
+            Callback that performs the allgather operation on the metric state. When `None`, DDP
+            will be used to perform the allgather. default: None
+        auto_reset_on_compute:
+            Specify if ``reset()`` should be called automatically after each ``compute()``.
+            Disabling it allows calculating running accumulated metrics. default: True
 
     Example:
 
@@ -54,12 +60,14 @@ def __init__(
         dist_sync_on_step: bool = False,
         process_group: Optional[Any] = None,
         dist_sync_fn: Callable = None,
+        auto_reset_on_compute: bool = True
     ):
         super().__init__(
             compute_on_step=compute_on_step,
             dist_sync_on_step=dist_sync_on_step,
             process_group=process_group,
             dist_sync_fn=dist_sync_fn,
+            auto_reset_on_compute=auto_reset_on_compute
         )
 
         self.add_state("sum_abs_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
diff --git a/pytorch_lightning/metrics/regression/mean_squared_error.py b/pytorch_lightning/metrics/regression/mean_squared_error.py
index b92c946715554..5efa1ed6b6827 100644
--- a/pytorch_lightning/metrics/regression/mean_squared_error.py
+++ b/pytorch_lightning/metrics/regression/mean_squared_error.py
@@ -37,6 +37,12 @@ class MeanSquaredError(Metric):
             before returning the value at the step. default: False
         process_group:
             Specify the process group on which synchronization is called. default: None (which selects the entire world)
+        dist_sync_fn:
+            Callback that performs the allgather operation on the metric state. When `None`, DDP
+            will be used to perform the allgather. default: None
+        auto_reset_on_compute:
+            Specify if ``reset()`` should be called automatically after each ``compute()``.
+            Disabling it allows calculating running accumulated metrics. default: True
 
     Example:
 
@@ -55,12 +61,14 @@ def __init__(
         dist_sync_on_step: bool = False,
         process_group: Optional[Any] = None,
         dist_sync_fn: Callable = None,
+        auto_reset_on_compute: bool = True
     ):
         super().__init__(
             compute_on_step=compute_on_step,
             dist_sync_on_step=dist_sync_on_step,
             process_group=process_group,
             dist_sync_fn=dist_sync_fn,
+            auto_reset_on_compute=auto_reset_on_compute
         )
 
         self.add_state("sum_squared_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
diff --git a/pytorch_lightning/metrics/regression/mean_squared_log_error.py b/pytorch_lightning/metrics/regression/mean_squared_log_error.py
index 8beda8036dc5f..c9d3dfd795c29 100644
--- a/pytorch_lightning/metrics/regression/mean_squared_log_error.py
+++ b/pytorch_lightning/metrics/regression/mean_squared_log_error.py
@@ -39,6 +39,12 @@ class MeanSquaredLogError(Metric):
             before returning the value at the step. default: False
         process_group:
             Specify the process group on which synchronization is called. default: None (which selects the entire world)
+        dist_sync_fn:
+            Callback that performs the allgather operation on the metric state. When `None`, DDP
+            will be used to perform the allgather. default: None
+        auto_reset_on_compute:
+            Specify if ``reset()`` should be called automatically after each ``compute()``.
+            Disabling it allows calculating running accumulated metrics. default: True
 
     Example:
 
@@ -57,12 +63,14 @@ def __init__(
         dist_sync_on_step: bool = False,
         process_group: Optional[Any] = None,
         dist_sync_fn: Callable = None,
+        auto_reset_on_compute: bool = True
     ):
         super().__init__(
             compute_on_step=compute_on_step,
             dist_sync_on_step=dist_sync_on_step,
             process_group=process_group,
             dist_sync_fn=dist_sync_fn,
+            auto_reset_on_compute=auto_reset_on_compute
        )
 
         self.add_state("sum_squared_log_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
diff --git a/pytorch_lightning/metrics/regression/psnr.py b/pytorch_lightning/metrics/regression/psnr.py
index db0b5deae6ddc..8dd5b8ceac085 100644
--- a/pytorch_lightning/metrics/regression/psnr.py
+++ b/pytorch_lightning/metrics/regression/psnr.py
@@ -46,6 +46,12 @@ class PSNR(Metric):
             before returning the value at the step. default: False
         process_group:
             Specify the process group on which synchronization is called. default: None (which selects the entire world)
+        dist_sync_fn:
+            Callback that performs the allgather operation on the metric state. When `None`, DDP
+            will be used to perform the allgather. default: None
+        auto_reset_on_compute:
+            Specify if ``reset()`` should be called automatically after each ``compute()``.
+            Disabling it allows calculating running accumulated metrics. default: True
 
     Example:
 
@@ -66,11 +72,15 @@ def __init__(
         compute_on_step: bool = True,
         dist_sync_on_step: bool = False,
         process_group: Optional[Any] = None,
+        dist_sync_fn: Callable = None,
+        auto_reset_on_compute: bool = True
     ):
         super().__init__(
             compute_on_step=compute_on_step,
             dist_sync_on_step=dist_sync_on_step,
             process_group=process_group,
+            dist_sync_fn=dist_sync_fn,
+            auto_reset_on_compute=auto_reset_on_compute
         )
 
         self.add_state("sum_squared_error", default=torch.tensor(0.0), dist_reduce_fx="sum")
diff --git a/pytorch_lightning/metrics/regression/ssim.py b/pytorch_lightning/metrics/regression/ssim.py
index e401f42c4d40c..f2833477bdb4a 100644
--- a/pytorch_lightning/metrics/regression/ssim.py
+++ b/pytorch_lightning/metrics/regression/ssim.py
@@ -36,6 +36,19 @@ class SSIM(Metric):
         data_range: Range of the image. If ``None``, it is determined from the image (max - min)
         k1: Parameter of SSIM. Default: 0.01
         k2: Parameter of SSIM. Default: 0.03
+        compute_on_step:
+            Forward only calls ``update()`` and returns None if this is set to False. default: True
+        dist_sync_on_step:
+            Synchronize metric state across processes at each ``forward()``
+            before returning the value at the step. default: False
+        process_group:
+            Specify the process group on which synchronization is called. default: None (which selects the entire world)
+        dist_sync_fn:
+            Callback that performs the allgather operation on the metric state. When `None`, DDP
+            will be used to perform the allgather. default: None
+        auto_reset_on_compute:
+            Specify if ``reset()`` should be called automatically after each ``compute()``.
+            Disabling it allows calculating running accumulated metrics. default: True
 
     Return:
         Tensor with SSIM score
 
@@ -60,11 +73,15 @@ def __init__(
         compute_on_step: bool = True,
         dist_sync_on_step: bool = False,
         process_group: Optional[Any] = None,
+        dist_sync_fn: Callable = None,
+        auto_reset_on_compute: bool = True
     ):
         super().__init__(
             compute_on_step=compute_on_step,
             dist_sync_on_step=dist_sync_on_step,
             process_group=process_group,
+            dist_sync_fn=dist_sync_fn,
+            auto_reset_on_compute=auto_reset_on_compute
         )
         rank_zero_warn(
             'Metric `SSIM` will save all targets and'
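With patch 7, every built-in metric forwards the new arguments, so running accumulation no longer requires a custom subclass. A rough usage sketch, assuming this branch is installed (the inputs here are random placeholder data):

```python
import torch

from pytorch_lightning.metrics import Accuracy

# keep accumulating correct/total counts across compute() calls
accuracy = Accuracy(auto_reset_on_compute=False)

for _ in range(3):  # e.g. three validation batches
    preds = torch.randint(0, 2, (8,))
    target = torch.randint(0, 2, (8,))
    accuracy(preds, target)

print(accuracy.compute())  # accuracy over all 24 samples seen so far

# the state survived compute(), so later batches extend the running value
accuracy(torch.randint(0, 2, (8,)), torch.randint(0, 2, (8,)))
print(accuracy.compute())  # now over 32 samples

accuracy.reset()  # start a fresh accumulation window explicitly
```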
From 8d0e7d98db8fa4ccedb6d6ef526e42b35daaf23b Mon Sep 17 00:00:00 2001
From: Nicki Skafte
Date: Sat, 19 Dec 2020 22:19:43 +0100
Subject: [PATCH 8/9] fix

---
 pytorch_lightning/metrics/classification/average_precision.py | 2 +-
 pytorch_lightning/metrics/classification/confusion_matrix.py  | 2 +-
 pytorch_lightning/metrics/classification/f_beta.py            | 2 +-
 pytorch_lightning/metrics/classification/precision_recall.py  | 2 +-
 .../metrics/classification/precision_recall_curve.py          | 2 +-
 pytorch_lightning/metrics/classification/roc.py               | 2 +-
 pytorch_lightning/metrics/regression/psnr.py                  | 2 +-
 pytorch_lightning/metrics/regression/ssim.py                  | 2 +-
 tests/metrics/test_metric.py                                  | 2 --
 9 files changed, 8 insertions(+), 10 deletions(-)

diff --git a/pytorch_lightning/metrics/classification/average_precision.py b/pytorch_lightning/metrics/classification/average_precision.py
index 1af75b29e3d9c..0ececc3f7374b 100644
--- a/pytorch_lightning/metrics/classification/average_precision.py
+++ b/pytorch_lightning/metrics/classification/average_precision.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional, Any, Union, List
+from typing import Any, Callable, List, Optional, Union
 
 import torch
 
diff --git a/pytorch_lightning/metrics/classification/confusion_matrix.py b/pytorch_lightning/metrics/classification/confusion_matrix.py
index 518d4afa2371f..55537654975f0 100644
--- a/pytorch_lightning/metrics/classification/confusion_matrix.py
+++ b/pytorch_lightning/metrics/classification/confusion_matrix.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Optional
+from typing import Any, Callable, Optional
 
 import torch
 
diff --git a/pytorch_lightning/metrics/classification/f_beta.py b/pytorch_lightning/metrics/classification/f_beta.py
index 897f4368041c7..f5c079af0e63d 100755
--- a/pytorch_lightning/metrics/classification/f_beta.py
+++ b/pytorch_lightning/metrics/classification/f_beta.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Optional
+from typing import Any, Callable, Optional
 
 import torch
 
diff --git a/pytorch_lightning/metrics/classification/precision_recall.py b/pytorch_lightning/metrics/classification/precision_recall.py
index 905694b01454f..a989efff1af45 100644
--- a/pytorch_lightning/metrics/classification/precision_recall.py
+++ b/pytorch_lightning/metrics/classification/precision_recall.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Optional
+from typing import Any, Callable, Optional
 
 import torch
 
diff --git a/pytorch_lightning/metrics/classification/precision_recall_curve.py b/pytorch_lightning/metrics/classification/precision_recall_curve.py
index ea9d167ce35f5..91088e4226f3b 100644
--- a/pytorch_lightning/metrics/classification/precision_recall_curve.py
+++ b/pytorch_lightning/metrics/classification/precision_recall_curve.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional, Any, Union, Tuple, List
+from typing import Any, Callable, List, Optional, Tuple, Union
 
 import torch
 
diff --git a/pytorch_lightning/metrics/classification/roc.py b/pytorch_lightning/metrics/classification/roc.py
index 25167bc615d6d..c44982d49b3e5 100644
--- a/pytorch_lightning/metrics/classification/roc.py
+++ b/pytorch_lightning/metrics/classification/roc.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional, Any, Union, List, Tuple
+from typing import Any, Callable, List, Optional, Tuple, Union
 
 import torch
 
diff --git a/pytorch_lightning/metrics/regression/psnr.py b/pytorch_lightning/metrics/regression/psnr.py
index 8dd5b8ceac085..4673718254e9f 100644
--- a/pytorch_lightning/metrics/regression/psnr.py
+++ b/pytorch_lightning/metrics/regression/psnr.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import torch
-from typing import Any, Optional
+from typing import Any, Callable, Optional
 
 from pytorch_lightning.metrics.metric import Metric
 from pytorch_lightning.metrics.functional.psnr import (
diff --git a/pytorch_lightning/metrics/regression/ssim.py b/pytorch_lightning/metrics/regression/ssim.py
index f2833477bdb4a..766c807b48f60 100644
--- a/pytorch_lightning/metrics/regression/ssim.py
+++ b/pytorch_lightning/metrics/regression/ssim.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import torch
-from typing import Any, Optional, Sequence
+from typing import Any, Callable, Optional, Sequence
 
 from pytorch_lightning.metrics.metric import Metric
 from pytorch_lightning.utilities import rank_zero_warn
diff --git a/tests/metrics/test_metric.py b/tests/metrics/test_metric.py
index 1a5008141537f..1257181ec4822 100644
--- a/tests/metrics/test_metric.py
+++ b/tests/metrics/test_metric.py
@@ -213,5 +213,3 @@ def compute(self):
 
     running_metric.reset()
     assert running_metric.x == torch.tensor(0.), "metric state should have been reset"
-
-
From d02499ab46c8ccb338c292c52f8b8aa2893342ae Mon Sep 17 00:00:00 2001
From: Jirka Borovec
Date: Wed, 6 Jan 2021 21:30:49 +0100
Subject: [PATCH 9/9] chlog

---
 CHANGELOG.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index df06c22d8b02a..e2ab594b05a52 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -21,6 +21,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added `R2Score` metric ([#5241](https://github.com/PyTorchLightning/pytorch-lightning/pull/5241))
 
+- Added the `auto_reset_on_compute` argument to metrics, allowing for running accumulation of metric states ([#5193](https://github.com/PyTorchLightning/pytorch-lightning/pull/5193))
+
 ### Changed
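End to end, the distributed behaviour the series lands is: a synchronized `compute()` reduces the state across processes, after which only rank zero keeps the reduced value while the other ranks are reset. A sketch of how one might verify this locally with two processes, mirroring `_test_running_accumulation` above (the address, port, and gloo backend are illustrative choices):

```python
import os

import torch
import torch.distributed as dist

from pytorch_lightning.metrics import Metric
from pytorch_lightning.utilities.distributed import all_except_rank_zero


class RunningSum(Metric):
    def __init__(self):
        super().__init__(auto_reset_on_compute=False)
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, value):
        self.total += value

    def compute(self):
        return self.total


def worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "8088"
    all_except_rank_zero.rank = rank  # as the tests above do
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    metric = RunningSum()
    metric(torch.tensor(1.0))
    print(rank, metric.compute())  # tensor(2.) on both ranks (summed)
    # after the synced compute(), only rank zero keeps the reduced state
    print(rank, metric.total)      # rank 0: tensor(2.), rank 1: tensor(0.)


if __name__ == "__main__":
    torch.multiprocessing.spawn(worker, args=(2,), nprocs=2)
```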