diff --git a/CHANGELOG.md b/CHANGELOG.md index 5b75e439f1662..1d550c6dc0a7a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -41,6 +41,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added correct `dataloader_idx` to batch transfer hooks ([#6241](https://github.com/PyTorchLightning/pytorch-lightning/pull/6241)) +- Added `include_none=bool` argument to `apply_to_collection` ([#7769](https://github.com/PyTorchLightning/pytorch-lightning/pull/7769)) + + +- Added `apply_to_collections` to apply a function to two zipped collections ([#7769](https://github.com/PyTorchLightning/pytorch-lightning/pull/7769)) + + - Added `ddp_fully_sharded` support ([#7487](https://github.com/PyTorchLightning/pytorch-lightning/pull/7487)) diff --git a/pytorch_lightning/utilities/apply_func.py b/pytorch_lightning/utilities/apply_func.py index 1cbab2fb8dee9..456bc5b5cd406 100644 --- a/pytorch_lightning/utilities/apply_func.py +++ b/pytorch_lightning/utilities/apply_func.py @@ -54,12 +54,18 @@ def from_numpy(value, device: torch.device = None): ] +def _is_namedtuple(obj: object) -> bool: + # https://github.com/pytorch/pytorch/blob/v1.8.1/torch/nn/parallel/scatter_gather.py#L4-L8 + return isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields") + + def apply_to_collection( data: Any, dtype: Union[type, tuple], function: Callable, *args, wrong_dtype: Optional[Union[type, tuple]] = None, + include_none: bool = True, **kwargs ) -> Any: """ @@ -70,40 +76,98 @@ def apply_to_collection( dtype: the given function will be applied to all elements of this dtype function: the function to apply *args: positional arguments (will be forwarded to calls of ``function``) - wrong_dtype: the given function won't be applied if this type is specified and the given collections is of - the :attr:`wrong_type` even if it is of type :attr`dtype` + wrong_dtype: the given function won't be applied if this type is specified and the given collections + is of the ``wrong_dtype`` even if it is of type ``dtype`` + include_none: Whether to include an element if the output of ``function`` is ``None``. **kwargs: keyword arguments (will be forwarded to calls of ``function``) Returns: - the resulting collection + The resulting collection """ - elem_type = type(data) - # Breaking condition if isinstance(data, dtype) and (wrong_dtype is None or not isinstance(data, wrong_dtype)): return function(data, *args, **kwargs) + elem_type = type(data) + # Recursively apply to collection items if isinstance(data, Mapping): - return elem_type({ - k: apply_to_collection(v, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) - for k, v in data.items() - }) - - if isinstance(data, tuple) and hasattr(data, '_fields'): # named tuple - return elem_type( - *(apply_to_collection(d, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) for d in data) - ) - - if isinstance(data, Sequence) and not isinstance(data, str): - return elem_type([ - apply_to_collection(d, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) for d in data - ]) + out = [] # can't use dict, need to preserve order if `OrderedDict` + for k, v in data.items(): + v = apply_to_collection(v, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) + if include_none or v is not None: + out.append((k, v)) + return elem_type(out) + + is_namedtuple = _is_namedtuple(data) + is_sequence = isinstance(data, Sequence) and not isinstance(data, str) + if is_namedtuple or is_sequence: + out = [] + for d in data: + v = apply_to_collection(d, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) + if include_none or v is not None: + out.append(v) + return elem_type(*out) if is_namedtuple else elem_type(out) # data is neither of dtype, nor a collection return data +def apply_to_collections( + data1: Optional[Any], + data2: Optional[Any], + dtype: Union[type, tuple], + function: Callable, + *args, + wrong_dtype: Optional[Union[type, tuple]] = None, + **kwargs +) -> Any: + """ + Zips two collections and applies a function to their items of a certain dtype. + + Args: + data1: The first collection + data2: The second collection + dtype: the given function will be applied to all elements of this dtype + function: the function to apply + *args: positional arguments (will be forwarded to calls of ``function``) + wrong_dtype: the given function won't be applied if this type is specified and the given collections + is of the ``wrong_dtype`` even if it is of type ``dtype`` + **kwargs: keyword arguments (will be forwarded to calls of ``function``) + + Returns: + The resulting collection + """ + if data1 is None and data2 is not None: + # in case they were passed reversed + data1, data2 = data2, None + + elem_type = type(data1) + + if isinstance(data1, dtype) and data2 is not None and (wrong_dtype is None or not isinstance(data1, wrong_dtype)): + return function(data1, data2, *args, **kwargs) + + if isinstance(data1, Mapping) and data2 is not None: + # use union because we want to fail if a key does not exist in both + zipped = {k: (data1[k], data2[k]) for k in data1.keys() | data2.keys()} + return elem_type({ + k: apply_to_collections(*v, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) + for k, v in zipped.items() + }) + + is_namedtuple = _is_namedtuple(data1) + is_sequence = isinstance(data1, Sequence) and not isinstance(data1, str) + if (is_namedtuple or is_sequence) and data2 is not None: + assert len(data1) == len(data2), 'Sequence collections have different sizes' + out = [ + apply_to_collections(v1, v2, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) + for v1, v2 in zip(data1, data2) + ] + return elem_type(*out) if is_namedtuple else elem_type(out) + + return apply_to_collection(data1, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) + + class TransferableDataType(ABC): """ A custom type for data that can be moved to a torch device via `.to(...)`. diff --git a/tests/utilities/test_apply_func.py b/tests/utilities/test_apply_func.py index a7eea3a749f26..1cc6ded7b25db 100644 --- a/tests/utilities/test_apply_func.py +++ b/tests/utilities/test_apply_func.py @@ -12,12 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. import numbers -from collections import namedtuple +from collections import namedtuple, OrderedDict import numpy as np +import pytest import torch -from pytorch_lightning.utilities.apply_func import apply_to_collection +from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections def test_recursive_application_to_collection(): @@ -30,7 +31,7 @@ def test_recursive_application_to_collection(): 'd': ntc(bar=5.), # named tuple 'e': np.array([10.]), # numpy array 'f': 'this_is_a_dummy_str', # string - 'g': 12. # number + 'g': 12., # number } expected_result = { @@ -40,7 +41,7 @@ def test_recursive_application_to_collection(): 'd': ntc(bar=torch.tensor([10.])), 'e': np.array([20.]), 'f': 'this_is_a_dummy_str', - 'g': 24. + 'g': 24., } reduced = apply_to_collection(to_reduce, (torch.Tensor, numbers.Number, np.ndarray), lambda x: x * 2) @@ -74,5 +75,82 @@ def test_recursive_application_to_collection(): assert isinstance(reduced['f'], str), 'A string should not be reduced' assert reduced['f'] == expected_result['f'], 'String not preserved during reduction' - assert isinstance(reduced['g'], numbers.Number), 'Reduction of a number should result in a tensor' + assert isinstance(reduced['g'], numbers.Number), 'Reduction of a number should result in a number' assert reduced['g'] == expected_result['g'], 'Reduction of a number did not yield the desired result' + + # mapping support + reduced = apply_to_collection({'a': 1, 'b': 2}, int, lambda x: str(x)) + assert reduced == {'a': '1', 'b': '2'} + reduced = apply_to_collection(OrderedDict([('b', 2), ('a', 1)]), int, lambda x: str(x)) + assert reduced == OrderedDict([('b', '2'), ('a', '1')]) + + +def test_apply_to_collection_include_none(): + to_reduce = [1, 2, 3.4, 5.6, 7] + + def fn(x): + if isinstance(x, float): + return x + + reduced = apply_to_collection(to_reduce, (int, float), fn) + assert reduced == [None, None, 3.4, 5.6, None] + + reduced = apply_to_collection(to_reduce, (int, float), fn, include_none=False) + assert reduced == [3.4, 5.6] + + +def test_apply_to_collections(): + to_reduce_1 = {'a': {'b': [1, 2]}, 'c': 5} + to_reduce_2 = {'a': {'b': [3, 4]}, 'c': 6} + + def fn(a, b): + return a + b + + # basic test + reduced = apply_to_collections(to_reduce_1, to_reduce_2, int, fn) + assert reduced == {'a': {'b': [4, 6]}, 'c': 11} + + with pytest.raises(KeyError): + # strict mode - if a key does not exist in both we fail + apply_to_collections({**to_reduce_2, 'd': 'foo'}, to_reduce_1, float, fn) + + # multiple dtypes + reduced = apply_to_collections(to_reduce_1, to_reduce_2, (list, int), fn) + assert reduced == {'a': {'b': [1, 2, 3, 4]}, 'c': 11} + + # wrong dtype + reduced = apply_to_collections(to_reduce_1, to_reduce_2, (list, int), fn, wrong_dtype=int) + assert reduced == {'a': {'b': [1, 2, 3, 4]}, 'c': 5} + + # list takes precedence because it is the type of data1 + reduced = apply_to_collections([1, 2, 3], [4], (int, list), fn) + assert reduced == [1, 2, 3, 4] + + # different sizes + with pytest.raises(AssertionError, match='Sequence collections have different sizes'): + apply_to_collections([[1, 2], [3]], [4], int, fn) + + def fn(a, b): + return a.keys() | b.keys() + + # base case + reduced = apply_to_collections(to_reduce_1, to_reduce_2, dict, fn) + assert reduced == {'a', 'c'} + + # type conversion + to_reduce = [(1, 2), (3, 4)] + reduced = apply_to_collections(to_reduce, to_reduce, int, lambda *x: sum(x)) + assert reduced == [(2, 4), (6, 8)] + + # named tuple + foo = namedtuple('Foo', ['bar']) + to_reduce = [foo(1), foo(2), foo(3)] + reduced = apply_to_collections(to_reduce, to_reduce, int, lambda *x: sum(x)) + assert reduced == [foo(2), foo(4), foo(6)] + + # passing none + reduced1 = apply_to_collections([1, 2, 3], None, int, lambda x: x * x) + reduced2 = apply_to_collections(None, [1, 2, 3], int, lambda x: x * x) + assert reduced1 == reduced2 == [1, 4, 9] + reduced = apply_to_collections(None, None, int, lambda x: x * x) + assert reduced is None