|
17 | 17 | from collections import OrderedDict |
18 | 18 |
|
19 | 19 | from ..configuration_utils import ConfigMixin |
| 20 | +from ..utils import DIFFUSERS_CACHE |
20 | 21 | from .controlnet import ( |
21 | 22 | StableDiffusionControlNetImg2ImgPipeline, |
22 | 23 | StableDiffusionControlNetInpaintPipeline, |
@@ -295,14 +296,37 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs): |
295 | 296 | >>> image = pipeline(prompt).images[0] |
296 | 297 | ``` |
297 | 298 | """ |
298 | | - config = cls.load_config(pretrained_model_or_path) |
| 299 | + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) |
| 300 | + force_download = kwargs.pop("force_download", False) |
| 301 | + resume_download = kwargs.pop("resume_download", False) |
| 302 | + proxies = kwargs.pop("proxies", None) |
| 303 | + use_auth_token = kwargs.pop("use_auth_token", None) |
| 304 | + local_files_only = kwargs.pop("local_files_only", False) |
| 305 | + revision = kwargs.pop("revision", None) |
| 306 | + subfolder = kwargs.pop("subfolder", None) |
| 307 | + user_agent = kwargs.pop("user_agent", {}) |
| 308 | + |
| 309 | + load_config_kwargs = { |
| 310 | + "cache_dir": cache_dir, |
| 311 | + "force_download": force_download, |
| 312 | + "resume_download": resume_download, |
| 313 | + "proxies": proxies, |
| 314 | + "use_auth_token": use_auth_token, |
| 315 | + "local_files_only": local_files_only, |
| 316 | + "revision": revision, |
| 317 | + "subfolder": subfolder, |
| 318 | + "user_agent": user_agent, |
| 319 | + } |
| 320 | + |
| 321 | + config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) |
299 | 322 | orig_class_name = config["_class_name"] |
300 | 323 |
|
301 | 324 | if "controlnet" in kwargs: |
302 | 325 | orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline") |
303 | 326 |
|
304 | 327 | text_2_image_cls = _get_task_class(AUTO_TEXT2IMAGE_PIPELINES_MAPPING, orig_class_name) |
305 | 328 |
|
| 329 | + kwargs = {**load_config_kwargs, **kwargs} |
306 | 330 | return text_2_image_cls.from_pretrained(pretrained_model_or_path, **kwargs) |
307 | 331 |
|
308 | 332 | @classmethod |
@@ -535,14 +559,37 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs): |
535 | 559 | >>> image = pipeline(prompt, image).images[0] |
536 | 560 | ``` |
537 | 561 | """ |
538 | | - config = cls.load_config(pretrained_model_or_path) |
| 562 | + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) |
| 563 | + force_download = kwargs.pop("force_download", False) |
| 564 | + resume_download = kwargs.pop("resume_download", False) |
| 565 | + proxies = kwargs.pop("proxies", None) |
| 566 | + use_auth_token = kwargs.pop("use_auth_token", None) |
| 567 | + local_files_only = kwargs.pop("local_files_only", False) |
| 568 | + revision = kwargs.pop("revision", None) |
| 569 | + subfolder = kwargs.pop("subfolder", None) |
| 570 | + user_agent = kwargs.pop("user_agent", {}) |
| 571 | + |
| 572 | + load_config_kwargs = { |
| 573 | + "cache_dir": cache_dir, |
| 574 | + "force_download": force_download, |
| 575 | + "resume_download": resume_download, |
| 576 | + "proxies": proxies, |
| 577 | + "use_auth_token": use_auth_token, |
| 578 | + "local_files_only": local_files_only, |
| 579 | + "revision": revision, |
| 580 | + "subfolder": subfolder, |
| 581 | + "user_agent": user_agent, |
| 582 | + } |
| 583 | + |
| 584 | + config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) |
539 | 585 | orig_class_name = config["_class_name"] |
540 | 586 |
|
541 | 587 | if "controlnet" in kwargs: |
542 | 588 | orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline") |
543 | 589 |
|
544 | 590 | image_2_image_cls = _get_task_class(AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, orig_class_name) |
545 | 591 |
|
| 592 | + kwargs = {**load_config_kwargs, **kwargs} |
546 | 593 | return image_2_image_cls.from_pretrained(pretrained_model_or_path, **kwargs) |
547 | 594 |
|
548 | 595 | @classmethod |
@@ -776,14 +823,37 @@ def from_pretrained(cls, pretrained_model_or_path, **kwargs): |
776 | 823 | >>> image = pipeline(prompt, image=init_image, mask_image=mask_image).images[0] |
777 | 824 | ``` |
778 | 825 | """ |
779 | | - config = cls.load_config(pretrained_model_or_path) |
| 826 | + cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) |
| 827 | + force_download = kwargs.pop("force_download", False) |
| 828 | + resume_download = kwargs.pop("resume_download", False) |
| 829 | + proxies = kwargs.pop("proxies", None) |
| 830 | + use_auth_token = kwargs.pop("use_auth_token", None) |
| 831 | + local_files_only = kwargs.pop("local_files_only", False) |
| 832 | + revision = kwargs.pop("revision", None) |
| 833 | + subfolder = kwargs.pop("subfolder", None) |
| 834 | + user_agent = kwargs.pop("user_agent", {}) |
| 835 | + |
| 836 | + load_config_kwargs = { |
| 837 | + "cache_dir": cache_dir, |
| 838 | + "force_download": force_download, |
| 839 | + "resume_download": resume_download, |
| 840 | + "proxies": proxies, |
| 841 | + "use_auth_token": use_auth_token, |
| 842 | + "local_files_only": local_files_only, |
| 843 | + "revision": revision, |
| 844 | + "subfolder": subfolder, |
| 845 | + "user_agent": user_agent, |
| 846 | + } |
| 847 | + |
| 848 | + config = cls.load_config(pretrained_model_or_path, **load_config_kwargs) |
780 | 849 | orig_class_name = config["_class_name"] |
781 | 850 |
|
782 | 851 | if "controlnet" in kwargs: |
783 | 852 | orig_class_name = config["_class_name"].replace("Pipeline", "ControlNetPipeline") |
784 | 853 |
|
785 | 854 | inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, orig_class_name) |
786 | 855 |
|
| 856 | + kwargs = {**load_config_kwargs, **kwargs} |
787 | 857 | return inpainting_cls.from_pretrained(pretrained_model_or_path, **kwargs) |
788 | 858 |
|
789 | 859 | @classmethod |
|
0 commit comments