     AutoencoderKL,
     DDPMScheduler,
     DiffusionPipeline,
-    DPMSolverMultistepScheduler,
     StableDiffusionPipeline,
     UNet2DConditionModel,
 )
@@ -1092,7 +1091,11 @@ def compute_text_embeddings(prompt):
                     revision=args.revision,
                     torch_dtype=weight_dtype,
                 )
-                pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+
+                # TODO temp hack for IF
+                # pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+                pipeline.scheduler = DDPMScheduler.from_config(pipeline.scheduler.config, variance_type="fixed_small")
+
                 pipeline = pipeline.to(accelerator.device)
                 pipeline.set_progress_bar_config(disable=True)

@@ -1143,7 +1146,11 @@ def compute_text_embeddings(prompt):
         pipeline = DiffusionPipeline.from_pretrained(
             args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
         )
-        pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+
+        # TODO temp for IF
+        # pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+        pipeline.scheduler = DDPMScheduler.from_config(pipeline.scheduler.config, variance_type="fixed_small")
+
         pipeline = pipeline.to(accelerator.device)

         # load attention processors
@@ -1153,7 +1160,7 @@ def compute_text_embeddings(prompt):
         if args.validation_prompt and args.num_validation_images > 0:
             generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
             images = [
-                pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]
+                pipeline(args.validation_prompt, generator=generator).images[0]
                 for _ in range(args.num_validation_images)
             ]

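For context on the scheduler swap in the hunks above, here is a minimal, self-contained sketch of the same pattern outside the training script. It only uses the diffusers calls that appear in the patch (DiffusionPipeline.from_pretrained, DDPMScheduler.from_config); the checkpoint id, prompt, and dtype are illustrative placeholders, not values taken from the patch.

import torch
from diffusers import DDPMScheduler, DiffusionPipeline

# Example checkpoint id for illustration only; the script uses whatever
# args.pretrained_model_name_or_path points to.
pipeline = DiffusionPipeline.from_pretrained(
    "DeepFloyd/IF-I-XL-v1.0", torch_dtype=torch.float16
)

# Per the "TODO temp hack for IF" comment, DPMSolverMultistepScheduler is not used
# for validation here; instead the pipeline's existing scheduler config is reused
# via DDPMScheduler.from_config, overriding the variance to "fixed_small".
pipeline.scheduler = DDPMScheduler.from_config(
    pipeline.scheduler.config, variance_type="fixed_small"
)
pipeline = pipeline.to("cuda")

# Generate a validation-style image; note that num_inference_steps is left at the
# pipeline default, matching the change to the validation call above.
image = pipeline(
    "a photo of sks dog",
    generator=torch.Generator(device="cuda").manual_seed(0),
).images[0]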