Classifier free guidance unconditioned value #8562
Conversation
Walkthrough

Adds a new `cfg_fill_value: float = -1.0` parameter to `sample()` across multiple inferer classes in `monai/inferers/inferer.py`. Replaces the hard-coded `uncondition.fill_(-1)` with `uncondition.fill_(cfg_fill_value)` and propagates this argument through nested sampling calls. Updates docstrings accordingly. Tests updated to pass `cfg_fill_value` in `DiffusionInferer` and `LatentDiffusionInferer` sampling tests. Default behavior remains unchanged when `cfg_fill_value` is not provided.

Estimated code review effort: 🎯 2 (Simple) | ⏱️ ~10 minutes

Assessment against linked issues: Out-of-scope changes: None found.
Looks good to me!
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (4)
monai/inferers/inferer.py (4)
Line 1100: Fix the ambiguous Tensor truth-value check when computing `log_predicted_variance`. This will raise "bool value of Tensor is ambiguous" when the predicted variance is learned.

Apply:

```diff
- log_predicted_variance = torch.log(predicted_variance) if predicted_variance else log_posterior_variance
+ log_predicted_variance = (
+     torch.log(predicted_variance) if predicted_variance is not None else log_posterior_variance
+ )
```
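To see why the truthiness check fails, here is a minimal pure-Python sketch; `FakeTensor` is a hypothetical stand-in that mimics PyTorch's behavior of raising when a multi-element tensor is truth-tested (no torch dependency assumed):

```python
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor: truth-testing a
    multi-element tensor raises, mirroring PyTorch's behavior."""

    def __init__(self, values):
        self.values = values

    def __bool__(self):
        if len(self.values) != 1:
            raise RuntimeError("Boolean value of Tensor with more than one element is ambiguous")
        return bool(self.values[0])


predicted_variance = FakeTensor([0.1, 0.2])  # multi-element, like a learned variance map

# Buggy pattern: `if predicted_variance` truth-tests the tensor and raises.
try:
    _ = "log" if predicted_variance else "posterior"
    raised = False
except RuntimeError:
    raised = True

# Fixed pattern: `is not None` never invokes __bool__ on the tensor.
chosen = "log" if predicted_variance is not None else "posterior"
```

The fix matters exactly when the variance is learned (a real tensor is present), which is the case the review flags.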
Line 1721: Same ambiguous Tensor truth-value bug as above. Align with the earlier fix.

Apply:

```diff
- log_predicted_variance = torch.log(predicted_variance) if predicted_variance else log_posterior_variance
+ log_predicted_variance = (
+     torch.log(predicted_variance) if predicted_variance is not None else log_posterior_variance
+ )
```
Line 1775: User-facing error message is missing a space. Tiny UX nit: the concatenated literals render as "Noneand".

Apply:

```diff
- raise ValueError("If ldm_latent_shape is None, autoencoder_latent_shape must be None" "and vice versa.")
+ raise ValueError(
+     "If ldm_latent_shape is None, autoencoder_latent_shape must be None and vice versa."
+ )
```
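Why the message fuses into "Noneand": adjacent string literals in Python are concatenated with no separator, so a missing trailing space disappears silently. A small illustration:

```python
# Adjacent string literals are concatenated with no separator inserted,
# so the missing trailing space produces a fused word in the message.
buggy = "If ldm_latent_shape is None, autoencoder_latent_shape must be None" "and vice versa."
fixed = "If ldm_latent_shape is None, autoencoder_latent_shape must be None and vice versa."
```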
Lines 1876-1878: Interpolation bug in the error message (missing f-string). The placeholders won't render values.

Apply:

```diff
- raise ValueError(
-     "If both autoencoder_model and diffusion_model implement SPADE, the number of semantic"
-     "labels for each must be compatible. Got {autoencoder_model.decoder.label_nc} and {diffusion_model.label_nc}"
- )
+ raise ValueError(
+     f"If both autoencoder_model and diffusion_model implement SPADE, the number of semantic "
+     f"labels for each must be compatible. Got {autoencoder_model.decoder.label_nc} and {diffusion_model.label_nc}"
+ )
```
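The bug in miniature: without the `f` prefix, `{...}` placeholders are emitted literally instead of being interpolated. The variable names below are illustrative, not the actual MONAI attributes:

```python
autoencoder_label_nc = 3  # illustrative values, not taken from the codebase
diffusion_label_nc = 5

# Missing f prefix: braces appear verbatim in the rendered message.
plain = "Got {autoencoder_label_nc} and {diffusion_label_nc}"

# With the f prefix, the values are interpolated as intended.
interpolated = f"Got {autoencoder_label_nc} and {diffusion_label_nc}"
```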
🧹 Nitpick comments (4)
monai/inferers/inferer.py (2)
Lines 966-967: Prefer `torch.full_like` and guard odd dtypes for the unconditional stub. Avoid a redundant allocation and potential dtype surprises (e.g., uint8/boolean casting) by constructing directly with the fill value.

Apply:

```diff
- uncondition = torch.ones_like(conditioning)
- uncondition.fill_(cfg_fill_value)
+ uncondition = torch.full_like(conditioning, cfg_fill_value)
```

Optionally assert a floating dtype if negative fill values are used, to prevent unsigned wrap-around. Would you like a follow-up patch to add a minimal runtime check?
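The suggested simplification can be sketched with NumPy as a stand-in for torch (`np.full_like` mirrors `torch.full_like`); this is an illustration of the allocation pattern, not the MONAI code itself:

```python
import numpy as np

conditioning = np.ones((2, 4), dtype=np.float32)  # stand-in conditioning tensor
cfg_fill_value = -1.0

# Two-step pattern from the original code: allocate ones, then overwrite in place.
uncondition_two_step = np.ones_like(conditioning)
uncondition_two_step.fill(cfg_fill_value)

# Single-step pattern suggested by the review: one allocation, same result,
# and the dtype is inherited from `conditioning` either way.
uncondition_direct = np.full_like(conditioning, cfg_fill_value)
```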
Lines 1531-1532: Mirror the `torch.full_like` change here too. Same rationale as above.

Apply:

```diff
  if conditioning is not None:
-     uncondition = torch.ones_like(conditioning)
-     uncondition.fill_(cfg_fill_value)
+     uncondition = torch.full_like(conditioning, cfg_fill_value)
```

tests/inferers/test_diffusion_inferer.py (1)
Lines 109-110: Add a non-default `cfg_fill_value` case to exercise the new API. The current test passes -1 (the default); include another value (e.g., 0.0) to validate the plumbing.

Apply:

```diff
  sample, intermediates = inferer.sample(
      input_noise=noise,
      diffusion_model=model,
      scheduler=scheduler,
      save_intermediates=True,
      intermediate_steps=1,
      cfg=5,
-     cfg_fill_value=-1,
+     cfg_fill_value=0.0,
  )
```

tests/inferers/test_latent_diffusion_inferer.py (1)
Lines 459-460: Also test a non-default `cfg_fill_value` here. Catches dtype/shape path issues for latent sampling too.

Apply:

```diff
  sample = inferer.sample(
      input_noise=noise,
      autoencoder_model=stage_1,
      diffusion_model=stage_2,
      scheduler=scheduler,
      seg=input_seg,
      cfg=5,
-     cfg_fill_value=-1,
+     cfg_fill_value=0.0,
  )
```
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Knowledge Base: Disabled due to Reviews > Disable Knowledge Base setting
📒 Files selected for processing (3)
- monai/inferers/inferer.py (12 hunks)
- tests/inferers/test_diffusion_inferer.py (1 hunks)
- tests/inferers/test_latent_diffusion_inferer.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (1)
**/*.py
⚙️ CodeRabbit configuration file
Review the Python code for quality and correctness. Ensure variable names adhere to PEP8 style guides and are sensible and informative in regard to their function, though permitting simple names for loop and comprehension variables. Ensure routine names are meaningful in regard to their function and use verbs, adjectives, and nouns in a semantically appropriate way. Docstrings should be present for all definitions and describe each variable, return value, and raised exception in the appropriate section of the Google style of docstrings. Examine code for logical errors or inconsistencies, and suggest what may be changed to address these. Suggest any enhancements for code improving efficiency, maintainability, comprehensibility, and correctness. Ensure new or modified definitions will be covered by existing or new unit tests.
Files:
tests/inferers/test_latent_diffusion_inferer.py
tests/inferers/test_diffusion_inferer.py
monai/inferers/inferer.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: flake8-py3 (pytype)
- GitHub Check: quick-py3 (macOS-latest)
🔇 Additional comments (6)
monai/inferers/inferer.py (6)
Line 919: The `cfg_fill_value` API addition looks good and is backward-compatible. The default keeps previous behavior; the docstring explains the intent clearly. Also applies to lines 933-934.

Lines 1266-1267: `cfg_fill_value` in `LatentDiffusionInferer.sample` is correctly plumbed and documented. Clear docstring; the default preserves behavior. Also applies to lines 1282-1283.

Lines 1307-1308: Propagation to `super().sample` is correct.

Lines 1487-1488: `cfg_fill_value` in `ControlNetDiffusionInferer.sample` is consistent and well documented. Also applies to lines 1502-1504.

Lines 1849-1850: `cfg_fill_value` in `ControlNetLatentDiffusionInferer.sample` is consistent. Also applies to lines 1867-1868.

Lines 1896-1897: Propagation to `super().sample` is correct.
/build
Fixes #8560.
Description
This PR adds an argument cfg_fill_value so that users can control the value that replaces the conditioning tensor during classifier-free guidance inference.
Previously, this value was hard-coded to -1, which might not be ideal for certain applications where the conditioning could have -1 as a normal value.
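For context, a hedged sketch of where `cfg_fill_value` enters classifier-free guidance, using NumPy and a toy callable in place of a diffusion model; only the parameter name `cfg_fill_value` comes from this PR, while the function name and signature below are illustrative:

```python
import numpy as np

def cfg_model_output(model, noisy_input, conditioning, cfg, cfg_fill_value=-1.0):
    """Classifier-free guidance sketch with a configurable fill value.

    `model` is any callable taking (x, conditioning); the blend
    uncond + cfg * (cond - uncond) is the standard CFG recipe.
    """
    # The "unconditioned" stub has the conditioning's shape but is filled
    # with cfg_fill_value instead of a hard-coded -1.
    uncondition = np.full_like(conditioning, cfg_fill_value)
    cond_out = model(noisy_input, conditioning)
    uncond_out = model(noisy_input, uncondition)
    return uncond_out + cfg * (cond_out - uncond_out)

# Toy "model": adds the mean of the conditioning to the input.
toy_model = lambda x, c: x + c.mean()

x = np.zeros((2, 2))
c = np.full((1, 4), 2.0)
out = cfg_model_output(toy_model, x, c, cfg=5, cfg_fill_value=0.0)
```

With `cfg_fill_value=0.0`, a conditioning that legitimately contains -1 is no longer conflated with the unconditioned stub, which is the motivation stated above.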
Types of changes
./runtests.sh -f -u --net --coverage
./runtests.sh --quick --unittests --disttests