diff --git a/scripts/convert_original_stable_diffusion_to_diffusers.py b/scripts/convert_original_stable_diffusion_to_diffusers.py index 376c1e8726de..a25000aa36c9 100644 --- a/scripts/convert_original_stable_diffusion_to_diffusers.py +++ b/scripts/convert_original_stable_diffusion_to_diffusers.py @@ -15,6 +15,7 @@ """ Conversion script for the LDM checkpoints. """ import argparse +import importlib import torch @@ -133,8 +134,22 @@ required=False, help="Set to a path, hub id to an already converted vae to not convert it again.", ) + parser.add_argument( + "--pipeline_class_name", + type=str, + default=None, + required=False, + help="Specify the pipeline class name", + ) + args = parser.parse_args() + if args.pipeline_class_name is not None: + library = importlib.import_module("diffusers") + class_obj = getattr(library, args.pipeline_class_name) + else: + pipeline_class = None + pipe = download_from_original_stable_diffusion_ckpt( checkpoint_path=args.checkpoint_path, original_config_file=args.original_config_file, @@ -152,6 +167,7 @@ clip_stats_path=args.clip_stats_path, controlnet=args.controlnet, vae_path=args.vae_path, + pipeline_class=pipeline_class, ) if args.half: