
Commit 0b61cea

[Flax] time embedding (#1081)
* initial get_sinusoidal_embeddings
* added asserts
* better var name
* fix docs
1 parent 33c4874 commit 0b61cea

File tree

1 file changed (+34, -16 lines)


src/diffusers/models/embeddings_flax.py

Lines changed: 34 additions & 16 deletions
@@ -17,23 +17,41 @@
 import jax.numpy as jnp
 
 
-# This is like models.embeddings.get_timestep_embedding (PyTorch) but
-# less general (only handles the case we currently need).
-def get_sinusoidal_embeddings(timesteps, embedding_dim, freq_shift: float = 1):
+def get_sinusoidal_embeddings(
+    timesteps: jnp.ndarray,
+    embedding_dim: int,
+    freq_shift: float = 1,
+    min_timescale: float = 1,
+    max_timescale: float = 1.0e4,
+    flip_sin_to_cos: bool = False,
+    scale: float = 1.0,
+) -> jnp.ndarray:
+    """Returns the positional encoding (same as Tensor2Tensor).
+    Args:
+        timesteps: a 1-D Tensor of N indices, one per batch element.
+            These may be fractional.
+        embedding_dim: The number of output channels.
+        min_timescale: The smallest time unit (should probably be 0.0).
+        max_timescale: The largest time unit.
+    Returns:
+        a Tensor of timing signals [N, num_channels]
     """
-    This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.
+    assert timesteps.ndim == 1, "Timesteps should be a 1d-array"
+    assert embedding_dim % 2 == 0, f"Embedding dimension {embedding_dim} should be even"
+    num_timescales = float(embedding_dim // 2)
+    log_timescale_increment = math.log(max_timescale / min_timescale) / (num_timescales - freq_shift)
+    inv_timescales = min_timescale * jnp.exp(jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment)
+    emb = jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0)
 
-    :param timesteps: a 1-D tensor of N indices, one per batch element.
-                      These may be fractional.
-    :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the
-                          embeddings. :return: an [N x dim] tensor of positional embeddings.
-    """
-    half_dim = embedding_dim // 2
-    emb = math.log(10000) / (half_dim - freq_shift)
-    emb = jnp.exp(jnp.arange(half_dim) * -emb)
-    emb = timesteps[:, None] * emb[None, :]
-    emb = jnp.concatenate([jnp.cos(emb), jnp.sin(emb)], -1)
-    return emb
+    # scale embeddings
+    scaled_time = scale * emb
+
+    if flip_sin_to_cos:
+        signal = jnp.concatenate([jnp.cos(scaled_time), jnp.sin(scaled_time)], axis=1)
+    else:
+        signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1)
+    signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim])
+    return signal
 
 
 class FlaxTimestepEmbedding(nn.Module):

@@ -70,4 +88,4 @@ class FlaxTimesteps(nn.Module):
 
     @nn.compact
     def __call__(self, timesteps):
-        return get_sinusoidal_embeddings(timesteps, self.dim, freq_shift=self.freq_shift)
+        return get_sinusoidal_embeddings(timesteps, embedding_dim=self.dim, freq_shift=self.freq_shift)
