Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/diffusers/pipeline_flax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,10 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], params: Union

for pipeline_component_name in model_index_dict.keys():
sub_model = getattr(self, pipeline_component_name)
if sub_model is None:
# edge case for saving a pipeline with safety_checker=None
continue

model_cls = sub_model.__class__

save_method_name = None
Expand Down Expand Up @@ -367,6 +371,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

# 3. Load each module in the pipeline
for name, (library_name, class_name) in init_dict.items():
if class_name is None:
# edge case for when the pipeline was saved with safety_checker=None
init_kwargs[name] = None
continue

is_pipeline_module = hasattr(pipelines, library_name)
loaded_sub_model = None
sub_model_should_be_defined = True
Expand Down
9 changes: 9 additions & 0 deletions src/diffusers/pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,10 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike]):

for pipeline_component_name in model_index_dict.keys():
sub_model = getattr(self, pipeline_component_name)
if sub_model is None:
# edge case for saving a pipeline with safety_checker=None
continue

model_cls = sub_model.__class__

save_method_name = None
Expand Down Expand Up @@ -477,6 +481,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

# 3. Load each module in the pipeline
for name, (library_name, class_name) in init_dict.items():
if class_name is None:
# edge case for when the pipeline was saved with safety_checker=None
init_kwargs[name] = None
continue

# 3.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
if class_name.startswith("Flax"):
class_name = class_name[4:]
Expand Down
11 changes: 11 additions & 0 deletions tests/pipelines/stable_diffusion/test_stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import gc
import random
import tempfile
import time
import unittest

Expand Down Expand Up @@ -318,6 +319,16 @@ def test_stable_diffusion_no_safety_checker(self):
image = pipe("example prompt", num_inference_steps=2).images[0]
assert image is not None

# check that there's no error when saving a pipeline with one of the models being None
with tempfile.TemporaryDirectory() as tmpdirname:
pipe.save_pretrained(tmpdirname)
Comment on lines +322 to +324
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Testing that the error is gone

pipe = StableDiffusionPipeline.from_pretrained(tmpdirname)

# sanity check that the pipeline still works
assert pipe.safety_checker is None
image = pipe("example prompt", num_inference_steps=2).images[0]
assert image is not None

def test_stable_diffusion_k_lms(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
unet = self.dummy_cond_unet
Expand Down