5 changes: 4 additions & 1 deletion src/diffusers/pipelines/stable_diffusion/__init__.py
@@ -6,7 +6,7 @@
import PIL
from PIL import Image

-from ...utils import BaseOutput, is_onnx_available, is_transformers_available
+from ...utils import BaseOutput, is_flax_available, is_onnx_available, is_transformers_available


@dataclass
@@ -35,3 +35,6 @@ class StableDiffusionPipelineOutput(BaseOutput):

if is_transformers_available() and is_onnx_available():
    from .pipeline_stable_diffusion_onnx import StableDiffusionOnnxPipeline

if is_transformers_available() and is_flax_available():
    from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
111 changes: 111 additions & 0 deletions src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py
@@ -0,0 +1,111 @@
import warnings

import numpy as np

import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.core.frozen_dict import FrozenDict
from flax.struct import field
from transformers import CLIPVisionConfig
from transformers.models.clip.modeling_flax_clip import FlaxCLIPVisionModule

from ...configuration_utils import ConfigMixin, flax_register_to_config
from ...modeling_flax_utils import FlaxModelMixin


def jax_cosine_distance(emb_1, emb_2, eps=1e-12):
    norm_emb_1 = jnp.divide(emb_1.T, jnp.clip(jnp.linalg.norm(emb_1, axis=1), a_min=eps)).T
    norm_emb_2 = jnp.divide(emb_2.T, jnp.clip(jnp.linalg.norm(emb_2, axis=1), a_min=eps)).T
    return jnp.matmul(norm_emb_1, norm_emb_2.T)


@flax_register_to_config
class FlaxStableDiffusionSafetyChecker(nn.Module, FlaxModelMixin, ConfigMixin):
    projection_dim: int = 768
    # CLIPVisionConfig fields
    vision_config: dict = field(default_factory=dict)
Member Author

The config file that is currently used contains all the CLIP configuration options. We just retrieve the ones corresponding to the CLIPVisionConfig that we use.

Contributor

Does the saving & loading work correctly with it?
Thinking long-term, what would be better is to just list the parameters that we would like people to change here (e.g. the size of the CLIP, ...) and then use CLIP's default config parameters (see suggestion below).

Member Author

Does the saving & loading work correctly with it?

Loading works (from CLIP config files such as the one currently in the fusing repo). I did not test saving yet; will do now.

Think long-term what would be better is to actually just write the parameters that we would people like to change

I thought about this approach, but there are many parameters in the configuration file and I wasn't sure which ones we wanted to expose. In addition, it could still be useful to reuse any CLIP configuration json and just retrieve the vision part, instead of reading the parameters from the root of the json, which I presume would be incompatible with existing configurations (see the short extraction sketch after the listing below).

These are all the relevant keys in the current configuration file:

  "vision_config": {
    "_name_or_path": "",
    "add_cross_attention": false,
    "architectures": null,
    "attention_dropout": 0.0,
    "bad_words_ids": null,
    "bos_token_id": null,
    "chunk_size_feed_forward": 0,
    "cross_attention_hidden_size": null,
    "decoder_start_token_id": null,
    "diversity_penalty": 0.0,
    "do_sample": false,
    "dropout": 0.0,
    "early_stopping": false,
    "encoder_no_repeat_ngram_size": 0,
    "eos_token_id": null,
    "exponential_decay_length_penalty": null,
    "finetuning_task": null,
    "forced_bos_token_id": null,
    "forced_eos_token_id": null,
    "hidden_act": "quick_gelu",
    "hidden_size": 1024,
    "id2label": {
      "0": "LABEL_0",
      "1": "LABEL_1"
    },
    "image_size": 224,
    "initializer_factor": 1.0,
    "initializer_range": 0.02,
    "intermediate_size": 4096,
    "is_decoder": false,
    "is_encoder_decoder": false,
    "label2id": {
      "LABEL_0": 0,
      "LABEL_1": 1
    },
    "layer_norm_eps": 1e-05,
    "length_penalty": 1.0,
    "max_length": 20,
    "min_length": 0,
    "model_type": "clip_vision_model",
    "no_repeat_ngram_size": 0,
    "num_attention_heads": 16,
    "num_beam_groups": 1,
    "num_beams": 1,
    "num_channels": 3,
    "num_hidden_layers": 24,
    "num_return_sequences": 1,
    "output_attentions": false,
    "output_hidden_states": false,
    "output_scores": false,
    "pad_token_id": null,
    "patch_size": 14,
    "prefix": null,
    "problem_type": null,
    "pruned_heads": {},
    "remove_invalid_values": false,
    "repetition_penalty": 1.0,
    "return_dict": true,
    "return_dict_in_generate": false,
    "sep_token_id": null,
    "task_specific_params": null,
    "temperature": 1.0,
    "tf_legacy_loss": false,
    "tie_encoder_decoder": false,
    "tie_word_embeddings": true,
    "tokenizer_class": null,
    "top_k": 50,
    "top_p": 1.0,
    "torch_dtype": null,
    "torchscript": false,
    "transformers_version": "4.22.0.dev0",
    "typical_p": 1.0,
    "use_bfloat16": false
  },
  "vision_config_dict": {
    "hidden_size": 1024,
    "intermediate_size": 4096,
    "num_attention_heads": 16,
    "num_hidden_layers": 24,
    "patch_size": 14
  }
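
As a rough illustration (editor's sketch, not part of this PR) of reusing a full CLIP configuration json and keeping only its vision section; the file path is a placeholder:

import json

from transformers import CLIPVisionConfig

# "clip_config.json" stands for any full CLIP configuration file that contains a
# "vision_config" section like the listing above (placeholder path).
with open("clip_config.json") as f:
    vision_config = json.load(f)["vision_config"]

# CLIPVisionConfig picks up the vision-specific fields; the remaining generic keys
# (generation defaults, label maps, ...) are forwarded to PretrainedConfig and simply
# stored as attributes, so the whole section can be passed through unchanged.
clip_vision_config = CLIPVisionConfig(**vision_config)
print(clip_vision_config.hidden_size, clip_vision_config.num_hidden_layers)  # 1024 24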

Member Author

I did not test saving yet, will do now.

Saving works after #565 was applied.

    dtype: jnp.dtype = jnp.float32

    def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
        # init input tensor
        input_shape = (
            1,
            self.vision_config["image_size"],
            self.vision_config["image_size"],
            self.vision_config["num_channels"],
        )
        pixel_values = jax.random.normal(rng, input_shape)
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}
        return self.init(rngs, pixel_values)["params"]

    def setup(self):
        clip_vision_config = CLIPVisionConfig(**self.vision_config)
        self.vision_model = FlaxCLIPVisionModule(clip_vision_config, dtype=self.dtype)
        self.visual_projection = nn.Dense(self.projection_dim, use_bias=False, dtype=self.dtype)

        self.concept_embeds = self.param("concept_embeds", jax.nn.initializers.ones, (17, self.projection_dim))
        self.special_care_embeds = self.param(
            "special_care_embeds", jax.nn.initializers.ones, (3, self.projection_dim)
        )

        self.concept_embeds_weights = self.param("concept_embeds_weights", jax.nn.initializers.ones, (17,))
        self.special_care_embeds_weights = self.param("special_care_embeds_weights", jax.nn.initializers.ones, (3,))

    def __call__(self, clip_input):
        pooled_output = self.vision_model(clip_input)[1]
        image_embeds = self.visual_projection(pooled_output)

        special_cos_dist = jax_cosine_distance(image_embeds, self.special_care_embeds)
        cos_dist = jax_cosine_distance(image_embeds, self.concept_embeds)
        return special_cos_dist, cos_dist

    def filtered_with_scores(self, special_cos_dist, cos_dist, images):
Member Author

This method is not meant to be used with pmap, but it's fast as it just checks the scores.

        batch_size = special_cos_dist.shape[0]
        special_cos_dist = np.asarray(special_cos_dist)
        cos_dist = np.asarray(cos_dist)

        result = []
        for i in range(batch_size):
            result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []}

            # increase this value to create a stronger `nsfw` filter
            # at the cost of increasing the possibility of filtering benign image inputs
            adjustment = 0.0

            for concept_idx in range(len(special_cos_dist[0])):
                concept_cos = special_cos_dist[i][concept_idx]
                concept_threshold = self.special_care_embeds_weights[concept_idx].item()
                result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
                if result_img["special_scores"][concept_idx] > 0:
                    result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]})
                    adjustment = 0.01
Contributor

Hey! Is there any reason to increase adjustment on the first "special care" concept, rather than after the loop, before we check "bad" concepts? Intuitively I understand that we use "special care" concepts to make the "bad" concepts check more strict, but I feel like "special care" should be orthogonal to other "special care" detection.

In the current version it doesn't really matter if there's one or more special care concepts, so I think either approach would be the same, but I'm curious if there's a specific reason behind this.

This is relevant because the conditional adjustment = 0.01 introduces a dependency between loop iterations. The alternative would be to move it outside the loop: "if any special-care score > 0, set adjustment = 0.01 for the bad-concepts check". With this, I think both loops could be vectorized (and in the Flax version they could probably be part of __call__); see the sketch below.

cc @patrickvonplaten
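
For illustration (editor's sketch of the alternative described above, not part of this PR; the function name and threshold arguments are hypothetical):

import jax.numpy as jnp

def has_nsfw_concept(special_cos_dist, cos_dist, special_care_thresholds, concept_thresholds):
    # special_cos_dist: (batch, 3), cos_dist: (batch, 17); thresholds are 1-D arrays
    special_scores = special_cos_dist - special_care_thresholds
    # one adjustment per image: 0.01 if any special-care score is positive, else 0.0
    adjustment = jnp.where(jnp.any(special_scores > 0, axis=1, keepdims=True), 0.01, 0.0)
    concept_scores = cos_dist - concept_thresholds + adjustment
    return jnp.any(concept_scores > 0, axis=1)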

Contributor (@patrickvonplaten, Oct 26, 2022)

That's a good point - we can/should definitely move it out of the loop :-)

Contributor

If you want feel free to open a PR for it :-)

Contributor

Really good point @jonatanklosko

Contributor

Perfect, thanks! I will submit a PR sometime this week :)

Contributor

awesome, thanks a lot!


            for concept_idx in range(len(cos_dist[0])):
                concept_cos = cos_dist[i][concept_idx]
                concept_threshold = self.concept_embeds_weights[concept_idx].item()
                result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
                if result_img["concept_scores"][concept_idx] > 0:
                    result_img["bad_concepts"].append(concept_idx)

            result.append(result_img)

        has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result]

        images_was_copied = False
        for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
            if has_nsfw_concept:
                if not images_was_copied:
                    images_was_copied = True
                    images = images.copy()

                images[idx] = np.zeros(images[idx].shape)  # black image

        if any(has_nsfw_concepts):
            warnings.warn(
                "Potential NSFW content was detected in one or more images. A black image will be returned"
                " instead. Try again with a different prompt and/or seed."
            )

        return images, has_nsfw_concepts
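
A minimal usage sketch (editor's illustration, not part of the diff): the repo id and subfolder are assumptions about where Flax safety-checker weights might live, and the inputs are dummy arrays that only show the expected shapes. The scoring pass goes through apply (the part that could be jitted or pmapped), while filtered_with_scores runs on the host with NumPy, as noted in the review comments above.

import jax
import numpy as np

from diffusers.pipelines.stable_diffusion import FlaxStableDiffusionSafetyChecker

# Hypothetical checkpoint location; adjust to wherever the Flax safety-checker weights live.
safety_checker, params = FlaxStableDiffusionSafetyChecker.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="safety_checker"
)

# Dummy inputs to show shapes: CLIP-preprocessed pixel values (channels-last) and decoded images.
clip_input = jax.random.normal(jax.random.PRNGKey(0), (1, 224, 224, 3))
images = np.zeros((1, 512, 512, 3), dtype=np.float32)

# Scoring pass: cosine distances against the special-care and concept embeddings.
special_cos_dist, cos_dist = safety_checker.apply({"params": params}, clip_input)

# Host-side filtering via apply(..., method=...): blacks out flagged images and
# returns a boolean flag per image.
images, has_nsfw = safety_checker.apply(
    {"params": params},
    special_cos_dist,
    cos_dist,
    images,
    method=safety_checker.filtered_with_scores,
)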