From 7940411401ddd6df73f2124d813bf2bc4f68ea4c Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Thu, 18 Aug 2022 19:54:11 +0530 Subject: [PATCH 01/13] add SafetyChecker --- src/diffusers/pipeline_utils.py | 1 + .../pipelines/stable_diffusion/__init__.py | 2 +- .../pipeline_stable_diffusion.py | 23 ++++++- .../stable_diffusion/safety_checker.py | 60 +++++++++++++++++++ 4 files changed, 83 insertions(+), 3 deletions(-) create mode 100644 src/diffusers/pipelines/stable_diffusion/safety_checker.py diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 94a6c67b1cc4..8a8ecf92f587 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -42,6 +42,7 @@ "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"], "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"], "PreTrainedModel": ["save_pretrained", "from_pretrained"], + "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"], }, } diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index 5e48f6f521f5..aedca6bbc3ca 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -3,4 +3,4 @@ if is_transformers_available(): - from .pipeline_stable_diffusion import StableDiffusionPipeline + from .pipeline_stable_diffusion import SafetyChecker, StableDiffusionPipeline diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 3b4acd46b172..8faa46155232 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -4,11 +4,12 @@ import torch from tqdm.auto import tqdm -from transformers import CLIPTextModel, CLIPTokenizer +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from ...models import AutoencoderKL, UNet2DConditionModel from ...pipeline_utils import DiffusionPipeline from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler +from .safety_checker import SafetyChecker class StableDiffusionPipeline(DiffusionPipeline): @@ -19,10 +20,20 @@ def __init__( tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], + safety_checker: SafetyChecker, + feature_extractor: CLIPFeatureExtractor, ): super().__init__() scheduler = scheduler.set_format("pt") - self.register_modules(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler) + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) @torch.no_grad() def __call__( @@ -53,6 +64,7 @@ def __call__( self.unet.to(torch_device) self.vae.to(torch_device) self.text_encoder.to(torch_device) + self.safety_checker.to(torch_device) # get prompt text embeddings text_input = self.tokenizer( @@ -136,6 +148,13 @@ def __call__( image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() + + safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image)).pixel_values + has_bad_concepts = self.safety_checker(safety_cheker_input) + + if has_bad_concepts: + raise ValueError("The generated image contains concepts that are not allowed.") + if output_type == "pil": image = self.numpy_to_pil(image) 
diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker.py b/src/diffusers/pipelines/stable_diffusion/safety_checker.py new file mode 100644 index 000000000000..0d965d114da9 --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion/safety_checker.py @@ -0,0 +1,60 @@ +import torch +import torch.nn as nn + +from transformers import CLIPConfig, CLIPProcessor, CLIPVisionModel, PreTrainedModel + + +def cosine_distance(image_embeds, text_embeds): + normalized_image_embeds = nn.functional.normalize(image_embeds) + normalized_text_embeds = nn.functional.normalize(text_embeds) + return torch.mm(normalized_image_embeds, normalized_text_embeds.T) + + +class SafetyChecker(PreTrainedModel): + config_class = CLIPConfig + + def __init__(self, config: CLIPConfig): + super().__init__(config) + + self.vision_model = CLIPVisionModel(config.vision_config) + self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False) + + self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False) + self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False) + + self.register_buffer("concept_embeds_weights", torch.ones(17)) + self.register_buffer("special_care_embeds_weights", torch.ones(3)) + + @torch.no_grad() + def forward(self, images): + """Get embeddings for images and output nsfw and concept scores""" + pooled_output = self.vision_model(**images)[1] # pooled_output + image_embeds = self.visual_projection(pooled_output) + + special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().numpy() + cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().numpy() + + result = [] + for i in range(image_embeds.shape[0]): + result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []} + adjustment = 0.05 + + for concet_idx in range(len(special_cos_dist[0])): + concept_cos = special_cos_dist[i][concet_idx] + concept_threshold = self.special_care_embeds_weights[concet_idx].item() + result_img["special_scores"][concet_idx] = round(concept_cos - concept_threshold + adjustment, 3) + if result_img["special_scores"][concet_idx] > 0: + result_img["special_care"].append({concet_idx, result_img["special_scores"][concet_idx]}) + adjustment = 0.01 + + for concet_idx in range(len(cos_dist[0])): + concept_cos = cos_dist[i][concet_idx] + concept_threshold = self.concept_embeds_weights[concet_idx].item() + result_img["concept_scores"][concet_idx] = round(concept_cos - concept_threshold + adjustment, 3) + if result_img["concept_scores"][concet_idx] > 0: + result_img["bad_concepts"].append(concet_idx) + + result.append(result_img) + + has_bad_concepts = [len(result[i]["bad_concepts"]) > 0 for i in range(len(result))] + return has_bad_concepts From cc2ea493b547e819f217c6baf8870842aeae4cf6 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Thu, 18 Aug 2022 21:49:58 +0530 Subject: [PATCH 02/13] better name, fix checker --- .../stable_diffusion/pipeline_stable_diffusion.py | 8 ++++---- .../pipelines/stable_diffusion/safety_checker.py | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 8faa46155232..5f4ec0cfd960 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -9,7 +9,7 @@ from 
...models import AutoencoderKL, UNet2DConditionModel from ...pipeline_utils import DiffusionPipeline from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler -from .safety_checker import SafetyChecker +from .safety_checker import StableDiffusionSafetyChecker class StableDiffusionPipeline(DiffusionPipeline): @@ -20,7 +20,7 @@ def __init__( tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler], - safety_checker: SafetyChecker, + safety_checker: StableDiffusionSafetyChecker, feature_extractor: CLIPFeatureExtractor, ): super().__init__() @@ -149,8 +149,8 @@ def __call__( image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() - safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image)).pixel_values - has_bad_concepts = self.safety_checker(safety_cheker_input) + safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").pixel_values + has_bad_concepts = self.safety_checker(safety_cheker_input.to(torch_device)) if has_bad_concepts: raise ValueError("The generated image contains concepts that are not allowed.") diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker.py b/src/diffusers/pipelines/stable_diffusion/safety_checker.py index 0d965d114da9..c6d84237f1a9 100644 --- a/src/diffusers/pipelines/stable_diffusion/safety_checker.py +++ b/src/diffusers/pipelines/stable_diffusion/safety_checker.py @@ -10,7 +10,7 @@ def cosine_distance(image_embeds, text_embeds): return torch.mm(normalized_image_embeds, normalized_text_embeds.T) -class SafetyChecker(PreTrainedModel): +class StableDiffusionSafetyChecker(PreTrainedModel): config_class = CLIPConfig def __init__(self, config: CLIPConfig): @@ -28,7 +28,7 @@ def __init__(self, config: CLIPConfig): @torch.no_grad() def forward(self, images): """Get embeddings for images and output nsfw and concept scores""" - pooled_output = self.vision_model(**images)[1] # pooled_output + pooled_output = self.vision_model(images)[1] # pooled_output image_embeds = self.visual_projection(pooled_output) special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().numpy() From e233bf34bbde4a7a7a52982a07bcf75d4bbae168 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Thu, 18 Aug 2022 21:50:05 +0530 Subject: [PATCH 03/13] add checker in main init --- src/diffusers/__init__.py | 1 + src/diffusers/pipelines/stable_diffusion/__init__.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 8b96ccf6d222..7abbc49a7d97 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -39,5 +39,6 @@ if is_transformers_available(): from .pipelines import LDMTextToImagePipeline, StableDiffusionPipeline + from .pipelines.stable_diffusion import StableDiffusionSafetyChecker else: from .utils.dummy_transformers_objects import * diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index aedca6bbc3ca..cd6e143b0e4a 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -3,4 +3,4 @@ if is_transformers_available(): - from .pipeline_stable_diffusion import SafetyChecker, StableDiffusionPipeline + from .pipeline_stable_diffusion import StableDiffusionSafetyChecker, StableDiffusionPipeline From f10c6e6a78ff721c70cfda57d53818eaa6c8be5c Mon Sep 17 00:00:00 2001 From: patil-suraj Date: 
Thu, 18 Aug 2022 22:38:47 +0530 Subject: [PATCH 04/13] remove from main init --- src/diffusers/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 7abbc49a7d97..8b96ccf6d222 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -39,6 +39,5 @@ if is_transformers_available(): from .pipelines import LDMTextToImagePipeline, StableDiffusionPipeline - from .pipelines.stable_diffusion import StableDiffusionSafetyChecker else: from .utils.dummy_transformers_objects import * From d89a9ce0f30e059ddff2f1e5f1fc40b3be5abe99 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Thu, 18 Aug 2022 22:39:04 +0530 Subject: [PATCH 05/13] update logic to detect pipeline module --- src/diffusers/pipeline_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 8a8ecf92f587..81201c905372 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -66,7 +66,9 @@ def register_modules(self, **kwargs): # check if the module is a pipeline module pipeline_file = module.__module__.split(".")[-1] pipeline_dir = module.__module__.split(".")[-2] - is_pipeline_module = pipeline_file == "pipeline_" + pipeline_dir and hasattr(pipelines, pipeline_dir) + # is_pipeline_module = pipeline_file == "pipeline_" + pipeline_dir and hasattr(pipelines, pipeline_dir) + path = module.__module__.split(".") + is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir) # if library is not in LOADABLE_CLASSES, then it is a custom module. # Or if it's a pipeline module, then the module is inside the pipeline From 68486d222011bcdcd7319da108c74af7ffe7dfcc Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Thu, 18 Aug 2022 23:00:24 +0530 Subject: [PATCH 06/13] style --- src/diffusers/pipelines/stable_diffusion/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index cd6e143b0e4a..5306ba821a1e 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -3,4 +3,4 @@ if is_transformers_available(): - from .pipeline_stable_diffusion import StableDiffusionSafetyChecker, StableDiffusionPipeline + from .pipeline_stable_diffusion import StableDiffusionPipeline, StableDiffusionSafetyChecker From 063f2c5fcb3c25e775f589cd84c76db311668bd9 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Thu, 18 Aug 2022 23:01:33 +0530 Subject: [PATCH 07/13] handle all safety logic in safety checker --- .../pipeline_stable_diffusion.py | 10 ++++------ .../stable_diffusion/safety_checker.py | 19 +++++++++++++------ 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 5f4ec0cfd960..804d83fe5d52 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -149,13 +149,11 @@ def __call__( image = (image / 2 + 0.5).clamp(0, 1) image = image.cpu().permute(0, 2, 3, 1).numpy() - safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").pixel_values - has_bad_concepts = self.safety_checker(safety_cheker_input.to(torch_device)) - - if has_bad_concepts: - raise ValueError("The 
generated image contains concepts that are not allowed.") + # run safety checker + safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(torch_device) + image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values) if output_type == "pil": image = self.numpy_to_pil(image) - return {"sample": image} + return {"sample": image, "nsfw": has_nsfw_concept} diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker.py b/src/diffusers/pipelines/stable_diffusion/safety_checker.py index c6d84237f1a9..6d902be89318 100644 --- a/src/diffusers/pipelines/stable_diffusion/safety_checker.py +++ b/src/diffusers/pipelines/stable_diffusion/safety_checker.py @@ -1,7 +1,8 @@ +import numpy as np import torch import torch.nn as nn -from transformers import CLIPConfig, CLIPProcessor, CLIPVisionModel, PreTrainedModel +from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel def cosine_distance(image_embeds, text_embeds): @@ -26,16 +27,17 @@ def __init__(self, config: CLIPConfig): self.register_buffer("special_care_embeds_weights", torch.ones(3)) @torch.no_grad() - def forward(self, images): + def forward(self, clip_input, images): """Get embeddings for images and output nsfw and concept scores""" - pooled_output = self.vision_model(images)[1] # pooled_output + pooled_output = self.vision_model(clip_input)[1] # pooled_output image_embeds = self.visual_projection(pooled_output) special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().numpy() cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().numpy() result = [] - for i in range(image_embeds.shape[0]): + batch_size = image_embeds.shape[0] + for i in range(batch_size): result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []} adjustment = 0.05 @@ -56,5 +58,10 @@ def forward(self, images): result.append(result_img) - has_bad_concepts = [len(result[i]["bad_concepts"]) > 0 for i in range(len(result))] - return has_bad_concepts + has_nsfw_concept = [len(result[i]["bad_concepts"]) > 0 for i in range(len(result))] + + for idx, has_nsfw_concept in enumerate(has_nsfw_concept): + if has_nsfw_concept: + images[idx] = np.zeros(images[idx].shape) # black image + + return images, has_nsfw_concept From 20b64b445f89929ba1446602e57e150d9a49104a Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Thu, 18 Aug 2022 23:19:38 +0530 Subject: [PATCH 08/13] draw text --- src/diffusers/pipelines/stable_diffusion/safety_checker.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker.py b/src/diffusers/pipelines/stable_diffusion/safety_checker.py index 6d902be89318..c908f27ac884 100644 --- a/src/diffusers/pipelines/stable_diffusion/safety_checker.py +++ b/src/diffusers/pipelines/stable_diffusion/safety_checker.py @@ -2,6 +2,7 @@ import torch import torch.nn as nn +from PIL import ImageDraw from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel @@ -62,6 +63,9 @@ def forward(self, clip_input, images): for idx, has_nsfw_concept in enumerate(has_nsfw_concept): if has_nsfw_concept: - images[idx] = np.zeros(images[idx].shape) # black image + black_image = np.zeros(images[idx].shape) # black image + draw = ImageDraw.Draw(black_image) + draw.text((10, 10), "Too NSFW for diffusers") # TODO: better text + images[idx] = black_image return images, has_nsfw_concept From 008720b1fa8478d9c9ed4b4a5f943be146e87bca Mon Sep 17 00:00:00 2001 
From: patil-suraj Date: Thu, 18 Aug 2022 23:22:53 +0530 Subject: [PATCH 09/13] can't draw --- src/diffusers/pipelines/stable_diffusion/safety_checker.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker.py b/src/diffusers/pipelines/stable_diffusion/safety_checker.py index c908f27ac884..8eaff3646619 100644 --- a/src/diffusers/pipelines/stable_diffusion/safety_checker.py +++ b/src/diffusers/pipelines/stable_diffusion/safety_checker.py @@ -63,9 +63,6 @@ def forward(self, clip_input, images): for idx, has_nsfw_concept in enumerate(has_nsfw_concept): if has_nsfw_concept: - black_image = np.zeros(images[idx].shape) # black image - draw = ImageDraw.Draw(black_image) - draw.text((10, 10), "Too NSFW for diffusers") # TODO: better text - images[idx] = black_image + images[idx] = np.zeros(images[idx].shape) # black image return images, has_nsfw_concept From fc877cc7432e5322649b6068996a9bae0e31d7a9 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Thu, 18 Aug 2022 23:54:35 +0530 Subject: [PATCH 10/13] small fixes --- .../pipelines/stable_diffusion/pipeline_stable_diffusion.py | 2 +- src/diffusers/pipelines/stable_diffusion/safety_checker.py | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 804d83fe5d52..baff1db97092 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -156,4 +156,4 @@ def __call__( if output_type == "pil": image = self.numpy_to_pil(image) - return {"sample": image, "nsfw": has_nsfw_concept} + return {"sample": image, "nsfw_content_detected": has_nsfw_concept} diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker.py b/src/diffusers/pipelines/stable_diffusion/safety_checker.py index 8eaff3646619..e56dd7cacd58 100644 --- a/src/diffusers/pipelines/stable_diffusion/safety_checker.py +++ b/src/diffusers/pipelines/stable_diffusion/safety_checker.py @@ -2,7 +2,6 @@ import torch import torch.nn as nn -from PIL import ImageDraw from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel @@ -29,7 +28,6 @@ def __init__(self, config: CLIPConfig): @torch.no_grad() def forward(self, clip_input, images): - """Get embeddings for images and output nsfw and concept scores""" pooled_output = self.vision_model(clip_input)[1] # pooled_output image_embeds = self.visual_projection(pooled_output) From 24e982f6310899069d41732fbb7be4e7318cf82f Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Fri, 19 Aug 2022 00:12:08 +0530 Subject: [PATCH 11/13] treat special care as nsfw --- src/diffusers/pipelines/stable_diffusion/safety_checker.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker.py b/src/diffusers/pipelines/stable_diffusion/safety_checker.py index e56dd7cacd58..8ffb47226e8d 100644 --- a/src/diffusers/pipelines/stable_diffusion/safety_checker.py +++ b/src/diffusers/pipelines/stable_diffusion/safety_checker.py @@ -57,7 +57,9 @@ def forward(self, clip_input, images): result.append(result_img) - has_nsfw_concept = [len(result[i]["bad_concepts"]) > 0 for i in range(len(result))] + has_nsfw_concept = [ + len(result[i]["bad_concepts"]) > 0 or len(result[i]["special_care"]) > 0 for i in range(len(result)) + ] for idx, has_nsfw_concept in enumerate(has_nsfw_concept): if 
has_nsfw_concept: From 76b10f9d84ebefe7f9952a3669a3825bc50b6980 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Fri, 19 Aug 2022 13:37:35 +0530 Subject: [PATCH 12/13] remove commented lines --- src/diffusers/pipeline_utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 81201c905372..5b781f0e0971 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -64,9 +64,7 @@ def register_modules(self, **kwargs): library = module.__module__.split(".")[0] # check if the module is a pipeline module - pipeline_file = module.__module__.split(".")[-1] pipeline_dir = module.__module__.split(".")[-2] - # is_pipeline_module = pipeline_file == "pipeline_" + pipeline_dir and hasattr(pipelines, pipeline_dir) path = module.__module__.split(".") is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir) From 2d136a0b60a45b0edc7e91488a02c0cd7f017e5b Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Fri, 19 Aug 2022 15:22:09 +0530 Subject: [PATCH 13/13] update safety checker --- .../stable_diffusion/safety_checker.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker.py b/src/diffusers/pipelines/stable_diffusion/safety_checker.py index 8ffb47226e8d..1c5db4210ddf 100644 --- a/src/diffusers/pipelines/stable_diffusion/safety_checker.py +++ b/src/diffusers/pipelines/stable_diffusion/safety_checker.py @@ -4,6 +4,11 @@ from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel +from ...utils import logging + + +logger = logging.get_logger(__name__) + def cosine_distance(image_embeds, text_embeds): normalized_image_embeds = nn.functional.normalize(image_embeds) @@ -57,12 +62,16 @@ def forward(self, clip_input, images): result.append(result_img) - has_nsfw_concept = [ - len(result[i]["bad_concepts"]) > 0 or len(result[i]["special_care"]) > 0 for i in range(len(result)) - ] + has_nsfw_concepts = [len(result[i]["bad_concepts"]) > 0 or i in range(len(result))] - for idx, has_nsfw_concept in enumerate(has_nsfw_concept): + for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): if has_nsfw_concept: images[idx] = np.zeros(images[idx].shape) # black image - return images, has_nsfw_concept + if any(has_nsfw_concepts): + logger.warning( + "Potential NSFW content was detected in one or more images. A black image will be returned instead." + " Try again with a different prompt and/or seed." + ) + + return images, has_nsfw_concepts
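
A note on the register_modules change in patches 05 and 12: the previous check only recognized a module literally named pipeline_<pipeline_dir>.py, so a sibling module such as safety_checker.py would have been resolved against the transformers library rather than the pipeline package. The snippet below is a small illustration of the new path-based check; the module path is the one added by this series.

from diffusers import pipelines

# Module path of the class being registered, split the way the new code splits it.
module_path = "diffusers.pipelines.stable_diffusion.safety_checker".split(".")
pipeline_dir = module_path[-2]  # "stable_diffusion"

# New check from patch 05: any module under the pipeline directory qualifies.
is_pipeline_module = pipeline_dir in module_path and hasattr(pipelines, pipeline_dir)
print(is_pipeline_module)  # True whenever the stable_diffusion subpackage is importable (transformers installed)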
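
The safety checker itself can be exercised standalone, mirroring the two calls the pipeline makes at the end of __call__: feature-extract the decoded images, feed the pixel_values in as clip_input, and read back the possibly blacked-out images together with the per-image flags. The sketch below uses a randomly initialized CLIPConfig, default CLIP preprocessing, and random input images purely for illustration; in practice both components are loaded from a trained checkpoint.

import numpy as np
from PIL import Image
from transformers import CLIPConfig, CLIPFeatureExtractor

from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker

# Placeholder components; a real pipeline loads these via from_pretrained.
checker = StableDiffusionSafetyChecker(CLIPConfig())
feature_extractor = CLIPFeatureExtractor()

# Stand-in for the decoded VAE output: NHWC float images in [0, 1].
images = np.random.rand(2, 512, 512, 3).astype(np.float32)
pil_images = [Image.fromarray((img * 255).round().astype("uint8")) for img in images]

# Same call pattern as pipeline_stable_diffusion.py after patch 07.
clip_input = feature_extractor(pil_images, return_tensors="pt").pixel_values
checked_images, has_nsfw_concepts = checker(images=images, clip_input=clip_input)

for idx, flagged in enumerate(has_nsfw_concepts):
    # Flagged entries come back as all-black arrays.
    print(f"image {idx}: nsfw_content_detected={flagged}, max pixel={checked_images[idx].max():.3f}")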
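
At the pipeline level the flags are returned to the caller under the "nsfw_content_detected" key (patches 07 and 10). Below is a sketch of how a caller might consume them, assuming a checkpoint whose model index already lists the safety_checker and feature_extractor components; the checkpoint path and prompt are placeholders and the call relies on the pipeline's default arguments.

from diffusers import StableDiffusionPipeline

# Placeholder checkpoint path; it must ship safety_checker and feature_extractor
# alongside the vae, text_encoder, tokenizer, unet and scheduler.
pipe = StableDiffusionPipeline.from_pretrained("path/to/stable-diffusion-checkpoint")

output = pipe("a placeholder prompt")
for image, flagged in zip(output["sample"], output["nsfw_content_detected"]):
    if flagged:
        print("Safety checker replaced this sample with a black image.")
    else:
        image.save("sample.png")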