|
1 | 1 | import warnings |
| 2 | +from typing import Optional, Tuple |
2 | 3 |
|
3 | 4 | import numpy as np |
4 | 5 |
|
5 | 6 | import jax |
6 | 7 | import jax.numpy as jnp |
7 | 8 | from flax import linen as nn |
8 | 9 | from flax.core.frozen_dict import FrozenDict |
9 | | -from flax.struct import field |
10 | | -from transformers import CLIPVisionConfig |
| 10 | +from transformers import CLIPConfig, FlaxPreTrainedModel |
11 | 11 | from transformers.models.clip.modeling_flax_clip import FlaxCLIPVisionModule |
12 | 12 |
|
13 | | -from ...configuration_utils import ConfigMixin, flax_register_to_config |
14 | | -from ...modeling_flax_utils import FlaxModelMixin |
15 | | - |
16 | 13 |
|
17 | 14 | def jax_cosine_distance(emb_1, emb_2, eps=1e-12): |
18 | 15 | norm_emb_1 = jnp.divide(emb_1.T, jnp.clip(jnp.linalg.norm(emb_1, axis=1), a_min=eps)).T |
19 | 16 | norm_emb_2 = jnp.divide(emb_2.T, jnp.clip(jnp.linalg.norm(emb_2, axis=1), a_min=eps)).T |
20 | 17 | return jnp.matmul(norm_emb_1, norm_emb_2.T) |
21 | 18 |
|
22 | 19 |
|
23 | | -@flax_register_to_config |
24 | | -class FlaxStableDiffusionSafetyChecker(nn.Module, FlaxModelMixin, ConfigMixin): |
25 | | - projection_dim: int = 768 |
26 | | - # CLIPVisionConfig fields |
27 | | - vision_config: dict = field(default_factory=dict) |
| 20 | +class FlaxStableDiffusionSafetyCheckerModule(nn.Module): |
| 21 | + config: CLIPConfig |
28 | 22 | dtype: jnp.dtype = jnp.float32 |
29 | 23 |
|
30 | | - def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict: |
31 | | - # init input tensor |
32 | | - input_shape = ( |
33 | | - 1, |
34 | | - self.vision_config["image_size"], |
35 | | - self.vision_config["image_size"], |
36 | | - self.vision_config["num_channels"], |
37 | | - ) |
38 | | - pixel_values = jax.random.normal(rng, input_shape) |
39 | | - params_rng, dropout_rng = jax.random.split(rng) |
40 | | - rngs = {"params": params_rng, "dropout": dropout_rng} |
41 | | - return self.init(rngs, pixel_values)["params"] |
42 | | - |
43 | 24 | def setup(self): |
44 | | - clip_vision_config = CLIPVisionConfig(**self.vision_config) |
45 | | - self.vision_model = FlaxCLIPVisionModule(clip_vision_config, dtype=self.dtype) |
46 | | - self.visual_projection = nn.Dense(self.projection_dim, use_bias=False, dtype=self.dtype) |
| 25 | + self.vision_model = FlaxCLIPVisionModule(self.config.vision_config) |
| 26 | + self.visual_projection = nn.Dense(self.config.projection_dim, use_bias=False, dtype=self.dtype) |
47 | 27 |
|
48 | | - self.concept_embeds = self.param("concept_embeds", jax.nn.initializers.ones, (17, self.projection_dim)) |
| 28 | + self.concept_embeds = self.param("concept_embeds", jax.nn.initializers.ones, (17, self.config.projection_dim)) |
49 | 29 | self.special_care_embeds = self.param( |
50 | | - "special_care_embeds", jax.nn.initializers.ones, (3, self.projection_dim) |
| 30 | + "special_care_embeds", jax.nn.initializers.ones, (3, self.config.projection_dim) |
51 | 31 | ) |
52 | 32 |
|
53 | 33 | self.concept_embeds_weights = self.param("concept_embeds_weights", jax.nn.initializers.ones, (17,)) |
@@ -109,3 +89,59 @@ def filtered_with_scores(self, special_cos_dist, cos_dist, images): |
109 | 89 | ) |
110 | 90 |
|
111 | 91 | return images, has_nsfw_concepts |
| 92 | + |
| 93 | + |
| 94 | +class FlaxStableDiffusionSafetyChecker(FlaxPreTrainedModel): |
| 95 | + config_class = CLIPConfig |
| 96 | + main_input_name = "clip_input" |
| 97 | + module_class = FlaxStableDiffusionSafetyCheckerModule |
| 98 | + |
| 99 | + def __init__( |
| 100 | + self, |
| 101 | + config: CLIPConfig, |
| 102 | + input_shape: Optional[Tuple] = None, |
| 103 | + seed: int = 0, |
| 104 | + dtype: jnp.dtype = jnp.float32, |
| 105 | + _do_init: bool = True, |
| 106 | + **kwargs, |
| 107 | + ): |
| 108 | + if input_shape is None: |
| 109 | + input_shape = (1, 224, 224, 3) |
| 110 | + module = self.module_class(config=config, dtype=dtype, **kwargs) |
| 111 | + super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) |
| 112 | + |
| 113 | + def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: |
| 114 | + # init input tensor |
| 115 | + clip_input = jax.random.normal(rng, input_shape) |
| 116 | + |
| 117 | + params_rng, dropout_rng = jax.random.split(rng) |
| 118 | + rngs = {"params": params_rng, "dropout": dropout_rng} |
| 119 | + |
| 120 | + random_params = self.module.init(rngs, clip_input)["params"] |
| 121 | + |
| 122 | + return random_params |
| 123 | + |
| 124 | + def __call__( |
| 125 | + self, |
| 126 | + clip_input, |
| 127 | + params: dict = None, |
| 128 | + ): |
| 129 | + clip_input = jnp.transpose(clip_input, (0, 2, 3, 1)) |
| 130 | + |
| 131 | + return self.module.apply( |
| 132 | + {"params": params or self.params}, |
| 133 | + jnp.array(clip_input, dtype=jnp.float32), |
| 134 | + rngs={}, |
| 135 | + ) |
| 136 | + |
| 137 | + def filtered_with_scores(self, special_cos_dist, cos_dist, images, params: dict = None): |
| 138 | + def _filtered_with_scores(module, special_cos_dist, cos_dist, images): |
| 139 | + return module.filtered_with_scores(special_cos_dist, cos_dist, images) |
| 140 | + |
| 141 | + return self.module.apply( |
| 142 | + {"params": params or self.params}, |
| 143 | + special_cos_dist, |
| 144 | + cos_dist, |
| 145 | + images, |
| 146 | + method=_filtered_with_scores, |
| 147 | + ) |
0 commit comments