Skip to content

Commit 2c1677e

Browse files
yiyixuxuyiyixuxu
andauthored
allow passing components to connected pipelines when use the combined pipeline (#4883)
* fix * add test --------- Co-authored-by: yiyixuxu <yixu310@gmail,com>
1 parent c73e609 commit 2c1677e

File tree

2 files changed

+41
-2
lines changed

2 files changed

+41
-2
lines changed

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1147,8 +1147,22 @@ def load_module(name, value):
11471147
"variant": variant,
11481148
"use_safetensors": use_safetensors,
11491149
}
1150+
1151+
def get_connected_passed_kwargs(prefix):
1152+
connected_passed_class_obj = {
1153+
k.replace(f"{prefix}_", ""): w for k, w in passed_class_obj.items() if k.split("_")[0] == prefix
1154+
}
1155+
connected_passed_pipe_kwargs = {
1156+
k.replace(f"{prefix}_", ""): w for k, w in passed_pipe_kwargs.items() if k.split("_")[0] == prefix
1157+
}
1158+
1159+
connected_passed_kwargs = {**connected_passed_class_obj, **connected_passed_pipe_kwargs}
1160+
return connected_passed_kwargs
1161+
11501162
connected_pipes = {
1151-
prefix: DiffusionPipeline.from_pretrained(repo_id, **load_kwargs.copy())
1163+
prefix: DiffusionPipeline.from_pretrained(
1164+
repo_id, **load_kwargs.copy(), **get_connected_passed_kwargs(prefix)
1165+
)
11521166
for prefix, repo_id in connected_pipes.items()
11531167
if repo_id is not None
11541168
}

tests/pipelines/test_pipelines_combined.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,13 @@
1818
import torch
1919
from huggingface_hub import ModelCard
2020

21-
from diffusers import DiffusionPipeline, KandinskyV22CombinedPipeline, KandinskyV22Pipeline, KandinskyV22PriorPipeline
21+
from diffusers import (
22+
DDPMScheduler,
23+
DiffusionPipeline,
24+
KandinskyV22CombinedPipeline,
25+
KandinskyV22Pipeline,
26+
KandinskyV22PriorPipeline,
27+
)
2228
from diffusers.pipelines.pipeline_utils import CONNECTED_PIPES_KEYS
2329

2430

@@ -101,3 +107,22 @@ def test_load_connected_checkpoint_default(self):
101107
assert dict(component.config) == dict(comp.config)
102108
else:
103109
assert component.__class__ == comp.__class__
110+
111+
def test_load_connected_checkpoint_with_passed_obj(self):
112+
pipeline = KandinskyV22CombinedPipeline.from_pretrained(
113+
"hf-internal-testing/tiny-random-kandinsky-v22-decoder"
114+
)
115+
prior_scheduler = DDPMScheduler.from_config(pipeline.prior_scheduler.config)
116+
scheduler = DDPMScheduler.from_config(pipeline.scheduler.config)
117+
118+
# make sure we pass a different scheduler and prior_scheduler
119+
assert pipeline.prior_scheduler.__class__ != prior_scheduler.__class__
120+
assert pipeline.scheduler.__class__ != scheduler.__class__
121+
122+
pipeline_new = KandinskyV22CombinedPipeline.from_pretrained(
123+
"hf-internal-testing/tiny-random-kandinsky-v22-decoder",
124+
prior_scheduler=prior_scheduler,
125+
scheduler=scheduler,
126+
)
127+
assert dict(pipeline_new.prior_scheduler.config) == dict(prior_scheduler.config)
128+
assert dict(pipeline_new.scheduler.config) == dict(scheduler.config)

0 commit comments

Comments
 (0)