diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index f0b353d931d4..a9eba10d157c 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -48,6 +48,7 @@ def __call__(
         generator: Optional[torch.Generator] = None,
         latents: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
+        censored: bool = False,
         **kwargs,
     ):
         if "torch_device" in kwargs:
@@ -161,11 +162,16 @@
         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
 
-        # run safety checker
-        safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
-        image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)
-
+        # run the safety checker only when censorship is enabled
+        has_nsfw_concept = False
+        if censored:
+            warnings.warn("Running in censored mode. If the NSFW warning appears repeatedly, try calling the pipeline with censored=False.")
+            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
+            image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
+            if any(has_nsfw_concept):
+                warnings.warn("NSFW content detected and censored. If this behavior is not desirable for your use case, call the pipeline with censored=False.")
+        else:
+            warnings.warn("Running in uncensored mode. The returned image may contain NSFW content.")
         if output_type == "pil":
             image = self.numpy_to_pil(image)
-
-        return {"sample": image, "nsfw_content_detected": has_nsfw_concept}
+        return {"sample": image, "nsfw_content_detected": has_nsfw_concept}
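
For reference, a minimal usage sketch of the new `censored` flag, assuming the standard `StableDiffusionPipeline.from_pretrained` entry point and the dict-style return used in this file; the model ID, prompt, and device handling are placeholders, not part of this diff.

```python
import torch
from diffusers import StableDiffusionPipeline

# Load the pipeline (model ID is a placeholder; adjust to your setup and auth requirements).
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

# Default: censored=False, so the safety checker is skipped and a warning notes
# that the returned image may be NSFW.
out = pipe("a photograph of an astronaut riding a horse")
image = out["sample"][0]

# Opt back into the safety checker: flagged images are censored by the checker and
# out["nsfw_content_detected"] reflects its per-image result.
out = pipe("a photograph of an astronaut riding a horse", censored=True)
if any(out["nsfw_content_detected"]):
    print("Safety checker flagged the output.")
```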