From 1201501ca6f2bf78f72dd5e82e2f367183ff7672 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 19 Sep 2022 09:28:37 +0200 Subject: [PATCH 1/7] Starting to integrate safety checker. --- .../pipelines/stable_diffusion/__init__.py | 5 +- .../stable_diffusion/safety_checker_flax.py | 112 ++++++++++++++++++ 2 files changed, 116 insertions(+), 1 deletion(-) create mode 100644 src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index 5ffda93f1721..ed93335299af 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -6,7 +6,7 @@ import PIL from PIL import Image -from ...utils import BaseOutput, is_onnx_available, is_transformers_available +from ...utils import BaseOutput, is_onnx_available, is_transformers_available, is_flax_available @dataclass @@ -35,3 +35,6 @@ class StableDiffusionPipelineOutput(BaseOutput): if is_transformers_available() and is_onnx_available(): from .pipeline_stable_diffusion_onnx import StableDiffusionOnnxPipeline + +if is_transformers_available() and is_flax_available(): + from .safety_checker_flax import FlaxStableDiffusionSafetyChecker \ No newline at end of file diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py b/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py new file mode 100644 index 000000000000..9e290caf3fa8 --- /dev/null +++ b/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py @@ -0,0 +1,112 @@ +import numpy as np +import jax +import jax.numpy as jnp +from flax import linen as nn +from flax.core.frozen_dict import FrozenDict +from flax.struct import dataclass, field + +import warnings +# from typing import Optional, Tuple + +# from transformers import CLIPConfig, FlaxPreTrainedModel +from transformers import CLIPVisionConfig +from transformers.models.clip.modeling_flax_clip import FlaxCLIPVisionModule + +from ...configuration_utils import ConfigMixin, flax_register_to_config +from ...modeling_flax_utils import FlaxModelMixin + + +def jax_cosine_distance(emb_1, emb_2, eps=1e-12): + norm_emb_1 = jnp.divide(emb_1.T, jnp.clip(jnp.linalg.norm(emb_1, axis=1), a_min=eps)).T + norm_emb_2 = jnp.divide(emb_2.T, jnp.clip(jnp.linalg.norm(emb_2, axis=1), a_min=eps)).T + return jnp.matmul(norm_emb_1, norm_emb_2.T) + + +@flax_register_to_config +class FlaxStableDiffusionSafetyChecker(nn.Module, FlaxModelMixin, ConfigMixin): + # Used to be a CLIPConfig configuration, but we only need the Vision configuration. + projection_dim: int = 768 + vision_config: dict = field(default_factory=dict) + # vision_config: CLIPVisionConfig = None + dtype: jnp.dtype = jnp.float32 + + def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict: + # init input tensor + input_shape = ( + 1, + self.vision_config["image_size"], + self.vision_config["image_size"], + self.vision_config["num_channels"], + ) + pixel_values = jax.random.normal(rng, input_shape) + params_rng, dropout_rng = jax.random.split(rng) + rngs = {"params": params_rng, "dropout": dropout_rng} + return self.init(rngs, pixel_values)["params"] + + def setup(self): + clip_vision_config = CLIPVisionConfig(self.vision_config) + self.vision_model = FlaxCLIPVisionModule(clip_vision_config) + self.visual_projection = nn.Dense(self.projection_dim, use_bias=False) + + self.concept_embeds = self.param("concept_embeds", jax.nn.initializers.ones, (17, self.projection_dim)) + self.special_care_embeds = self.param("special_care_embeds", jax.nn.initializers.ones, (3, self.projection_dim)) + + self.concept_embeds_weights = self.param("concept_embeds_weights", jax.nn.initializers.ones, (17,)) + self.special_care_embeds_weights = self.param("special_care_embeds_weights", jax.nn.initializers.ones, (3,)) + + def __call__(self, clip_input): + pooled_output = self.vision_model(clip_input)[1] # pooled_output + image_embeds = self.visual_projection(pooled_output) + + special_cos_dist = jax_cosine_distance(image_embeds, self.special_care_embeds) + cos_dist = jax_cosine_distance(image_embeds, self.concept_embeds) + return special_cos_dist, cos_dist + + def filtered_with_scores(self, special_cos_dist, cos_dist, images): + batch_size = special_cos_dist.shape[0] + special_cos_dist = np.asarray(special_cos_dist) + cos_dist = np.asarray(cos_dist) + + result = [] + for i in range(batch_size): + result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []} + + # increase this value to create a stronger `nfsw` filter + # at the cost of increasing the possibility of filtering benign image inputs + adjustment = 0.0 + + for concept_idx in range(len(special_cos_dist[0])): + concept_cos = special_cos_dist[i][concept_idx] + concept_threshold = self.special_care_embeds_weights[concept_idx].item() + result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) + if result_img["special_scores"][concept_idx] > 0: + result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]}) + adjustment = 0.01 + + for concept_idx in range(len(cos_dist[0])): + concept_cos = cos_dist[i][concept_idx] + concept_threshold = self.concept_embeds_weights[concept_idx].item() + result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) + if result_img["concept_scores"][concept_idx] > 0: + result_img["bad_concepts"].append(concept_idx) + + result.append(result_img) + + has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result] + + images_was_copied = False + for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): + if has_nsfw_concept: + if not images_was_copied: + images_was_copied = True + images = images.copy() + + images[idx] = np.zeros(images[idx].shape) # black image + + if any(has_nsfw_concepts): + warnings.warn( + "Potential NSFW content was detected in one or more images. A black image will be returned instead." + " Try again with a different prompt and/or seed." + ) + + return images, has_nsfw_concepts From a3fb07f8eeeedfd5bf0a14f1fa36490be55a7de9 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 19 Sep 2022 10:17:17 +0200 Subject: [PATCH 2/7] Fix initialization of CLIPVisionConfig --- .../pipelines/stable_diffusion/safety_checker_flax.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py b/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py index 9e290caf3fa8..a1e4ce515470 100644 --- a/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py +++ b/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py @@ -44,7 +44,7 @@ def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict: return self.init(rngs, pixel_values)["params"] def setup(self): - clip_vision_config = CLIPVisionConfig(self.vision_config) + clip_vision_config = CLIPVisionConfig(**self.vision_config) self.vision_model = FlaxCLIPVisionModule(clip_vision_config) self.visual_projection = nn.Dense(self.projection_dim, use_bias=False) @@ -55,7 +55,7 @@ def setup(self): self.special_care_embeds_weights = self.param("special_care_embeds_weights", jax.nn.initializers.ones, (3,)) def __call__(self, clip_input): - pooled_output = self.vision_model(clip_input)[1] # pooled_output + pooled_output = self.vision_model(clip_input)[1] image_embeds = self.visual_projection(pooled_output) special_cos_dist = jax_cosine_distance(image_embeds, self.special_care_embeds) From 612b0c0ea52ae070d08aa40ed3902b63a56fdf6d Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 19 Sep 2022 10:19:17 +0200 Subject: [PATCH 3/7] Remove commented lines. --- .../pipelines/stable_diffusion/safety_checker_flax.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py b/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py index a1e4ce515470..122500d37e68 100644 --- a/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py +++ b/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py @@ -6,9 +6,7 @@ from flax.struct import dataclass, field import warnings -# from typing import Optional, Tuple -# from transformers import CLIPConfig, FlaxPreTrainedModel from transformers import CLIPVisionConfig from transformers.models.clip.modeling_flax_clip import FlaxCLIPVisionModule @@ -24,10 +22,9 @@ def jax_cosine_distance(emb_1, emb_2, eps=1e-12): @flax_register_to_config class FlaxStableDiffusionSafetyChecker(nn.Module, FlaxModelMixin, ConfigMixin): - # Used to be a CLIPConfig configuration, but we only need the Vision configuration. projection_dim: int = 768 + # CLIPVisionConfig fields vision_config: dict = field(default_factory=dict) - # vision_config: CLIPVisionConfig = None dtype: jnp.dtype = jnp.float32 def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict: From c8d2cb9bf49340296eba0899a0f8c574adc326e4 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 19 Sep 2022 10:56:48 +0200 Subject: [PATCH 4/7] make style --- .../pipelines/stable_diffusion/__init__.py | 4 ++-- .../stable_diffusion/safety_checker_flax.py | 14 ++++++++------ 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/__init__.py b/src/diffusers/pipelines/stable_diffusion/__init__.py index ed93335299af..e41043e0ad53 100644 --- a/src/diffusers/pipelines/stable_diffusion/__init__.py +++ b/src/diffusers/pipelines/stable_diffusion/__init__.py @@ -6,7 +6,7 @@ import PIL from PIL import Image -from ...utils import BaseOutput, is_onnx_available, is_transformers_available, is_flax_available +from ...utils import BaseOutput, is_flax_available, is_onnx_available, is_transformers_available @dataclass @@ -37,4 +37,4 @@ class StableDiffusionPipelineOutput(BaseOutput): from .pipeline_stable_diffusion_onnx import StableDiffusionOnnxPipeline if is_transformers_available() and is_flax_available(): - from .safety_checker_flax import FlaxStableDiffusionSafetyChecker \ No newline at end of file + from .safety_checker_flax import FlaxStableDiffusionSafetyChecker diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py b/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py index 122500d37e68..0ebd1a114ebf 100644 --- a/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py +++ b/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py @@ -1,12 +1,12 @@ +import warnings + import numpy as np + import jax import jax.numpy as jnp from flax import linen as nn from flax.core.frozen_dict import FrozenDict from flax.struct import dataclass, field - -import warnings - from transformers import CLIPVisionConfig from transformers.models.clip.modeling_flax_clip import FlaxCLIPVisionModule @@ -46,7 +46,9 @@ def setup(self): self.visual_projection = nn.Dense(self.projection_dim, use_bias=False) self.concept_embeds = self.param("concept_embeds", jax.nn.initializers.ones, (17, self.projection_dim)) - self.special_care_embeds = self.param("special_care_embeds", jax.nn.initializers.ones, (3, self.projection_dim)) + self.special_care_embeds = self.param( + "special_care_embeds", jax.nn.initializers.ones, (3, self.projection_dim) + ) self.concept_embeds_weights = self.param("concept_embeds_weights", jax.nn.initializers.ones, (17,)) self.special_care_embeds_weights = self.param("special_care_embeds_weights", jax.nn.initializers.ones, (3,)) @@ -102,8 +104,8 @@ def filtered_with_scores(self, special_cos_dist, cos_dist, images): if any(has_nsfw_concepts): warnings.warn( - "Potential NSFW content was detected in one or more images. A black image will be returned instead." - " Try again with a different prompt and/or seed." + "Potential NSFW content was detected in one or more images. A black image will be returned" + " instead. Try again with a different prompt and/or seed." ) return images, has_nsfw_concepts From 804e98a8fbe7dcd12dfa81a65d38f3a4b93be5e3 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 19 Sep 2022 11:06:59 +0200 Subject: [PATCH 5/7] Remove unused import --- src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py b/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py index 0ebd1a114ebf..233f548c1efb 100644 --- a/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py +++ b/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py @@ -6,7 +6,7 @@ import jax.numpy as jnp from flax import linen as nn from flax.core.frozen_dict import FrozenDict -from flax.struct import dataclass, field +from flax.struct import field from transformers import CLIPVisionConfig from transformers.models.clip.modeling_flax_clip import FlaxCLIPVisionModule From 882542c5821e14e682b9751c096f18ea82380ad0 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 19 Sep 2022 15:24:15 +0200 Subject: [PATCH 6/7] Pass dtype to modules Co-authored-by: Suraj Patil --- src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py b/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py index 233f548c1efb..7ee85a46c171 100644 --- a/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py +++ b/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py @@ -42,7 +42,7 @@ def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict: def setup(self): clip_vision_config = CLIPVisionConfig(**self.vision_config) - self.vision_model = FlaxCLIPVisionModule(clip_vision_config) + self.vision_model = FlaxCLIPVisionModule(clip_vision_config, dtype=self.dtype) self.visual_projection = nn.Dense(self.projection_dim, use_bias=False) self.concept_embeds = self.param("concept_embeds", jax.nn.initializers.ones, (17, self.projection_dim)) From 9706814a5e53a4dff9bd893bfe606924fa675a3b Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 19 Sep 2022 15:24:40 +0200 Subject: [PATCH 7/7] Pass dtype to modules Co-authored-by: Suraj Patil --- src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py b/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py index 7ee85a46c171..de84b793a176 100644 --- a/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py +++ b/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py @@ -43,7 +43,7 @@ def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict: def setup(self): clip_vision_config = CLIPVisionConfig(**self.vision_config) self.vision_model = FlaxCLIPVisionModule(clip_vision_config, dtype=self.dtype) - self.visual_projection = nn.Dense(self.projection_dim, use_bias=False) + self.visual_projection = nn.Dense(self.projection_dim, use_bias=False, dtype=self.dtype) self.concept_embeds = self.param("concept_embeds", jax.nn.initializers.ones, (17, self.projection_dim)) self.special_care_embeds = self.param(