-
Notifications
You must be signed in to change notification settings - Fork 3.6k
apply_to_collection improvements and add apply_to_collections
#7769
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
b998f87
066945c
8ccb997
28d1cec
1e8e409
8dd9a25
ebd4694
127f842
b397ad3
b0ac231
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -54,12 +54,18 @@ def from_numpy(value, device: torch.device = None): | |
| ] | ||
|
|
||
|
|
||
| def _is_namedtuple(obj: object) -> bool: | ||
| # https://github.com/pytorch/pytorch/blob/v1.8.1/torch/nn/parallel/scatter_gather.py#L4-L8 | ||
| return isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields") | ||
|
|
||
|
|
||
| def apply_to_collection( | ||
| data: Any, | ||
| dtype: Union[type, tuple], | ||
| function: Callable, | ||
| *args, | ||
| wrong_dtype: Optional[Union[type, tuple]] = None, | ||
| include_none: bool = True, | ||
| **kwargs | ||
| ) -> Any: | ||
| """ | ||
|
|
@@ -70,40 +76,98 @@ def apply_to_collection( | |
| dtype: the given function will be applied to all elements of this dtype | ||
| function: the function to apply | ||
| *args: positional arguments (will be forwarded to calls of ``function``) | ||
| wrong_dtype: the given function won't be applied if this type is specified and the given collections is of | ||
| the :attr:`wrong_type` even if it is of type :attr`dtype` | ||
| wrong_dtype: the given function won't be applied if this type is specified and the given collections | ||
| is of the ``wrong_dtype`` even if it is of type ``dtype`` | ||
| include_none: Whether to include an element if the output of ``function`` is ``None``. | ||
| **kwargs: keyword arguments (will be forwarded to calls of ``function``) | ||
|
|
||
| Returns: | ||
| the resulting collection | ||
| The resulting collection | ||
| """ | ||
| elem_type = type(data) | ||
|
|
||
| # Breaking condition | ||
| if isinstance(data, dtype) and (wrong_dtype is None or not isinstance(data, wrong_dtype)): | ||
| return function(data, *args, **kwargs) | ||
|
|
||
| elem_type = type(data) | ||
|
|
||
| # Recursively apply to collection items | ||
| if isinstance(data, Mapping): | ||
| return elem_type({ | ||
| k: apply_to_collection(v, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) | ||
| for k, v in data.items() | ||
| }) | ||
|
|
||
| if isinstance(data, tuple) and hasattr(data, '_fields'): # named tuple | ||
| return elem_type( | ||
| *(apply_to_collection(d, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) for d in data) | ||
| ) | ||
|
|
||
| if isinstance(data, Sequence) and not isinstance(data, str): | ||
| return elem_type([ | ||
| apply_to_collection(d, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) for d in data | ||
| ]) | ||
| out = [] # can't use dict, need to preserve order if `OrderedDict` | ||
| for k, v in data.items(): | ||
| v = apply_to_collection(v, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) | ||
| if include_none or v is not None: | ||
| out.append((k, v)) | ||
| return elem_type(out) | ||
|
|
||
| is_namedtuple = _is_namedtuple(data) | ||
| is_sequence = isinstance(data, Sequence) and not isinstance(data, str) | ||
| if is_namedtuple or is_sequence: | ||
| out = [] | ||
| for d in data: | ||
| v = apply_to_collection(d, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) | ||
| if include_none or v is not None: | ||
| out.append(v) | ||
| return elem_type(*out) if is_namedtuple else elem_type(out) | ||
|
|
||
| # data is neither of dtype, nor a collection | ||
| return data | ||
|
|
||
|
|
||
| def apply_to_collections( | ||
| data1: Optional[Any], | ||
| data2: Optional[Any], | ||
| dtype: Union[type, tuple], | ||
| function: Callable, | ||
| *args, | ||
| wrong_dtype: Optional[Union[type, tuple]] = None, | ||
| **kwargs | ||
| ) -> Any: | ||
| """ | ||
| Zips two collections and applies a function to their items of a certain dtype. | ||
|
|
||
| Args: | ||
| data1: The first collection | ||
| data2: The second collection | ||
| dtype: the given function will be applied to all elements of this dtype | ||
| function: the function to apply | ||
| *args: positional arguments (will be forwarded to calls of ``function``) | ||
| wrong_dtype: the given function won't be applied if this type is specified and the given collections | ||
| is of the ``wrong_dtype`` even if it is of type ``dtype`` | ||
| **kwargs: keyword arguments (will be forwarded to calls of ``function``) | ||
|
|
||
| Returns: | ||
| The resulting collection | ||
| """ | ||
| if data1 is None and data2 is not None: | ||
| # in case they were passed reversed | ||
| data1, data2 = data2, None | ||
|
|
||
| elem_type = type(data1) | ||
|
|
||
| if isinstance(data1, dtype) and data2 is not None and (wrong_dtype is None or not isinstance(data1, wrong_dtype)): | ||
carmocca marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| return function(data1, data2, *args, **kwargs) | ||
|
|
||
| if isinstance(data1, Mapping) and data2 is not None: | ||
| # use union because we want to fail if a key does not exist in both | ||
carmocca marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| zipped = {k: (data1[k], data2[k]) for k in data1.keys() | data2.keys()} | ||
| return elem_type({ | ||
| k: apply_to_collections(*v, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) | ||
| for k, v in zipped.items() | ||
| }) | ||
|
Comment on lines
+152
to
+156
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @awaelchli this implementation is not OrderedDict safe as But I don't think we need to implement it in this PR. Could add a warning if the instance is an ordereddict
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I just realized the previous implementation also did not take care of OrderedDict. |
||
|
|
||
| is_namedtuple = _is_namedtuple(data1) | ||
| is_sequence = isinstance(data1, Sequence) and not isinstance(data1, str) | ||
| if (is_namedtuple or is_sequence) and data2 is not None: | ||
| assert len(data1) == len(data2), 'Sequence collections have different sizes' | ||
| out = [ | ||
| apply_to_collections(v1, v2, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) | ||
| for v1, v2 in zip(data1, data2) | ||
| ] | ||
| return elem_type(*out) if is_namedtuple else elem_type(out) | ||
|
|
||
| return apply_to_collection(data1, dtype, function, *args, wrong_dtype=wrong_dtype, **kwargs) | ||
|
|
||
|
|
||
| class TransferableDataType(ABC): | ||
| """ | ||
| A custom type for data that can be moved to a torch device via `.to(...)`. | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There's no guarantee that an arbitrary
Mappingwill support construction from Iterate[Tuple[K, T]]. This will fail for custom collections.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For example, this will fail when data is CfgNode (from YAQS) (https://github.com/rbgirshick/yacs/blob/master/yacs/config.py#L74).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are you saying this was supported before and now this change broke? Could you elaborate?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yep, we use Lightning in Detectron2Go (based on Detectron). The configuration there is specified by YACS (a CfgNode object).
We use the
save_hyperametersfunction on this object, which eventually calls the code here. The previous code (which callselem_type === CfgDictwith adict) worked fine. With this new change, an error is raised sinceelem_type === CfgDictcannot be constructed from a List of Tuples.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I believe the modification:
will fix this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@kandluis are you interested in sending a patch with your proposal?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure -- sent out #7851 :)