Skip to content

Commit 7527ab1

Browse files
committed
replace unsued arg
1 parent e9dde49 commit 7527ab1

File tree

1 file changed

+2
-3
lines changed

1 file changed

+2
-3
lines changed

src/diffusers/models/unet_2d_condition_flax.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def setup(self):
192192
def __call__(
193193
self,
194194
sample,
195-
timestep,
195+
timesteps,
196196
encoder_hidden_states,
197197
return_dict: bool = True,
198198
train: bool = False,
@@ -214,11 +214,10 @@ def __call__(
214214
When returning a tuple, the first element is the sample tensor.
215215
"""
216216
# 1. time
217-
timesteps = timestep
218217
if not isinstance(timesteps, jnp.ndarray):
219218
timesteps = jnp.array([timesteps], dtype=jnp.int32)
220219
elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0:
221-
timesteps = timesteps.to(dtype=jnp.float32)
220+
timesteps = timesteps.astype(dtype=jnp.float32)
222221
timesteps = timesteps[None]
223222

224223
t_emb = self.time_proj(timesteps)

0 commit comments

Comments
 (0)