Skip to content

Commit 195b24b

Browse files
authored
apply_to_collection improvements and add apply_to_collections (#7769)
* `apply_to_collection` improvements and add `apply_to_collections` * Update CHANGELOG * Minor fix * Minor fix * Remove attr * Swap is first is None * None test * OrderedDict support * flake8 * Fix docstring
1 parent 1dd61e4 commit 195b24b

File tree

3 files changed

+172
-24
lines changed

3 files changed

+172
-24
lines changed

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
4141
- Added correct `dataloader_idx` to batch transfer hooks ([#6241](https://github.com/PyTorchLightning/pytorch-lightning/pull/6241))
4242

4343

44+
- Added `include_none=bool` argument to `apply_to_collection` ([#7769](https://github.com/PyTorchLightning/pytorch-lightning/pull/7769))
45+
46+
47+
- Added `apply_to_collections` to apply a function to two zipped collections ([#7769](https://github.com/PyTorchLightning/pytorch-lightning/pull/7769))
48+
49+
4450
- Added `ddp_fully_sharded` support ([#7487](https://github.com/PyTorchLightning/pytorch-lightning/pull/7487))
4551

4652

pytorch_lightning/utilities/apply_func.py

Lines changed: 83 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -54,12 +54,18 @@ def from_numpy(value, device: torch.device = None):
5454
]
5555

5656

57+
def _is_namedtuple(obj: object) -> bool:
58+
# https://github.com/pytorch/pytorch/blob/v1.8.1/torch/nn/parallel/scatter_gather.py#L4-L8
59+
return isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")
60+
61+
5762
def apply_to_collection(
5863
data: Any,
5964
dtype: Union[type, tuple],
6065
function: Callable,
6166
*args,
6267
wrong_dtype: Optional[Union[type, tuple]] = None,
68+
include_none: bool = True,
6369
**kwargs
6470
) -> Any:
6571
"""
@@ -70,40 +76,98 @@ def apply_to_collection(
7076
dtype: the given function will be applied to all elements of this dtype
7177
function: the function to apply
7278
*args: positional arguments (will be forwarded to calls of ``function``)
73-
wrong_dtype: the given function won't be applied if this type is specified and the given collections is of
74-
the :attr:`wrong_type` even if it is of type :attr`dtype`
79+
wrong_dtype: the given function won't be applied if this type is specified and the given collections
80+
is of the ``wrong_dtype`` even if it is of type ``dtype``
81+
include_none: Whether to include an element if the output of ``function`` is ``None``.
7582
**kwargs: keyword arguments (will be forwarded to calls of ``function``)
7683
7784
Returns:
78-
the resulting collection
85+
The resulting collection
7986
"""
80-
elem_type = type(data)
81-
8287
# Breaking condition
8388
if isinstance(data, dtype) and (wrong_dtype is None or not isinstance(data, wrong_dtype)):
8489
return function(data, *args, **kwargs)
8590

91+
elem_type = type(data)
92+
8693
# Recursively apply to collection items
8794
if isinstance(data, Mapping):
88-
return elem_type({
89-
k: apply_to_collection(v, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
90-
for k, v in data.items()
91-
})
92-
93-
if isinstance(data, tuple) and hasattr(data, '_fields'): # named tuple
94-
return elem_type(
95-
*(apply_to_collection(d, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) for d in data)
96-
)
97-
98-
if isinstance(data, Sequence) and not isinstance(data, str):
99-
return elem_type([
100-
apply_to_collection(d, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) for d in data
101-
])
95+
out = [] # can't use dict, need to preserve order if `OrderedDict`
96+
for k, v in data.items():
97+
v = apply_to_collection(v, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
98+
if include_none or v is not None:
99+
out.append((k, v))
100+
return elem_type(out)
101+
102+
is_namedtuple = _is_namedtuple(data)
103+
is_sequence = isinstance(data, Sequence) and not isinstance(data, str)
104+
if is_namedtuple or is_sequence:
105+
out = []
106+
for d in data:
107+
v = apply_to_collection(d, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
108+
if include_none or v is not None:
109+
out.append(v)
110+
return elem_type(*out) if is_namedtuple else elem_type(out)
102111

103112
# data is neither of dtype, nor a collection
104113
return data
105114

106115

116+
def apply_to_collections(
117+
data1: Optional[Any],
118+
data2: Optional[Any],
119+
dtype: Union[type, tuple],
120+
function: Callable,
121+
*args,
122+
wrong_dtype: Optional[Union[type, tuple]] = None,
123+
**kwargs
124+
) -> Any:
125+
"""
126+
Zips two collections and applies a function to their items of a certain dtype.
127+
128+
Args:
129+
data1: The first collection
130+
data2: The second collection
131+
dtype: the given function will be applied to all elements of this dtype
132+
function: the function to apply
133+
*args: positional arguments (will be forwarded to calls of ``function``)
134+
wrong_dtype: the given function won't be applied if this type is specified and the given collections
135+
is of the ``wrong_dtype`` even if it is of type ``dtype``
136+
**kwargs: keyword arguments (will be forwarded to calls of ``function``)
137+
138+
Returns:
139+
The resulting collection
140+
"""
141+
if data1 is None and data2 is not None:
142+
# in case they were passed reversed
143+
data1, data2 = data2, None
144+
145+
elem_type = type(data1)
146+
147+
if isinstance(data1, dtype) and data2 is not None and (wrong_dtype is None or not isinstance(data1, wrong_dtype)):
148+
return function(data1, data2, *args, **kwargs)
149+
150+
if isinstance(data1, Mapping) and data2 is not None:
151+
# use union because we want to fail if a key does not exist in both
152+
zipped = {k: (data1[k], data2[k]) for k in data1.keys() | data2.keys()}
153+
return elem_type({
154+
k: apply_to_collections(*v, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
155+
for k, v in zipped.items()
156+
})
157+
158+
is_namedtuple = _is_namedtuple(data1)
159+
is_sequence = isinstance(data1, Sequence) and not isinstance(data1, str)
160+
if (is_namedtuple or is_sequence) and data2 is not None:
161+
assert len(data1) == len(data2), 'Sequence collections have different sizes'
162+
out = [
163+
apply_to_collections(v1, v2, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
164+
for v1, v2 in zip(data1, data2)
165+
]
166+
return elem_type(*out) if is_namedtuple else elem_type(out)
167+
168+
return apply_to_collection(data1, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
169+
170+
107171
class TransferableDataType(ABC):
108172
"""
109173
A custom type for data that can be moved to a torch device via `.to(...)`.

tests/utilities/test_apply_func.py

Lines changed: 83 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,13 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import numbers
15-
from collections import namedtuple
15+
from collections import namedtuple, OrderedDict
1616

1717
import numpy as np
18+
import pytest
1819
import torch
1920

20-
from pytorch_lightning.utilities.apply_func import apply_to_collection
21+
from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections
2122

2223

2324
def test_recursive_application_to_collection():
@@ -30,7 +31,7 @@ def test_recursive_application_to_collection():
3031
'd': ntc(bar=5.), # named tuple
3132
'e': np.array([10.]), # numpy array
3233
'f': 'this_is_a_dummy_str', # string
33-
'g': 12. # number
34+
'g': 12., # number
3435
}
3536

3637
expected_result = {
@@ -40,7 +41,7 @@ def test_recursive_application_to_collection():
4041
'd': ntc(bar=torch.tensor([10.])),
4142
'e': np.array([20.]),
4243
'f': 'this_is_a_dummy_str',
43-
'g': 24.
44+
'g': 24.,
4445
}
4546

4647
reduced = apply_to_collection(to_reduce, (torch.Tensor, numbers.Number, np.ndarray), lambda x: x * 2)
@@ -74,5 +75,82 @@ def test_recursive_application_to_collection():
7475
assert isinstance(reduced['f'], str), 'A string should not be reduced'
7576
assert reduced['f'] == expected_result['f'], 'String not preserved during reduction'
7677

77-
assert isinstance(reduced['g'], numbers.Number), 'Reduction of a number should result in a tensor'
78+
assert isinstance(reduced['g'], numbers.Number), 'Reduction of a number should result in a number'
7879
assert reduced['g'] == expected_result['g'], 'Reduction of a number did not yield the desired result'
80+
81+
# mapping support
82+
reduced = apply_to_collection({'a': 1, 'b': 2}, int, lambda x: str(x))
83+
assert reduced == {'a': '1', 'b': '2'}
84+
reduced = apply_to_collection(OrderedDict([('b', 2), ('a', 1)]), int, lambda x: str(x))
85+
assert reduced == OrderedDict([('b', '2'), ('a', '1')])
86+
87+
88+
def test_apply_to_collection_include_none():
89+
to_reduce = [1, 2, 3.4, 5.6, 7]
90+
91+
def fn(x):
92+
if isinstance(x, float):
93+
return x
94+
95+
reduced = apply_to_collection(to_reduce, (int, float), fn)
96+
assert reduced == [None, None, 3.4, 5.6, None]
97+
98+
reduced = apply_to_collection(to_reduce, (int, float), fn, include_none=False)
99+
assert reduced == [3.4, 5.6]
100+
101+
102+
def test_apply_to_collections():
103+
to_reduce_1 = {'a': {'b': [1, 2]}, 'c': 5}
104+
to_reduce_2 = {'a': {'b': [3, 4]}, 'c': 6}
105+
106+
def fn(a, b):
107+
return a + b
108+
109+
# basic test
110+
reduced = apply_to_collections(to_reduce_1, to_reduce_2, int, fn)
111+
assert reduced == {'a': {'b': [4, 6]}, 'c': 11}
112+
113+
with pytest.raises(KeyError):
114+
# strict mode - if a key does not exist in both we fail
115+
apply_to_collections({**to_reduce_2, 'd': 'foo'}, to_reduce_1, float, fn)
116+
117+
# multiple dtypes
118+
reduced = apply_to_collections(to_reduce_1, to_reduce_2, (list, int), fn)
119+
assert reduced == {'a': {'b': [1, 2, 3, 4]}, 'c': 11}
120+
121+
# wrong dtype
122+
reduced = apply_to_collections(to_reduce_1, to_reduce_2, (list, int), fn, wrong_dtype=int)
123+
assert reduced == {'a': {'b': [1, 2, 3, 4]}, 'c': 5}
124+
125+
# list takes precedence because it is the type of data1
126+
reduced = apply_to_collections([1, 2, 3], [4], (int, list), fn)
127+
assert reduced == [1, 2, 3, 4]
128+
129+
# different sizes
130+
with pytest.raises(AssertionError, match='Sequence collections have different sizes'):
131+
apply_to_collections([[1, 2], [3]], [4], int, fn)
132+
133+
def fn(a, b):
134+
return a.keys() | b.keys()
135+
136+
# base case
137+
reduced = apply_to_collections(to_reduce_1, to_reduce_2, dict, fn)
138+
assert reduced == {'a', 'c'}
139+
140+
# type conversion
141+
to_reduce = [(1, 2), (3, 4)]
142+
reduced = apply_to_collections(to_reduce, to_reduce, int, lambda *x: sum(x))
143+
assert reduced == [(2, 4), (6, 8)]
144+
145+
# named tuple
146+
foo = namedtuple('Foo', ['bar'])
147+
to_reduce = [foo(1), foo(2), foo(3)]
148+
reduced = apply_to_collections(to_reduce, to_reduce, int, lambda *x: sum(x))
149+
assert reduced == [foo(2), foo(4), foo(6)]
150+
151+
# passing none
152+
reduced1 = apply_to_collections([1, 2, 3], None, int, lambda x: x * x)
153+
reduced2 = apply_to_collections(None, [1, 2, 3], int, lambda x: x * x)
154+
assert reduced1 == reduced2 == [1, 4, 9]
155+
reduced = apply_to_collections(None, None, int, lambda x: x * x)
156+
assert reduced is None

0 commit comments

Comments
 (0)