Skip to content

Commit b74f8ac

Browse files
authored
Use apply_to_collection in metrics_to_scalars (#7888)
* Use `apply_to_collection` in `metrics_to_scalars` * Typing * Update CHANGELOG * Update pytorch_lightning/utilities/metrics.py * Whitespace
1 parent 0fda862 commit b74f8ac

File tree

3 files changed

+17
-19
lines changed

3 files changed

+17
-19
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
8787
- Changed these `Trainer` methods to be protected: `call_setup_hook`, `call_configure_sharded_model`, `pre_dispatch`, `dispatch`, `post_dispatch`, `call_teardown_hook`, `run_train`, `run_sanity_check`, `run_evaluate`, `run_evaluation`, `run_predict`, `track_output_for_epoch_end`
8888

8989

90+
- Changed `metrics_to_scalars` to work with any collection or value ([#7888](https://github.com/PyTorchLightning/pytorch-lightning/pull/7888))
91+
92+
9093
- Changed `clip_grad_norm` to use `torch.nn.utils.clip_grad_norm_` ([#7025](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025))
9194

9295

pytorch_lightning/utilities/metrics.py

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,29 +12,24 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
"""Helper functions to operate on metric values. """
15+
import numbers
16+
from typing import Any
1517

1618
import torch
1719

20+
from pytorch_lightning.utilities.apply_func import apply_to_collection
1821
from pytorch_lightning.utilities.exceptions import MisconfigurationException
1922

2023

21-
def metrics_to_scalars(metrics: dict) -> dict:
22-
""" Recursively walk through a dictionary of metrics and convert single-item tensors to scalar values. """
24+
def metrics_to_scalars(metrics: Any) -> Any:
25+
"""Recursively walk through a collection and convert single-item tensors to scalar values"""
2326

24-
# TODO: this is duplicated in MetricsHolder. should be unified
25-
new_metrics = {}
26-
for k, v in metrics.items():
27-
if isinstance(v, torch.Tensor):
28-
if v.numel() != 1:
29-
raise MisconfigurationException(
30-
f"The metric `{k}` does not contain a single element"
31-
f" thus it cannot be converted to float. Found `{v}`"
32-
)
33-
v = v.item()
27+
def to_item(value: torch.Tensor) -> numbers.Number:
28+
if value.numel() != 1:
29+
raise MisconfigurationException(
30+
f"The metric `{value}` does not contain a single element"
31+
f" thus it cannot be converted to float."
32+
)
33+
return value.item()
3434

35-
if isinstance(v, dict):
36-
v = metrics_to_scalars(v)
37-
38-
new_metrics[k] = v
39-
40-
return new_metrics
35+
return apply_to_collection(metrics, torch.Tensor, to_item)

tests/trainer/logging_/test_logger_connector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -487,7 +487,7 @@ def test_step(self, *args, **kwargs):
487487

488488
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
489489

490-
match = "The metric `test` does not contain a single element"
490+
match = "The metric `.*` does not contain a single element"
491491
with pytest.raises(MisconfigurationException, match=match):
492492
trainer.validate(model)
493493
with pytest.raises(MisconfigurationException, match=match):

0 commit comments

Comments
 (0)