Skip to content

Commit 3be4891

Browse files
authored
feat: allow offset_noise in dreambooth training example (#2826)
1 parent d82b032 commit 3be4891

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

examples/dreambooth/train_dreambooth.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,16 @@ def parse_args(input_args=None):
417417
),
418418
)
419419

420+
parser.add_argument(
421+
"--offset_noise",
422+
action="store_true",
423+
default=False,
424+
help=(
425+
"Fine-tuning against a modified noise"
426+
" See: https://www.crosslabs.org//blog/diffusion-with-offset-noise for more information."
427+
),
428+
)
429+
420430
if input_args is not None:
421431
args = parser.parse_args(input_args)
422432
else:
@@ -943,7 +953,12 @@ def load_model_hook(models, input_dir):
943953
latents = latents * vae.config.scaling_factor
944954

945955
# Sample noise that we'll add to the latents
946-
noise = torch.randn_like(latents)
956+
if args.offset_noise:
957+
noise = torch.randn_like(latents) + 0.1 * torch.randn(
958+
latents.shape[0], latents.shape[1], 1, 1, device=latents.device
959+
)
960+
else:
961+
noise = torch.randn_like(latents)
947962
bsz = latents.shape[0]
948963
# Sample a random timestep for each image
949964
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)

0 commit comments

Comments
 (0)