Skip to content

Commit 4a38166

Browse files
authored
Allow saving None pipeline components (#1118)
* Allow saving `None` pipeline components * support flax as well * style
1 parent 0edf9ca commit 4a38166

File tree

3 files changed

+29
-0
lines changed

3 files changed

+29
-0
lines changed

src/diffusers/pipeline_flax_utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,10 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike], params: Union
161161

162162
for pipeline_component_name in model_index_dict.keys():
163163
sub_model = getattr(self, pipeline_component_name)
164+
if sub_model is None:
165+
# edge case for saving a pipeline with safety_checker=None
166+
continue
167+
164168
model_cls = sub_model.__class__
165169

166170
save_method_name = None
@@ -367,6 +371,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
367371

368372
# 3. Load each module in the pipeline
369373
for name, (library_name, class_name) in init_dict.items():
374+
if class_name is None:
375+
# edge case for when the pipeline was saved with safety_checker=None
376+
init_kwargs[name] = None
377+
continue
378+
370379
is_pipeline_module = hasattr(pipelines, library_name)
371380
loaded_sub_model = None
372381
sub_model_should_be_defined = True

src/diffusers/pipeline_utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,10 @@ def save_pretrained(self, save_directory: Union[str, os.PathLike]):
176176

177177
for pipeline_component_name in model_index_dict.keys():
178178
sub_model = getattr(self, pipeline_component_name)
179+
if sub_model is None:
180+
# edge case for saving a pipeline with safety_checker=None
181+
continue
182+
179183
model_cls = sub_model.__class__
180184

181185
save_method_name = None
@@ -477,6 +481,11 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
477481

478482
# 3. Load each module in the pipeline
479483
for name, (library_name, class_name) in init_dict.items():
484+
if class_name is None:
485+
# edge case for when the pipeline was saved with safety_checker=None
486+
init_kwargs[name] = None
487+
continue
488+
480489
# 3.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
481490
if class_name.startswith("Flax"):
482491
class_name = class_name[4:]

tests/pipelines/stable_diffusion/test_stable_diffusion.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import gc
1717
import random
18+
import tempfile
1819
import time
1920
import unittest
2021

@@ -318,6 +319,16 @@ def test_stable_diffusion_no_safety_checker(self):
318319
image = pipe("example prompt", num_inference_steps=2).images[0]
319320
assert image is not None
320321

322+
# check that there's no error when saving a pipeline with one of the models being None
323+
with tempfile.TemporaryDirectory() as tmpdirname:
324+
pipe.save_pretrained(tmpdirname)
325+
pipe = StableDiffusionPipeline.from_pretrained(tmpdirname)
326+
327+
# sanity check that the pipeline still works
328+
assert pipe.safety_checker is None
329+
image = pipe("example prompt", num_inference_steps=2).images[0]
330+
assert image is not None
331+
321332
def test_stable_diffusion_k_lms(self):
322333
device = "cpu" # ensure determinism for the device-dependent torch.Generator
323334
unet = self.dummy_cond_unet

0 commit comments

Comments
 (0)