|
1 | | -import warnings |
2 | 1 | from typing import Optional, Tuple |
3 | 2 |
|
4 | | -import numpy as np |
5 | | - |
6 | 3 | import jax |
7 | 4 | import jax.numpy as jnp |
8 | 5 | from flax import linen as nn |
@@ -39,56 +36,22 @@ def __call__(self, clip_input): |
39 | 36 |
|
40 | 37 | special_cos_dist = jax_cosine_distance(image_embeds, self.special_care_embeds) |
41 | 38 | cos_dist = jax_cosine_distance(image_embeds, self.concept_embeds) |
42 | | - return special_cos_dist, cos_dist |
43 | | - |
44 | | - def filtered_with_scores(self, special_cos_dist, cos_dist, images): |
45 | | - batch_size = special_cos_dist.shape[0] |
46 | | - special_cos_dist = np.asarray(special_cos_dist) |
47 | | - cos_dist = np.asarray(cos_dist) |
48 | | - |
49 | | - result = [] |
50 | | - for i in range(batch_size): |
51 | | - result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []} |
52 | | - |
53 | | - # increase this value to create a stronger `nfsw` filter |
54 | | - # at the cost of increasing the possibility of filtering benign image inputs |
55 | | - adjustment = 0.0 |
56 | | - |
57 | | - for concept_idx in range(len(special_cos_dist[0])): |
58 | | - concept_cos = special_cos_dist[i][concept_idx] |
59 | | - concept_threshold = self.special_care_embeds_weights[concept_idx].item() |
60 | | - result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) |
61 | | - if result_img["special_scores"][concept_idx] > 0: |
62 | | - result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]}) |
63 | | - adjustment = 0.01 |
64 | | - |
65 | | - for concept_idx in range(len(cos_dist[0])): |
66 | | - concept_cos = cos_dist[i][concept_idx] |
67 | | - concept_threshold = self.concept_embeds_weights[concept_idx].item() |
68 | | - result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) |
69 | | - if result_img["concept_scores"][concept_idx] > 0: |
70 | | - result_img["bad_concepts"].append(concept_idx) |
71 | | - |
72 | | - result.append(result_img) |
73 | 39 |
|
74 | | - has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result] |
| 40 | + # increase this value to create a stronger `nfsw` filter |
| 41 | + # at the cost of increasing the possibility of filtering benign image inputs |
| 42 | + adjustment = 0.0 |
75 | 43 |
|
76 | | - images_was_copied = False |
77 | | - for idx, has_nsfw_concept in enumerate(has_nsfw_concepts): |
78 | | - if has_nsfw_concept: |
79 | | - if not images_was_copied: |
80 | | - images_was_copied = True |
81 | | - images = images.copy() |
| 44 | + special_scores = special_cos_dist - self.special_care_embeds_weights[None, :] + adjustment |
| 45 | + special_scores = jnp.round(special_scores, 3) |
| 46 | + is_special_care = jnp.any(special_scores > 0, axis=1, keepdims=True) |
| 47 | + # Use a lower threshold if an image has any special care concept |
| 48 | + special_adjustment = is_special_care * 0.01 |
82 | 49 |
|
83 | | - images[idx] = np.zeros(images[idx].shape) # black image |
| 50 | + concept_scores = cos_dist - self.concept_embeds_weights[None, :] + special_adjustment |
| 51 | + concept_scores = jnp.round(concept_scores, 3) |
| 52 | + has_nsfw_concepts = jnp.any(concept_scores > 0, axis=1) |
84 | 53 |
|
85 | | - if any(has_nsfw_concepts): |
86 | | - warnings.warn( |
87 | | - "Potential NSFW content was detected in one or more images. A black image will be returned" |
88 | | - " instead. Try again with a different prompt and/or seed." |
89 | | - ) |
90 | | - |
91 | | - return images, has_nsfw_concepts |
| 54 | + return has_nsfw_concepts |
92 | 55 |
|
93 | 56 |
|
94 | 57 | class FlaxStableDiffusionSafetyChecker(FlaxPreTrainedModel): |
@@ -133,15 +96,3 @@ def __call__( |
133 | 96 | jnp.array(clip_input, dtype=jnp.float32), |
134 | 97 | rngs={}, |
135 | 98 | ) |
136 | | - |
137 | | - def filtered_with_scores(self, special_cos_dist, cos_dist, images, params: dict = None): |
138 | | - def _filtered_with_scores(module, special_cos_dist, cos_dist, images): |
139 | | - return module.filtered_with_scores(special_cos_dist, cos_dist, images) |
140 | | - |
141 | | - return self.module.apply( |
142 | | - {"params": params or self.params}, |
143 | | - special_cos_dist, |
144 | | - cos_dist, |
145 | | - images, |
146 | | - method=_filtered_with_scores, |
147 | | - ) |
0 commit comments