diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py index e4f56d94dac8..fe0e284c6720 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -1,3 +1,4 @@ +import warnings from functools import partial from typing import Dict, List, Optional, Union @@ -97,9 +98,9 @@ def prepare_inputs(self, prompt: Union[str, List[str]]): ) return text_input.input_ids - def _get_safety_scores(self, features, params): - special_cos_dist, cos_dist = self.safety_checker(features, params) - return (special_cos_dist, cos_dist) + def _get_has_nsfw_concepts(self, features, params): + has_nsfw_concepts = self.safety_checker(features, params) + return has_nsfw_concepts def _run_safety_checker(self, images, safety_model_params, jit=False): # safety_model_params should already be replicated when jit is True @@ -108,20 +109,28 @@ def _run_safety_checker(self, images, safety_model_params, jit=False): if jit: features = shard(features) - special_cos_dist, cos_dist = _p_get_safety_scores(self, features, safety_model_params) - special_cos_dist = unshard(special_cos_dist) - cos_dist = unshard(cos_dist) + has_nsfw_concepts = _p_get_has_nsfw_concepts(self, features, safety_model_params) + has_nsfw_concepts = unshard(has_nsfw_concepts) safety_model_params = unreplicate(safety_model_params) else: - special_cos_dist, cos_dist = self._get_safety_scores(features, safety_model_params) + has_nsfw_concepts = self._get_has_nsfw_concepts(features, safety_model_params) - images, has_nsfw = self.safety_checker.filtered_with_scores( - special_cos_dist, - cos_dist, - images, - safety_model_params, - ) - return images, has_nsfw + images_was_copied = False + for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): + if has_nsfw_concept: + if not images_was_copied: + images_was_copied = True 
+ images = images.copy() + + images[idx] = np.zeros(images[idx].shape, dtype=np.uint8) # black image + + if any(has_nsfw_concepts): + warnings.warn( + "Potential NSFW content was detected in one or more images. A black image will be returned" + " instead. Try again with a different prompt and/or seed." + ) + + return images, has_nsfw_concepts def _generate( self, @@ -310,8 +319,8 @@ def _p_generate( @partial(jax.pmap, static_broadcasted_argnums=(0,)) -def _p_get_safety_scores(pipe, features, params): - return pipe._get_safety_scores(features, params) +def _p_get_has_nsfw_concepts(pipe, features, params): + return pipe._get_has_nsfw_concepts(features, params) def unshard(x: jnp.ndarray): diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py b/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py index b3cd8eef02fa..e4ea381a8c9b 100644 --- a/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py +++ b/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py @@ -1,8 +1,5 @@ -import warnings from typing import Optional, Tuple -import numpy as np - import jax import jax.numpy as jnp from flax import linen as nn @@ -39,56 +36,22 @@ def __call__(self, clip_input): special_cos_dist = jax_cosine_distance(image_embeds, self.special_care_embeds) cos_dist = jax_cosine_distance(image_embeds, self.concept_embeds) - return special_cos_dist, cos_dist - - def filtered_with_scores(self, special_cos_dist, cos_dist, images): - batch_size = special_cos_dist.shape[0] - special_cos_dist = np.asarray(special_cos_dist) - cos_dist = np.asarray(cos_dist) - - result = [] - for i in range(batch_size): - result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []} - - # increase this value to create a stronger `nfsw` filter - # at the cost of increasing the possibility of filtering benign image inputs - adjustment = 0.0 - - for concept_idx in range(len(special_cos_dist[0])): - concept_cos = 
special_cos_dist[i][concept_idx] - concept_threshold = self.special_care_embeds_weights[concept_idx].item() - result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) - if result_img["special_scores"][concept_idx] > 0: - result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]}) - adjustment = 0.01 - - for concept_idx in range(len(cos_dist[0])): - concept_cos = cos_dist[i][concept_idx] - concept_threshold = self.concept_embeds_weights[concept_idx].item() - result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) - if result_img["concept_scores"][concept_idx] > 0: - result_img["bad_concepts"].append(concept_idx) - - result.append(result_img) - has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result] + # increase this value to create a stronger `nsfw` filter + # at the cost of increasing the possibility of filtering benign image inputs + adjustment = 0.0 - images_was_copied = False - for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): - if has_nsfw_concept: - if not images_was_copied: - images_was_copied = True - images = images.copy() + special_scores = special_cos_dist - self.special_care_embeds_weights[None, :] + adjustment + special_scores = jnp.round(special_scores, 3) + is_special_care = jnp.any(special_scores > 0, axis=1, keepdims=True) + # Use a lower threshold if an image has any special care concept + special_adjustment = is_special_care * 0.01 - images[idx] = np.zeros(images[idx].shape) # black image + concept_scores = cos_dist - self.concept_embeds_weights[None, :] + special_adjustment + concept_scores = jnp.round(concept_scores, 3) + has_nsfw_concepts = jnp.any(concept_scores > 0, axis=1) - if any(has_nsfw_concepts): - warnings.warn( - "Potential NSFW content was detected in one or more images. A black image will be returned" - " instead. Try again with a different prompt and/or seed." 
- ) - - return images, has_nsfw_concepts + return has_nsfw_concepts class FlaxStableDiffusionSafetyChecker(FlaxPreTrainedModel): @@ -133,15 +96,3 @@ def __call__( jnp.array(clip_input, dtype=jnp.float32), rngs={}, ) - - def filtered_with_scores(self, special_cos_dist, cos_dist, images, params: dict = None): - def _filtered_with_scores(module, special_cos_dist, cos_dist, images): - return module.filtered_with_scores(special_cos_dist, cos_dist, images) - - return self.module.apply( - {"params": params or self.params}, - special_cos_dist, - cos_dist, - images, - method=_filtered_with_scores, - )