# JAX/Flax safety checker #558
New file (111 lines) adding `FlaxStableDiffusionSafetyChecker`:

```python
import warnings

import numpy as np

import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.core.frozen_dict import FrozenDict
from flax.struct import field
from transformers import CLIPVisionConfig
from transformers.models.clip.modeling_flax_clip import FlaxCLIPVisionModule

from ...configuration_utils import ConfigMixin, flax_register_to_config
from ...modeling_flax_utils import FlaxModelMixin


def jax_cosine_distance(emb_1, emb_2, eps=1e-12):
    norm_emb_1 = jnp.divide(emb_1.T, jnp.clip(jnp.linalg.norm(emb_1, axis=1), a_min=eps)).T
    norm_emb_2 = jnp.divide(emb_2.T, jnp.clip(jnp.linalg.norm(emb_2, axis=1), a_min=eps)).T
    return jnp.matmul(norm_emb_1, norm_emb_2.T)
```
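Despite its name, `jax_cosine_distance` returns a cosine *similarity* matrix: each embedding row is L2-normalized (with `eps` guarding against division by zero) before the pairwise dot products. A quick illustrative shape check, not part of the PR:

```python
import jax.numpy as jnp

a = jnp.ones((4, 768))   # e.g. a batch of 4 image embeddings
b = jnp.ones((17, 768))  # e.g. the 17 concept embeddings
sims = jax_cosine_distance(a, b)
print(sims.shape)  # (4, 17) pairwise similarities; identical directions give 1.0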
```python
@flax_register_to_config
class FlaxStableDiffusionSafetyChecker(nn.Module, FlaxModelMixin, ConfigMixin):
    projection_dim: int = 768
    # CLIPVisionConfig fields
    vision_config: dict = field(default_factory=dict)
    dtype: jnp.dtype = jnp.float32

    def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
        # init input tensor
        input_shape = (
            1,
            self.vision_config["image_size"],
            self.vision_config["image_size"],
            self.vision_config["num_channels"],
        )
        pixel_values = jax.random.normal(rng, input_shape)
        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}
        return self.init(rngs, pixel_values)["params"]

    def setup(self):
        clip_vision_config = CLIPVisionConfig(**self.vision_config)
        self.vision_model = FlaxCLIPVisionModule(clip_vision_config, dtype=self.dtype)
        self.visual_projection = nn.Dense(self.projection_dim, use_bias=False, dtype=self.dtype)

        self.concept_embeds = self.param("concept_embeds", jax.nn.initializers.ones, (17, self.projection_dim))
        self.special_care_embeds = self.param(
            "special_care_embeds", jax.nn.initializers.ones, (3, self.projection_dim)
        )

        self.concept_embeds_weights = self.param("concept_embeds_weights", jax.nn.initializers.ones, (17,))
        self.special_care_embeds_weights = self.param("special_care_embeds_weights", jax.nn.initializers.ones, (3,))

    def __call__(self, clip_input):
        pooled_output = self.vision_model(clip_input)[1]
        image_embeds = self.visual_projection(pooled_output)

        special_cos_dist = jax_cosine_distance(image_embeds, self.special_care_embeds)
        cos_dist = jax_cosine_distance(image_embeds, self.concept_embeds)
        return special_cos_dist, cos_dist
```
> **pcuenca (Member, Author)** on `filtered_with_scores`: This method is not meant to be used with […]

```python
    def filtered_with_scores(self, special_cos_dist, cos_dist, images):
        batch_size = special_cos_dist.shape[0]
        special_cos_dist = np.asarray(special_cos_dist)
        cos_dist = np.asarray(cos_dist)

        result = []
        for i in range(batch_size):
            result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []}

            # increase this value to create a stronger `nsfw` filter
            # at the cost of increasing the possibility of filtering benign image inputs
            adjustment = 0.0

            for concept_idx in range(len(special_cos_dist[0])):
                concept_cos = special_cos_dist[i][concept_idx]
                concept_threshold = self.special_care_embeds_weights[concept_idx].item()
                result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
                if result_img["special_scores"][concept_idx] > 0:
                    result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]})
                    adjustment = 0.01
```
> **jonatanklosko (Contributor):** Hey! Is there any reason to increase `adjustment` inside the loop? In the current version it doesn't really matter whether there is one or more special-care concepts, so I think either approach would be the same, but I'm curious if there is a specific reason behind this. This is relevant because the conditional […]
>
> **Contributor:** That's a good point - we can/should definitely move it out of the loop :-)
>
> **Contributor:** If you want, feel free to open a PR for it :-)
>
> **Contributor:** Really good point @jonatanklosko
>
> **jonatanklosko (Contributor):** Perfect, thanks! I will submit a PR sometime this week :)
>
> **Contributor:** awesome, thanks a lot!
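A minimal sketch of what moving the bump out of the loop might look like (hypothetical; the actual follow-up PR may differ):

```python
# Sketch only: compute the special-care scores without mutating `adjustment`
# inside the loop, then apply a single conditional bump afterwards.
for concept_idx in range(len(special_cos_dist[0])):
    concept_cos = special_cos_dist[i][concept_idx]
    concept_threshold = self.special_care_embeds_weights[concept_idx].item()
    result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold, 3)
    if result_img["special_scores"][concept_idx] > 0:
        result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]})

# any special-care hit makes the subsequent concept check stricter
adjustment = 0.01 if result_img["special_care"] else 0.0
```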
```python
            for concept_idx in range(len(cos_dist[0])):
                concept_cos = cos_dist[i][concept_idx]
                concept_threshold = self.concept_embeds_weights[concept_idx].item()
                result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
                if result_img["concept_scores"][concept_idx] > 0:
                    result_img["bad_concepts"].append(concept_idx)

            result.append(result_img)

        has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result]

        images_was_copied = False
        for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
            if has_nsfw_concept:
                if not images_was_copied:
                    images_was_copied = True
                    images = images.copy()

                images[idx] = np.zeros(images[idx].shape)  # black image

        if any(has_nsfw_concepts):
            warnings.warn(
                "Potential NSFW content was detected in one or more images. A black image will be returned"
                " instead. Try again with a different prompt and/or seed."
            )

        return images, has_nsfw_concepts
```
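The module thus splits the check into a jit-friendly distance computation (`__call__`) and a host-side filtering step that uses Python loops and `.item()`, which do not trace under `jax.jit`. A usage sketch under assumptions not stated in the PR (`safety_checker` and `params` obtained via `init_weights` or a pretrained checkpoint; `clip_input` a preprocessed pixel batch; `images` a NumPy batch):

```python
import jax
import numpy as np

# Stage 1: cosine distances, pure JAX, safe to jit-compile.
@jax.jit
def get_distances(params, clip_input):
    return safety_checker.apply({"params": params}, clip_input)

special_cos_dist, cos_dist = get_distances(params, clip_input)

# Stage 2: host-side thresholding and blacking-out, invoked via apply()
# but kept outside the jitted function (see the author comment above).
images, has_nsfw_concepts = safety_checker.apply(
    {"params": params},
    special_cos_dist,
    cos_dist,
    np.asarray(images),
    method=safety_checker.filtered_with_scores,
)
```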
Resolved conversation (marked as resolved by pcuenca) on the `CLIPVisionConfig` construction in `setup`:

> **pcuenca (Author):** The config file that is currently used contains all the CLIP configuration options. We just retrieve the ones corresponding to the `CLIPVisionConfig` that we use.
>
> **Reviewer:** Does the saving & loading work correctly with it? I think what would be better long-term is to write only the parameters that we would like people to change in here (e.g. the size of the CLIP, ...) and then use the CLIP default config parameters (see suggestion below).
>
> **pcuenca (Author):** Loading works (from CLIP config files such as the one currently in the `fusing` repo). I did not test saving yet, will do now. I thought about this approach, but there are many parameters in the configuration file and I wasn't sure which ones we wanted to expose. In addition, it could still be useful to reuse any CLIP configuration JSON and just retrieve the vision part, instead of reading the parameters from the root of the JSON, which I presume would be incompatible with existing configurations. These are all the relevant keys in the current configuration file: […]
>
> **pcuenca (Author):** Saving works after #565 was applied.
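To illustrate the "retrieve the vision part" approach discussed above, a hypothetical snippet (the checkpoint name is an example, not necessarily the one used in the PR):

```python
from transformers import CLIPConfig

# Load any full CLIP configuration and keep only its vision section,
# which is what the module's `vision_config` dict field expects.
clip_config = CLIPConfig.from_pretrained("openai/clip-vit-large-patch14")
vision_kwargs = clip_config.vision_config.to_dict()

safety_checker = FlaxStableDiffusionSafetyChecker(vision_config=vision_kwargs)
```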