|
17 | 17 | import jax.numpy as jnp |
18 | 18 |
|
19 | 19 |
|
def get_sinusoidal_embeddings(
    timesteps: jnp.ndarray,
    embedding_dim: int,
    freq_shift: float = 1,
    min_timescale: float = 1,
    max_timescale: float = 1.0e4,
    flip_sin_to_cos: bool = False,
    scale: float = 1.0,
) -> jnp.ndarray:
    """Return sinusoidal timestep embeddings (Tensor2Tensor-style timing signal).

    For channel ``i`` in ``range(embedding_dim // 2)`` the inverse timescale is
    ``min_timescale * exp(-i * log(max_timescale / min_timescale) / (embedding_dim // 2 - freq_shift))``.
    Each timestep is multiplied by every inverse timescale and by ``scale``;
    the sine and cosine of that product form the two halves of the output.

    Args:
        timesteps: a 1-D array of N indices, one per batch element.
            These may be fractional.
        embedding_dim: The number of output channels. Must be even.
        freq_shift: Value subtracted from ``embedding_dim // 2`` in the
            denominator of the log-timescale increment. Must differ from
            ``embedding_dim // 2``.
        min_timescale: The smallest time unit. Must be non-zero, because it
            is used as a divisor of ``max_timescale``.
        max_timescale: The largest time unit.
        flip_sin_to_cos: If True, the output is ``[cos, sin]`` along the
            channel axis; otherwise it is ``[sin, cos]``.
        scale: Multiplier applied to the sinusoid arguments before taking
            sine and cosine.

    Returns:
        An array of timing signals of shape ``[N, embedding_dim]``.

    Raises:
        ValueError: If ``timesteps`` is not 1-D, if ``embedding_dim`` is odd,
            or if the timescale increment would require a division by zero.
    """
    # Explicit raises instead of `assert`, so validation survives `python -O`.
    if timesteps.ndim != 1:
        raise ValueError("Timesteps should be a 1d-array")
    if embedding_dim % 2 != 0:
        raise ValueError(f"Embedding dimension {embedding_dim} should be even")

    half_dim = embedding_dim // 2
    try:
        log_timescale_increment = math.log(max_timescale / min_timescale) / (half_dim - freq_shift)
    except ZeroDivisionError as err:
        # Surface the offending arguments instead of a bare "division by zero".
        raise ValueError(
            "Cannot compute timescales: `min_timescale` must be non-zero and "
            f"`embedding_dim // 2` ({half_dim}) must differ from `freq_shift` ({freq_shift})"
        ) from err

    inv_timescales = min_timescale * jnp.exp(jnp.arange(half_dim, dtype=jnp.float32) * -log_timescale_increment)
    emb = jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0)

    # scale embeddings
    scaled_time = scale * emb

    sin_part = jnp.sin(scaled_time)
    cos_part = jnp.cos(scaled_time)
    halves = [cos_part, sin_part] if flip_sin_to_cos else [sin_part, cos_part]
    # Two [N, embedding_dim // 2] halves joined on axis 1 are already
    # [N, embedding_dim]; no reshape is needed.
    return jnp.concatenate(halves, axis=1)
37 | 55 |
|
38 | 56 |
|
39 | 57 | class FlaxTimestepEmbedding(nn.Module): |
@@ -70,4 +88,4 @@ class FlaxTimesteps(nn.Module): |
70 | 88 |
|
@nn.compact
def __call__(self, timesteps):
    """Return sinusoidal embeddings of `timesteps` using this module's `dim` and `freq_shift`."""
    embedding = get_sinusoidal_embeddings(
        timesteps,
        embedding_dim=self.dim,
        freq_shift=self.freq_shift,
    )
    return embedding
0 commit comments