Skip to content

Commit 851e043

Browse files
kandluispre-commit-ci[bot]
authored andcommitted
[bugfix] Minor improvements to apply_to_collection and type signature of log_dict (#7851)
* minor fixeS * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent b0f0d8b commit 851e043

File tree

4 files changed

+17
-4
lines changed

4 files changed

+17
-4
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
165165

166166
### Fixed
167167

168+
- Fixed `apply_to_collection` works on Custom Collections now ([#7851](https://github.com/PyTorchLightning/pytorch-lightning/pull/7851))
169+
168170
- Fixed ambiguous warning when both overfit and train dataloader shuffling are enabled ([#7685](https://github.com/PyTorchLightning/pytorch-lightning/pull/7685))
169171

170172
- Fixed dataloaders are not reset when tuning the model ([#7566](https://github.com/PyTorchLightning/pytorch-lightning/pull/7566))

pytorch_lightning/core/lightning.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from argparse import Namespace
2727
from functools import partial
2828
from pathlib import Path
29-
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union
29+
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, TYPE_CHECKING, Union
3030

3131
import torch
3232
from torch import ScriptModule, Tensor
@@ -363,7 +363,7 @@ def log(
363363

364364
def log_dict(
365365
self,
366-
dictionary: Dict[str, _METRIC_COLLECTION],
366+
dictionary: Mapping[str, _METRIC_COLLECTION],
367367
prog_bar: bool = False,
368368
logger: bool = True,
369369
on_step: Optional[bool] = None,

pytorch_lightning/utilities/apply_func.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
import operator
1515
from abc import ABC
16+
from collections import OrderedDict
1617
from collections.abc import Mapping, Sequence
1718
from copy import copy
1819
from functools import partial
@@ -92,12 +93,12 @@ def apply_to_collection(
9293

9394
# Recursively apply to collection items
9495
if isinstance(data, Mapping):
95-
out = [] # can't use dict, need to preserve order if `OrderedDict`
96+
out = []
9697
for k, v in data.items():
9798
v = apply_to_collection(v, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
9899
if include_none or v is not None:
99100
out.append((k, v))
100-
return elem_type(out)
101+
return elem_type(OrderedDict(out))
101102

102103
is_namedtuple = _is_namedtuple(data)
103104
is_sequence = isinstance(data, Sequence) and not isinstance(data, str)

tests/utilities/test_apply_func.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,16 @@ def test_recursive_application_to_collection():
8484
reduced = apply_to_collection(OrderedDict([('b', 2), ('a', 1)]), int, lambda x: str(x))
8585
assert reduced == OrderedDict([('b', '2'), ('a', '1')])
8686

87+
# custom mappings
88+
class _CustomCollection(dict):
89+
90+
def __init__(self, initial_dict):
91+
super().__init__(initial_dict)
92+
93+
to_reduce = _CustomCollection({'a': 1, 'b': 2, 'c': 3})
94+
reduced = apply_to_collection(to_reduce, int, lambda x: str(x))
95+
assert reduced == _CustomCollection({'a': '1', 'b': '2', 'c': '3'})
96+
8797

8898
def test_apply_to_collection_include_none():
8999
to_reduce = [1, 2, 3.4, 5.6, 7]

0 commit comments

Comments
 (0)