Skip to content

Commit aef11cb

Browse files
yiyixuxuyiyixuxupatrickvonplaten
authored
add pipeline_class_name argument to Stable Diffusion conversion script (#4461)
* add pipeline class * Apply suggestions from code review Co-authored-by: Patrick von Platen <[email protected]> * style --------- Co-authored-by: yiyixuxu <yixu310@gmail,com> Co-authored-by: Patrick von Platen <[email protected]>
1 parent 71c8224 commit aef11cb

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

scripts/convert_original_stable_diffusion_to_diffusers.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
""" Conversion script for the LDM checkpoints. """
1616

1717
import argparse
18+
import importlib
1819

1920
import torch
2021

@@ -133,8 +134,22 @@
133134
required=False,
134135
help="Set to a path, hub id to an already converted vae to not convert it again.",
135136
)
137+
parser.add_argument(
138+
"--pipeline_class_name",
139+
type=str,
140+
default=None,
141+
required=False,
142+
help="Specify the pipeline class name",
143+
)
144+
136145
args = parser.parse_args()
137146

147+
if args.pipeline_class_name is not None:
148+
library = importlib.import_module("diffusers")
149+
class_obj = getattr(library, args.pipeline_class_name)
150+
else:
151+
pipeline_class = None
152+
138153
pipe = download_from_original_stable_diffusion_ckpt(
139154
checkpoint_path=args.checkpoint_path,
140155
original_config_file=args.original_config_file,
@@ -152,6 +167,7 @@
152167
clip_stats_path=args.clip_stats_path,
153168
controlnet=args.controlnet,
154169
vae_path=args.vae_path,
170+
pipeline_class=pipeline_class,
155171
)
156172

157173
if args.half:

0 commit comments

Comments
 (0)