
Commit 5975c2a

Add TextToImage and StableDiffusion3TextToImage
1 parent 49def20 commit 5975c2a

15 files changed: +834, -40 lines changed

keras_nlp/api/models/__init__.py

Lines changed: 7 additions & 0 deletions

@@ -220,9 +220,16 @@
 )
 from keras_nlp.src.models.roberta.roberta_tokenizer import RobertaTokenizer
 from keras_nlp.src.models.seq_2_seq_lm import Seq2SeqLM
+from keras_nlp.src.models.stable_diffusion_v3.stable_diffusion_3_backbone import (
+    StableDiffusion3Backbone,
+)
+from keras_nlp.src.models.stable_diffusion_v3.stable_diffusion_3_text_to_image import (
+    StableDiffusion3TextToImage,
+)
 from keras_nlp.src.models.t5.t5_backbone import T5Backbone
 from keras_nlp.src.models.t5.t5_tokenizer import T5Tokenizer
 from keras_nlp.src.models.task import Task
+from keras_nlp.src.models.text_to_image import TextToImage
 from keras_nlp.src.models.vgg.vgg_backbone import VGGBackbone
 from keras_nlp.src.models.vgg.vgg_image_classifier import VGGImageClassifier
 from keras_nlp.src.models.vit_det.vit_det_backbone import ViTDetBackbone
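For orientation, a minimal usage sketch of the newly exported task. The preset name and the from_preset/generate calls below follow KerasNLP's usual task conventions and are assumptions, not something shown in this diff:

from keras_nlp.models import StableDiffusion3TextToImage

# Hypothetical preset name; presets are not part of this commit.
text_to_image = StableDiffusion3TextToImage.from_preset("stable_diffusion_3_medium")

# Generate an image for a text prompt.
image = text_to_image.generate("a photograph of an astronaut riding a horse")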

keras_nlp/src/models/stable_diffusion_v3/clip_encoder_block.py

Lines changed: 17 additions & 3 deletions

@@ -19,6 +19,20 @@ def quick_gelu(x):
     return x * ops.sigmoid(1.702 * x)


+class CLIPMultiHeadAttention(layers.MultiHeadAttention):
+    # We should set compute_dtype to be float32 in Softmax.
+    # TODO: We can fix this upstream.
+    def _build_attention(self, rank):
+        super()._build_attention(rank)
+        self._softmax.dtype_policy = "float32"
+
+    def _masked_softmax(self, attention_scores, attention_mask=None):
+        attention_scores = super()._masked_softmax(
+            attention_scores, attention_mask
+        )
+        return ops.cast(attention_scores, self.compute_dtype)
+
+
 class CLIPEncoderBlock(layers.Layer):
     def __init__(
         self,
@@ -43,16 +57,16 @@ def __init__(
             intermediate_activation = quick_gelu

         self.layer_norm_1 = layers.LayerNormalization(
-            epsilon=0.00001, dtype=self.dtype_policy, name="layer_norm_1"
+            epsilon=1e-5, dtype="float32", name="layer_norm_1"
         )
-        self.attention = layers.MultiHeadAttention(
+        self.attention = CLIPMultiHeadAttention(
             num_heads,
             hidden_dim // num_heads,
             dtype=self.dtype_policy,
             name="attention",
         )
         self.layer_norm_2 = layers.LayerNormalization(
-            epsilon=0.00001, dtype=self.dtype_policy, name="layer_norm_2"
+            epsilon=1e-5, dtype="float32", name="layer_norm_2"
         )
         self.dense_1 = layers.Dense(
             self.intermediate_dim, dtype=self.dtype_policy, name="dense_1"
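The softmax over attention logits is where half precision loses accuracy, so the subclass above pins only that sub-layer to float32 and casts the probabilities back to the compute dtype. A quick check sketch, not part of the commit (shapes and the mixed-precision setup are illustrative):

import numpy as np
import keras
from keras_nlp.src.models.stable_diffusion_v3.clip_encoder_block import (
    CLIPMultiHeadAttention,
)

keras.mixed_precision.set_global_policy("mixed_float16")

attention = CLIPMultiHeadAttention(num_heads=8, key_dim=64)
x = np.random.rand(1, 77, 512).astype("float16")
y = attention(x, x)

print(attention._softmax.dtype_policy)  # pinned to float32 by _build_attention
print(y.dtype)  # float16: the layer's compute dtype is unchanged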

keras_nlp/src/models/stable_diffusion_v3/clip_preprocessor.py

Lines changed: 2 additions & 2 deletions

@@ -36,9 +36,9 @@ def __init__(
         tokenizer,
         sequence_length=77,
         add_start_token=True,
-        add_end_token=False,
+        add_end_token=True,
         to_lower=True,
-        pad_with_end_token=True,
+        pad_with_end_token=False,
         **kwargs,
     ):
         super().__init__(**kwargs)
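The defaults flip here: a prompt now gets an explicit end token appended, and padding no longer repeats the end token. On toy token ids (invented values; the real CLIP start/end/pad ids differ), the before/after looks roughly like this:

# Toy illustration of the changed defaults; ids are made up.
start, end, pad = 1, 2, 0
prompt_ids = [5, 6, 7]
sequence_length = 8

# Old defaults: add_end_token=False, pad_with_end_token=True.
old = [start] + prompt_ids
old += [end] * (sequence_length - len(old))   # [1, 5, 6, 7, 2, 2, 2, 2]

# New defaults: add_end_token=True, pad_with_end_token=False.
new = [start] + prompt_ids + [end]
new += [pad] * (sequence_length - len(new))   # [1, 5, 6, 7, 2, 0, 0, 0]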

keras_nlp/src/models/stable_diffusion_v3/clip_text_encoder.py

Lines changed: 45 additions & 15 deletions

@@ -23,6 +23,44 @@
 )


+class Projection(layers.Layer):
+    def __init__(self, hidden_dim, **kwargs):
+        super().__init__(**kwargs)
+        self.hidden_dim = int(hidden_dim)
+
+        self.text_projection = layers.Dense(
+            hidden_dim,
+            use_bias=False,
+            dtype=self.dtype_policy,
+            name="text_projection",
+        )
+
+    def build(self, inputs_shape, token_ids_shape):
+        inputs_shape = list(inputs_shape)
+        self.text_projection.build([None, inputs_shape[-1]])
+        self.text_projection._kernel.assign(
+            ops.transpose(ops.eye(self.hidden_dim), (1, 0))
+        )
+
+    def call(self, inputs, token_ids):
+        indices = ops.expand_dims(
+            ops.cast(ops.argmax(token_ids, axis=-1), "int32"), axis=-1
+        )
+        pooled_output = ops.take_along_axis(inputs, indices[:, :, None], axis=1)
+        pooled_output = ops.squeeze(pooled_output, axis=1)
+        projection_output = self.text_projection(pooled_output)
+        return projection_output, pooled_output
+
+    def get_config(self):
+        config = super().get_config()
+        config.update(
+            {
+                "hidden_dim": self.hidden_dim,
+            }
+        )
+        return config
+
+
 class CLIPTextEncoder(keras.Model):
     def __init__(
         self,
@@ -63,13 +101,10 @@ def __init__(
             for _ in range(num_layers)
         ]
         self.layer_norm = layers.LayerNormalization(
-            epsilon=0.00001, dtype=dtype, name="layer_norm"
+            epsilon=1e-6, dtype="float32", name="layer_norm"
         )
-        self.text_projection = layers.Dense(
-            hidden_dim,
-            use_bias=False,
-            dtype=dtype,
-            name="text_projection",
+        self.text_projection = Projection(
+            hidden_dim, dtype=dtype, name="text_projection"
         )

         # === Functional Model ===
@@ -78,24 +113,19 @@ def __init__(
         )
         x = self.embedding(encoder_token_ids)
         encoder_intermediate_output = None
+
         # Encoder.
         for i, block in enumerate(self.encoder_layers):
             x = block(x)
             if i == intermediate_output_index:
                 encoder_intermediate_output = x
         x = self.layer_norm(x)
         encoder_output = x
-        if encoder_intermediate_output is not None:
-            encoder_intermediate_output = self.layer_norm(
-                encoder_intermediate_output
-            )
+
         # Projection.
-        indices = ops.expand_dims(
-            ops.cast(ops.argmax(encoder_token_ids, axis=-1), "int32"), axis=-1
+        projection_output, pooled_output = self.text_projection(
+            x, encoder_token_ids
         )
-        pooled_output = ops.take_along_axis(x, indices[:, :, None], axis=1)
-        pooled_output = ops.squeeze(pooled_output, axis=1)
-        projection_output = self.text_projection(pooled_output)

         outputs = {
             "encoder_sequence_output": encoder_output,
Lines changed: 55 additions & 0 deletions

@@ -0,0 +1,55 @@
+# Copyright 2024 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from keras import ops
+
+
+class FlowMatchEulerDiscreteScheduler:
+    def __init__(self, num_train_timesteps=1000, shift=1.0):
+        self.num_train_timesteps = int(num_train_timesteps)
+        self.shift = float(shift)
+
+        timesteps = ops.linspace(
+            1, num_train_timesteps, num_train_timesteps, dtype="float32"
+        )
+        timesteps = ops.flip(timesteps, axis=0)
+        sigmas = self.timestep_to_sigma(timesteps)
+
+        self.timesteps = ops.multiply(sigmas, num_train_timesteps)
+        self.sigma_min = sigmas[-1]
+        self.sigma_max = sigmas[0]
+
+    def sigma_to_timestep(self, sigma):
+        return sigma * self.num_train_timesteps
+
+    def timestep_to_sigma(self, timestep):
+        sigma = ops.divide(timestep, self.num_train_timesteps)
+        if self.shift != 1.0:
+            sigma = ops.divide(
+                ops.multiply(self.shift, sigma),
+                ops.add(1, ops.multiply(self.shift - 1.0, sigma)),
+            )
+        return sigma
+
+    def get_sigma(self, step, num_steps):
+        start = self.sigma_to_timestep(self.sigma_max)
+        end = self.sigma_to_timestep(self.sigma_min)
+        step_size = ops.divide(
+            ops.subtract(end, start), ops.subtract(num_steps, 1)
+        )
+        result_timestep = ops.add(start, ops.multiply(step, step_size))
+        result_sigma = self.timestep_to_sigma(result_timestep)
+        return ops.maximum(result_sigma, 0.0)
+
+    def step(self, latents, noise_residual, sigma, sigma_next):
+        return latents + (sigma_next - sigma) * noise_residual
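The scheduler maps a timestep t to a sigma via sigma = t / num_train_timesteps, optionally shifted to shift * sigma / (1 + (shift - 1) * sigma), and step is a plain Euler update from sigma toward sigma_next. A minimal sampling-loop sketch, assuming the class is importable and with a placeholder standing in for the MMDiT noise prediction (the real loop lives in the text-to-image task, not here):

import numpy as np
from keras import ops

def predict_noise(latents, timestep):
    # Placeholder for the diffusion model's forward pass.
    return ops.zeros_like(latents)

# shift=3.0 and the latent shape are illustrative values, not taken from this diff.
scheduler = FlowMatchEulerDiscreteScheduler(num_train_timesteps=1000, shift=3.0)
num_steps = 28
latents = np.random.normal(size=(1, 64, 64, 16)).astype("float32")

for step in range(num_steps):
    sigma = scheduler.get_sigma(step, num_steps)
    sigma_next = scheduler.get_sigma(step + 1, num_steps)  # get_sigma floors the result at 0.0
    timestep = scheduler.sigma_to_timestep(sigma)
    noise_residual = predict_noise(latents, timestep)
    # Euler step: move the latents from sigma toward sigma_next.
    latents = scheduler.step(latents, noise_residual, sigma, sigma_next)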

keras_nlp/src/models/stable_diffusion_v3/mmdit.py

Lines changed: 2 additions & 2 deletions

@@ -186,7 +186,7 @@ def __init__(self, hidden_dim, output_dim, **kwargs):
             epsilon=1e-6,
             center=False,
             scale=False,
-            dtype=self.dtype_policy,
+            dtype="float32",
             name="norm",
         )
         self.output_dense = layers.Dense(
@@ -274,7 +274,7 @@ def __init__(
         output_dim,
         mlp_ratio=4.0,
         latent_shape=(64, 64, 16),
-        context_shape=(1024, 4096),
+        context_shape=(None, 4096),
         pooled_projection_shape=(2048,),
         data_format=None,
         dtype=None,

keras_nlp/src/models/stable_diffusion_v3/mmdit_block.py

Lines changed: 6 additions & 3 deletions

@@ -55,7 +55,7 @@ def __init__(
             epsilon=1e-6,
             center=False,
             scale=False,
-            dtype=self.dtype_policy,
+            dtype="float32",
             name="norm1",
         )
         self.attention_qkv = layers.Dense(
@@ -69,7 +69,7 @@
             epsilon=1e-6,
             center=False,
             scale=False,
-            dtype=self.dtype_policy,
+            dtype="float32",
             name="norm2",
         )
         self.mlp = models.Sequential(
@@ -230,6 +230,7 @@ def __init__(
             dtype=self.dtype_policy,
             name="context_block",
         )
+        self.softmax = layers.Softmax(dtype="float32")

     def build(self, inputs_shape, context_shape, timestep_embedding_shape):
         self.x_block.build(inputs_shape, timestep_embedding_shape)
@@ -240,7 +241,9 @@ def _compute_attention(self, query, key, value):
             query, ops.cast(self._inverse_sqrt_key_dim, query.dtype)
         )
         attention_scores = ops.einsum(self._dot_product_equation, key, query)
-        attention_scores = ops.nn.softmax(attention_scores, axis=-1)
+        original_dtype = attention_scores.dtype
+        attention_scores = self.softmax(attention_scores)
+        attention_scores = ops.cast(attention_scores, original_dtype)
         attention_output = ops.einsum(
             self._combine_equation, attention_scores, value
         )
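This is the same precision fix applied to the block's hand-rolled einsum attention: probabilities come from a dedicated layers.Softmax pinned to float32 and are then cast back so the value einsum stays in the compute dtype. A standalone sketch of the pattern (shapes are illustrative):

import numpy as np
from keras import layers, ops

# Softmax sub-layer pinned to float32, as in the diff above.
softmax = layers.Softmax(dtype="float32")

# Half-precision attention scores with shape (batch, heads, q_len, k_len).
scores = np.random.normal(size=(1, 8, 77, 77)).astype("float16")

probs = softmax(scores)                 # computed in float32
probs = ops.cast(probs, scores.dtype)   # cast back to float16 for the value einsum
print(probs.dtype)  # float16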
