Skip to content

Commit ec0f594

Browse files
committed
wip
1 parent 1602604 commit ec0f594

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

examples/dreambooth/train_dreambooth_lora.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@
4444
AutoencoderKL,
4545
DDPMScheduler,
4646
DiffusionPipeline,
47-
DPMSolverMultistepScheduler,
4847
StableDiffusionPipeline,
4948
UNet2DConditionModel,
5049
)
@@ -1092,7 +1091,11 @@ def compute_text_embeddings(prompt):
10921091
revision=args.revision,
10931092
torch_dtype=weight_dtype,
10941093
)
1095-
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
1094+
1095+
# TODO temp hack for IF
1096+
# pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
1097+
pipeline.scheduler = DDPMScheduler.from_config(pipeline.scheduler.config, variance_type="fixed_small")
1098+
10961099
pipeline = pipeline.to(accelerator.device)
10971100
pipeline.set_progress_bar_config(disable=True)
10981101

@@ -1143,7 +1146,11 @@ def compute_text_embeddings(prompt):
11431146
pipeline = DiffusionPipeline.from_pretrained(
11441147
args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
11451148
)
1146-
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
1149+
1150+
# TODO temp for IF
1151+
# pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
1152+
pipeline.scheduler = DDPMScheduler.from_config(pipeline.scheduler.config, variance_type="fixed_small")
1153+
11471154
pipeline = pipeline.to(accelerator.device)
11481155

11491156
# load attention processors
@@ -1153,7 +1160,7 @@ def compute_text_embeddings(prompt):
11531160
if args.validation_prompt and args.num_validation_images > 0:
11541161
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
11551162
images = [
1156-
pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
1163+
pipeline(args.validation_prompt, generator=generator).images[0]
11571164
for _ in range(args.num_validation_images)
11581165
]
11591166

0 commit comments

Comments
 (0)