Skip to content

Commit fde9abc

Browse files
pcuencapatil-suraj
andauthored
JAX/Flax safety checker (#558)
* Starting to integrate safety checker. * Fix initialization of CLIPVisionConfig * Remove commented lines. * make style * Remove unused import * Pass dtype to modules Co-authored-by: Suraj Patil <[email protected]> * Pass dtype to modules Co-authored-by: Suraj Patil <[email protected]> Co-authored-by: Suraj Patil <[email protected]>
1 parent b1182bc commit fde9abc

File tree

2 files changed

+115
-1
lines changed

2 files changed

+115
-1
lines changed

src/diffusers/pipelines/stable_diffusion/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import PIL
77
from PIL import Image
88

9-
from ...utils import BaseOutput, is_onnx_available, is_transformers_available
9+
from ...utils import BaseOutput, is_flax_available, is_onnx_available, is_transformers_available
1010

1111

1212
@dataclass
@@ -35,3 +35,6 @@ class StableDiffusionPipelineOutput(BaseOutput):
3535

3636
if is_transformers_available() and is_onnx_available():
3737
from .pipeline_stable_diffusion_onnx import StableDiffusionOnnxPipeline
38+
39+
if is_transformers_available() and is_flax_available():
40+
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
import warnings
2+
3+
import numpy as np
4+
5+
import jax
6+
import jax.numpy as jnp
7+
from flax import linen as nn
8+
from flax.core.frozen_dict import FrozenDict
9+
from flax.struct import field
10+
from transformers import CLIPVisionConfig
11+
from transformers.models.clip.modeling_flax_clip import FlaxCLIPVisionModule
12+
13+
from ...configuration_utils import ConfigMixin, flax_register_to_config
14+
from ...modeling_flax_utils import FlaxModelMixin
15+
16+
17+
def jax_cosine_distance(emb_1, emb_2, eps=1e-12):
18+
norm_emb_1 = jnp.divide(emb_1.T, jnp.clip(jnp.linalg.norm(emb_1, axis=1), a_min=eps)).T
19+
norm_emb_2 = jnp.divide(emb_2.T, jnp.clip(jnp.linalg.norm(emb_2, axis=1), a_min=eps)).T
20+
return jnp.matmul(norm_emb_1, norm_emb_2.T)
21+
22+
23+
@flax_register_to_config
24+
class FlaxStableDiffusionSafetyChecker(nn.Module, FlaxModelMixin, ConfigMixin):
25+
projection_dim: int = 768
26+
# CLIPVisionConfig fields
27+
vision_config: dict = field(default_factory=dict)
28+
dtype: jnp.dtype = jnp.float32
29+
30+
def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
31+
# init input tensor
32+
input_shape = (
33+
1,
34+
self.vision_config["image_size"],
35+
self.vision_config["image_size"],
36+
self.vision_config["num_channels"],
37+
)
38+
pixel_values = jax.random.normal(rng, input_shape)
39+
params_rng, dropout_rng = jax.random.split(rng)
40+
rngs = {"params": params_rng, "dropout": dropout_rng}
41+
return self.init(rngs, pixel_values)["params"]
42+
43+
def setup(self):
44+
clip_vision_config = CLIPVisionConfig(**self.vision_config)
45+
self.vision_model = FlaxCLIPVisionModule(clip_vision_config, dtype=self.dtype)
46+
self.visual_projection = nn.Dense(self.projection_dim, use_bias=False, dtype=self.dtype)
47+
48+
self.concept_embeds = self.param("concept_embeds", jax.nn.initializers.ones, (17, self.projection_dim))
49+
self.special_care_embeds = self.param(
50+
"special_care_embeds", jax.nn.initializers.ones, (3, self.projection_dim)
51+
)
52+
53+
self.concept_embeds_weights = self.param("concept_embeds_weights", jax.nn.initializers.ones, (17,))
54+
self.special_care_embeds_weights = self.param("special_care_embeds_weights", jax.nn.initializers.ones, (3,))
55+
56+
def __call__(self, clip_input):
57+
pooled_output = self.vision_model(clip_input)[1]
58+
image_embeds = self.visual_projection(pooled_output)
59+
60+
special_cos_dist = jax_cosine_distance(image_embeds, self.special_care_embeds)
61+
cos_dist = jax_cosine_distance(image_embeds, self.concept_embeds)
62+
return special_cos_dist, cos_dist
63+
64+
def filtered_with_scores(self, special_cos_dist, cos_dist, images):
65+
batch_size = special_cos_dist.shape[0]
66+
special_cos_dist = np.asarray(special_cos_dist)
67+
cos_dist = np.asarray(cos_dist)
68+
69+
result = []
70+
for i in range(batch_size):
71+
result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []}
72+
73+
# increase this value to create a stronger `nfsw` filter
74+
# at the cost of increasing the possibility of filtering benign image inputs
75+
adjustment = 0.0
76+
77+
for concept_idx in range(len(special_cos_dist[0])):
78+
concept_cos = special_cos_dist[i][concept_idx]
79+
concept_threshold = self.special_care_embeds_weights[concept_idx].item()
80+
result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
81+
if result_img["special_scores"][concept_idx] > 0:
82+
result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]})
83+
adjustment = 0.01
84+
85+
for concept_idx in range(len(cos_dist[0])):
86+
concept_cos = cos_dist[i][concept_idx]
87+
concept_threshold = self.concept_embeds_weights[concept_idx].item()
88+
result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
89+
if result_img["concept_scores"][concept_idx] > 0:
90+
result_img["bad_concepts"].append(concept_idx)
91+
92+
result.append(result_img)
93+
94+
has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result]
95+
96+
images_was_copied = False
97+
for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
98+
if has_nsfw_concept:
99+
if not images_was_copied:
100+
images_was_copied = True
101+
images = images.copy()
102+
103+
images[idx] = np.zeros(images[idx].shape) # black image
104+
105+
if any(has_nsfw_concepts):
106+
warnings.warn(
107+
"Potential NSFW content was detected in one or more images. A black image will be returned"
108+
" instead. Try again with a different prompt and/or seed."
109+
)
110+
111+
return images, has_nsfw_concepts

0 commit comments

Comments
 (0)