-
Notifications
You must be signed in to change notification settings - Fork 6.5k
JAX/Flax safety checker #558
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
The documentation is not available anymore as the PR was closed or merged. |
| class FlaxStableDiffusionSafetyChecker(nn.Module, FlaxModelMixin, ConfigMixin): | ||
| projection_dim: int = 768 | ||
| # CLIPVisionConfig fields | ||
| vision_config: dict = field(default_factory=dict) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The config file that is currently used contains all the CLIP configuration options. We just retrieve the ones corresponding to the CLIPVisionConfig that we use.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does the saving & loading work correctly with it?
Think long-term what would be better is to actually just write the parameters that we would people like to change in here (e.g. the size of the CLIP, ....) and then use CLIP default config parameters (see suggestion below)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does the saving & loading work correctly with it?
Loading works (from CLIP config files as the one currently in the fusing repo). I did not test saving yet, will do now.
Think long-term what would be better is to actually just write the parameters that we would people like to change
I thought about this approach, but there are many parameters in the configuration file and wasn't sure which ones we wanted to expose. In addition, it could still be useful to reuse any CLIP configuration json and just retrieve the vision part, instead of reading the parameters from the root of the json, which I presume would be incompatible with existing configurations.
These are all the relevant keys in the current configuration file:
"vision_config": {
"_name_or_path": "",
"add_cross_attention": false,
"architectures": null,
"attention_dropout": 0.0,
"bad_words_ids": null,
"bos_token_id": null,
"chunk_size_feed_forward": 0,
"cross_attention_hidden_size": null,
"decoder_start_token_id": null,
"diversity_penalty": 0.0,
"do_sample": false,
"dropout": 0.0,
"early_stopping": false,
"encoder_no_repeat_ngram_size": 0,
"eos_token_id": null,
"exponential_decay_length_penalty": null,
"finetuning_task": null,
"forced_bos_token_id": null,
"forced_eos_token_id": null,
"hidden_act": "quick_gelu",
"hidden_size": 1024,
"id2label": {
"0": "LABEL_0",
"1": "LABEL_1"
},
"image_size": 224,
"initializer_factor": 1.0,
"initializer_range": 0.02,
"intermediate_size": 4096,
"is_decoder": false,
"is_encoder_decoder": false,
"label2id": {
"LABEL_0": 0,
"LABEL_1": 1
},
"layer_norm_eps": 1e-05,
"length_penalty": 1.0,
"max_length": 20,
"min_length": 0,
"model_type": "clip_vision_model",
"no_repeat_ngram_size": 0,
"num_attention_heads": 16,
"num_beam_groups": 1,
"num_beams": 1,
"num_channels": 3,
"num_hidden_layers": 24,
"num_return_sequences": 1,
"output_attentions": false,
"output_hidden_states": false,
"output_scores": false,
"pad_token_id": null,
"patch_size": 14,
"prefix": null,
"problem_type": null,
"pruned_heads": {},
"remove_invalid_values": false,
"repetition_penalty": 1.0,
"return_dict": true,
"return_dict_in_generate": false,
"sep_token_id": null,
"task_specific_params": null,
"temperature": 1.0,
"tf_legacy_loss": false,
"tie_encoder_decoder": false,
"tie_word_embeddings": true,
"tokenizer_class": null,
"top_k": 50,
"top_p": 1.0,
"torch_dtype": null,
"torchscript": false,
"transformers_version": "4.22.0.dev0",
"typical_p": 1.0,
"use_bfloat16": false
},
"vision_config_dict": {
"hidden_size": 1024,
"intermediate_size": 4096,
"num_attention_heads": 16,
"num_hidden_layers": 24,
"patch_size": 14
}There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I did not test saving yet, will do now.
Saving works after #565 was applied.
| cos_dist = jax_cosine_distance(image_embeds, self.concept_embeds) | ||
| return special_cos_dist, cos_dist | ||
|
|
||
| def filtered_with_scores(self, special_cos_dist, cos_dist, images): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This method is not meant to be used with pmap, but it's fast as it just checks the scores.
patrickvonplaten
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If saving & loading works correctly good to go for me!
patil-suraj
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks for adding this!
src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py
Outdated
Show resolved
Hide resolved
src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py
Outdated
Show resolved
Hide resolved
Co-authored-by: Suraj Patil <[email protected]>
Co-authored-by: Suraj Patil <[email protected]>
| result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) | ||
| if result_img["special_scores"][concept_idx] > 0: | ||
| result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]}) | ||
| adjustment = 0.01 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey! Is there any reason to increase adjustment on the first "special care" concept, rather than after the loop, before we check "bad" concepts? Intuitively I understand that we use "special care" concepts to make the "bad" concepts check more strict, but I feel like "special care" should be orthogonal to other "special care" detection.
In the current version it doesn't really matter if there's one or more special care concepts, so I think either approach would be the same, but I'm curious if there's a specific reason behind this.
This is relevant because the conditional adjustment = 0.01 introduces a dependency between loop iterations. The alternative approach would be to move this outside the loop as: "if any special care score > 0 then set adjustment = 0.01 for bad concepts". With this, I think both loops could be vectorized (and in flax version they could probably be a part of __call__).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's a good point - we can/should definitely move it out of the loop :-)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you want feel free to open a PR for it :-)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Really good point @jonatanklosko
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Perfect, thanks! I will submit a PR sometime this week :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
awesome, thanks a lot!
* Starting to integrate safety checker. * Fix initialization of CLIPVisionConfig * Remove commented lines. * make style * Remove unused import * Pass dtype to modules Co-authored-by: Suraj Patil <[email protected]> * Pass dtype to modules Co-authored-by: Suraj Patil <[email protected]> Co-authored-by: Suraj Patil <[email protected]>
Docstrings and tests pending.
To test: