6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -41,6 +41,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added correct `dataloader_idx` to batch transfer hooks ([#6241](https://github.com/PyTorchLightning/pytorch-lightning/pull/6241))


- Added `include_none=bool` argument to `apply_to_collection` ([#7769](https://github.com/PyTorchLightning/pytorch-lightning/pull/7769))


- Added `apply_to_collections` to apply a function to two zipped collections ([#7769](https://github.com/PyTorchLightning/pytorch-lightning/pull/7769))


- Added `ddp_fully_sharded` support ([#7487](https://github.com/PyTorchLightning/pytorch-lightning/pull/7487))


102 changes: 83 additions & 19 deletions pytorch_lightning/utilities/apply_func.py
@@ -54,12 +54,18 @@ def from_numpy(value, device: torch.device = None):
]


def _is_namedtuple(obj: object) -> bool:
    # https://github.com/pytorch/pytorch/blob/v1.8.1/torch/nn/parallel/scatter_gather.py#L4-L8
    return isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")


def apply_to_collection(
    data: Any,
    dtype: Union[type, tuple],
    function: Callable,
    *args,
    wrong_dtype: Optional[Union[type, tuple]] = None,
    include_none: bool = True,
    **kwargs
) -> Any:
    """
@@ -70,40 +76,98 @@ def apply_to_collection(
        dtype: the given function will be applied to all elements of this dtype
        function: the function to apply
        *args: positional arguments (will be forwarded to calls of ``function``)
        wrong_dtype: the given function won't be applied if this type is specified and the given collection
            is of the ``wrong_dtype`` even if it is of type ``dtype``
        include_none: Whether to include an element if the output of ``function`` is ``None``.
        **kwargs: keyword arguments (will be forwarded to calls of ``function``)

    Returns:
        The resulting collection
    """
    # Breaking condition
    if isinstance(data, dtype) and (wrong_dtype is None or not isinstance(data, wrong_dtype)):
        return function(data, *args, **kwargs)

    elem_type = type(data)

    # Recursively apply to collection items
    if isinstance(data, Mapping):
        out = []  # can't use dict, need to preserve order if `OrderedDict`
        for k, v in data.items():
            v = apply_to_collection(v, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
            if include_none or v is not None:
                out.append((k, v))
        return elem_type(out)
Review thread on the new Mapping branch:

Contributor:
There's no guarantee that an arbitrary Mapping will support construction from Iterable[Tuple[K, V]]. This will fail for custom collections.

Contributor:
For example, this will fail when data is CfgNode (from YACS) (https://github.com/rbgirshick/yacs/blob/master/yacs/config.py#L74).

Contributor:
Are you saying this was supported before and this change broke it? Could you elaborate?

Contributor:
Yep, we use Lightning in Detectron2Go (based on Detectron). The configuration there is specified by YACS (a CfgNode object).

We use the save_hyperparameters function on this object, which eventually calls the code here. The previous code (which called elem_type, here CfgNode, with a dict) worked fine. With this new change, an error is raised since CfgNode cannot be constructed from a list of tuples.

Contributor:
I believe the modification:

from collections import OrderedDict
...
elem_type(OrderedDict(out))

will fix this.

Contributor Author:
@kandluis are you interested in sending a patch with your proposal?

Contributor:
Sure -- sent out #7851 :)
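
To make the failure mode concrete, here is a minimal sketch; StrictDict is a hypothetical stand-in for a CfgNode-like Mapping whose constructor accepts a dict but not an iterable of tuples:

from collections import OrderedDict


class StrictDict(dict):
    # Hypothetical stand-in for a CfgNode-like Mapping: only accepts a dict.
    def __init__(self, init_dict=None):
        if init_dict is not None and not isinstance(init_dict, dict):
            raise TypeError('StrictDict must be constructed from a dict')
        super().__init__(init_dict or {})


out = [('a', 1), ('b', 2)]  # what the new Mapping branch builds

try:
    StrictDict(out)  # elem_type(out) as written above
except TypeError as err:
    print(err)  # fails for this custom Mapping

StrictDict(OrderedDict(out))  # the proposed elem_type(OrderedDict(out)) works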


    is_namedtuple = _is_namedtuple(data)
    is_sequence = isinstance(data, Sequence) and not isinstance(data, str)
    if is_namedtuple or is_sequence:
        out = []
        for d in data:
            v = apply_to_collection(d, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
            if include_none or v is not None:
                out.append(v)
        return elem_type(*out) if is_namedtuple else elem_type(out)

    # data is neither of dtype, nor a collection
    return data
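
For reference, a small usage sketch of the updated helper (the values are arbitrary): None results are kept by default and dropped with include_none=False, and namedtuples are rebuilt by unpacking so their type is preserved.

from collections import namedtuple

from pytorch_lightning.utilities.apply_func import apply_to_collection

Foo = namedtuple('Foo', ['bar'])


def halve_floats(x):
    # returns a value only for floats; ints map to None
    if isinstance(x, float):
        return x / 2


apply_to_collection([1., 2, 3.], (int, float), halve_floats)  # [0.5, None, 1.5]
apply_to_collection([1., 2, 3.], (int, float), halve_floats, include_none=False)  # [0.5, 1.5]
apply_to_collection(Foo(bar=4.), float, halve_floats)  # Foo(bar=2.0)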


def apply_to_collections(
    data1: Optional[Any],
    data2: Optional[Any],
    dtype: Union[type, tuple],
    function: Callable,
    *args,
    wrong_dtype: Optional[Union[type, tuple]] = None,
    **kwargs
) -> Any:
    """
    Zips two collections and applies a function to their items of a certain dtype.

    Args:
        data1: The first collection
        data2: The second collection
        dtype: the given function will be applied to all elements of this dtype
        function: the function to apply
        *args: positional arguments (will be forwarded to calls of ``function``)
        wrong_dtype: the given function won't be applied if this type is specified and the given collection
            is of the ``wrong_dtype`` even if it is of type ``dtype``
        **kwargs: keyword arguments (will be forwarded to calls of ``function``)

    Returns:
        The resulting collection
    """
    if data1 is None and data2 is not None:
        # in case they were passed reversed
        data1, data2 = data2, None

    elem_type = type(data1)

    if isinstance(data1, dtype) and data2 is not None and (wrong_dtype is None or not isinstance(data1, wrong_dtype)):
        return function(data1, data2, *args, **kwargs)

    if isinstance(data1, Mapping) and data2 is not None:
        # use union because we want to fail if a key does not exist in both
        zipped = {k: (data1[k], data2[k]) for k in data1.keys() | data2.keys()}
        return elem_type({
            k: apply_to_collections(*v, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
            for k, v in zipped.items()
        })
Review thread on lines +152 to +156 (the Mapping branch):

Contributor Author:
@awaelchli this implementation is not OrderedDict-safe, as data1.keys() | data2.keys() is a set.

But I don't think we need to implement that in this PR. We could add a warning if the instance is an OrderedDict.

Contributor:
I just realized the previous implementation also did not take care of OrderedDict.
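
A quick illustration of the concern, assuming two OrderedDicts with matching keys:

from collections import OrderedDict

d1 = OrderedDict([('b', 1), ('a', 2)])
d2 = OrderedDict([('b', 3), ('a', 4)])

# the union of two dict key views is a plain set, so the insertion
# order of the OrderedDicts is not guaranteed to survive
print(d1.keys() | d2.keys())  # e.g. {'a', 'b'} -- order undefined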


    is_namedtuple = _is_namedtuple(data1)
    is_sequence = isinstance(data1, Sequence) and not isinstance(data1, str)
    if (is_namedtuple or is_sequence) and data2 is not None:
        assert len(data1) == len(data2), 'Sequence collections have different sizes'
        out = [
            apply_to_collections(v1, v2, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
            for v1, v2 in zip(data1, data2)
        ]
        return elem_type(*out) if is_namedtuple else elem_type(out)

    return apply_to_collection(data1, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs)
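
A brief usage sketch, mirroring the tests below; when one collection is None the call falls back to apply_to_collection on the other:

from pytorch_lightning.utilities.apply_func import apply_to_collections

d1 = {'a': {'b': [1, 2]}, 'c': 5}
d2 = {'a': {'b': [3, 4]}, 'c': 6}

# matching leaves are zipped and passed to the function pairwise
apply_to_collections(d1, d2, int, lambda x, y: x + y)  # {'a': {'b': [4, 6]}, 'c': 11}

# with one side None, the function receives single elements
apply_to_collections(d1, None, int, lambda x: x * 10)  # {'a': {'b': [10, 20]}, 'c': 50}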


class TransferableDataType(ABC):
    """
    A custom type for data that can be moved to a torch device via `.to(...)`.
88 changes: 83 additions & 5 deletions tests/utilities/test_apply_func.py
@@ -12,12 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numbers
from collections import namedtuple, OrderedDict

import numpy as np
import pytest
import torch

from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections


def test_recursive_application_to_collection():
@@ -30,7 +31,7 @@ def test_recursive_application_to_collection():
        'd': ntc(bar=5.),  # named tuple
        'e': np.array([10.]),  # numpy array
        'f': 'this_is_a_dummy_str',  # string
        'g': 12.,  # number
    }

expected_result = {
Expand All @@ -40,7 +41,7 @@ def test_recursive_application_to_collection():
'd': ntc(bar=torch.tensor([10.])),
'e': np.array([20.]),
'f': 'this_is_a_dummy_str',
'g': 24.
'g': 24.,
}

    reduced = apply_to_collection(to_reduce, (torch.Tensor, numbers.Number, np.ndarray), lambda x: x * 2)
@@ -74,5 +75,82 @@ def test_recursive_application_to_collection():
    assert isinstance(reduced['f'], str), 'A string should not be reduced'
    assert reduced['f'] == expected_result['f'], 'String not preserved during reduction'

    assert isinstance(reduced['g'], numbers.Number), 'Reduction of a number should result in a number'
    assert reduced['g'] == expected_result['g'], 'Reduction of a number did not yield the desired result'

    # mapping support
    reduced = apply_to_collection({'a': 1, 'b': 2}, int, lambda x: str(x))
    assert reduced == {'a': '1', 'b': '2'}
    reduced = apply_to_collection(OrderedDict([('b', 2), ('a', 1)]), int, lambda x: str(x))
    assert reduced == OrderedDict([('b', '2'), ('a', '1')])


def test_apply_to_collection_include_none():
    to_reduce = [1, 2, 3.4, 5.6, 7]

    def fn(x):
        if isinstance(x, float):
            return x

    reduced = apply_to_collection(to_reduce, (int, float), fn)
    assert reduced == [None, None, 3.4, 5.6, None]

    reduced = apply_to_collection(to_reduce, (int, float), fn, include_none=False)
    assert reduced == [3.4, 5.6]


def test_apply_to_collections():
    to_reduce_1 = {'a': {'b': [1, 2]}, 'c': 5}
    to_reduce_2 = {'a': {'b': [3, 4]}, 'c': 6}

    def fn(a, b):
        return a + b

    # basic test
    reduced = apply_to_collections(to_reduce_1, to_reduce_2, int, fn)
    assert reduced == {'a': {'b': [4, 6]}, 'c': 11}

    with pytest.raises(KeyError):
        # strict mode - if a key does not exist in both we fail
        apply_to_collections({**to_reduce_2, 'd': 'foo'}, to_reduce_1, float, fn)

    # multiple dtypes
    reduced = apply_to_collections(to_reduce_1, to_reduce_2, (list, int), fn)
    assert reduced == {'a': {'b': [1, 2, 3, 4]}, 'c': 11}

    # wrong dtype
    reduced = apply_to_collections(to_reduce_1, to_reduce_2, (list, int), fn, wrong_dtype=int)
    assert reduced == {'a': {'b': [1, 2, 3, 4]}, 'c': 5}

    # list takes precedence because it is the type of data1
    reduced = apply_to_collections([1, 2, 3], [4], (int, list), fn)
    assert reduced == [1, 2, 3, 4]

    # different sizes
    with pytest.raises(AssertionError, match='Sequence collections have different sizes'):
        apply_to_collections([[1, 2], [3]], [4], int, fn)

    def fn(a, b):
        return a.keys() | b.keys()

    # base case
    reduced = apply_to_collections(to_reduce_1, to_reduce_2, dict, fn)
    assert reduced == {'a', 'c'}

    # type conversion
    to_reduce = [(1, 2), (3, 4)]
    reduced = apply_to_collections(to_reduce, to_reduce, int, lambda *x: sum(x))
    assert reduced == [(2, 4), (6, 8)]

    # named tuple
    foo = namedtuple('Foo', ['bar'])
    to_reduce = [foo(1), foo(2), foo(3)]
    reduced = apply_to_collections(to_reduce, to_reduce, int, lambda *x: sum(x))
    assert reduced == [foo(2), foo(4), foo(6)]

    # passing none
    reduced1 = apply_to_collections([1, 2, 3], None, int, lambda x: x * x)
    reduced2 = apply_to_collections(None, [1, 2, 3], int, lambda x: x * x)
    assert reduced1 == reduced2 == [1, 4, 9]
    reduced = apply_to_collections(None, None, int, lambda x: x * x)
    assert reduced is None