|
18 | 18 | import torch |
19 | 19 | from huggingface_hub import ModelCard |
20 | 20 |
|
21 | | -from diffusers import DiffusionPipeline, KandinskyV22CombinedPipeline, KandinskyV22Pipeline, KandinskyV22PriorPipeline |
| 21 | +from diffusers import ( |
| 22 | + DDPMScheduler, |
| 23 | + DiffusionPipeline, |
| 24 | + KandinskyV22CombinedPipeline, |
| 25 | + KandinskyV22Pipeline, |
| 26 | + KandinskyV22PriorPipeline, |
| 27 | +) |
22 | 28 | from diffusers.pipelines.pipeline_utils import CONNECTED_PIPES_KEYS |
23 | 29 |
|
24 | 30 |
|
@@ -101,3 +107,22 @@ def test_load_connected_checkpoint_default(self): |
101 | 107 | assert dict(component.config) == dict(comp.config) |
102 | 108 | else: |
103 | 109 | assert component.__class__ == comp.__class__ |
| 110 | + |
| 111 | + def test_load_connected_checkpoint_with_passed_obj(self): |
| 112 | + pipeline = KandinskyV22CombinedPipeline.from_pretrained( |
| 113 | + "hf-internal-testing/tiny-random-kandinsky-v22-decoder" |
| 114 | + ) |
| 115 | + prior_scheduler = DDPMScheduler.from_config(pipeline.prior_scheduler.config) |
| 116 | + scheduler = DDPMScheduler.from_config(pipeline.scheduler.config) |
| 117 | + |
| 118 | + # make sure we pass a different scheduler and prior_scheduler |
| 119 | + assert pipeline.prior_scheduler.__class__ != prior_scheduler.__class__ |
| 120 | + assert pipeline.scheduler.__class__ != scheduler.__class__ |
| 121 | + |
| 122 | + pipeline_new = KandinskyV22CombinedPipeline.from_pretrained( |
| 123 | + "hf-internal-testing/tiny-random-kandinsky-v22-decoder", |
| 124 | + prior_scheduler=prior_scheduler, |
| 125 | + scheduler=scheduler, |
| 126 | + ) |
| 127 | + assert dict(pipeline_new.prior_scheduler.config) == dict(prior_scheduler.config) |
| 128 | + assert dict(pipeline_new.scheduler.config) == dict(scheduler.config) |
0 commit comments