From 56d8b749e6efa943b4cf991543cac7dc0b9e9741 Mon Sep 17 00:00:00 2001
From: Sam Rahimi
Date: Thu, 25 Aug 2022 04:47:35 -0500
Subject: [PATCH] Made the safety checker an on/off switch

The safety checker can now be switched on or off, allowing a greater
diversity of use cases for the Stable Diffusion pipeline.
---
 .../pipeline_stable_diffusion.py             | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index 550513b5c943..6e187f35a069 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -47,6 +47,7 @@ def __call__(
         eta: Optional[float] = 0.0,
         generator: Optional[torch.Generator] = None,
         output_type: Optional[str] = "pil",
+        censored: bool = False,
         **kwargs,
     ):
         if "torch_device" in kwargs:
@@ -154,11 +155,16 @@ def __call__(
         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
 
-        # run safety checker
-        safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
-        image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)
-
+        # run the safety checker only if censorship is enabled
+        has_nsfw_concept = False
+        if censored:
+            warnings.warn("Running in censored mode. If you repeatedly get the NSFW warning, try calling the pipeline with censored=False.")
+            safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
+            image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_checker_input.pixel_values)
+            if any(has_nsfw_concept):
+                warnings.warn("Output was censored. If this behavior is not desirable for your use case, call the pipeline with censored=False.")
+        else:
+            warnings.warn("Running in uncensored mode. The returned image may be NSFW.")
         if output_type == "pil":
             image = self.numpy_to_pil(image)
-
-        return {"sample": image, "nsfw_content_detected": has_nsfw_concept}
+        return {"sample": image, "nsfw_content_detected": has_nsfw_concept}
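
Usage sketch of the new censored flag, for reviewers. This is not part of the patch itself; the model id, use_auth_token, and autocast usage are illustrative assumptions based on the README of this era, and only the censored keyword comes from the change above.

    # Minimal sketch: generate an image with the patched pipeline, skipping the safety checker.
    import torch
    from torch import autocast
    from diffusers import StableDiffusionPipeline

    # Model id and auth handling are illustrative, not part of this patch.
    pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", use_auth_token=True)
    pipe = pipe.to("cuda")

    prompt = "a photograph of an astronaut riding a horse"
    with autocast("cuda"):
        # censored=True runs the safety checker and warns on flagged content;
        # censored=False (shown here) skips the checker entirely.
        output = pipe(prompt, censored=False)

    image = output["sample"][0]              # list of PIL images when output_type="pil"
    nsfw = output["nsfw_content_detected"]   # detection result, False when the checker is skipped
    image.save("astronaut.png")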