@@ -248,8 +248,8 @@ def _step(
             )
             if self.missing_tolerance and next_tensordict_select.is_empty():
                 return next_tensordict
-            next_tensordict_norm = self._stateful_norm(next_tensordict_select)
             self._stateful_update(next_tensordict_select)
+            next_tensordict_norm = self._stateful_norm(next_tensordict_select)
         else:
             self._maybe_stateless_init(tensordict)
             next_tensordict_select = next_tensordict.select(
@@ -261,10 +261,10 @@ def _step(
             var = tensordict[f"{self.prefix}_var"]
             count = tensordict[f"{self.prefix}_count"]
 
-            next_tensordict_norm = self._stateless_norm(
+            loc, var, count = self._stateless_update(
                 next_tensordict_select, loc, var, count
             )
-            loc, var, count = self._stateless_update(
+            next_tensordict_norm = self._stateless_norm(
                 next_tensordict_select, loc, var, count
            )
             # updates have been done in-place, we're good
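
Both branches of `_step` now update the running statistics before normalizing, so the incoming sample is folded into the statistics that are then used to scale it. A minimal standalone sketch of that ordering (illustrative names only, not the library code):

import torch

def update(x, loc, var, count, decay=0.99):
    # EMA of x and x**2; count tracks how many updates have been applied.
    count = count + 1
    weight = 1 - decay
    return loc.lerp(x, weight), var.lerp(x.pow(2), weight), count

def norm(x, loc, var, eps=1e-4):
    scale = (var - loc.pow(2)).sqrt().clamp_min(eps)
    return (x - loc) / scale

loc, var, count = torch.zeros(3), torch.zeros(3), torch.zeros(())
x = torch.randn(3)
loc, var, count = update(x, loc, var, count)   # update first ...
y = norm(x, loc, var)                          # ... then normalize with the fresh stats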
@@ -328,27 +328,38 @@ def _in_keys_safe(self):
             return self.in_keys[:-3]
         return self.in_keys
 
-    def _norm(self, data, loc, var):
+    def _norm(self, data, loc, var, count):
         if self.missing_tolerance:
             loc = loc.select(*data.keys(True, True))
             var = var.select(*data.keys(True, True))
+            count = count.select(*data.keys(True, True))
             if loc.is_empty():
                 return data
 
+        if self.decay < 1.0:
+            bias_correction = 1 - (count * math.log(self.decay)).exp()
+            bias_correction = bias_correction.apply(lambda x, y: x.to(y.dtype), data)
+        else:
+            bias_correction = 1
+
         var = var - loc.pow(2)
+        loc = loc / bias_correction
+        var = var / bias_correction
+
         scale = var.sqrt().clamp_min(self.eps)
 
         data_update = (data - loc) / scale
         if self.out_keys[: len(self.in_keys)] != self.in_keys:
             # map names
             for in_key, out_key in _zip_strict(self._in_keys_safe, self.out_keys):
-                data_update.rename_key_(in_key, out_key)
+                if in_key in data_update:
+                    data_update.rename_key_(in_key, out_key)
         else:
             pass
         return data_update
 
     def _stateful_norm(self, data):
-        return self._norm(data, self._loc, self._var)
+        return self._norm(data, self._loc, self._var, self._count)
 
     def _stateful_update(self, data):
         if self.frozen:
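
`_norm` now receives the running `count` and divides both moments by an exponential-moving-average bias correction before computing the scale. `(count * math.log(self.decay)).exp()` is an element-wise way of writing `decay ** count` for tensor-valued counts, so `bias_correction = 1 - decay ** count`, the same correction Adam applies to its moment estimates. A worked example of the arithmetic (hypothetical numbers, not the library code):

import math
import torch

decay = 0.99
count = torch.tensor(10.0)
bias_correction = 1 - (count * math.log(decay)).exp()   # == 1 - 0.99**10 ≈ 0.0956

# An EMA of a constant c, started at zero, reaches only c * (1 - decay**n)
# after n updates; dividing by the correction factor recovers c.
c, ema = torch.tensor(5.0), torch.tensor(0.0)
for _ in range(10):
    ema = ema.lerp(c, 1 - decay)
print(torch.isclose(ema / bias_correction, c))   # tensor(True)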
@@ -363,14 +374,14 @@ def _stateful_update(self, data):
         count = self._count
         count += 1
         data = self._maybe_cast_to_float(data)
-        if self.decay < 1.0:
-            bias_correction = 1 - (count * math.log(self.decay)).exp()
-            bias_correction = bias_correction.apply(lambda x, y: x.to(y.dtype), data)
+        if self.decay != 1.0:
+            weight = 1 - self.decay
+            loc.lerp_(end=data, weight=weight)
+            var.lerp_(end=data.pow(2), weight=weight)
         else:
-            bias_correction = 1
-        weight = (1 - self.decay) / bias_correction
-        loc.lerp_(end=data, weight=weight)
-        var.lerp_(end=data.pow(2), weight=weight)
+            weight = 1 / count
+            loc.lerp_(end=data, weight=weight)
+            var.lerp_(end=data.pow(2), weight=weight)
 
     def _maybe_stateless_init(self, data):
         if not self.initialized or f"{self.prefix}_loc" not in data.keys():
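
With the correction moved into `_norm`, `_stateful_update` no longer needs `bias_correction`: for `decay != 1.0` it keeps a plain EMA with weight `1 - decay`, and for `decay == 1.0` it degrades to a cumulative average by using weight `1 / count`. A small sketch showing that the `1 / count` branch reproduces the ordinary running mean (illustrative, not the library code):

import torch

xs = torch.randn(100, 3)
loc, count = torch.zeros(3), 0
for x in xs:
    count += 1
    loc = loc.lerp(x, 1 / count)     # equivalent to loc += (x - loc) / count
assert torch.allclose(loc, xs.mean(0), atol=1e-5)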
@@ -398,20 +409,18 @@ def _maybe_stateless_init(self, data):
             data[f"{self.prefix}_var"] = var
 
     def _stateless_norm(self, data, loc, var, count):
-        data = self._norm(data, loc, var)
+        data = self._norm(data, loc, var, count)
         return data
 
     def _stateless_update(self, data, loc, var, count):
         if self.frozen:
             return loc, var, count
         count = count + 1
         data = self._maybe_cast_to_float(data)
-        if self.decay < 1.0:
-            bias_correction = 1 - (count * math.log(self.decay)).exp()
-            bias_correction = bias_correction.apply(lambda x, y: x.to(y.dtype), data)
+        if self.decay != 1.0:
+            weight = 1 - self.decay
         else:
-            bias_correction = 1
-            weight = (1 - self.decay) / bias_correction
+            weight = 1 / count
         loc = loc.lerp(end=data, weight=weight)
         var = var.lerp(end=data.pow(2), weight=weight)
         return loc, var, count
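
`_stateless_update` mirrors the stateful path but keeps the functional contract: `loc`, `var` and `count` arrive with the tensordict and the updated values are returned out-of-place (`lerp` rather than `lerp_`) so the caller can write them back. A rough sketch of that contract, under the assumption that the statistics travel with the data:

import torch

def stateless_update(x, loc, var, count, decay=0.99):
    count = count + 1
    weight = 1 - decay if decay != 1.0 else 1 / count
    # out-of-place: the caller is responsible for storing the new tensors
    return loc.lerp(x, weight), var.lerp(x.pow(2), weight), count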
@@ -563,10 +572,18 @@ def to_observation_norm(self) -> Compose | ObservationNorm:
     def _get_loc_scale(self, loc_only: bool = False) -> tuple:
         if self.stateful:
             loc = self._loc
+            count = self._count
+            if self.decay != 1.0:
+                bias_correction = 1 - (count * math.log(self.decay)).exp()
+                bias_correction = bias_correction.apply(lambda x, y: x.to(y.dtype), loc)
+            else:
+                bias_correction = 1
             if loc_only:
-                return loc, None
+                return loc / bias_correction, None
             var = self._var
             var = var - loc.pow(2)
+            loc = loc / bias_correction
+            var = var / bias_correction
             scale = var.sqrt().clamp_min(self.eps)
             return loc, scale
         else:
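
`_get_loc_scale` is the export path (used, e.g., when converting the transform via `to_observation_norm`), so it applies the same correction to the raw buffers before handing out `loc` and `scale`. A rough standalone equivalent, assuming `_loc`/`_var` hold uncorrected EMAs of `x` and `x**2` (names, a scalar `count`, and defaults are illustrative):

import math
import torch

def get_loc_scale(loc_buf, var_buf, count, decay=0.99, eps=1e-4):
    if decay != 1.0:
        bias_correction = 1 - math.exp(count * math.log(decay))   # 1 - decay**count
    else:
        bias_correction = 1.0
    var = var_buf - loc_buf.pow(2)
    loc = loc_buf / bias_correction
    var = var / bias_correction
    return loc, var.sqrt().clamp_min(eps)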