From f34bf742d0f8009008d02c17679ffd63de00c3fb Mon Sep 17 00:00:00 2001 From: anton-l Date: Thu, 3 Nov 2022 14:25:51 +0100 Subject: [PATCH 1/3] Allow saving `None` pipeline components --- src/diffusers/pipeline_utils.py | 9 +++++++++ .../stable_diffusion/test_stable_diffusion.py | 11 +++++++++++ 2 files changed, 20 insertions(+) diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 94c1e135abe5..4ba8d2d9300d 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -176,6 +176,10 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike]): for pipeline_component_name in model_index_dict.keys(): sub_model = getattr(self, pipeline_component_name) + if sub_model is None: + # edge case for saving a pipeline with safety_checker=None + continue + model_cls = sub_model.__class__ save_method_name = None @@ -477,6 +481,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # 3. Load each module in the pipeline for name, (library_name, class_name) in init_dict.items(): + if class_name is None: + # edge case for when the pipeline was saved with safety_checker=None + init_kwargs[name] = None + continue + # 3.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names if class_name.startswith("Flax"): class_name = class_name[4:] diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index 260d58e94b04..6c15021ef9fd 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -17,6 +17,7 @@ import random import time import unittest +import tempfile import numpy as np import torch @@ -318,6 +319,16 @@ def test_stable_diffusion_no_safety_checker(self): image = pipe("example prompt", num_inference_steps=2).images[0] assert image is not None + # check that there's no error when saving a pipeline with one of the models being None + with tempfile.TemporaryDirectory() as tmpdirname: + pipe.save_pretrained(tmpdirname) + pipe = StableDiffusionPipeline.from_pretrained(tmpdirname) + + # sanity check that the pipeline still works + assert pipe.safety_checker is None + image = pipe("example prompt", num_inference_steps=2).images[0] + assert image is not None + def test_stable_diffusion_k_lms(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator unet = self.dummy_cond_unet From a56b5777a788257851d7e34745732ad5081ca33e Mon Sep 17 00:00:00 2001 From: anton-l Date: Thu, 3 Nov 2022 15:34:10 +0100 Subject: [PATCH 2/3] support flax as well --- src/diffusers/pipeline_flax_utils.py | 8 ++++++++ tests/pipelines/stable_diffusion/test_stable_diffusion.py | 2 +- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipeline_flax_utils.py b/src/diffusers/pipeline_flax_utils.py index c281c772dbd2..bc1230f669d8 100644 --- a/src/diffusers/pipeline_flax_utils.py +++ b/src/diffusers/pipeline_flax_utils.py @@ -161,6 +161,9 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], params: Union for pipeline_component_name in model_index_dict.keys(): sub_model = getattr(self, pipeline_component_name) + if sub_model is None: + # edge case for saving a pipeline with safety_checker=None + continue model_cls = sub_model.__class__ save_method_name = None @@ -367,6 +370,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P # 3. Load each module in the pipeline for name, (library_name, class_name) in init_dict.items(): + if class_name is None: + # edge case for when the pipeline was saved with safety_checker=None + init_kwargs[name] = None + continue + is_pipeline_module = hasattr(pipelines, library_name) loaded_sub_model = None sub_model_should_be_defined = True diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index 6c15021ef9fd..ded2470cc2e5 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -15,9 +15,9 @@ import gc import random +import tempfile import time import unittest -import tempfile import numpy as np import torch From 008c42cde49399fc3ed477275637a2b95286f1b9 Mon Sep 17 00:00:00 2001 From: anton-l Date: Thu, 3 Nov 2022 15:34:48 +0100 Subject: [PATCH 3/3] style --- src/diffusers/pipeline_flax_utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipeline_flax_utils.py b/src/diffusers/pipeline_flax_utils.py index bc1230f669d8..e63009b49c8f 100644 --- a/src/diffusers/pipeline_flax_utils.py +++ b/src/diffusers/pipeline_flax_utils.py @@ -164,6 +164,7 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], params: Union if sub_model is None: # edge case for saving a pipeline with safety_checker=None continue + model_cls = sub_model.__class__ save_method_name = None