Skip to content

Commit c6629e6

Browse files
authored
[flax safety checker] Use FlaxPreTrainedModel for saving/loading (#591)
* use FlaxPreTrainedModel for flax safety module * fix name * fix one more * Apply suggestions from code review
1 parent 8a6833b commit c6629e6

File tree

1 file changed

+64
-28
lines changed

1 file changed

+64
-28
lines changed

src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py

Lines changed: 64 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,33 @@
11
import warnings
2+
from typing import Optional, Tuple
23

34
import numpy as np
45

56
import jax
67
import jax.numpy as jnp
78
from flax import linen as nn
89
from flax.core.frozen_dict import FrozenDict
9-
from flax.struct import field
10-
from transformers import CLIPVisionConfig
10+
from transformers import CLIPConfig, FlaxPreTrainedModel
1111
from transformers.models.clip.modeling_flax_clip import FlaxCLIPVisionModule
1212

13-
from ...configuration_utils import ConfigMixin, flax_register_to_config
14-
from ...modeling_flax_utils import FlaxModelMixin
15-
1613

1714
def jax_cosine_distance(emb_1, emb_2, eps=1e-12):
1815
norm_emb_1 = jnp.divide(emb_1.T, jnp.clip(jnp.linalg.norm(emb_1, axis=1), a_min=eps)).T
1916
norm_emb_2 = jnp.divide(emb_2.T, jnp.clip(jnp.linalg.norm(emb_2, axis=1), a_min=eps)).T
2017
return jnp.matmul(norm_emb_1, norm_emb_2.T)
2118

2219

23-
@flax_register_to_config
24-
class FlaxStableDiffusionSafetyChecker(nn.Module, FlaxModelMixin, ConfigMixin):
25-
projection_dim: int = 768
26-
# CLIPVisionConfig fields
27-
vision_config: dict = field(default_factory=dict)
20+
class FlaxStableDiffusionSafetyCheckerModule(nn.Module):
21+
config: CLIPConfig
2822
dtype: jnp.dtype = jnp.float32
2923

30-
def init_weights(self, rng: jax.random.PRNGKey) -> FrozenDict:
31-
# init input tensor
32-
input_shape = (
33-
1,
34-
self.vision_config["image_size"],
35-
self.vision_config["image_size"],
36-
self.vision_config["num_channels"],
37-
)
38-
pixel_values = jax.random.normal(rng, input_shape)
39-
params_rng, dropout_rng = jax.random.split(rng)
40-
rngs = {"params": params_rng, "dropout": dropout_rng}
41-
return self.init(rngs, pixel_values)["params"]
42-
4324
def setup(self):
44-
clip_vision_config = CLIPVisionConfig(**self.vision_config)
45-
self.vision_model = FlaxCLIPVisionModule(clip_vision_config, dtype=self.dtype)
46-
self.visual_projection = nn.Dense(self.projection_dim, use_bias=False, dtype=self.dtype)
25+
self.vision_model = FlaxCLIPVisionModule(self.config.vision_config)
26+
self.visual_projection = nn.Dense(self.config.projection_dim, use_bias=False, dtype=self.dtype)
4727

48-
self.concept_embeds = self.param("concept_embeds", jax.nn.initializers.ones, (17, self.projection_dim))
28+
self.concept_embeds = self.param("concept_embeds", jax.nn.initializers.ones, (17, self.config.projection_dim))
4929
self.special_care_embeds = self.param(
50-
"special_care_embeds", jax.nn.initializers.ones, (3, self.projection_dim)
30+
"special_care_embeds", jax.nn.initializers.ones, (3, self.config.projection_dim)
5131
)
5232

5333
self.concept_embeds_weights = self.param("concept_embeds_weights", jax.nn.initializers.ones, (17,))
@@ -109,3 +89,59 @@ def filtered_with_scores(self, special_cos_dist, cos_dist, images):
10989
)
11090

11191
return images, has_nsfw_concepts
92+
93+
94+
class FlaxStableDiffusionSafetyChecker(FlaxPreTrainedModel):
95+
config_class = CLIPConfig
96+
main_input_name = "clip_input"
97+
module_class = FlaxStableDiffusionSafetyCheckerModule
98+
99+
def __init__(
100+
self,
101+
config: CLIPConfig,
102+
input_shape: Optional[Tuple] = None,
103+
seed: int = 0,
104+
dtype: jnp.dtype = jnp.float32,
105+
_do_init: bool = True,
106+
**kwargs,
107+
):
108+
if input_shape is None:
109+
input_shape = (1, 224, 224, 3)
110+
module = self.module_class(config=config, dtype=dtype, **kwargs)
111+
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
112+
113+
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
114+
# init input tensor
115+
clip_input = jax.random.normal(rng, input_shape)
116+
117+
params_rng, dropout_rng = jax.random.split(rng)
118+
rngs = {"params": params_rng, "dropout": dropout_rng}
119+
120+
random_params = self.module.init(rngs, clip_input)["params"]
121+
122+
return random_params
123+
124+
def __call__(
125+
self,
126+
clip_input,
127+
params: dict = None,
128+
):
129+
clip_input = jnp.transpose(clip_input, (0, 2, 3, 1))
130+
131+
return self.module.apply(
132+
{"params": params or self.params},
133+
jnp.array(clip_input, dtype=jnp.float32),
134+
rngs={},
135+
)
136+
137+
def filtered_with_scores(self, special_cos_dist, cos_dist, images, params: dict = None):
138+
def _filtered_with_scores(module, special_cos_dist, cos_dist, images):
139+
return module.filtered_with_scores(special_cos_dist, cos_dist, images)
140+
141+
return self.module.apply(
142+
{"params": params or self.params},
143+
special_cos_dist,
144+
cos_dist,
145+
images,
146+
method=_filtered_with_scores,
147+
)

0 commit comments

Comments
 (0)