fix: support decollate for numpy scalars #8470
base: dev
187c141 to c438fe0
monai/data/utils.py
Outdated
@@ -625,6 +625,8 @@ def decollate_batch(batch, detach: bool = True, pad=True, fill_value=None):
        type(batch).__module__ == "numpy" and not isinstance(batch, Iterable)
    ):
        return batch
    if isinstance(batch, np.ndarray) and batch.ndim == 0:
Thanks for the PR! Do you think it might be beneficial to convert the array into a tensor? That way, the data could be handled more consistently.
We could; I don't think it matters for my use cases. As long as the function handles NumPy scalars in the form of an array, it is good for me!
I will add this change and convert it to a tensor there (L629) if you prefer :)
Thanks for the quick fix!
May I ask why the conversion to a tensor happens only when batch.ndim == 0 here?
I noticed that the decollate_batch function behaves differently on torch tensors and NumPy arrays (see discussion #8472), so I don't want to convert NumPy arrays to torch tensors, as that would introduce breaking changes.
This PR only addresses issue #8471, as I think that behavior was not expected and should be supported (?).
451c207 to 49d4954
Could we consider a more complete solution? The issue, it seems, is that 0-d arrays register as iterable but can't actually be iterated over. We already check for non-iterable things in
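The quirk described in the comment above can be reproduced in a few lines. This is a minimal sketch that assumes only NumPy is installed; it does not call MONAI, and the variable names are illustrative.

```python
from collections.abc import Iterable

import numpy as np

zero_d = np.array(1)         # 0-d array: a scalar wrapped in an ndarray
np_scalar = np.float32(1.0)  # NumPy scalar type (not an ndarray)

# ndarray defines __iter__, so the abstract-base-class check passes...
zero_d_is_iterable = isinstance(zero_d, Iterable)

# ...but actually iterating a 0-d array raises TypeError.
try:
    iter(zero_d)
    zero_d_iterates = True
except TypeError:
    zero_d_iterates = False

# NumPy scalar types such as float32 do not define __iter__, so a
# `not isinstance(batch, Iterable)` guard already catches them.
np_scalar_is_iterable = isinstance(np_scalar, Iterable)

print(zero_d_is_iterable, zero_d_iterates, np_scalar_is_iterable)
```

This is why an `isinstance(batch, Iterable)` test alone cannot separate 0-d arrays from arrays that can be split along their first dimension.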
Thanks for the feedback. The initial PR was:

    if isinstance(batch, (float, int, str, bytes)) or (
        type(batch).__module__ == "numpy" and not isinstance(batch, Iterable)
    ):
        return batch
    if isinstance(batch, np.ndarray) and batch.ndim == 0:
        return batch.item() if detach else batch
    # rest ...

Is this something that you find more complete? Note that I refactored the PR to convert from a NumPy array to a torch tensor, as suggested by @KumoLiu.
What I had in mind was more of the following change:

    ...
    if batch is None or isinstance(batch, (float, int, str, bytes)):
        return batch
    if getattr(batch, "ndim", -1) == 0:  # assumes only Numpy objects and Pytorch tensors have ndim
        return batch.item() if detach else batch
    if isinstance(batch, torch.Tensor):
        if detach:
            batch = batch.detach()
        # REMOVE
        # if batch.ndim == 0:
        #     return batch.item() if detach else batch
    ...
Thanks! I will update the PR to include these changes.
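The early check proposed in this thread can be sketched in isolation. `unwrap_zero_dim` is a hypothetical helper name used only for illustration, not a MONAI function; the sketch uses NumPy only and leaves out the torch `.detach()` step and the rest of the decollation logic.

```python
import numpy as np


def unwrap_zero_dim(batch, detach: bool = True):
    """Hypothetical helper mirroring the proposed early-return logic."""
    # Plain Python scalars and None pass through unchanged.
    if batch is None or isinstance(batch, (float, int, str, bytes)):
        return batch
    # Any object reporting ndim == 0 is treated as a scalar. This assumes such
    # objects also provide .item(), as NumPy arrays and torch tensors do.
    if getattr(batch, "ndim", -1) == 0:
        return batch.item() if detach else batch
    # Anything else would continue to the real decollation logic (not shown).
    return batch


scalar_array = np.array(1)
unwrapped = unwrap_zero_dim(scalar_array)           # plain Python int
kept = unwrap_zero_dim(scalar_array, detach=False)  # the same 0-d array object
vector = unwrap_zero_dim(np.array([1, 2]))          # returned untouched by this sketch
```

Using `getattr(batch, "ndim", -1)` keeps the check duck-typed: objects without an `ndim` attribute fall through to the later branches instead of raising `AttributeError`.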
Walkthrough
Adds a pre-check in monai/data/utils.py::decollate_batch for 0‑dimensional inputs: if the incoming batch has
Estimated code review effort: 🎯 1 (Trivial) | ⏱️ ~2 minutes
Actionable comments posted: 0
🧹 Nitpick comments (3)
monai/data/utils.py (3)
628-635: Prefer generic 0-D detection early; drop special-case conversion.
Handle any object with ndim==0 up-front, then the torch branch can skip its own scalar check. This reduces duplication and avoids unnecessary conversion for NumPy scalars when detach=False.
Apply this diff (and remove the torch scalar early-return below):

@@
-    if isinstance(batch, (float, int, str, bytes)) or (
-        type(batch).__module__ == "numpy" and not isinstance(batch, Iterable)
-    ):
-        return batch
-    if isinstance(batch, np.ndarray) and batch.ndim == 0:
-        batch = torch.from_numpy(batch)
+    if isinstance(batch, (float, int, str, bytes)) or (
+        type(batch).__module__ == "numpy" and not isinstance(batch, Iterable)
+    ):
+        return batch
+    # generic 0-D objects (NumPy/Torch/others exposing ndim): treat as scalars
+    if getattr(batch, "ndim", -1) == 0:
+        return batch.item() if detach else batch
@@
-    if isinstance(batch, torch.Tensor):
-        if detach:
-            batch = batch.detach()
-        if batch.ndim == 0:
-            return batch.item() if detach else batch
+    if isinstance(batch, torch.Tensor):
+        if detach:
+            batch = batch.detach()
+        # ndim==0 handled above
628-630: Graceful fallback for unsupported NumPy dtypes.
torch.from_numpy will fail for some dtypes (e.g., datetime64, object). Consider catching and falling back to .item() or returning the ndarray unchanged.

-    if isinstance(batch, np.ndarray) and batch.ndim == 0:
-        batch = torch.from_numpy(batch)
+    if isinstance(batch, np.ndarray) and batch.ndim == 0:
+        try:
+            batch = torch.from_numpy(batch)
+        except (TypeError, ValueError):
+            return batch.item() if detach else batch
614-621: Doc/test touch-up for 0-D NumPy arrays.
Please note 0-D NumPy array behavior in the decollate_batch docstring and add unit tests for:
- np.array(1) with detach=True/False
- Nested structures containing 0-D arrays
- Edge dtype (e.g., bool, float32)
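As a rough illustration of what such dtype checks would exercise, the sketch below shows what `.item()` returns for a few 0-d arrays. It uses NumPy only and is not taken from MONAI's test suite; the variable names are illustrative.

```python
import datetime

import numpy as np

# .item() converts a 0-d array to the closest built-in Python type.
as_int = np.array(1).item()
as_bool = np.array(True).item()
as_float = np.array(1.5, dtype=np.float32).item()

# A dtype that torch.from_numpy does not accept: datetime64 with day
# resolution. Its .item() yields a datetime.date instead.
as_date = np.array(np.datetime64("2020-01-01")).item()

print(type(as_int), type(as_bool), type(as_float), type(as_date))
```

A `.item()`-based path therefore covers dtypes such as `datetime64` without any tensor conversion.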
🔇 Additional comments (1)
monai/data/utils.py (1)
628-635: Good fix: avoids 0-D ndarray iteration error and aligns with tensor path.
Converting 0-D NumPy arrays to torch first prevents the TypeError raised by iterating 0-D arrays and lets the existing scalar-tensor handling (.item() when detach=True) kick in. Looks correct.
abb8218 to ec3e5d9
fix linter
Signed-off-by: Arthur Dujardin <[email protected]>
fix numpy decollate multi arrays
Signed-off-by: Arthur Dujardin <[email protected]>
fix linter
Signed-off-by: Arthur Dujardin <[email protected]>
fix numpy scalar support
Signed-off-by: Arthur Dujardin <[email protected]>
minor refactoring for typing
Signed-off-by: Arthur Dujardin <[email protected]>
convert scalar array to tensor
Signed-off-by: Arthur Dujardin <[email protected]>
update decollate item
ec3e5d9 to f9dba63
for more information, see https://pre-commit.ci
Signed-off-by: Arthur Dujardin <[email protected]>
Description
This PR adds support for NumPy scalars (e.g. in the form of np.array(1)) in the decollate_batch function (fixes issue #8471).
Types of changes
./runtests.sh -f -u --net --coverage
./runtests.sh --quick --unittests --disttests
make html command in the docs/ folder.