From 34ca64306f4ad39e764eba0a66b51df9bce7a2fb Mon Sep 17 00:00:00 2001
From: banteg <4562643+banteg@users.noreply.github.com>
Date: Thu, 1 Sep 2022 10:19:49 +0400
Subject: [PATCH] feat: add allow nsfw flag

---
 .../stable_diffusion/pipeline_stable_diffusion.py | 9 ++++++---
 .../pipeline_stable_diffusion_img2img.py          | 9 ++++++---
 .../pipeline_stable_diffusion_inpaint.py          | 9 ++++++---
 3 files changed, 18 insertions(+), 9 deletions(-)

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index d4290da6f030..62fca76da18d 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -47,6 +47,7 @@ def __call__(
         generator: Optional[torch.Generator] = None,
         latents: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
+        allow_nsfw: bool = True,
         **kwargs,
     ):
         if "torch_device" in kwargs:
@@ -161,9 +162,11 @@ def __call__(
         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
 
-        # run safety checker
-        safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
-        image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)
+        has_nsfw_concept = None
+        if not allow_nsfw:
+            # run safety checker
+            safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
+            image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)
 
         if output_type == "pil":
             image = self.numpy_to_pil(image)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
index 2c3d5c8e15e8..487d654196b9 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
@@ -58,6 +58,7 @@ def __call__(
         eta: Optional[float] = 0.0,
         generator: Optional[torch.Generator] = None,
         output_type: Optional[str] = "pil",
+        allow_nsfw: bool = True,
     ):
 
         if isinstance(prompt, str):
@@ -160,9 +161,11 @@ def __call__(
         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
 
-        # run safety checker
-        safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
-        image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)
+        has_nsfw_concept = None
+        if not allow_nsfw:
+            # run safety checker
+            safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
+            image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)
 
         if output_type == "pil":
             image = self.numpy_to_pil(image)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
index 6827846722d7..b43ff7b5ba62 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -72,6 +72,7 @@ def __call__(
         eta: Optional[float] = 0.0,
         generator: Optional[torch.Generator] = None,
         output_type: Optional[str] = "pil",
+        allow_nsfw: bool = True,
     ):
 
         if isinstance(prompt, str):
@@ -187,9 +188,11 @@ def __call__(
         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()
 
-        # run safety checker
-        safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
-        image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)
+        has_nsfw_concept = None
+        if not allow_nsfw:
+            # run safety checker
+            safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(self.device)
+            image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)
 
         if output_type == "pil":
             image = self.numpy_to_pil(image)
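
Usage sketch (not part of the patch): the snippet below shows how a caller might exercise the new allow_nsfw flag. The checkpoint name and the dict-style return keys ("sample", "nsfw_content_detected") are assumptions based on the diffusers API around the time of this patch and may differ in other versions.

    # Hypothetical usage of the allow_nsfw flag added by this patch.
    # Assumed: the CompVis/stable-diffusion-v1-4 checkpoint and the
    # dict-style pipeline output used by diffusers in this era.
    import torch
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
    pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

    # allow_nsfw=True (the new default) skips the safety checker entirely,
    # so "nsfw_content_detected" stays None.
    output = pipe("a photograph of an astronaut riding a horse", allow_nsfw=True)
    output["sample"][0].save("astronaut.png")

    # allow_nsfw=False restores the previous behavior: the safety checker
    # runs on the decoded images and reports flagged outputs.
    checked = pipe("a photograph of an astronaut riding a horse", allow_nsfw=False)
    print(checked["nsfw_content_detected"])

Because the flag defaults to True, applying this patch changes the default behavior of all three pipelines: the safety checker only runs when the caller explicitly passes allow_nsfw=False.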