Torchvision pretrain fix #8563
Conversation
Signed-off-by: Eric Kerfoot <[email protected]>
Signed-off-by: Eric Kerfoot <[email protected]>
Walkthrough: Replaces deprecated torchvision `pretrained` usage with the explicit `weights` API.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
Assessment against linked issues: Out-of-scope changes
Pre-merge checks (3 passed, 2 warnings)
❌ Failed checks (2 warnings)
✅ Passed checks (3 passed)
for more information, see https://pre-commit.ci
Actionable comments posted: 0
🧹 Nitpick comments (4)
monai/networks/blocks/fcn.py (1)
126-128: Weights API switch looks good; expose weights to callers for flexibility.
Using `ResNet50_Weights.IMAGENET1K_V1` preserves legacy behavior. Consider exposing a `weights` arg (defaulting to V1 when `pretrained=True`) so callers can opt into newer defaults (e.g., V2) or custom weights without patching code. Example diff:

```diff
-    def __init__(
-        self, out_channels: int = 1, upsample_mode: str = "bilinear", pretrained: bool = True, progress: bool = True
-    ):
+    def __init__(
+        self,
+        out_channels: int = 1,
+        upsample_mode: str = "bilinear",
+        pretrained: bool = True,
+        progress: bool = True,
+        weights=None,
+    ):
@@
-        resnet = models.resnet50(
-            progress=progress, weights=models.ResNet50_Weights.IMAGENET1K_V1 if pretrained else None
-        )
+        selected_weights = models.ResNet50_Weights.IMAGENET1K_V1 if (pretrained and weights is None) else weights
+        resnet = models.resnet50(progress=progress, weights=selected_weights)
```
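If the proposed `weights` argument were adopted, callers could opt into newer weights roughly as follows. This is a hypothetical usage sketch based on the diff above; the import path for `FCN` is assumed.

```python
import torchvision.models as models

from monai.networks.blocks import FCN  # assumed export location of the FCN block

# Default behaviour: legacy ImageNet V1 weights, matching pretrained=True today.
net_legacy = FCN(out_channels=1, pretrained=True)

# Hypothetical opt-in to the newer V2 weights via the proposed `weights` argument.
net_v2 = FCN(out_channels=1, pretrained=True, weights=models.ResNet50_Weights.IMAGENET1K_V2)
```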
monai/networks/nets/milmodel.py (2)

19-22: Import path change is fine; consider aligning tests.
Tests still import `optional_import` from `monai.utils.module`. Not a blocker, but consider unifying on `monai.utils` for consistency, as in the sketch below.
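A one-line alignment sketch for the tests, assuming `optional_import` is re-exported from the `monai.utils` package as the comment implies:

```python
# current test import (works, but references the submodule directly):
# from monai.utils.module import optional_import

# suggested for consistency with the library code:
from monai.utils import optional_import

models, has_torchvision = optional_import("torchvision.models")
```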
96-106: Add torchvision version-compat fallback for string backbones.
Calling `torch_model(weights="DEFAULT")` will fail on older torchvision. A tiny `TypeError` fallback keeps backward compatibility without reintroducing deprecation warnings on newer versions.

```diff
-            net = torch_model(weights="DEFAULT" if pretrained else None)
+            try:
+                net = torch_model(weights="DEFAULT" if pretrained else None)
+            except TypeError:
+                net = torch_model(pretrained=pretrained)
```

Also applies to: 101-102
tests/networks/nets/test_milmodel.py (1)
53-61: Guard import-time backbone construction to avoid failures when torchvision isn't present.
Building `backbone = models.densenet121()` at import time will error if torchvision is missing. Wrap with a guard.

```diff
-# custom backbone
-backbone = models.densenet121()
-backbone_nfeatures = backbone.classifier.in_features
-backbone.classifier = torch.nn.Identity()
-TEST_CASE_MILMODEL.append(
-    [
-        {"num_classes": 5, "backbone": backbone, "backbone_num_features": backbone_nfeatures, "pretrained": False},
-        (2, 2, 3, 512, 512),
-        (2, 5),
-    ]
-)
+# custom backbone
+if models is not None:
+    backbone = models.densenet121()
+    backbone_nfeatures = backbone.classifier.in_features
+    backbone.classifier = torch.nn.Identity()
+    TEST_CASE_MILMODEL.append(
+        [
+            {"num_classes": 5, "backbone": backbone, "backbone_num_features": backbone_nfeatures, "pretrained": False},
+            (2, 2, 3, 512, 512),
+            (2, 5),
+        ]
+    )
```
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Knowledge Base: Disabled due to Reviews > Disable Knowledge Base setting
📒 Files selected for processing (5)
- monai/networks/blocks/fcn.py (1 hunks)
- monai/networks/nets/milmodel.py (3 hunks)
- monai/networks/nets/torchvision_fc.py (1 hunks)
- tests/networks/nets/test_densenet.py (1 hunks)
- tests/networks/nets/test_milmodel.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
monai/networks/blocks/fcn.py
monai/networks/nets/torchvision_fc.py
tests/networks/nets/test_densenet.py
tests/networks/nets/test_milmodel.py
monai/networks/nets/milmodel.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
- GitHub Check: min-dep-pytorch (2.6.0)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: min-dep-pytorch (2.8.0)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: min-dep-py3 (3.9)
- GitHub Check: min-dep-py3 (3.11)
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: min-dep-py3 (3.10)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: build-docs
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: quick-py3 (macOS-latest)
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: packaging
🔇 Additional comments (4)
tests/networks/nets/test_densenet.py (1)
91-103: Harden test: wrap torchvision download in skip context and use tolerance-based compare.
Avoid spurious CI failures if weights download flakes and prefer numerical closeness over exact equality.

```diff
 @parameterized.expand([TEST_PRETRAINED_2D_CASE_3])
 @skipUnless(has_torchvision, "Requires `torchvision` package.")
 def test_pretrain_consistency(self, model, input_param, input_shape):
     example = torch.randn(input_shape).to(device)
-    with skip_if_downloading_fails():
-        net = model(**input_param).to(device)
-    with eval_mode(net):
-        result = net.features.forward(example)
-    torchvision_net = torchvision.models.densenet121(weights="DEFAULT").to(device)
-    with eval_mode(torchvision_net):
-        expected_result = torchvision_net.features.forward(example)
-    self.assertTrue(torch.all(result == expected_result))
+    with skip_if_downloading_fails():
+        net = model(**input_param).to(device)
+        with eval_mode(net):
+            result = net.features.forward(example)
+        tv_net = torchvision.models.densenet121(weights="DEFAULT").to(device)
+        with eval_mode(tv_net):
+            expected = tv_net.features.forward(example)
+        # allow tiny numerical diffs across backends
+        torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-6)
```

Torchvision supports `weights="DEFAULT"` strings (>=0.13.0) (docs.pytorch.org, pytorch.org).

monai/networks/nets/torchvision_fc.py (1)
115-121: Make pretrained/weights interplay explicit with compatibility fallback.
- Document that `pretrained=True, weights=None` maps to `weights="DEFAULT"`.
- Add a `try/except AttributeError` around `getattr(models, model_name)` to raise `ValueError` for unknown models.
- Wrap the weights-based constructor in `try/except TypeError` to fall back on `pretrained=` for older torchvision (see the sketch below).

Unable to verify torchvision's API in this environment—please run the Python snippet in your local setup to confirm whether `weights=` is supported and adjust the fallback accordingly.
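A minimal sketch of the suggested fallback, assuming `models` is the torchvision models namespace obtained via `optional_import` and that the hypothetical helper `resolve_backbone` mirrors the relevant constructor arguments in `torchvision_fc.py` (the real signature may differ):

```python
from monai.utils import optional_import

models, _ = optional_import("torchvision.models")


def resolve_backbone(model_name: str, pretrained: bool = True, weights=None):
    """Sketch: build a torchvision model with a weights/pretrained compatibility fallback."""
    try:
        model_fn = getattr(models, model_name)
    except AttributeError as exc:
        raise ValueError(f"Unknown torchvision model: {model_name}") from exc

    # pretrained=True with no explicit weights maps to the "DEFAULT" weights string.
    selected = weights if weights is not None else ("DEFAULT" if pretrained else None)
    try:
        # torchvision >= 0.13 accepts the `weights=` keyword (including "DEFAULT").
        return model_fn(weights=selected)
    except TypeError:
        # older torchvision only understands the deprecated `pretrained=` flag.
        return model_fn(pretrained=pretrained)
```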
monai/networks/nets/milmodel.py (1)

75-79: Good: explicit ResNet50 V1 weights to preserve legacy behavior.
No issues here.
tests/networks/nets/test_milmodel.py (1)
47-50: Nice coverage: exercise both pretrained states.
LGTM.
/build
Hi @ericspod, do you remember the reason why we fixed the isort version? Can we try loosening the range of package versions in this PR?
No, I don't remember I'm afraid; I'm going to try loosening it now.
Signed-off-by: Eric Kerfoot <[email protected]>
/build
Actionable comments posted: 0
🧹 Nitpick comments (1)
requirements-dev.txt (1)
21-21: Optional: cap to a minor while we validate, or consolidate on Ruff for import sorting.
- Conservative cap until CI passes broadly:

```diff
-isort>=5.1, !=6.0.0
+isort>=5.1, !=6.0.0, <6.1
```

- Or drop isort and rely on ruff's isort rules to reduce tool overlap (follow-up PR).
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (1)
- requirements-dev.txt (1 hunks)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (19)
- GitHub Check: min-dep-pytorch (2.8.0)
- GitHub Check: min-dep-py3 (3.12)
- GitHub Check: min-dep-py3 (3.11)
- GitHub Check: min-dep-py3 (3.9)
- GitHub Check: min-dep-pytorch (2.6.0)
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: min-dep-py3 (3.10)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: build-docs
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: packaging
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: quick-py3 (macOS-latest)
🔇 Additional comments (2)
requirements-dev.txt (2)
21-21: LGTM: loosens isort to admit 6.0.1 and resolves the resolver conflict.
This matches the CI feedback that complained about `<6.0` vs `==6.0.1`.
21-21: Verify no other isort pins/configs still conflict with 6.x.
requirements-dev.txt:21 — `isort>=5.1, !=6.0.0` (this excludes only 6.0.0, so the rest of 6.x is admitted).
No other isort pins or pre-commit hook found in the repo.
setup.cfg contains an `[isort]` section (`profile = black`, `known_first_party = monai`) — review for v6 breaking changes.
Signed-off-by: Eric Kerfoot <[email protected]>
If I remember correctly, the new version of isort does weird things when sorting imports that are a little complicated, which old versions don't agree with.
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
monai/transforms/io/array.py (1)
544-553: Guard meta_data and add robust key fallback to avoid runtime errors.
If img isn't a MetaTensor, meta_data is undefined, causing UnboundLocalError. Also, be resilient to either enum being used.
Apply:

```diff
         if isinstance(img, MetaTensor):
             meta_data = img.meta
+        else:
+            raise TypeError("WriteFileMapping expects a MetaTensor with metadata.")
-        if MetaKeys.SAVED_TO not in meta_data:
+        if MetaKeys.SAVED_TO not in meta_data:
             raise KeyError(
                 "Missing 'saved_to' key in metadata. Check SaveImage argument 'savepath_in_metadict' is True."
             )
-        input_path = meta_data[ImageMetaKey.FILENAME_OR_OBJ]
+        input_path = meta_data.get(ImageMetaKey.FILENAME_OR_OBJ) or meta_data.get(MetaKeys.FILENAME_OR_OBJ)
+        if input_path is None:
+            raise KeyError("Missing 'filename_or_obj' key in metadata.")
         output_path = meta_data[MetaKeys.SAVED_TO]
```
🧹 Nitpick comments (1)
monai/transforms/io/array.py (1)
389-393: Tighten type hint for padding_mode.
Default is an enum value; reflect that in the annotation. Apply:

```diff
-        padding_mode: str = GridSamplePadMode.BORDER,
+        padding_mode: GridSamplePadMode | str = GridSamplePadMode.BORDER,
```
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Knowledge base: Disabled due to Reviews -> Disable Knowledge Base setting
📒 Files selected for processing (1)
- monai/transforms/io/array.py (3 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides, are sensible and informative in regards to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regards to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings. Examine code for logical error or inconsistencies, and suggest what may be changed to addressed these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
monai/transforms/io/array.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (18)
- GitHub Check: flake8-py3 (mypy)
- GitHub Check: packaging
- GitHub Check: flake8-py3 (codeformat)
- GitHub Check: build-docs
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: quick-py3 (ubuntu-latest)
- GitHub Check: quick-py3 (windows-latest)
- GitHub Check: min-dep-pytorch (2.7.1)
- GitHub Check: min-dep-pytorch (2.6.0)
- GitHub Check: min-dep-pytorch (2.5.1)
- GitHub Check: min-dep-pytorch (2.8.0)
- GitHub Check: min-dep-os (ubuntu-latest)
- GitHub Check: min-dep-os (windows-latest)
- GitHub Check: min-dep-py3 (3.9)
- GitHub Check: min-dep-py3 (3.11)
- GitHub Check: min-dep-py3 (3.10)
- GitHub Check: min-dep-os (macOS-latest)
- GitHub Check: min-dep-py3 (3.12)
🔇 Additional comments (1)
monai/transforms/io/array.py (1)
296-298: Do not mirror FILENAME_OR_OBJ into MetaKeys.
No references to MetaKeys.FILENAME_OR_OBJ were found; code consistently uses ImageMetaKey / Key.FILENAME_OR_OBJ (e.g. monai/transforms/io/array.py:297, monai/transforms/spatial/array.py:326, monai/data/csv_saver.py:95).
Likely an incorrect or invalid review comment.
/build
Seems there are many conflicts in the PyTorch 25.0x base image.
/build
Since we haven't fully tested on the 25.0x PyTorch base images, I used the 24.10 base image instead.
Fixes #8552.
Description
This updates how pretrained model weights are loaded through Torchvision. This may not preserve historical results if the weights being loaded are now different, since the "DEFAULT" weights may not be the weights that were loaded when using the `pretrained=True` argument. I tried to preserve behaviour as indicated in the Torchvision source code where possible.
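For context, a minimal before/after sketch of the torchvision weights API this change moves to (illustrative only, not code from this PR):

```python
import torchvision.models as models

# Deprecated style: emits a deprecation warning on recent torchvision.
# resnet = models.resnet50(pretrained=True)

# Pin the legacy ImageNet weights explicitly to preserve historical behaviour:
resnet_legacy = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

# Or take whatever torchvision currently ships as the default weights,
# which may differ from the legacy ones (hence the caveat above):
resnet_default = models.resnet50(weights="DEFAULT")
```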
Types of changes
- `./runtests.sh -f -u --net --coverage`
- `./runtests.sh --quick --unittests --disttests`
- `make html` command in the `docs/` folder.