diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index a7b6031d137a..fa66c2aec852 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -112,23 +112,26 @@ def register_modules(self, **kwargs): for name, module in kwargs.items(): # retrieve library - library = module.__module__.split(".")[0] + if module is None: + register_dict = {name: (None, None)} + else: + library = module.__module__.split(".")[0] - # check if the module is a pipeline module - pipeline_dir = module.__module__.split(".")[-2] - path = module.__module__.split(".") - is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir) + # check if the module is a pipeline module + pipeline_dir = module.__module__.split(".")[-2] + path = module.__module__.split(".") + is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir) - # if library is not in LOADABLE_CLASSES, then it is a custom module. - # Or if it's a pipeline module, then the module is inside the pipeline - # folder so we set the library to module name. - if library not in LOADABLE_CLASSES or is_pipeline_module: - library = pipeline_dir + # if library is not in LOADABLE_CLASSES, then it is a custom module. + # Or if it's a pipeline module, then the module is inside the pipeline + # folder so we set the library to module name. + if library not in LOADABLE_CLASSES or is_pipeline_module: + library = pipeline_dir - # retrieve class_name - class_name = module.__class__.__name__ + # retrieve class_name + class_name = module.__class__.__name__ - register_dict = {name: (library, class_name)} + register_dict = {name: (library, class_name)} # save model index config self.register_to_config(**register_dict) @@ -422,6 +425,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P is_pipeline_module = hasattr(pipelines, library_name) loaded_sub_model = None + sub_model_should_be_defined = True # if the model is in a pipeline module, then we load it from the pipeline if name in passed_class_obj: @@ -442,6 +446,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P f"{passed_class_obj[name]} is of type: {type(passed_class_obj[name])}, but should be" f" {expected_class_obj}" ) + elif passed_class_obj[name] is None: + logger.warn( + f"You have passed `None` for {name} to disable its functionality in {pipeline_class}. Note" + f" that this might lead to problems when using {pipeline_class} and is not recommended." + ) + sub_model_should_be_defined = False else: logger.warn( f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it" @@ -462,7 +472,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P importable_classes = LOADABLE_CLASSES[library_name] class_candidates = {c: getattr(library, c) for c in importable_classes.keys()} - if loaded_sub_model is None: + if loaded_sub_model is None and sub_model_should_be_defined: load_method_name = None for class_name, class_candidate in class_candidates.items(): if issubclass(class_obj, class_candidate): diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index 615fa404da0b..8c07afe58fc2 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import List, Union +from typing import List, Optional, Union import numpy as np @@ -20,11 +20,11 @@ class StableDiffusionPipelineOutput(BaseOutput): num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. nsfw_content_detected (`List[bool]`) List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work" - (nsfw) content. + (nsfw) content, or `None` if safety checking could not be performed. """ images: Union[List[PIL.Image.Image], np.ndarray] - nsfw_content_detected: List[bool] + nsfw_content_detected: Optional[List[bool]] if is_transformers_available() and is_torch_available(): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index ca6c580ffc55..c5b7ebd48baa 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -71,6 +71,16 @@ def __init__( new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) + if safety_checker is None: + logger.warn( + f"You have disabed the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + self.register_modules( vae=vae, text_encoder=text_encoder, @@ -335,10 +345,15 @@ def __call__( # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 image = image.cpu().permute(0, 2, 3, 1).float().numpy() - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype) - ) + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to( + self.device + ) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype) + ) + else: + has_nsfw_concept = None if output_type == "pil": image = self.numpy_to_pil(image) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 3e5ac4b33582..b998e5c613b7 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -83,6 +83,16 @@ def __init__( new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) + if safety_checker is None: + logger.warn( + f"You have disabed the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + self.register_modules( vae=vae, text_encoder=text_encoder, @@ -358,10 +368,15 @@ def __call__( image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) - image, has_nsfw_concept = self.safety_checker( - images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype) - ) + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to( + self.device + ) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype) + ) + else: + has_nsfw_concept = None if output_type == "pil": image = self.numpy_to_pil(image) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 30a588e754b3..ba31840209b1 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -98,6 +98,16 @@ def __init__( new_config["steps_offset"] = 1 scheduler._internal_dict = FrozenDict(new_config) + if safety_checker is None: + logger.warn( + f"You have disabed the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + self.register_modules( vae=vae, text_encoder=text_encoder, @@ -382,8 +392,13 @@ def __call__( image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() - safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device) - image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values) + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to( + self.device + ) + image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values) + else: + has_nsfw_concept = None if output_type == "pil": image = self.numpy_to_pil(image) diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 8004241ac15e..e8f13ded3da5 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -498,6 +498,17 @@ def test_from_pretrained_error_message_uninstalled_packages(self): assert isinstance(pipe, StableDiffusionPipeline) assert isinstance(pipe.scheduler, LMSDiscreteScheduler) + def test_stable_diffusion_no_safety_checker(self): + pipe = StableDiffusionPipeline.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-lms-pipe", safety_checker=None + ) + assert isinstance(pipe, StableDiffusionPipeline) + assert isinstance(pipe.scheduler, LMSDiscreteScheduler) + assert pipe.safety_checker is None + + image = pipe("example prompt", num_inference_steps=2).images[0] + assert image is not None + def test_stable_diffusion_k_lms(self): device = "cpu" # ensure determinism for the device-dependent torch.Generator unet = self.dummy_cond_unet