150 changes: 115 additions & 35 deletions scripts/convert_original_stable_diffusion_to_diffusers.py
@@ -33,6 +33,7 @@
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
HeunDiscreteScheduler,
LDMTextToImagePipeline,
LMSDiscreteScheduler,
PNDMScheduler,
@@ -232,6 +233,15 @@ def create_unet_diffusers_config(original_config, image_size: int):

vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)

head_dim = unet_params.num_heads if "num_heads" in unet_params else None
use_linear_projection = (
unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
)
if use_linear_projection:
# stable diffusion 2-base-512 and 2-768
if head_dim is None:
head_dim = [5, 10, 20, 20]

config = dict(
sample_size=image_size // vae_scale_factor,
in_channels=unet_params.in_channels,
@@ -241,7 +251,8 @@ def create_unet_diffusers_config(original_config, image_size: int):
block_out_channels=tuple(block_out_channels),
layers_per_block=unet_params.num_res_blocks,
cross_attention_dim=unet_params.context_dim,
attention_head_dim=unet_params.num_heads,
attention_head_dim=head_dim,
use_linear_projection=use_linear_projection,
)

return config
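As a rough, standalone illustration of the new head_dim / use_linear_projection inference above (plain dicts stand in for the OmegaConf node the function actually receives; the values are hypothetical):

# Hypothetical SD 2.x style UNet params: linear transformer projections, no "num_heads" key.
unet_params = {"use_linear_in_transformer": True}

head_dim = unet_params["num_heads"] if "num_heads" in unet_params else None   # -> None
use_linear_projection = unet_params.get("use_linear_in_transformer", False)   # -> True
if use_linear_projection and head_dim is None:
    # stable diffusion 2-base-512 and 2-768 use these per-block head dims
    head_dim = [5, 10, 20, 20]

print(head_dim, use_linear_projection)  # [5, 10, 20, 20] True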
@@ -636,6 +647,22 @@ def convert_ldm_clip_checkpoint(checkpoint):
return text_model


def convert_open_clip_checkpoint(checkpoint):
text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")

# SKIP for now - need openclip -> HF conversion script here
# keys = list(checkpoint.keys())
#
# text_model_dict = {}
# for key in keys:
# if key.startswith("cond_stage_model.model.transformer"):
# text_model_dict[key[len("cond_stage_model.model.transformer.") :]] = checkpoint[key]
#
# text_model.load_state_dict(text_model_dict)

return text_model
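Until the OpenCLIP -> HF remapping sketched in the commented-out block is implemented, one way to see what it would have to cover is to list the text-encoder tensors in the original checkpoint (a debugging sketch; the prefix follows the commented-out code above):

def list_open_clip_keys(checkpoint):
    # The OpenCLIP text encoder lives under this prefix in original SD 2.x checkpoints.
    return [k for k in checkpoint if k.startswith("cond_stage_model.model.")]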


if __name__ == "__main__":
parser = argparse.ArgumentParser()

@@ -657,13 +684,22 @@ def convert_ldm_clip_checkpoint(checkpoint):
)
parser.add_argument(
"--image_size",
default=512,
default=None,
type=int,
help=(
"The image size that the model was trained on. Use 512 for Stable Diffusion v1.X and Stable Siffusion v2"
" Base. Use 768 for Stable Diffusion v2."
),
)
parser.add_argument(
"--prediction_type",
default=None,
type=str,
help=(
"The prediction type that the model was trained on. Use 'epsilon' for Stable Diffusion v1.X and Stable"
" Siffusion v2 Base. Use 'v-prediction' for Stable Diffusion v2."
),
)
parser.add_argument(
"--extract_ema",
action="store_true",
@@ -674,73 +710,117 @@ def convert_ldm_clip_checkpoint(checkpoint):
),
)
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")

args = parser.parse_args()

image_size = args.image_size
prediction_type = args.prediction_type

checkpoint = torch.load(args.checkpoint_path)
global_step = checkpoint["global_step"]
checkpoint = checkpoint["state_dict"]

if args.original_config_file is None:
os.system(
"wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
)
args.original_config_file = "./v1-inference.yaml"
key_name = "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight"

if key_name in checkpoint and checkpoint[key_name].shape[-1] == 1024:
# model_type = "v2"
os.system(
"wget https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/v2-inference-v.yaml"
)
args.original_config_file = "./v2-inference-v.yaml"
else:
# model_type = "v1"
os.system(
"wget https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
)
args.original_config_file = "./v1-inference.yaml"

original_config = OmegaConf.load(args.original_config_file)

checkpoint = torch.load(args.checkpoint_path)
checkpoint = checkpoint["state_dict"]
if (
"parameterization" in original_config["model"]["params"]
and original_config["model"]["params"]["parameterization"] == "v"
):
if prediction_type is None:
# NOTE: For stable diffusion 2 base it is recommended to pass `prediction_type="epsilon"`
# explicitly, as the fallback below relies on a brittle global step parameter
prediction_type = "epsilon" if global_step == 875000 else "v_prediction"
if image_size is None:
# NOTE: For stable diffusion 2 base one has to pass `image_size=512`
# explicitly, as the fallback below relies on a brittle global step parameter
image_size = 512 if global_step == 875000 else 768
else:
if prediction_type is None:
prediction_type = "epsilon"
if image_size is None:
image_size = 512
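As a concrete example of how the fallback behaves for the two official v2 checkpoints, and how to bypass it from the command line (the checkpoint file name below is illustrative):

# global_step == 875000 (the 512-base v2 checkpoint) -> prediction_type = "epsilon",      image_size = 512
# any other global_step with parameterization == "v" -> prediction_type = "v_prediction", image_size = 768
#
# Passing the flags explicitly avoids the brittle global-step guess, e.g.:
#   python scripts/convert_original_stable_diffusion_to_diffusers.py \
#       --checkpoint_path 768-v-ema.ckpt --image_size 768 --prediction_type v_prediction \
#       --dump_path ./stable-diffusion-2-diffusers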

num_train_timesteps = original_config.model.params.timesteps
beta_start = original_config.model.params.linear_start
beta_end = original_config.model.params.linear_end

scheduler = DDIMScheduler(
beta_end=beta_end,
beta_schedule="scaled_linear",
beta_start=beta_start,
num_train_timesteps=num_train_timesteps,
steps_offset=1,
clip_sample=False,
set_alpha_to_one=False,
prediction_type=prediction_type,
)
if args.scheduler_type == "pndm":
scheduler = PNDMScheduler(
beta_end=beta_end,
beta_schedule="scaled_linear",
beta_start=beta_start,
num_train_timesteps=num_train_timesteps,
skip_prk_steps=True,
)
config = dict(scheduler.config)
config["skip_prk_steps"] = True
scheduler = PNDMScheduler.from_config(config)
elif args.scheduler_type == "lms":
scheduler = LMSDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear")
scheduler = LMSDiscreteScheduler.from_config(scheduler.config)
elif args.scheduler_type == "heun":
scheduler = HeunDiscreteScheduler.from_config(scheduler.config)
elif args.scheduler_type == "euler":
scheduler = EulerDiscreteScheduler(beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear")
scheduler = EulerDiscreteScheduler.from_config(scheduler.config)
elif args.scheduler_type == "euler-ancestral":
scheduler = EulerAncestralDiscreteScheduler(
beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear"
)
scheduler = EulerAncestralDiscreteScheduler.from_config(scheduler.config)
elif args.scheduler_type == "dpm":
scheduler = DPMSolverMultistepScheduler(
beta_start=beta_start, beta_end=beta_end, beta_schedule="scaled_linear"
)
scheduler = DPMSolverMultistepScheduler.from_config(scheduler.config)
elif args.scheduler_type == "ddim":
scheduler = DDIMScheduler(
beta_start=beta_start,
beta_end=beta_end,
beta_schedule="scaled_linear",
clip_sample=False,
set_alpha_to_one=False,
)
scheduler = scheduler  # keep the base DDIMScheduler constructed above
else:
raise ValueError(f"Scheduler of type {args.scheduler_type} doesn't exist!")
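A minimal, self-contained sketch of the from_config pattern used above: the base DDIMScheduler carries the shared hyperparameters (betas, timestep count, prediction type), and the requested scheduler class is then instantiated from that config. The numeric values below are illustrative Stable Diffusion defaults, not taken from any particular checkpoint:

from diffusers import DDIMScheduler, EulerDiscreteScheduler

base = DDIMScheduler(
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    num_train_timesteps=1000,
    steps_offset=1,
    clip_sample=False,
    set_alpha_to_one=False,
    prediction_type="v_prediction",
)
# from_config re-uses the shared fields and drops any the target class does not accept.
euler = EulerDiscreteScheduler.from_config(base.config)
print(dict(euler.config).get("prediction_type"))  # "v_prediction" on versions that support it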

# Convert the UNet2DConditionModel model.
unet_config = create_unet_diffusers_config(original_config, image_size=args.image_size)
unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
unet = UNet2DConditionModel(**unet_config)

converted_unet_checkpoint = convert_ldm_unet_checkpoint(
checkpoint, unet_config, path=args.checkpoint_path, extract_ema=args.extract_ema
)

unet = UNet2DConditionModel(**unet_config)
unet.load_state_dict(converted_unet_checkpoint)
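If a converted checkpoint does not load cleanly, a non-strict load reports exactly which keys failed to map (a debugging sketch, not part of the conversion itself):

missing, unexpected = unet.load_state_dict(converted_unet_checkpoint, strict=False)
if missing or unexpected:
    print(f"missing keys: {missing}\nunexpected keys: {unexpected}")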

# Convert the VAE model.
vae_config = create_vae_diffusers_config(original_config, image_size=args.image_size)
vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)

vae = AutoencoderKL(**vae_config)
vae.load_state_dict(converted_vae_checkpoint)

# Convert the text model.
text_model_type = original_config.model.params.cond_stage_config.target.split(".")[-1]
if text_model_type == "FrozenCLIPEmbedder":
if text_model_type == "FrozenOpenCLIPEmbedder":
text_model = convert_open_clip_checkpoint(checkpoint)
tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2", subfolder="tokenizer")
pipe = StableDiffusionPipeline(
vae=vae,
text_encoder=text_model,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=None,
feature_extractor=None,
requires_safety_checker=False,
)
elif text_model_type == "FrozenCLIPEmbedder":
text_model = convert_ldm_clip_checkpoint(checkpoint)
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
safety_checker = StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker")
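Assuming the remainder of the script (truncated in this diff) saves the assembled pipeline to --dump_path, the converted weights can then be used like any other diffusers checkpoint. The path and prompt below are illustrative:

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("./stable-diffusion-2-diffusers", torch_dtype=torch.float16)
pipe = pipe.to("cuda")
image = pipe("a photograph of an astronaut riding a horse").images[0]
image.save("astronaut.png")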