Skip to content

Commit b5bc69d

Browse files
virginiafdez and Virginia Fernandez authored
Classifier free guidance unconditioned value (#8562)
Fixes #8560.

### Description

This PR adds an argument `cfg_fill_value` so that users can control the value that replaces the conditioning tensor during classifier-free guidance inference. Previously, this value was hard-coded to -1, which might not be ideal for certain applications where the conditioning could have -1 as a normal value.

### Types of changes

<!--- Put an `x` in all the boxes that apply, and remove the not applicable items -->

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [x] New tests added to cover the changes (modified existing ones).
- [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.

---------

Co-authored-by: Virginia Fernandez <[email protected]>
1 parent 0968da2 commit b5bc69d

File tree

3 files changed

+15
-3
lines changed

3 files changed

+15
-3
lines changed

monai/inferers/inferer.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -916,6 +916,7 @@ def sample(
916916
verbose: bool = True,
917917
seg: torch.Tensor | None = None,
918918
cfg: float | None = None,
919+
cfg_fill_value: float = -1.0,
919920
) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
920921
"""
921922
Args:
@@ -929,6 +930,7 @@ def sample(
929930
verbose: if true, prints the progression bar of the sampling process.
930931
seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided.
931932
cfg: classifier-free-guidance scale, which indicates the level of strengthening on the conditioning.
933+
cfg_fill_value: the fill value to use for the unconditioned input when using classifier-free guidance.
932934
"""
933935
if mode not in ["crossattn", "concat"]:
934936
raise NotImplementedError(f"{mode} condition is not supported")
@@ -961,7 +963,7 @@ def sample(
961963
model_input = torch.cat([image] * 2, dim=0)
962964
if conditioning is not None:
963965
uncondition = torch.ones_like(conditioning)
964-
uncondition.fill_(-1)
966+
uncondition.fill_(cfg_fill_value)
965967
conditioning_input = torch.cat([uncondition, conditioning], dim=0)
966968
else:
967969
conditioning_input = None
@@ -1261,6 +1263,7 @@ def sample( # type: ignore[override]
12611263
verbose: bool = True,
12621264
seg: torch.Tensor | None = None,
12631265
cfg: float | None = None,
1266+
cfg_fill_value: float = -1.0,
12641267
) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
12651268
"""
12661269
Args:
@@ -1276,6 +1279,7 @@ def sample( # type: ignore[override]
12761279
seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model
12771280
is instance of SPADEAutoencoderKL, segmentation must be provided.
12781281
cfg: classifier-free-guidance scale, which indicates the level of strengthening on the conditioning.
1282+
cfg_fill_value: the fill value to use for the unconditioned input when using classifier-free guidance.
12791283
"""
12801284

12811285
if (
@@ -1300,6 +1304,7 @@ def sample( # type: ignore[override]
13001304
verbose=verbose,
13011305
seg=seg,
13021306
cfg=cfg,
1307+
cfg_fill_value=cfg_fill_value,
13031308
)
13041309

13051310
if save_intermediates:
@@ -1479,6 +1484,7 @@ def sample( # type: ignore[override]
14791484
verbose: bool = True,
14801485
seg: torch.Tensor | None = None,
14811486
cfg: float | None = None,
1487+
cfg_fill_value: float = -1.0,
14821488
) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
14831489
"""
14841490
Args:
@@ -1493,7 +1499,8 @@ def sample( # type: ignore[override]
14931499
mode: Conditioning mode for the network.
14941500
verbose: if true, prints the progression bar of the sampling process.
14951501
seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided.
1496-
cfg: classifier-free-guidance scale, which indicates the level of strengthening on the conditioning.
1502+
cfg: classifier-free-guidance scale, which indicates the level of strengthening on the conditioning.
1503+
cfg_fill_value: the fill value to use for the unconditioned input when using classifier-free guidance.
14971504
"""
14981505
if mode not in ["crossattn", "concat"]:
14991506
raise NotImplementedError(f"{mode} condition is not supported")
@@ -1521,7 +1528,7 @@ def sample( # type: ignore[override]
15211528
model_input = torch.cat([image] * 2, dim=0)
15221529
if conditioning is not None:
15231530
uncondition = torch.ones_like(conditioning)
1524-
uncondition.fill_(-1)
1531+
uncondition.fill_(cfg_fill_value)
15251532
conditioning_input = torch.cat([uncondition, conditioning], dim=0)
15261533
else:
15271534
conditioning_input = None
@@ -1839,6 +1846,7 @@ def sample( # type: ignore[override]
18391846
verbose: bool = True,
18401847
seg: torch.Tensor | None = None,
18411848
cfg: float | None = None,
1849+
cfg_fill_value: float = -1.0,
18421850
) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]:
18431851
"""
18441852
Args:
@@ -1856,6 +1864,7 @@ def sample( # type: ignore[override]
18561864
seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model
18571865
is instance of SPADEAutoencoderKL, segmentation must be provided.
18581866
cfg: classifier-free-guidance scale, which indicates the level of strengthening on the conditioning.
1867+
cfg_fill_value: the fill value to use for the unconditioned input when using classifier-free guidance.
18591868
"""
18601869

18611870
if (
@@ -1884,6 +1893,7 @@ def sample( # type: ignore[override]
18841893
verbose=verbose,
18851894
seg=seg,
18861895
cfg=cfg,
1896+
cfg_fill_value=cfg_fill_value,
18871897
)
18881898

18891899
if save_intermediates:

tests/inferers/test_diffusion_inferer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ def test_sample_cfg(self, model_params, input_shape):
106106
save_intermediates=True,
107107
intermediate_steps=1,
108108
cfg=5,
109+
cfg_fill_value=-1,
109110
)
110111
self.assertEqual(sample.shape, noise.shape)
111112

tests/inferers/test_latent_diffusion_inferer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,7 @@ def test_sample_shape_with_cfg(
456456
scheduler=scheduler,
457457
seg=input_seg,
458458
cfg=5,
459+
cfg_fill_value=-1,
459460
)
460461
else:
461462
sample = inferer.sample(

0 commit comments

Comments
 (0)