Skip to content

Commit a971c59

Browse files
yiyixuxuyiyixuxupatrickvonplaten
authored
fix auto_pipeline: pass kwargs to load_config (#4793)
* fix --------- Co-authored-by: yiyixuxu <yixu310@gmail,com> Co-authored-by: Patrick von Platen <[email protected]>
1 parent 934d439 commit a971c59

File tree

2 files changed

+100
-3
lines changed

2 files changed

+100
-3
lines changed

src/diffusers/pipelines/auto_pipeline.py

Lines changed: 73 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from collections import OrderedDict
1818

1919
from ..configuration_utils import ConfigMixin
20+
from ..utils import DIFFUSERS_CACHE
2021
from .controlnet import (
2122
StableDiffusionControlNetImg2ImgPipeline,
2223
StableDiffusionControlNetInpaintPipeline,
@@ -295,14 +296,37 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
295296
>>> image = pipeline(prompt).images[0]
296297
```
297298
"""
298-
config = cls.load_config(pretrained_model_or_path)
299+
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
300+
force_download = kwargs.pop("force_download", False)
301+
resume_download = kwargs.pop("resume_download", False)
302+
proxies = kwargs.pop("proxies", None)
303+
use_auth_token = kwargs.pop("use_auth_token", None)
304+
local_files_only = kwargs.pop("local_files_only", False)
305+
revision = kwargs.pop("revision", None)
306+
subfolder = kwargs.pop("subfolder", None)
307+
user_agent = kwargs.pop("user_agent", {})
308+
309+
load_config_kwargs = {
310+
"cache_dir": cache_dir,
311+
"force_download": force_download,
312+
"resume_download": resume_download,
313+
"proxies": proxies,
314+
"use_auth_token": use_auth_token,
315+
"local_files_only": local_files_only,
316+
"revision": revision,
317+
"subfolder": subfolder,
318+
"user_agent": user_agent,
319+
}
320+
321+
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
299322
orig_class_name = config["_class_name"]
300323

301324
if "controlnet" in kwargs:
302325
orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline")
303326

304327
text_2_image_cls = _get_task_class(AUTO_TEXT2IMAGE_PIPELINES_MAPPING, orig_class_name)
305328

329+
kwargs = {**load_config_kwargs, **kwargs}
306330
return text_2_image_cls.from_pretrained(pretrained_model_or_path, **kwargs)
307331

308332
@classmethod
@@ -535,14 +559,37 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
535559
>>> image = pipeline(prompt, image).images[0]
536560
```
537561
"""
538-
config = cls.load_config(pretrained_model_or_path)
562+
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
563+
force_download = kwargs.pop("force_download", False)
564+
resume_download = kwargs.pop("resume_download", False)
565+
proxies = kwargs.pop("proxies", None)
566+
use_auth_token = kwargs.pop("use_auth_token", None)
567+
local_files_only = kwargs.pop("local_files_only", False)
568+
revision = kwargs.pop("revision", None)
569+
subfolder = kwargs.pop("subfolder", None)
570+
user_agent = kwargs.pop("user_agent", {})
571+
572+
load_config_kwargs = {
573+
"cache_dir": cache_dir,
574+
"force_download": force_download,
575+
"resume_download": resume_download,
576+
"proxies": proxies,
577+
"use_auth_token": use_auth_token,
578+
"local_files_only": local_files_only,
579+
"revision": revision,
580+
"subfolder": subfolder,
581+
"user_agent": user_agent,
582+
}
583+
584+
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
539585
orig_class_name = config["_class_name"]
540586

541587
if "controlnet" in kwargs:
542588
orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline")
543589

544590
image_2_image_cls = _get_task_class(AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, orig_class_name)
545591

592+
kwargs = {**load_config_kwargs, **kwargs}
546593
return image_2_image_cls.from_pretrained(pretrained_model_or_path, **kwargs)
547594

548595
@classmethod
@@ -776,14 +823,37 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs):
776823
>>> image = pipeline(prompt, image=init_image, mask_image=mask_image).images[0]
777824
```
778825
"""
779-
config = cls.load_config(pretrained_model_or_path)
826+
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
827+
force_download = kwargs.pop("force_download", False)
828+
resume_download = kwargs.pop("resume_download", False)
829+
proxies = kwargs.pop("proxies", None)
830+
use_auth_token = kwargs.pop("use_auth_token", None)
831+
local_files_only = kwargs.pop("local_files_only", False)
832+
revision = kwargs.pop("revision", None)
833+
subfolder = kwargs.pop("subfolder", None)
834+
user_agent = kwargs.pop("user_agent", {})
835+
836+
load_config_kwargs = {
837+
"cache_dir": cache_dir,
838+
"force_download": force_download,
839+
"resume_download": resume_download,
840+
"proxies": proxies,
841+
"use_auth_token": use_auth_token,
842+
"local_files_only": local_files_only,
843+
"revision": revision,
844+
"subfolder": subfolder,
845+
"user_agent": user_agent,
846+
}
847+
848+
config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
780849
orig_class_name = config["_class_name"]
781850

782851
if "controlnet" in kwargs:
783852
orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline")
784853

785854
inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, orig_class_name)
786855

856+
kwargs = {**load_config_kwargs, **kwargs}
787857
return inpainting_cls.from_pretrained(pretrained_model_or_path, **kwargs)
788858

789859
@classmethod

tests/pipelines/test_pipelines_auto.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,11 @@
1414
# limitations under the License.
1515

1616
import gc
17+
import os
18+
import shutil
1719
import unittest
1820
from collections import OrderedDict
21+
from pathlib import Path
1922

2023
import torch
2124

@@ -24,6 +27,7 @@
2427
AutoPipelineForInpainting,
2528
AutoPipelineForText2Image,
2629
ControlNetModel,
30+
DiffusionPipeline,
2731
)
2832
from diffusers.pipelines.auto_pipeline import (
2933
AUTO_IMAGE2IMAGE_PIPELINES_MAPPING,
@@ -81,6 +85,29 @@ def test_from_pipe_consistent_sdxl(self):
8185

8286
assert dict(pipe.config) == original_config
8387

88+
def test_kwargs_local_files_only(self):
89+
repo = "hf-internal-testing/tiny-stable-diffusion-torch"
90+
tmpdirname = DiffusionPipeline.download(repo)
91+
tmpdirname = Path(tmpdirname)
92+
93+
# edit commit_id to so that it's not the latest commit
94+
commit_id = tmpdirname.name
95+
new_commit_id = commit_id + "hug"
96+
97+
ref_dir = tmpdirname.parent.parent / "refs/main"
98+
with open(ref_dir, "w") as f:
99+
f.write(new_commit_id)
100+
101+
new_tmpdirname = tmpdirname.parent / new_commit_id
102+
os.rename(tmpdirname, new_tmpdirname)
103+
104+
try:
105+
AutoPipelineForText2Image.from_pretrained(repo, local_files_only=True)
106+
except OSError:
107+
assert False, "not able to load local files"
108+
109+
shutil.rmtree(tmpdirname.parent.parent)
110+
84111

85112
@slow
86113
class AutoPipelineIntegrationTest(unittest.TestCase):

0 commit comments

Comments
 (0)