@@ -1,3 +1,4 @@
+import warnings
 from functools import partial
 from typing import Dict, List, Optional, Union
 
@@ -97,9 +98,9 @@ def prepare_inputs(self, prompt: Union[str, List[str]]):
         )
         return text_input.input_ids
 
-    def _get_safety_scores(self, features, params):
-        special_cos_dist, cos_dist = self.safety_checker(features, params)
-        return (special_cos_dist, cos_dist)
+    def _get_has_nsfw_concepts(self, features, params):
+        has_nsfw_concepts = self.safety_checker(features, params)
+        return has_nsfw_concepts
 
     def _run_safety_checker(self, images, safety_model_params, jit=False):
         # safety_model_params should already be replicated when jit is True
@@ -108,20 +109,28 @@ def _run_safety_checker(self, images, safety_model_params, jit=False):
 
         if jit:
             features = shard(features)
-            special_cos_dist, cos_dist = _p_get_safety_scores(self, features, safety_model_params)
-            special_cos_dist = unshard(special_cos_dist)
-            cos_dist = unshard(cos_dist)
+            has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params)
+            has_nsfw_concepts = unshard(has_nsfw_concepts)
             safety_model_params = unreplicate(safety_model_params)
         else:
-            special_cos_dist, cos_dist = self._get_safety_scores(features, safety_model_params)
+            has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params)
 
-        images, has_nsfw = self.safety_checker.filtered_with_scores(
-            special_cos_dist,
-            cos_dist,
-            images,
-            safety_model_params,
-        )
-        return images, has_nsfw
+        images_was_copied = False
+        for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
+            if has_nsfw_concept:
+                if not images_was_copied:
+                    images_was_copied = True
+                    images = images.copy()
+
+                images[idx] = np.zeros(images[idx].shape, dtype=np.uint8)  # black image
+
+        if any(has_nsfw_concepts):
+            warnings.warn(
+                "Potential NSFW content was detected in one or more images. A black image will be returned"
+                " instead. Try again with a different prompt and/or seed."
+            )
+
+        return images, has_nsfw_concepts
 
     def _generate(
         self,
@@ -310,8 +319,8 @@ def _p_generate(
 
 
 @partial(jax.pmap, static_broadcasted_argnums=(0,))
-def _p_get_safety_scores(pipe, features, params):
-    return pipe._get_safety_scores(features, params)
+def _p_get_has_nsfw_concepts(pipe, features, params):
+    return pipe._get_has_nsfw_concepts(features, params)
 
 
 def unshard(x: jnp.ndarray):
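
The pmapped helper above relies on `static_broadcasted_argnums=(0,)` to pass the pipeline object itself into the compiled function: argument 0 is treated as a compile-time constant and broadcast to every device, while `features` and `params` stay sharded along the leading device axis. A minimal, self-contained sketch of the same pattern; the `Scaler` class and all values are illustrative stand-ins, not part of the diff:

from functools import partial

import jax
import jax.numpy as jnp


class Scaler:
    # stands in for the pipeline instance passed as the static argument
    def scale(self, x, factor):
        return x * factor


@partial(jax.pmap, static_broadcasted_argnums=(0,))
def _p_scale(obj, x, factor):
    return obj.scale(x, factor)


n = jax.local_device_count()
x = jnp.ones((n, 4))            # leading axis is the device axis
factor = jnp.full((n, 1), 2.0)  # one scalar per device
print(_p_scale(Scaler(), x, factor))  # each device scales its shard by 2

Because the static argument keys the compilation cache by hash, passing a fresh object on every call would retrigger tracing; the pipeline sidesteps this by always passing the same long-lived `self`.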
src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py (12 additions, 61 deletions)
@@ -1,8 +1,5 @@
-import warnings
 from typing import Optional, Tuple
 
-import numpy as np
-
 import jax
 import jax.numpy as jnp
 from flax import linen as nn
@@ -39,56 +36,22 @@ def __call__(self, clip_input):
 
         special_cos_dist = jax_cosine_distance(image_embeds, self.special_care_embeds)
         cos_dist = jax_cosine_distance(image_embeds, self.concept_embeds)
-        return special_cos_dist, cos_dist
-
-    def filtered_with_scores(self, special_cos_dist, cos_dist, images):
-        batch_size = special_cos_dist.shape[0]
-        special_cos_dist = np.asarray(special_cos_dist)
-        cos_dist = np.asarray(cos_dist)
-
-        result = []
-        for i in range(batch_size):
-            result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []}
-
-            # increase this value to create a stronger `nsfw` filter
-            # at the cost of increasing the possibility of filtering benign image inputs
-            adjustment = 0.0
-
-            for concept_idx in range(len(special_cos_dist[0])):
-                concept_cos = special_cos_dist[i][concept_idx]
-                concept_threshold = self.special_care_embeds_weights[concept_idx].item()
-                result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
-                if result_img["special_scores"][concept_idx] > 0:
-                    result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]})
-                    adjustment = 0.01
-
-            for concept_idx in range(len(cos_dist[0])):
-                concept_cos = cos_dist[i][concept_idx]
-                concept_threshold = self.concept_embeds_weights[concept_idx].item()
-                result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
-                if result_img["concept_scores"][concept_idx] > 0:
-                    result_img["bad_concepts"].append(concept_idx)
-
-            result.append(result_img)
-
-        has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result]
-
-        images_was_copied = False
-        for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
-            if has_nsfw_concept:
-                if not images_was_copied:
-                    images_was_copied = True
-                    images = images.copy()
-
-                images[idx] = np.zeros(images[idx].shape)  # black image
-
-        if any(has_nsfw_concepts):
-            warnings.warn(
-                "Potential NSFW content was detected in one or more images. A black image will be returned"
-                " instead. Try again with a different prompt and/or seed."
-            )
-
-        return images, has_nsfw_concepts
+
+        # increase this value to create a stronger `nsfw` filter
+        # at the cost of increasing the possibility of filtering benign image inputs
+        adjustment = 0.0
+
+        special_scores = special_cos_dist - self.special_care_embeds_weights[None, :] + adjustment
+        special_scores = jnp.round(special_scores, 3)
+        is_special_care = jnp.any(special_scores > 0, axis=1, keepdims=True)
+        # Use a lower threshold if an image has any special care concept
+        special_adjustment = is_special_care * 0.01
+
+        concept_scores = cos_dist - self.concept_embeds_weights[None, :] + special_adjustment
+        concept_scores = jnp.round(concept_scores, 3)
+        has_nsfw_concepts = jnp.any(concept_scores > 0, axis=1)
+
+        return has_nsfw_concepts
 
 
 class FlaxStableDiffusionSafetyChecker(FlaxPreTrainedModel):
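
The rewritten `__call__` replaces the removed per-image Python loops with broadcast arithmetic: each threshold vector is subtracted row-wise, `is_special_care` lowers the effective threshold by 0.01 for any image that trips a special-care concept, and `jnp.any(..., axis=1)` collapses each row to a single boolean. A small sketch with made-up distances and thresholds (all shapes and values here are toy assumptions):

import jax.numpy as jnp

# toy batch: 2 images, 3 special-care concepts, 4 nsfw concepts
special_cos_dist = jnp.array([[0.10, 0.20, 0.05],
                              [0.00, 0.00, 0.00]])
special_thresholds = jnp.array([0.15, 0.15, 0.15])
cos_dist = jnp.array([[0.10, 0.30, 0.20, 0.05],
                      [0.10, 0.10, 0.10, 0.10]])
concept_thresholds = jnp.array([0.25, 0.25, 0.25, 0.25])

special_scores = special_cos_dist - special_thresholds[None, :]       # (2, 3)
is_special_care = jnp.any(special_scores > 0, axis=1, keepdims=True)  # (2, 1)
special_adjustment = is_special_care * 0.01   # lower the bar for flagged rows

concept_scores = cos_dist - concept_thresholds[None, :] + special_adjustment
print(jnp.any(concept_scores > 0, axis=1))    # [ True False]

Image 0 exceeds a special-care threshold (0.20 > 0.15), so its concept scores get the +0.01 adjustment and 0.30 - 0.25 + 0.01 flags it; image 1 stays below every threshold. Keeping the whole computation in `jnp` is what lets the checker run under `pmap` without host round-trips.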
@@ -133,15 +96,3 @@ def __call__(
             jnp.array(clip_input, dtype=jnp.float32),
             rngs={},
         )
-
-    def filtered_with_scores(self, special_cos_dist, cos_dist, images, params: dict = None):
-        def _filtered_with_scores(module, special_cos_dist, cos_dist, images):
-            return module.filtered_with_scores(special_cos_dist, cos_dist, images)
-
-        return self.module.apply(
-            {"params": params or self.params},
-            special_cos_dist,
-            cos_dist,
-            images,
-            method=_filtered_with_scores,
-        )
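
With `filtered_with_scores` removed from both classes, the checker's public surface shrinks to `__call__`, which returns only per-image boolean flags; blanking the flagged images now happens on the host, in the pipeline loop shown in the first file. A hedged sketch of that division of labor, with the helper name and toy batch assumed for illustration:

import numpy as np


def black_out_flagged(images: np.ndarray, has_nsfw_concepts) -> np.ndarray:
    # mirrors the pipeline loop above: copy once, on the first flagged
    # image, so the caller's array is never mutated in place
    images_was_copied = False
    for idx, flagged in enumerate(has_nsfw_concepts):
        if flagged:
            if not images_was_copied:
                images_was_copied = True
                images = images.copy()
            images[idx] = np.zeros(images[idx].shape, dtype=np.uint8)  # black image
    return images


# toy batch: two white 8x8 RGB images, the second one flagged
images = np.full((2, 8, 8, 3), 255, dtype=np.uint8)
filtered = black_out_flagged(images, [False, True])
print(filtered[1].max(), images[1].max())  # 0 255 (original left intact)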