
Commit 36f7ef5

jachiam authored and patrickvonplaten committed
Checkpoint conversion script from Diffusers => Stable Diffusion (CompVis) (huggingface#701)
* Conversion script
* ran black
* ran isort
* remove unused import
* map location so everything gets loaded onto CPU before conversion
* ran black again
* Update setup.py

Co-authored-by: Patrick von Platen <[email protected]>
1 parent a3e21e9 commit 36f7ef5

2 files changed: +237 -9 lines
Lines changed: 234 additions & 0 deletions
@@ -0,0 +1,234 @@
# Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint.
# *Only* converts the UNet, VAE, and Text Encoder.
# Does not convert optimizer state or anything else.

import argparse
import os.path as osp

import torch

# =================#
# UNet Conversion #
# =================#

unet_conversion_map = [
    # (stable-diffusion, HF Diffusers)
    ("time_embed.0.weight", "time_embedding.linear_1.weight"),
    ("time_embed.0.bias", "time_embedding.linear_1.bias"),
    ("time_embed.2.weight", "time_embedding.linear_2.weight"),
    ("time_embed.2.bias", "time_embedding.linear_2.bias"),
    ("input_blocks.0.0.weight", "conv_in.weight"),
    ("input_blocks.0.0.bias", "conv_in.bias"),
    ("out.0.weight", "conv_norm_out.weight"),
    ("out.0.bias", "conv_norm_out.bias"),
    ("out.2.weight", "conv_out.weight"),
    ("out.2.bias", "conv_out.bias"),
]

unet_conversion_map_resnet = [
    # (stable-diffusion, HF Diffusers)
    ("in_layers.0", "norm1"),
    ("in_layers.2", "conv1"),
    ("out_layers.0", "norm2"),
    ("out_layers.3", "conv2"),
    ("emb_layers.1", "time_emb_proj"),
    ("skip_connection", "conv_shortcut"),
]

unet_conversion_map_layer = []
# hardcoded number of downblocks and resnets/attentions...
# would need smarter logic for other networks.
for i in range(4):
    # loop over downblocks/upblocks

    for j in range(2):
        # loop over resnets/attentions for downblocks
        hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
        sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
        unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))

        if i < 3:
            # no attention layers in down_blocks.3
            hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
            sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
            unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))

    for j in range(3):
        # loop over resnets/attentions for upblocks
        hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
        sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
        unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))

        if i > 0:
            # no attention layers in up_blocks.0
            hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
            sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
            unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))

    if i < 3:
        # no downsample in down_blocks.3
        hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
        sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
        unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))

        # no upsample in up_blocks.3
        hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
        sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
        unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))

hf_mid_atn_prefix = "mid_block.attentions.0."
sd_mid_atn_prefix = "middle_block.1."
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))

for j in range(2):
    hf_mid_res_prefix = f"mid_block.resnets.{j}."
    sd_mid_res_prefix = f"middle_block.{2*j}."
    unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
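# For illustration (not used by the code): with i = 0, j = 0 the down-block
# loop above appends
#     ("input_blocks.1.0.", "down_blocks.0.resnets.0.")
# i.e. the HF Diffusers prefix down_blocks.0.resnets.0. is renamed to the
# original Stable Diffusion prefix input_blocks.1.0.
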
def convert_unet_state_dict(unet_state_dict):
    # buyer beware: this is a *brittle* function,
    # and correct output requires that all of these pieces interact in
    # the exact order in which I have arranged them.
    mapping = {k: k for k in unet_state_dict.keys()}
    for sd_name, hf_name in unet_conversion_map:
        mapping[hf_name] = sd_name
    for k, v in mapping.items():
        if "resnets" in k:
            for sd_part, hf_part in unet_conversion_map_resnet:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
    for k, v in mapping.items():
        for sd_part, hf_part in unet_conversion_map_layer:
            v = v.replace(hf_part, sd_part)
        mapping[k] = v
    new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()}
    return new_state_dict
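# Worked example of the renaming passes above, tracing one key: the HF key
#     down_blocks.0.resnets.0.norm1.weight
# is first rewritten by the resnet map to
#     down_blocks.0.resnets.0.in_layers.0.weight
# and then by the layer map to its SD name
#     input_blocks.1.0.in_layers.0.weight
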
# ================#
# VAE Conversion #
# ================#

vae_conversion_map = [
    # (stable-diffusion, HF Diffusers)
    ("nin_shortcut", "conv_shortcut"),
    ("norm_out", "conv_norm_out"),
    ("mid.attn_1.", "mid_block.attentions.0."),
]

for i in range(4):
    # down_blocks have two resnets
    for j in range(2):
        hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
        sd_down_prefix = f"encoder.down.{i}.block.{j}."
        vae_conversion_map.append((sd_down_prefix, hf_down_prefix))

    if i < 3:
        hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
        sd_downsample_prefix = f"down.{i}.downsample."
        vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))

        hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
        sd_upsample_prefix = f"up.{3-i}.upsample."
        vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))

    # up_blocks have three resnets
    # also, up blocks in hf are numbered in reverse from sd
    for j in range(3):
        hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
        sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
        vae_conversion_map.append((sd_up_prefix, hf_up_prefix))

# this part accounts for mid blocks in both the encoder and the decoder
for i in range(2):
    hf_mid_res_prefix = f"mid_block.resnets.{i}."
    sd_mid_res_prefix = f"mid.block_{i+1}."
    vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
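# Illustration of the reversed up-block numbering: for i = 0, j = 0 the
# decoder loop above appends
#     ("decoder.up.3.block.0.", "decoder.up_blocks.0.resnets.0.")
# so HF's first decoder up-block corresponds to SD's up.3.
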
vae_conversion_map_attn = [
    # (stable-diffusion, HF Diffusers)
    ("norm.", "group_norm."),
    ("q.", "query."),
    ("k.", "key."),
    ("v.", "value."),
    ("proj_out.", "proj_attn."),
]

def reshape_weight_for_sd(w):
    # convert HF linear weights to SD conv2d weights
    return w.reshape(*w.shape, 1, 1)
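# Why the reshape is needed: HF Diffusers stores the VAE attention q/k/v/proj
# weights as nn.Linear weights of shape [out, in], while the original SD VAE
# implements them as 1x1 Conv2d layers expecting [out, in, 1, 1]. For example,
# a [512, 512] linear weight becomes [512, 512, 1, 1].
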
def convert_vae_state_dict(vae_state_dict):
    mapping = {k: k for k in vae_state_dict.keys()}
    for k, v in mapping.items():
        for sd_part, hf_part in vae_conversion_map:
            v = v.replace(hf_part, sd_part)
        mapping[k] = v
    for k, v in mapping.items():
        if "attentions" in k:
            for sd_part, hf_part in vae_conversion_map_attn:
                v = v.replace(hf_part, sd_part)
            mapping[k] = v
    new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
    weights_to_convert = ["q", "k", "v", "proj_out"]
    for k, v in new_state_dict.items():
        for weight_name in weights_to_convert:
            if f"mid.attn_1.{weight_name}.weight" in k:
                print(f"Reshaping {k} for SD format")
                new_state_dict[k] = reshape_weight_for_sd(v)
    return new_state_dict
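# Worked example for the function above: the HF key
#     decoder.mid_block.attentions.0.query.weight
# becomes decoder.mid.attn_1.q.weight via the two maps, and is then reshaped
# by reshape_weight_for_sd ([512, 512] -> [512, 512, 1, 1] in the standard
# SD VAE).
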
# =========================#
# Text Encoder Conversion #
# =========================#
# pretty much a no-op


def convert_text_enc_state_dict(text_enc_dict):
    return text_enc_dict

if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.")
    parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.")
    parser.add_argument("--half", action="store_true", help="Save weights in half precision.")

    args = parser.parse_args()

    assert args.model_path is not None, "Must provide a model path!"
    assert args.checkpoint_path is not None, "Must provide a checkpoint path!"

    unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin")
    vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin")
    text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin")

    # Convert the UNet model
    unet_state_dict = torch.load(unet_path, map_location="cpu")
    unet_state_dict = convert_unet_state_dict(unet_state_dict)
    unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()}

    # Convert the VAE model
    vae_state_dict = torch.load(vae_path, map_location="cpu")
    vae_state_dict = convert_vae_state_dict(vae_state_dict)
    vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()}

    # Convert the text encoder model
    text_enc_dict = torch.load(text_enc_path, map_location="cpu")
    text_enc_dict = convert_text_enc_state_dict(text_enc_dict)
    text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()}

    # Put together new checkpoint
    state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict}
    if args.half:
        state_dict = {k: v.half() for k, v in state_dict.items()}
    state_dict = {"state_dict": state_dict}
    torch.save(state_dict, args.checkpoint_path)

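A typical invocation of the new script (assuming it is saved as convert_diffusers_to_original_stable_diffusion.py; the file name is not shown on this page):

python convert_diffusers_to_original_stable_diffusion.py --model_path ./my-diffusers-pipeline --checkpoint_path ./model.ckpt --half

Here --model_path points at a saved Diffusers pipeline directory (containing unet/, vae/ and text_encoder/ subfolders), --checkpoint_path is the CompVis-style checkpoint file to write, and the optional --half flag stores the weights in fp16.
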
setup.py

Lines changed: 3 additions & 9 deletions
@@ -67,12 +67,13 @@
 you need to go back to main before executing this.
 """
 
-import re
 import os
+import re
 from distutils.core import Command
 
 from setuptools import find_packages, setup
 
+
 # IMPORTANT:
 # 1. all dependencies should be listed here with their version requirements if any
 # 2. once modified, run: `make deps_table_update` to update src/diffusers/dependency_versions_table.py

@@ -177,14 +178,7 @@ def run(self):
 extras["docs"] = deps_list("hf-doc-builder")
 extras["training"] = deps_list("accelerate", "datasets", "tensorboard", "modelcards")
 extras["test"] = deps_list(
-    "datasets",
-    "onnxruntime",
-    "pytest",
-    "pytest-timeout",
-    "pytest-xdist",
-    "scipy",
-    "torchvision",
-    "transformers"
+    "datasets", "onnxruntime", "pytest", "pytest-timeout", "pytest-xdist", "scipy", "torchvision", "transformers"
 )
 extras["torch"] = deps_list("torch")