Skip to content

Commit 127f842

Browse files
committed
OrderedDict support
1 parent ebd4694 commit 127f842

File tree

2 files changed

+20
-12
lines changed

2 files changed

+20
-12
lines changed

pytorch_lightning/utilities/apply_func.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
import operator
1515
from abc import ABC
16+
from collections import OrderedDict
1617
from collections.abc import Mapping, Sequence
1718
from copy import copy
1819
from functools import partial
@@ -84,24 +85,24 @@ def apply_to_collection(
8485
Returns:
8586
The resulting collection
8687
"""
87-
8888
# Breaking condition
8989
if isinstance(data, dtype) and (wrong_dtype is None or not isinstance(data, wrong_dtype)):
9090
return function(data, *args, **kwargs)
9191

92+
elem_type = type(data)
93+
9294
# Recursively apply to collection items
9395
if isinstance(data, Mapping):
94-
out = {}
96+
out = [] # can't use dict, need to preserve order if `OrderedDict`
9597
for k, v in data.items():
9698
v = apply_to_collection(v, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
9799
if include_none or v is not None:
98-
out[k] = v
99-
return out
100+
out.append((k, v))
101+
return elem_type(out)
100102

101103
is_namedtuple = _is_namedtuple(data)
102104
is_sequence = isinstance(data, Sequence) and not isinstance(data, str)
103105
if is_namedtuple or is_sequence:
104-
elem_type = type(data)
105106
out = []
106107
for d in data:
107108
v = apply_to_collection(d, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
@@ -142,22 +143,23 @@ def apply_to_collections(
142143
# in case they were passed reversed
143144
data1, data2 = data2, None
144145

146+
elem_type = type(data1)
147+
145148
if isinstance(data1, dtype) and data2 is not None and (wrong_dtype is None or not isinstance(data1, wrong_dtype)):
146149
return function(data1, data2, *args, **kwargs)
147150

148151
if isinstance(data1, Mapping) and data2 is not None:
149152
# use union because we want to fail if a key does not exist in both
150153
zipped = {k: (data1[k], data2[k]) for k in data1.keys() | data2.keys()}
151-
return {
154+
return elem_type({
152155
k: apply_to_collections(*v, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
153156
for k, v in zipped.items()
154-
}
157+
})
155158

156159
is_namedtuple = _is_namedtuple(data1)
157160
is_sequence = isinstance(data1, Sequence) and not isinstance(data1, str)
158161
if (is_namedtuple or is_sequence) and data2 is not None:
159162
assert len(data1) == len(data2), 'Sequence collections have different sizes'
160-
elem_type = type(data1)
161163
out = [
162164
apply_to_collections(v1, v2, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
163165
for v1, v2 in zip(data1, data2)

tests/utilities/test_apply_func.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import numbers
15-
from collections import namedtuple
15+
from collections import namedtuple, OrderedDict
1616

1717
import numpy as np
1818
import pytest
@@ -31,7 +31,7 @@ def test_recursive_application_to_collection():
3131
'd': ntc(bar=5.), # named tuple
3232
'e': np.array([10.]), # numpy array
3333
'f': 'this_is_a_dummy_str', # string
34-
'g': 12. # number
34+
'g': 12., # number
3535
}
3636

3737
expected_result = {
@@ -41,7 +41,7 @@ def test_recursive_application_to_collection():
4141
'd': ntc(bar=torch.tensor([10.])),
4242
'e': np.array([20.]),
4343
'f': 'this_is_a_dummy_str',
44-
'g': 24.
44+
'g': 24.,
4545
}
4646

4747
reduced = apply_to_collection(to_reduce, (torch.Tensor, numbers.Number, np.ndarray), lambda x: x * 2)
@@ -75,9 +75,15 @@ def test_recursive_application_to_collection():
7575
assert isinstance(reduced['f'], str), 'A string should not be reduced'
7676
assert reduced['f'] == expected_result['f'], 'String not preserved during reduction'
7777

78-
assert isinstance(reduced['g'], numbers.Number), 'Reduction of a number should result in a tensor'
78+
assert isinstance(reduced['g'], numbers.Number), 'Reduction of a number should result in a number'
7979
assert reduced['g'] == expected_result['g'], 'Reduction of a number did not yield the desired result'
8080

81+
# mapping support
82+
reduced = apply_to_collection({'a': 1, 'b': 2}, int, lambda x: str(x))
83+
assert reduced == {'a': '1', 'b': '2'}
84+
reduced = apply_to_collection(OrderedDict([('b', 2), ('a', 1)]), int, lambda x: str(x))
85+
assert reduced == OrderedDict([('b', '2'), ('a', '1')])
86+
8187

8288
def test_apply_to_collection_include_none():
8389
to_reduce = [1, 2, 3.4, 5.6, 7]

0 commit comments

Comments
 (0)