From 2d8976380a928d7a0711afc712013d971b16aa5f Mon Sep 17 00:00:00 2001 From: Luis Perez Date: Sun, 6 Jun 2021 18:07:45 -0700 Subject: [PATCH 1/3] minor fixeS --- pytorch_lightning/core/lightning.py | 4 ++-- pytorch_lightning/utilities/apply_func.py | 5 +++-- tests/utilities/test_apply_func.py | 9 +++++++++ 3 files changed, 14 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index a24dc9d367f87..67dceb37579fc 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -26,7 +26,7 @@ from argparse import Namespace from functools import partial from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, TYPE_CHECKING, Union import torch from torch import ScriptModule, Tensor @@ -363,7 +363,7 @@ def log( def log_dict( self, - dictionary: Dict[str, _METRIC_COLLECTION], + dictionary: Mapping[str, _METRIC_COLLECTION], prog_bar: bool = False, logger: bool = True, on_step: Optional[bool] = None, diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index 456bc5b5cd406..70d8e4a2b6ff6 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -13,6 +13,7 @@ # limitations under the License. import operator from abc import ABC +from collections import OrderedDict from collections.abc import Mapping, Sequence from copy import copy from functools import partial @@ -92,12 +93,12 @@ def apply_to_collection( # Recursively apply to collection items if isinstance(data, Mapping): - out = [] # can't use dict, need to preserve order if `OrderedDict` + out = [] for k, v in data.items(): v = apply_to_collection(v, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) if include_none or v is not None: out.append((k, v)) - return elem_type(out) + return elem_type(OrderedDict(out)) is_namedtuple = _is_namedtuple(data) is_sequence = isinstance(data, Sequence) and not isinstance(data, str) diff --git a/tests/utilities/test_apply_func.py b/tests/utilities/test_apply_func.py index 1cc6ded7b25db..bcdc1eae1870c 100644 --- a/tests/utilities/test_apply_func.py +++ b/tests/utilities/test_apply_func.py @@ -84,6 +84,15 @@ def test_recursive_application_to_collection(): reduced = apply_to_collection(OrderedDict([('b', 2), ('a', 1)]), int, lambda x: str(x)) assert reduced == OrderedDict([('b', '2'), ('a', '1')]) + # custom mappings + class _CustomCollection(dict): + def __init__(self, initial_dict): + super().__init__(initial_dict) + + to_reduce = _CustomCollection({'a': 1, 'b': 2, 'c': 3}) + reduced = apply_to_collection(to_reduce, int, lambda x: str(x)) + assert reduced == _CustomCollection({'a': '1', 'b': '2', 'c': '3'}) + def test_apply_to_collection_include_none(): to_reduce = [1, 2, 3.4, 5.6, 7] From fb1590057ee23494e1adc72954780f451b46508e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 7 Jun 2021 01:22:59 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/utilities/test_apply_func.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/utilities/test_apply_func.py b/tests/utilities/test_apply_func.py index bcdc1eae1870c..2457cf998c2cd 100644 --- a/tests/utilities/test_apply_func.py +++ b/tests/utilities/test_apply_func.py @@ -86,6 +86,7 @@ def test_recursive_application_to_collection(): # custom mappings class _CustomCollection(dict): + def __init__(self, initial_dict): super().__init__(initial_dict) From 89af8058980f93fbf3fad21f7c3c32c0bd078b5b Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Mon, 7 Jun 2021 09:20:39 +0200 Subject: [PATCH 3/3] Update CHANGELOG.md --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index ef2f8070b1019..3e3c6959c55d7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -162,6 +162,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Fixed `apply_to_collection` works on Custom Collections now ([#7851](https://github.com/PyTorchLightning/pytorch-lightning/pull/7851)) + - Fixed ambiguous warning when both overfit and train dataloader shuffling are enabled ([#7685](https://github.com/PyTorchLightning/pytorch-lightning/pull/7685)) - Fixed dataloaders are not reset when tuning the model ([#7566](https://github.com/PyTorchLightning/pytorch-lightning/pull/7566))