
Commit d9cfe32

CompVis -> diffusers script - allow converting from merged checkpoint to either EMA or non-EMA (#991)

* improve script
* up

1 parent 0343d8f · commit d9cfe32
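
This change lets the conversion script handle "merged" CompVis checkpoints that ship both EMA and non-EMA copies of the UNet weights. As the new help text below puts it, EMA weights usually yield higher-quality images at inference, while non-EMA weights are usually better for continuing fine-tuning. A typical invocation after this change might look like the following (the checkpoint path and output directory are placeholders):

python scripts/convert_original_stable_diffusion_to_diffusers.py --checkpoint_path <merged_checkpoint.ckpt> --extract_ema --dump_path <output_dir>

Omitting --extract_ema extracts the non-EMA weights instead.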


scripts/convert_original_stable_diffusion_to_diffusers.py

Lines changed: 36 additions & 4 deletions
@@ -285,15 +285,34 @@ def create_ldm_bert_config(original_config):
     return config


-def convert_ldm_unet_checkpoint(checkpoint, config):
+def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False):
     """
     Takes a state dict and a config, and returns a converted checkpoint.
     """

     # extract state_dict for UNet
     unet_state_dict = {}
-    unet_key = "model.diffusion_model."
     keys = list(checkpoint.keys())
+
+    unet_key = "model.diffusion_model."
+    # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
+    if sum(k.startswith("model_ema") for k in keys) > 100:
+        print(f"Checkpoint {path} has both EMA and non-EMA weights.")
+        if extract_ema:
+            print(
+                "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
+                " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
+            )
+            for key in keys:
+                if key.startswith("model.diffusion_model"):
+                    flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
+                    unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
+        else:
+            print(
+                "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
+                " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
+            )
+
     for key in keys:
         if key.startswith(unet_key):
             unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
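
For readers unfamiliar with the CompVis EMA layout: the flat_ema_key construction above relies on LitEma storing each shadow parameter under the original parameter name with its dots stripped. A minimal sketch of the mapping, using an illustrative UNet key that is not part of this diff:

# Sketch: derive the flattened EMA key for one regular checkpoint key.
# The leading "model." segment is dropped and the remaining dots removed,
# matching how CompVis' LitEma names its shadow parameters.
key = "model.diffusion_model.input_blocks.0.0.weight"  # illustrative
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
assert flat_ema_key == "model_ema.diffusion_modelinput_blocks00weight"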
@@ -630,6 +649,15 @@ def convert_ldm_clip_checkpoint(checkpoint):
         type=str,
         help="Type of scheduler to use. Should be one of ['pndm', 'lms', 'ddim']",
     )
+    parser.add_argument(
+        "--extract_ema",
+        action="store_true",
+        help=(
+            "Only relevant for checkpoints that have both EMA and non-EMA weights. Whether to extract the EMA weights"
+            " or not. Defaults to `False`. Add `--extract_ema` to extract the EMA weights. EMA weights usually yield"
+            " higher quality images for inference. Non-EMA weights are usually better to continue fine-tuning."
+        ),
+    )
     parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")

     args = parser.parse_args()
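
Since the new flag only matters for merged checkpoints, it can be worth checking up front whether a checkpoint actually contains EMA weights. A small standalone sketch (the path is illustrative):

import torch

# Load only the state dict and count flattened EMA entries; well over 100
# such keys is what the script treats as "EMA weights are present".
state_dict = torch.load("merged_checkpoint.ckpt", map_location="cpu")["state_dict"]
num_ema_keys = sum(k.startswith("model_ema") for k in state_dict)
print(f"{num_ema_keys} model_ema keys found")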
@@ -641,7 +669,9 @@ def convert_ldm_clip_checkpoint(checkpoint):
         args.original_config_file = "./v1-inference.yaml"

     original_config = OmegaConf.load(args.original_config_file)
-    checkpoint = torch.load(args.checkpoint_path)["state_dict"]
+
+    checkpoint = torch.load(args.checkpoint_path)
+    checkpoint = checkpoint["state_dict"]

     num_train_timesteps = original_config.model.params.timesteps
     beta_start = original_config.model.params.linear_start
@@ -669,7 +699,9 @@ def convert_ldm_clip_checkpoint(checkpoint):

     # Convert the UNet2DConditionModel model.
     unet_config = create_unet_diffusers_config(original_config)
-    converted_unet_checkpoint = convert_ldm_unet_checkpoint(checkpoint, unet_config)
+    converted_unet_checkpoint = convert_ldm_unet_checkpoint(
+        checkpoint, unet_config, path=args.checkpoint_path, extract_ema=args.extract_ema
+    )

     unet = UNet2DConditionModel(**unet_config)
     unet.load_state_dict(converted_unet_checkpoint)
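
Outside the CLI, the updated function can be called the same way the script now calls it. A hedged sketch, assuming the script's helpers are importable, a unet_config built via create_unet_diffusers_config is in scope, and the path is illustrative:

import torch

# Mirrors the script's new call. Note that convert_ldm_unet_checkpoint pops
# matching keys out of the dict it is given, so reload the checkpoint before
# converting a second time.
checkpoint = torch.load("merged_checkpoint.ckpt", map_location="cpu")["state_dict"]
converted_unet_checkpoint = convert_ldm_unet_checkpoint(
    checkpoint, unet_config, path="merged_checkpoint.ckpt", extract_ema=True
)

With extract_ema=False (the default), a merged checkpoint keeps its non-EMA weights and the function prints a hint about the --extract_ema flag.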
