Commit 8e4fd68

Move safety detection to model call in Flax safety checker (#1023)
* Move safety detection to model call in Flax safety checker
* Update src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py
1 parent 95414bd commit 8e4fd68

File tree

2 files changed: +37 -77 lines changed

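In the first file's jitted path below, the safety model is invoked across devices through jax.pmap, with the pipeline object broadcast statically rather than sharded. A minimal, self-contained sketch of that pattern; every name here is an illustrative stand-in, not the pipeline's API:

from functools import partial

import jax
import jax.numpy as jnp

# Illustrative stand-in for the safety model: any object whose method is
# called inside the pmapped function.
class TinyChecker:
    def check(self, features, params):
        # placeholder computation; one boolean flag per example
        return jnp.any(features * params > 0.5, axis=-1)

# argnum 0 (the checker) is broadcast statically to every device;
# the array arguments carry a leading device axis and are sharded.
@partial(jax.pmap, static_broadcasted_argnums=(0,))
def _p_check(checker, features, params):
    return checker.check(features, params)

n_dev = jax.local_device_count()
features = jnp.full((n_dev, 2, 4), 0.3)  # (devices, per-device batch, dim)
params = jnp.full((n_dev,), 2.0)         # one params replica per device
flags = _p_check(TinyChecker(), features, params)  # shape: (devices, 2)

After the pmapped call, the per-device results are flattened back into a single batch, which is what the pipeline's unshard helper does.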

src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py

Lines changed: 25 additions & 16 deletions
@@ -1,3 +1,4 @@
+import warnings
 from functools import partial
 from typing import Dict, List, Optional, Union
 
@@ -97,9 +98,9 @@ def prepare_inputs(self, prompt: Union[str, List[str]]):
         )
         return text_input.input_ids
 
-    def _get_safety_scores(self, features, params):
-        special_cos_dist, cos_dist = self.safety_checker(features, params)
-        return (special_cos_dist, cos_dist)
+    def _get_has_nsfw_concepts(self, features, params):
+        has_nsfw_concepts = self.safety_checker(features, params)
+        return has_nsfw_concepts
 
     def _run_safety_checker(self, images, safety_model_params, jit=False):
         # safety_model_params should already be replicated when jit is True
@@ -108,20 +109,28 @@ def _run_safety_checker(self, images, safety_model_params, jit=False):
 
         if jit:
             features = shard(features)
-            special_cos_dist, cos_dist = _p_get_safety_scores(self, features, safety_model_params)
-            special_cos_dist = unshard(special_cos_dist)
-            cos_dist = unshard(cos_dist)
+            has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params)
+            has_nsfw_concepts = unshard(has_nsfw_concepts)
             safety_model_params = unreplicate(safety_model_params)
         else:
-            special_cos_dist, cos_dist = self._get_safety_scores(features, safety_model_params)
+            has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params)
 
-        images, has_nsfw = self.safety_checker.filtered_with_scores(
-            special_cos_dist,
-            cos_dist,
-            images,
-            safety_model_params,
-        )
-        return images, has_nsfw
+        images_was_copied = False
+        for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
+            if has_nsfw_concept:
+                if not images_was_copied:
+                    images_was_copied = True
+                    images = images.copy()
+
+                images[idx] = np.zeros(images[idx].shape, dtype=np.uint8)  # black image
+
+        if any(has_nsfw_concepts):
+            warnings.warn(
+                "Potential NSFW content was detected in one or more images. A black image will be returned"
+                " instead. Try again with a different prompt and/or seed."
+            )
+
+        return images, has_nsfw_concepts
 
     def _generate(
         self,
@@ -310,8 +319,8 @@ def _p_generate(
 
 
 @partial(jax.pmap, static_broadcasted_argnums=(0,))
-def _p_get_safety_scores(pipe, features, params):
-    return pipe._get_safety_scores(features, params)
+def _p_get_has_nsfw_concepts(pipe, features, params):
+    return pipe._get_has_nsfw_concepts(features, params)
 
 
 def unshard(x: jnp.ndarray):
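With detection moved inside the model call, the pipeline now consumes the returned boolean flags directly on the host. A standalone sketch of the copy-on-write blanking loop added above (the function name is hypothetical; the body mirrors the diff):

import numpy as np

def blank_flagged_images(images: np.ndarray, has_nsfw_concepts) -> np.ndarray:
    # Copy lazily: the input batch is duplicated only if at least one
    # image is flagged, then flagged entries are replaced by black images.
    images_was_copied = False
    for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
        if has_nsfw_concept:
            if not images_was_copied:
                images_was_copied = True
                images = images.copy()
            images[idx] = np.zeros(images[idx].shape, dtype=np.uint8)  # black image
    return images

batch = np.full((2, 8, 8, 3), 255, dtype=np.uint8)
out = blank_flagged_images(batch, [False, True])
assert out[1].sum() == 0 and batch[1].sum() > 0  # caller's batch is untouched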

src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py

Lines changed: 12 additions & 61 deletions
@@ -1,8 +1,5 @@
-import warnings
 from typing import Optional, Tuple
 
-import numpy as np
-
 import jax
 import jax.numpy as jnp
 from flax import linen as nn
@@ -39,56 +36,22 @@ def __call__(self, clip_input):
 
         special_cos_dist = jax_cosine_distance(image_embeds, self.special_care_embeds)
         cos_dist = jax_cosine_distance(image_embeds, self.concept_embeds)
-        return special_cos_dist, cos_dist
-
-    def filtered_with_scores(self, special_cos_dist, cos_dist, images):
-        batch_size = special_cos_dist.shape[0]
-        special_cos_dist = np.asarray(special_cos_dist)
-        cos_dist = np.asarray(cos_dist)
-
-        result = []
-        for i in range(batch_size):
-            result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []}
-
-            # increase this value to create a stronger `nfsw` filter
-            # at the cost of increasing the possibility of filtering benign image inputs
-            adjustment = 0.0
-
-            for concept_idx in range(len(special_cos_dist[0])):
-                concept_cos = special_cos_dist[i][concept_idx]
-                concept_threshold = self.special_care_embeds_weights[concept_idx].item()
-                result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
-                if result_img["special_scores"][concept_idx] > 0:
-                    result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]})
-                    adjustment = 0.01
-
-            for concept_idx in range(len(cos_dist[0])):
-                concept_cos = cos_dist[i][concept_idx]
-                concept_threshold = self.concept_embeds_weights[concept_idx].item()
-                result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
-                if result_img["concept_scores"][concept_idx] > 0:
-                    result_img["bad_concepts"].append(concept_idx)
-
-            result.append(result_img)
 
-        has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result]
+        # increase this value to create a stronger `nfsw` filter
+        # at the cost of increasing the possibility of filtering benign image inputs
+        adjustment = 0.0
 
-        images_was_copied = False
-        for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
-            if has_nsfw_concept:
-                if not images_was_copied:
-                    images_was_copied = True
-                    images = images.copy()
+        special_scores = special_cos_dist - self.special_care_embeds_weights[None, :] + adjustment
+        special_scores = jnp.round(special_scores, 3)
+        is_special_care = jnp.any(special_scores > 0, axis=1, keepdims=True)
+        # Use a lower threshold if an image has any special care concept
+        special_adjustment = is_special_care * 0.01
 
-                images[idx] = np.zeros(images[idx].shape)  # black image
+        concept_scores = cos_dist - self.concept_embeds_weights[None, :] + special_adjustment
+        concept_scores = jnp.round(concept_scores, 3)
+        has_nsfw_concepts = jnp.any(concept_scores > 0, axis=1)
 
-        if any(has_nsfw_concepts):
-            warnings.warn(
-                "Potential NSFW content was detected in one or more images. A black image will be returned"
-                " instead. Try again with a different prompt and/or seed."
-            )
-
-        return images, has_nsfw_concepts
+        return has_nsfw_concepts
 
 
 class FlaxStableDiffusionSafetyChecker(FlaxPreTrainedModel):
@@ -133,15 +96,3 @@ def __call__(
             jnp.array(clip_input, dtype=jnp.float32),
             rngs={},
         )
-
-    def filtered_with_scores(self, special_cos_dist, cos_dist, images, params: dict = None):
-        def _filtered_with_scores(module, special_cos_dist, cos_dist, images):
-            return module.filtered_with_scores(special_cos_dist, cos_dist, images)
-
-        return self.module.apply(
-            {"params": params or self.params},
-            special_cos_dist,
-            cos_dist,
-            images,
-            method=_filtered_with_scores,
-        )
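The rewritten __call__ above replaces the per-image, per-concept Python loops with whole-batch array operations. A NumPy sketch of the same scoring math, with made-up shapes and thresholds for illustration:

import numpy as np

rng = np.random.default_rng(0)
batch, n_special, n_concepts = 4, 3, 17
special_cos_dist = rng.uniform(0.0, 1.0, (batch, n_special))
cos_dist = rng.uniform(0.0, 1.0, (batch, n_concepts))
special_care_embeds_weights = rng.uniform(0.4, 0.6, n_special)
concept_embeds_weights = rng.uniform(0.4, 0.6, n_concepts)

adjustment = 0.0  # raise to filter more aggressively

# (batch, n_special): score every image against every special-care concept at once
special_scores = np.round(special_cos_dist - special_care_embeds_weights[None, :] + adjustment, 3)
is_special_care = np.any(special_scores > 0, axis=1, keepdims=True)
special_adjustment = is_special_care * 0.01  # lower thresholds for special-care images

# (batch, n_concepts): one pass over all images and concepts
concept_scores = np.round(cos_dist - concept_embeds_weights[None, :] + special_adjustment, 3)
has_nsfw_concepts = np.any(concept_scores > 0, axis=1)  # one boolean per image

Unlike the deleted host-side loops, this formulation is pure array math, so it can run inside the model call under jax.jit or jax.pmap.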
