@@ -1092,7 +1092,17 @@ def compute_text_embeddings(prompt):
             revision=args.revision,
             torch_dtype=weight_dtype,
         )
-        pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+
+        # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
+        variance_type = pipeline.scheduler.config.variance_type
+
+        if variance_type in ["learned", "learned_range"]:
+            variance_type = "fixed_small"
+
+        pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
+            pipeline.scheduler.config, variance_type=variance_type
+        )
+
         pipeline = pipeline.to(accelerator.device)
         pipeline.set_progress_bar_config(disable=True)
 
@@ -1143,7 +1153,17 @@ def compute_text_embeddings(prompt):
         pipeline = DiffusionPipeline.from_pretrained(
             args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
         )
-        pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
+
+        # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
+        variance_type = pipeline.scheduler.config.variance_type
+
+        if variance_type in ["learned", "learned_range"]:
+            variance_type = "fixed_small"
+
+        pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
+            pipeline.scheduler.config, variance_type=variance_type
+        )
+
         pipeline = pipeline.to(accelerator.device)
 
         # load attention processors
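For reference, the same scheduler override can be applied when loading a trained checkpoint for inference outside this script. The sketch below is a minimal, illustrative example assuming the diffusers library: the model path is a placeholder, and getattr with a default is used (a small deviation from the diff) so the snippet also runs when the loaded scheduler has no variance_type entry in its config.

import torch
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler

# Placeholder path: point this at your own trained checkpoint.
model_path = "path/to/your-trained-model"

pipeline = DiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)

# Same logic as the commit: the model was trained on the simplified learning
# objective, so if the original scheduler was configured to predict a variance,
# force a fixed one so DPMSolverMultistepScheduler ignores it.
variance_type = getattr(pipeline.scheduler.config, "variance_type", "fixed_small")
if variance_type in ["learned", "learned_range"]:
    variance_type = "fixed_small"

pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
    pipeline.scheduler.config, variance_type=variance_type
)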