From f47c2d21a4d25ab99daae5e955f8bd11606c6e67 Mon Sep 17 00:00:00 2001 From: Josh Date: Sun, 2 Oct 2022 16:12:52 -0700 Subject: [PATCH 1/7] Conversion script --- ..._diffusers_to_original_stable_diffusion.py | 226 ++++++++++++++++++ 1 file changed, 226 insertions(+) create mode 100644 scripts/convert_diffusers_to_original_stable_diffusion.py diff --git a/scripts/convert_diffusers_to_original_stable_diffusion.py b/scripts/convert_diffusers_to_original_stable_diffusion.py new file mode 100644 index 000000000000..124992c3180e --- /dev/null +++ b/scripts/convert_diffusers_to_original_stable_diffusion.py @@ -0,0 +1,226 @@ +# Script for converting a HF Diffusers saved pipeline to a Stable Diffusion checkpoint. +# *Only* converts the UNet, VAE, and Text Encoder. +# Does not convert optimizer state or any other thing. + +import argparse +import os +import os.path as osp +import torch + +#=================# +# UNet Conversion # +#=================# + +unet_conversion_map = [ + # (stable-diffusion, HF Diffusers) + ('time_embed.0.weight', 'time_embedding.linear_1.weight'), + ('time_embed.0.bias', 'time_embedding.linear_1.bias'), + ('time_embed.2.weight', 'time_embedding.linear_2.weight'), + ('time_embed.2.bias', 'time_embedding.linear_2.bias'), + ('input_blocks.0.0.weight', 'conv_in.weight'), + ('input_blocks.0.0.bias', 'conv_in.bias'), + ('out.0.weight', 'conv_norm_out.weight'), + ('out.0.bias', 'conv_norm_out.bias'), + ('out.2.weight', 'conv_out.weight'), + ('out.2.bias', 'conv_out.bias') +] + +unet_conversion_map_resnet = [ + # (stable-diffusion, HF Diffusers) + ("in_layers.0", "norm1"), + ("in_layers.2", "conv1"), + ("out_layers.0", "norm2"), + ("out_layers.3", "conv2"), + ("emb_layers.1", "time_emb_proj"), + ("skip_connection", "conv_shortcut") +] + +unet_conversion_map_layer = [] +# hardcoded number of downblocks and resnets/attentions... +# would need smarter logic for other networks. +for i in range(4): + # loop over downblocks/upblocks + + for j in range(2): + # loop over resnets/attentions for downblocks + hf_down_res_prefix = f'down_blocks.{i}.resnets.{j}.' + sd_down_res_prefix = f'input_blocks.{3*i + j + 1}.0.' + unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) + + if i < 3: + # no attention layers in down_blocks.3 + hf_down_atn_prefix = f'down_blocks.{i}.attentions.{j}.' + sd_down_atn_prefix = f'input_blocks.{3*i + j + 1}.1.' + unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) + + for j in range(3): + # loop over resnets/attentions for upblocks + hf_up_res_prefix = f'up_blocks.{i}.resnets.{j}.' + sd_up_res_prefix = f'output_blocks.{3*i + j}.0.' + unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) + + if i > 0: + # no attention layers in up_blocks.0 + hf_up_atn_prefix = f'up_blocks.{i}.attentions.{j}.' + sd_up_atn_prefix = f'output_blocks.{3*i + j}.1.' + unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) + + if i < 3: + # no downsample in down_blocks.3 + hf_downsample_prefix = f'down_blocks.{i}.downsamplers.0.conv.' + sd_downsample_prefix = f'input_blocks.{3*(i+1)}.0.op.' + unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) + + # no upsample in up_blocks.3 + hf_upsample_prefix = f'up_blocks.{i}.upsamplers.0.' + sd_upsample_prefix = f'output_blocks.{3*i + 2}.{1 if i == 0 else 2}.' + unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) + +hf_mid_atn_prefix = 'mid_block.attentions.0.' +sd_mid_atn_prefix = 'middle_block.1.' +unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) + +for j in range(2): + hf_mid_res_prefix = f'mid_block.resnets.{j}.' + sd_mid_res_prefix = f'middle_block.{2*j}.' + unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) + +def convert_unet_state_dict(unet_state_dict): + # buyer beware: this is a *brittle* function, + # and correct output requires that all of these pieces interact in + # the exact order in which I have arranged them. + mapping = {k:k for k in unet_state_dict.keys()} + for sd_name, hf_name in unet_conversion_map: + mapping[hf_name] = sd_name + for k,v in mapping.items(): + if 'resnets' in k: + for sd_part, hf_part in unet_conversion_map_resnet: + v = v.replace(hf_part, sd_part) + mapping[k] = v + for k,v in mapping.items(): + for sd_part, hf_part in unet_conversion_map_layer: + v = v.replace(hf_part, sd_part) + mapping[k] = v + new_state_dict = {v:unet_state_dict[k] for k,v in mapping.items()} + return new_state_dict + +#================# +# VAE Conversion # +#================# + +vae_conversion_map = [ + # (stable-diffusion, HF Diffusers) + ('nin_shortcut', 'conv_shortcut'), + ('norm_out', 'conv_norm_out'), + ('mid.attn_1.', 'mid_block.attentions.0.') +] + +for i in range(4): + # down_blocks have two resnets + for j in range(2): + hf_down_prefix = f'encoder.down_blocks.{i}.resnets.{j}.' + sd_down_prefix = f'encoder.down.{i}.block.{j}.' + vae_conversion_map.append((sd_down_prefix, hf_down_prefix)) + + if i < 3: + hf_downsample_prefix = f'down_blocks.{i}.downsamplers.0.' + sd_downsample_prefix = f'down.{i}.downsample.' + vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix)) + + hf_upsample_prefix = f'up_blocks.{i}.upsamplers.0.' + sd_upsample_prefix = f'up.{3-i}.upsample.' + vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix)) + + # up_blocks have three resnets + # also, up blocks in hf are numbered in reverse from sd + for j in range(3): + hf_up_prefix = f'decoder.up_blocks.{i}.resnets.{j}.' + sd_up_prefix = f'decoder.up.{3-i}.block.{j}.' + vae_conversion_map.append((sd_up_prefix, hf_up_prefix)) + +# this part accounts for mid blocks in both the encoder and the decoder +for i in range(2): + hf_mid_res_prefix = f'mid_block.resnets.{i}.' + sd_mid_res_prefix = f'mid.block_{i+1}.' + vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) + + +vae_conversion_map_attn = [ + # (stable-diffusion, HF Diffusers) + ('norm.', 'group_norm.'), + ('q.','query.'), + ('k.','key.'), + ('v.','value.'), + ('proj_out.','proj_attn.') +] + +def reshape_weight_for_sd(w): + # convert HF linear weights to SD conv2d weights + return w.reshape(*w.shape, 1, 1) + + +def convert_vae_state_dict(vae_state_dict): + mapping = {k:k for k in vae_state_dict.keys()} + for k,v in mapping.items(): + for sd_part, hf_part in vae_conversion_map: + v = v.replace(hf_part, sd_part) + mapping[k] = v + for k,v in mapping.items(): + if 'attentions' in k: + for sd_part, hf_part in vae_conversion_map_attn: + v = v.replace(hf_part, sd_part) + mapping[k] = v + new_state_dict = {v:vae_state_dict[k] for k,v in mapping.items()} + weights_to_convert = ['q', 'k', 'v', 'proj_out'] + for k,v in new_state_dict.items(): + for weight_name in weights_to_convert: + if f'mid.attn_1.{weight_name}.weight' in k: + print(f'Reshaping {k} for SD format') + new_state_dict[k] = reshape_weight_for_sd(v) + return new_state_dict + +#=========================# +# Text Encoder Conversion # +#=========================# +# pretty much a no-op + +def convert_text_enc_state_dict(text_enc_dict): + return text_enc_dict + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.") + parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.") + + args = parser.parse_args() + + assert args.model_path is not None, \ + "Must provide a model path!" + + assert args.checkpoint_path is not None, \ + "Must provide a checkpoint path!" + + unet_path = osp.join(args.model_path, 'unet', 'diffusion_pytorch_model.bin') + vae_path = osp.join(args.model_path, 'vae', 'diffusion_pytorch_model.bin') + text_enc_path = osp.join(args.model_path, 'text_encoder', 'pytorch_model.bin') + + # Convert the UNet model + unet_state_dict = torch.load(unet_path) + unet_state_dict = convert_unet_state_dict(unet_state_dict) + unet_state_dict = {"model.diffusion_model."+k:v for k,v in unet_state_dict.items()} + + # Convert the VAE model + vae_state_dict = torch.load(vae_path) + vae_state_dict = convert_vae_state_dict(vae_state_dict) + vae_state_dict = {"first_stage_model."+k:v for k,v in vae_state_dict.items()} + + # Convert the text encoder model + text_enc_dict = torch.load(text_enc_path) + text_enc_dict = convert_text_enc_state_dict(text_enc_dict) + text_enc_dict = {"cond_stage_model.transformer."+k:v for k,v in text_enc_dict.items()} + + # Put together new checkpoint + state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict} + state_dict = {"state_dict": state_dict} + torch.save(state_dict, args.checkpoint_path) \ No newline at end of file From 3e16a013e363d53160290a40152c081aff26bf47 Mon Sep 17 00:00:00 2001 From: Josh Date: Sun, 2 Oct 2022 16:27:31 -0700 Subject: [PATCH 2/7] ran black --- ..._diffusers_to_original_stable_diffusion.py | 156 +++++++++--------- setup.py | 9 +- 2 files changed, 81 insertions(+), 84 deletions(-) diff --git a/scripts/convert_diffusers_to_original_stable_diffusion.py b/scripts/convert_diffusers_to_original_stable_diffusion.py index 124992c3180e..c9e462c9c60e 100644 --- a/scripts/convert_diffusers_to_original_stable_diffusion.py +++ b/scripts/convert_diffusers_to_original_stable_diffusion.py @@ -7,22 +7,22 @@ import os.path as osp import torch -#=================# +# =================# # UNet Conversion # -#=================# +# =================# unet_conversion_map = [ # (stable-diffusion, HF Diffusers) - ('time_embed.0.weight', 'time_embedding.linear_1.weight'), - ('time_embed.0.bias', 'time_embedding.linear_1.bias'), - ('time_embed.2.weight', 'time_embedding.linear_2.weight'), - ('time_embed.2.bias', 'time_embedding.linear_2.bias'), - ('input_blocks.0.0.weight', 'conv_in.weight'), - ('input_blocks.0.0.bias', 'conv_in.bias'), - ('out.0.weight', 'conv_norm_out.weight'), - ('out.0.bias', 'conv_norm_out.bias'), - ('out.2.weight', 'conv_out.weight'), - ('out.2.bias', 'conv_out.bias') + ("time_embed.0.weight", "time_embedding.linear_1.weight"), + ("time_embed.0.bias", "time_embedding.linear_1.bias"), + ("time_embed.2.weight", "time_embedding.linear_2.weight"), + ("time_embed.2.bias", "time_embedding.linear_2.bias"), + ("input_blocks.0.0.weight", "conv_in.weight"), + ("input_blocks.0.0.bias", "conv_in.bias"), + ("out.0.weight", "conv_norm_out.weight"), + ("out.0.bias", "conv_norm_out.bias"), + ("out.2.weight", "conv_out.weight"), + ("out.2.bias", "conv_out.bias"), ] unet_conversion_map_resnet = [ @@ -32,7 +32,7 @@ ("out_layers.0", "norm2"), ("out_layers.3", "conv2"), ("emb_layers.1", "time_emb_proj"), - ("skip_connection", "conv_shortcut") + ("skip_connection", "conv_shortcut"), ] unet_conversion_map_layer = [] @@ -43,150 +43,156 @@ for j in range(2): # loop over resnets/attentions for downblocks - hf_down_res_prefix = f'down_blocks.{i}.resnets.{j}.' - sd_down_res_prefix = f'input_blocks.{3*i + j + 1}.0.' + hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." + sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) if i < 3: # no attention layers in down_blocks.3 - hf_down_atn_prefix = f'down_blocks.{i}.attentions.{j}.' - sd_down_atn_prefix = f'input_blocks.{3*i + j + 1}.1.' + hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." + sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) for j in range(3): # loop over resnets/attentions for upblocks - hf_up_res_prefix = f'up_blocks.{i}.resnets.{j}.' - sd_up_res_prefix = f'output_blocks.{3*i + j}.0.' + hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." + sd_up_res_prefix = f"output_blocks.{3*i + j}.0." unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) if i > 0: # no attention layers in up_blocks.0 - hf_up_atn_prefix = f'up_blocks.{i}.attentions.{j}.' - sd_up_atn_prefix = f'output_blocks.{3*i + j}.1.' + hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." + sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) if i < 3: # no downsample in down_blocks.3 - hf_downsample_prefix = f'down_blocks.{i}.downsamplers.0.conv.' - sd_downsample_prefix = f'input_blocks.{3*(i+1)}.0.op.' + hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." + sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) # no upsample in up_blocks.3 - hf_upsample_prefix = f'up_blocks.{i}.upsamplers.0.' - sd_upsample_prefix = f'output_blocks.{3*i + 2}.{1 if i == 0 else 2}.' + hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." + sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}." unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) -hf_mid_atn_prefix = 'mid_block.attentions.0.' -sd_mid_atn_prefix = 'middle_block.1.' +hf_mid_atn_prefix = "mid_block.attentions.0." +sd_mid_atn_prefix = "middle_block.1." unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) for j in range(2): - hf_mid_res_prefix = f'mid_block.resnets.{j}.' - sd_mid_res_prefix = f'middle_block.{2*j}.' + hf_mid_res_prefix = f"mid_block.resnets.{j}." + sd_mid_res_prefix = f"middle_block.{2*j}." unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) + def convert_unet_state_dict(unet_state_dict): # buyer beware: this is a *brittle* function, # and correct output requires that all of these pieces interact in # the exact order in which I have arranged them. - mapping = {k:k for k in unet_state_dict.keys()} + mapping = {k: k for k in unet_state_dict.keys()} for sd_name, hf_name in unet_conversion_map: mapping[hf_name] = sd_name - for k,v in mapping.items(): - if 'resnets' in k: + for k, v in mapping.items(): + if "resnets" in k: for sd_part, hf_part in unet_conversion_map_resnet: v = v.replace(hf_part, sd_part) mapping[k] = v - for k,v in mapping.items(): + for k, v in mapping.items(): for sd_part, hf_part in unet_conversion_map_layer: v = v.replace(hf_part, sd_part) mapping[k] = v - new_state_dict = {v:unet_state_dict[k] for k,v in mapping.items()} + new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()} return new_state_dict -#================# + +# ================# # VAE Conversion # -#================# +# ================# vae_conversion_map = [ # (stable-diffusion, HF Diffusers) - ('nin_shortcut', 'conv_shortcut'), - ('norm_out', 'conv_norm_out'), - ('mid.attn_1.', 'mid_block.attentions.0.') + ("nin_shortcut", "conv_shortcut"), + ("norm_out", "conv_norm_out"), + ("mid.attn_1.", "mid_block.attentions.0."), ] for i in range(4): # down_blocks have two resnets for j in range(2): - hf_down_prefix = f'encoder.down_blocks.{i}.resnets.{j}.' - sd_down_prefix = f'encoder.down.{i}.block.{j}.' + hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}." + sd_down_prefix = f"encoder.down.{i}.block.{j}." vae_conversion_map.append((sd_down_prefix, hf_down_prefix)) if i < 3: - hf_downsample_prefix = f'down_blocks.{i}.downsamplers.0.' - sd_downsample_prefix = f'down.{i}.downsample.' + hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0." + sd_downsample_prefix = f"down.{i}.downsample." vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix)) - hf_upsample_prefix = f'up_blocks.{i}.upsamplers.0.' - sd_upsample_prefix = f'up.{3-i}.upsample.' + hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." + sd_upsample_prefix = f"up.{3-i}.upsample." vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix)) # up_blocks have three resnets # also, up blocks in hf are numbered in reverse from sd for j in range(3): - hf_up_prefix = f'decoder.up_blocks.{i}.resnets.{j}.' - sd_up_prefix = f'decoder.up.{3-i}.block.{j}.' + hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}." + sd_up_prefix = f"decoder.up.{3-i}.block.{j}." vae_conversion_map.append((sd_up_prefix, hf_up_prefix)) # this part accounts for mid blocks in both the encoder and the decoder for i in range(2): - hf_mid_res_prefix = f'mid_block.resnets.{i}.' - sd_mid_res_prefix = f'mid.block_{i+1}.' + hf_mid_res_prefix = f"mid_block.resnets.{i}." + sd_mid_res_prefix = f"mid.block_{i+1}." vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) vae_conversion_map_attn = [ # (stable-diffusion, HF Diffusers) - ('norm.', 'group_norm.'), - ('q.','query.'), - ('k.','key.'), - ('v.','value.'), - ('proj_out.','proj_attn.') + ("norm.", "group_norm."), + ("q.", "query."), + ("k.", "key."), + ("v.", "value."), + ("proj_out.", "proj_attn."), ] + def reshape_weight_for_sd(w): # convert HF linear weights to SD conv2d weights return w.reshape(*w.shape, 1, 1) def convert_vae_state_dict(vae_state_dict): - mapping = {k:k for k in vae_state_dict.keys()} - for k,v in mapping.items(): + mapping = {k: k for k in vae_state_dict.keys()} + for k, v in mapping.items(): for sd_part, hf_part in vae_conversion_map: v = v.replace(hf_part, sd_part) mapping[k] = v - for k,v in mapping.items(): - if 'attentions' in k: + for k, v in mapping.items(): + if "attentions" in k: for sd_part, hf_part in vae_conversion_map_attn: v = v.replace(hf_part, sd_part) mapping[k] = v - new_state_dict = {v:vae_state_dict[k] for k,v in mapping.items()} - weights_to_convert = ['q', 'k', 'v', 'proj_out'] - for k,v in new_state_dict.items(): + new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()} + weights_to_convert = ["q", "k", "v", "proj_out"] + for k, v in new_state_dict.items(): for weight_name in weights_to_convert: - if f'mid.attn_1.{weight_name}.weight' in k: - print(f'Reshaping {k} for SD format') + if f"mid.attn_1.{weight_name}.weight" in k: + print(f"Reshaping {k} for SD format") new_state_dict[k] = reshape_weight_for_sd(v) return new_state_dict -#=========================# + +# =========================# # Text Encoder Conversion # -#=========================# +# =========================# # pretty much a no-op + def convert_text_enc_state_dict(text_enc_dict): return text_enc_dict + if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -195,32 +201,30 @@ def convert_text_enc_state_dict(text_enc_dict): args = parser.parse_args() - assert args.model_path is not None, \ - "Must provide a model path!" + assert args.model_path is not None, "Must provide a model path!" - assert args.checkpoint_path is not None, \ - "Must provide a checkpoint path!" + assert args.checkpoint_path is not None, "Must provide a checkpoint path!" - unet_path = osp.join(args.model_path, 'unet', 'diffusion_pytorch_model.bin') - vae_path = osp.join(args.model_path, 'vae', 'diffusion_pytorch_model.bin') - text_enc_path = osp.join(args.model_path, 'text_encoder', 'pytorch_model.bin') + unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin") + vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin") + text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin") # Convert the UNet model unet_state_dict = torch.load(unet_path) unet_state_dict = convert_unet_state_dict(unet_state_dict) - unet_state_dict = {"model.diffusion_model."+k:v for k,v in unet_state_dict.items()} + unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()} # Convert the VAE model vae_state_dict = torch.load(vae_path) vae_state_dict = convert_vae_state_dict(vae_state_dict) - vae_state_dict = {"first_stage_model."+k:v for k,v in vae_state_dict.items()} + vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()} # Convert the text encoder model text_enc_dict = torch.load(text_enc_path) text_enc_dict = convert_text_enc_state_dict(text_enc_dict) - text_enc_dict = {"cond_stage_model.transformer."+k:v for k,v in text_enc_dict.items()} + text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()} # Put together new checkpoint state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict} state_dict = {"state_dict": state_dict} - torch.save(state_dict, args.checkpoint_path) \ No newline at end of file + torch.save(state_dict, args.checkpoint_path) diff --git a/setup.py b/setup.py index ff5f14564487..a965e7bfa12f 100644 --- a/setup.py +++ b/setup.py @@ -177,14 +177,7 @@ def run(self): extras["docs"] = deps_list("hf-doc-builder") extras["training"] = deps_list("accelerate", "datasets", "tensorboard", "modelcards") extras["test"] = deps_list( - "datasets", - "onnxruntime-gpu", - "pytest", - "pytest-timeout", - "pytest-xdist", - "scipy", - "torchvision", - "transformers" + "datasets", "onnxruntime-gpu", "pytest", "pytest-timeout", "pytest-xdist", "scipy", "torchvision", "transformers" ) extras["torch"] = deps_list("torch") From abe7b8f1b52917d3b5c4e7f266093a30b0b2d529 Mon Sep 17 00:00:00 2001 From: Josh Date: Sun, 2 Oct 2022 16:40:10 -0700 Subject: [PATCH 3/7] ran isort --- scripts/convert_diffusers_to_original_stable_diffusion.py | 2 ++ setup.py | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/scripts/convert_diffusers_to_original_stable_diffusion.py b/scripts/convert_diffusers_to_original_stable_diffusion.py index c9e462c9c60e..ed09659b0bee 100644 --- a/scripts/convert_diffusers_to_original_stable_diffusion.py +++ b/scripts/convert_diffusers_to_original_stable_diffusion.py @@ -5,8 +5,10 @@ import argparse import os import os.path as osp + import torch + # =================# # UNet Conversion # # =================# diff --git a/setup.py b/setup.py index a965e7bfa12f..9a907f47932a 100644 --- a/setup.py +++ b/setup.py @@ -67,12 +67,13 @@ you need to go back to main before executing this. """ -import re import os +import re from distutils.core import Command from setuptools import find_packages, setup + # IMPORTANT: # 1. all dependencies should be listed here with their version requirements if any # 2. once modified, run: `make deps_table_update` to update src/diffusers/dependency_versions_table.py From 6a9ccb5862da787ac2af6317f411adf34183b313 Mon Sep 17 00:00:00 2001 From: Josh Date: Sun, 2 Oct 2022 16:41:43 -0700 Subject: [PATCH 4/7] remove unused import --- scripts/convert_diffusers_to_original_stable_diffusion.py | 1 - 1 file changed, 1 deletion(-) diff --git a/scripts/convert_diffusers_to_original_stable_diffusion.py b/scripts/convert_diffusers_to_original_stable_diffusion.py index ed09659b0bee..34b4c7ef7a5d 100644 --- a/scripts/convert_diffusers_to_original_stable_diffusion.py +++ b/scripts/convert_diffusers_to_original_stable_diffusion.py @@ -3,7 +3,6 @@ # Does not convert optimizer state or any other thing. import argparse -import os import os.path as osp import torch From 3810020902db7aae2ad85ee6d77c609cf032389d Mon Sep 17 00:00:00 2001 From: Josh Date: Sun, 2 Oct 2022 17:05:05 -0700 Subject: [PATCH 5/7] map location so everything gets loaded onto CPU before conversion --- scripts/convert_diffusers_to_original_stable_diffusion.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/scripts/convert_diffusers_to_original_stable_diffusion.py b/scripts/convert_diffusers_to_original_stable_diffusion.py index 34b4c7ef7a5d..f1bc55bca9fb 100644 --- a/scripts/convert_diffusers_to_original_stable_diffusion.py +++ b/scripts/convert_diffusers_to_original_stable_diffusion.py @@ -211,17 +211,17 @@ def convert_text_enc_state_dict(text_enc_dict): text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin") # Convert the UNet model - unet_state_dict = torch.load(unet_path) + unet_state_dict = torch.load(unet_path, map_location='cpu') unet_state_dict = convert_unet_state_dict(unet_state_dict) unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()} # Convert the VAE model - vae_state_dict = torch.load(vae_path) + vae_state_dict = torch.load(vae_path, map_location='cpu') vae_state_dict = convert_vae_state_dict(vae_state_dict) vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()} # Convert the text encoder model - text_enc_dict = torch.load(text_enc_path) + text_enc_dict = torch.load(text_enc_path, map_location='cpu') text_enc_dict = convert_text_enc_state_dict(text_enc_dict) text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()} From a56ec4b1c27a4664facd92e192ae2db5f6401d25 Mon Sep 17 00:00:00 2001 From: Josh Date: Sun, 2 Oct 2022 17:11:04 -0700 Subject: [PATCH 6/7] ran black again --- .../convert_diffusers_to_original_stable_diffusion.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/scripts/convert_diffusers_to_original_stable_diffusion.py b/scripts/convert_diffusers_to_original_stable_diffusion.py index f1bc55bca9fb..9888f628a9e3 100644 --- a/scripts/convert_diffusers_to_original_stable_diffusion.py +++ b/scripts/convert_diffusers_to_original_stable_diffusion.py @@ -199,6 +199,7 @@ def convert_text_enc_state_dict(text_enc_dict): parser.add_argument("--model_path", default=None, type=str, required=True, help="Path to the model to convert.") parser.add_argument("--checkpoint_path", default=None, type=str, required=True, help="Path to the output model.") + parser.add_argument("--half", action="store_true", help="Save weights in half precision.") args = parser.parse_args() @@ -211,21 +212,23 @@ def convert_text_enc_state_dict(text_enc_dict): text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin") # Convert the UNet model - unet_state_dict = torch.load(unet_path, map_location='cpu') + unet_state_dict = torch.load(unet_path, map_location="cpu") unet_state_dict = convert_unet_state_dict(unet_state_dict) unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()} # Convert the VAE model - vae_state_dict = torch.load(vae_path, map_location='cpu') + vae_state_dict = torch.load(vae_path, map_location="cpu") vae_state_dict = convert_vae_state_dict(vae_state_dict) vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()} # Convert the text encoder model - text_enc_dict = torch.load(text_enc_path, map_location='cpu') + text_enc_dict = torch.load(text_enc_path, map_location="cpu") text_enc_dict = convert_text_enc_state_dict(text_enc_dict) text_enc_dict = {"cond_stage_model.transformer." + k: v for k, v in text_enc_dict.items()} # Put together new checkpoint state_dict = {**unet_state_dict, **vae_state_dict, **text_enc_dict} + if args.half: + state_dict = {k: v.half() for k, v in state_dict.items()} state_dict = {"state_dict": state_dict} torch.save(state_dict, args.checkpoint_path) From 628eb4cae7fdd03477601496fd686c054cf8f814 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 4 Oct 2022 13:17:18 +0200 Subject: [PATCH 7/7] Update setup.py --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index e990a73bc0b6..affb2e06fc56 100644 --- a/setup.py +++ b/setup.py @@ -178,7 +178,7 @@ def run(self): extras["docs"] = deps_list("hf-doc-builder") extras["training"] = deps_list("accelerate", "datasets", "tensorboard", "modelcards") extras["test"] = deps_list( - "datasets", "onnxruntime-gpu", "pytest", "pytest-timeout", "pytest-xdist", "scipy", "torchvision", "transformers" + "datasets", "onnxruntime", "pytest", "pytest-timeout", "pytest-xdist", "scipy", "torchvision", "transformers" ) extras["torch"] = deps_list("torch")