Merged
Add safety module #213
Changes from 12 commits
Commits (13):
7940411  add SafetyChecker (patil-suraj)
cc2ea49  better name, fix checker (patil-suraj)
e233bf3  add checker in main init (patil-suraj)
f10c6e6  remove from main init (patil-suraj)
d89a9ce  update logic to detect pipeline module (patil-suraj)
68486d2  style (patil-suraj)
063f2c5  handle all safety logic in safety checker (patil-suraj)
20b64b4  draw text (patil-suraj)
008720b  can't draw (patil-suraj)
fc877cc  small fixes (patil-suraj)
24e982f  treat special care as nsfw (patil-suraj)
76b10f9  remove commented lines (patil-suraj)
2d136a0  update safety checker (patil-suraj)
src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

@@ -4,11 +4,12 @@
 import torch

 from tqdm.auto import tqdm
-from transformers import CLIPTextModel, CLIPTokenizer
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...pipeline_utils import DiffusionPipeline
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
+from .safety_checker import StableDiffusionSafetyChecker


 class StableDiffusionPipeline(DiffusionPipeline):

@@ -19,10 +20,20 @@ def __init__(
         tokenizer: CLIPTokenizer,
         unet: UNet2DConditionModel,
         scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
+        safety_checker: StableDiffusionSafetyChecker,
+        feature_extractor: CLIPFeatureExtractor,
     ):
         super().__init__()
         scheduler = scheduler.set_format("pt")
-        self.register_modules(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
+        self.register_modules(
+            vae=vae,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            unet=unet,
+            scheduler=scheduler,
+            safety_checker=safety_checker,
+            feature_extractor=feature_extractor,
+        )

     @torch.no_grad()
     def __call__(

@@ -53,6 +64,7 @@ def __call__(
         self.unet.to(torch_device)
         self.vae.to(torch_device)
         self.text_encoder.to(torch_device)
+        self.safety_checker.to(torch_device)

         # get prompt text embeddings
         text_input = self.tokenizer(

@@ -136,7 +148,12 @@ def __call__(
         image = (image / 2 + 0.5).clamp(0, 1)
         image = image.cpu().permute(0, 2, 3, 1).numpy()

+        # run safety checker
+        safety_cheker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(torch_device)
+        image, has_nsfw_concept = self.safety_checker(images=image, clip_input=safety_cheker_input.pixel_values)
+
         if output_type == "pil":
             image = self.numpy_to_pil(image)

-        return {"sample": image}
+        return {"sample": image, "nsfw_content_detected": has_nsfw_concept}
src/diffusers/pipelines/stable_diffusion/safety_checker.py (new file: 68 additions, 0 deletions)

@@ -0,0 +1,68 @@
+import numpy as np
+import torch
+import torch.nn as nn
+
+from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel
+
+
+def cosine_distance(image_embeds, text_embeds):
+    normalized_image_embeds = nn.functional.normalize(image_embeds)
+    normalized_text_embeds = nn.functional.normalize(text_embeds)
+    return torch.mm(normalized_image_embeds, normalized_text_embeds.T)
+
+
+class StableDiffusionSafetyChecker(PreTrainedModel):
+    config_class = CLIPConfig
+
+    def __init__(self, config: CLIPConfig):
+        super().__init__(config)
+
+        self.vision_model = CLIPVisionModel(config.vision_config)
+        self.visual_projection = nn.Linear(config.vision_config.hidden_size, config.projection_dim, bias=False)
+
+        self.concept_embeds = nn.Parameter(torch.ones(17, config.projection_dim), requires_grad=False)
+        self.special_care_embeds = nn.Parameter(torch.ones(3, config.projection_dim), requires_grad=False)
+
+        self.register_buffer("concept_embeds_weights", torch.ones(17))
+        self.register_buffer("special_care_embeds_weights", torch.ones(3))
+
+    @torch.no_grad()
+    def forward(self, clip_input, images):
+        pooled_output = self.vision_model(clip_input)[1]  # pooled_output
+        image_embeds = self.visual_projection(pooled_output)
+
+        special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds).cpu().numpy()
+        cos_dist = cosine_distance(image_embeds, self.concept_embeds).cpu().numpy()
+
+        result = []
+        batch_size = image_embeds.shape[0]
+        for i in range(batch_size):
+            result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []}
+            adjustment = 0.05
+
+            for concet_idx in range(len(special_cos_dist[0])):
+                concept_cos = special_cos_dist[i][concet_idx]
+                concept_threshold = self.special_care_embeds_weights[concet_idx].item()
+                result_img["special_scores"][concet_idx] = round(concept_cos - concept_threshold + adjustment, 3)
+                if result_img["special_scores"][concet_idx] > 0:
+                    result_img["special_care"].append({concet_idx, result_img["special_scores"][concet_idx]})
+                    adjustment = 0.01
+
+            for concet_idx in range(len(cos_dist[0])):
+                concept_cos = cos_dist[i][concet_idx]
+                concept_threshold = self.concept_embeds_weights[concet_idx].item()
+                result_img["concept_scores"][concet_idx] = round(concept_cos - concept_threshold + adjustment, 3)
+                if result_img["concept_scores"][concet_idx] > 0:
+                    result_img["bad_concepts"].append(concet_idx)
+
+            result.append(result_img)
+
+        has_nsfw_concept = [
+            len(result[i]["bad_concepts"]) > 0 or len(result[i]["special_care"]) > 0 for i in range(len(result))
+        ]
+
+        for idx, has_nsfw_concept in enumerate(has_nsfw_concept):
+            if has_nsfw_concept:
+                images[idx] = np.zeros(images[idx].shape)  # black image
+
+        return images, has_nsfw_concept
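To make the scoring loops above easier to follow, here is a small standalone sketch of the per-concept thresholding on made-up numbers. The similarity values and thresholds below are dummies; only the score = cosine − threshold + adjustment comparison and the adjustment update mirror the forward pass.

```python
import numpy as np

# Dummy cosine similarities for one image against 3 "special care" and 17 concept embeddings.
rng = np.random.default_rng(0)
special_cos_dist = rng.uniform(0.0, 0.4, size=3)
cos_dist = rng.uniform(0.0, 0.4, size=17)

# Dummy per-concept thresholds (the real ones are special_care_embeds_weights / concept_embeds_weights).
special_thresholds = np.full(3, 0.3)
concept_thresholds = np.full(17, 0.3)

adjustment = 0.05
special_care, bad_concepts = [], []

for idx in range(len(special_cos_dist)):
    score = round(float(special_cos_dist[idx]) - float(special_thresholds[idx]) + adjustment, 3)
    if score > 0:
        special_care.append(idx)
        adjustment = 0.01  # once any special-care concept fires, later scores use the smaller offset

for idx in range(len(cos_dist)):
    score = round(float(cos_dist[idx]) - float(concept_thresholds[idx]) + adjustment, 3)
    if score > 0:
        bad_concepts.append(idx)

# The image would be blacked out if either list is non-empty.
has_nsfw_concept = bool(special_care or bad_concepts)
print(special_care, bad_concepts, has_nsfw_concept)
```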
Review comment: More general logic to detect if a module comes from a `pipeline` module. For now this is only needed for the `LDMBert` model and the safety checker. This should probably be in a separate PR.
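This comment appears to refer to the loading logic behind commit d89a9ce ("update logic to detect pipeline module"), which is not part of the diff shown above. As a rough, hypothetical illustration of the idea only (the helper name and exact check are assumptions, not the PR's code), such a test could inspect the dotted path of a class's defining module:

```python
# Hypothetical sketch only: report whether a class is defined under a `pipelines`
# package (e.g. diffusers.pipelines.stable_diffusion.safety_checker) rather than
# in a core library module.
def is_pipeline_module(cls: type) -> bool:
    return "pipelines" in cls.__module__.split(".")


class NotAPipelineComponent:
    pass


print(is_pipeline_module(NotAPipelineComponent))  # False: defined in __main__
```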
pipelinemodule. For now this is only needed forLDMBertmodel and the safety checker.This should probably be in a separate PR.