File tree Expand file tree Collapse file tree 1 file changed +2
-3
lines changed Expand file tree Collapse file tree 1 file changed +2
-3
lines changed Original file line number Diff line number Diff line change @@ -192,7 +192,7 @@ def setup(self):
192192 def __call__ (
193193 self ,
194194 sample ,
195- timestep ,
195+ timesteps ,
196196 encoder_hidden_states ,
197197 return_dict : bool = True ,
198198 train : bool = False ,
@@ -214,11 +214,10 @@ def __call__(
214214 When returning a tuple, the first element is the sample tensor.
215215 """
216216 # 1. time
217- timesteps = timestep
218217 if not isinstance (timesteps , jnp .ndarray ):
219218 timesteps = jnp .array ([timesteps ], dtype = jnp .int32 )
220219 elif isinstance (timesteps , jnp .ndarray ) and len (timesteps .shape ) == 0 :
221- timesteps = timesteps .to (dtype = jnp .float32 )
220+ timesteps = timesteps .astype (dtype = jnp .float32 )
222221 timesteps = timesteps [None ]
223222
224223 t_emb = self .time_proj (timesteps )
You can’t perform that action at this time.
0 commit comments