diff --git a/CHANGELOG.md b/CHANGELOG.md index ef2f8070b1019..3e3c6959c55d7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -162,6 +162,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Fixed `apply_to_collection` works on Custom Collections now ([#7851](https://github.com/PyTorchLightning/pytorch-lightning/pull/7851)) + - Fixed ambiguous warning when both overfit and train dataloader shuffling are enabled ([#7685](https://github.com/PyTorchLightning/pytorch-lightning/pull/7685)) - Fixed dataloaders are not reset when tuning the model ([#7566](https://github.com/PyTorchLightning/pytorch-lightning/pull/7566)) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index a24dc9d367f87..67dceb37579fc 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -26,7 +26,7 @@ from argparse import Namespace from functools import partial from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, TYPE_CHECKING, Union import torch from torch import ScriptModule, Tensor @@ -363,7 +363,7 @@ def log( def log_dict( self, - dictionary: Dict[str, _METRIC_COLLECTION], + dictionary: Mapping[str, _METRIC_COLLECTION], prog_bar: bool = False, logger: bool = True, on_step: Optional[bool] = None, diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index 456bc5b5cd406..70d8e4a2b6ff6 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -13,6 +13,7 @@ # limitations under the License. import operator from abc import ABC +from collections import OrderedDict from collections.abc import Mapping, Sequence from copy import copy from functools import partial @@ -92,12 +93,12 @@ def apply_to_collection( # Recursively apply to collection items if isinstance(data, Mapping): - out = [] # can't use dict, need to preserve order if `OrderedDict` + out = [] for k, v in data.items(): v = apply_to_collection(v, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) if include_none or v is not None: out.append((k, v)) - return elem_type(out) + return elem_type(OrderedDict(out)) is_namedtuple = _is_namedtuple(data) is_sequence = isinstance(data, Sequence) and not isinstance(data, str) diff --git a/tests/utilities/test_apply_func.py b/tests/utilities/test_apply_func.py index 1cc6ded7b25db..2457cf998c2cd 100644 --- a/tests/utilities/test_apply_func.py +++ b/tests/utilities/test_apply_func.py @@ -84,6 +84,16 @@ def test_recursive_application_to_collection(): reduced = apply_to_collection(OrderedDict([('b', 2), ('a', 1)]), int, lambda x: str(x)) assert reduced == OrderedDict([('b', '2'), ('a', '1')]) + # custom mappings + class _CustomCollection(dict): + + def __init__(self, initial_dict): + super().__init__(initial_dict) + + to_reduce = _CustomCollection({'a': 1, 'b': 2, 'c': 3}) + reduced = apply_to_collection(to_reduce, int, lambda x: str(x)) + assert reduced == _CustomCollection({'a': '1', 'b': '2', 'c': '3'}) + def test_apply_to_collection_include_none(): to_reduce = [1, 2, 3.4, 5.6, 7]