@@ -248,8 +248,8 @@ def _step(
248248 )
249249 if self .missing_tolerance and next_tensordict_select .is_empty ():
250250 return next_tensordict
251- next_tensordict_norm = self ._stateful_norm (next_tensordict_select )
252251 self ._stateful_update (next_tensordict_select )
252+ next_tensordict_norm = self ._stateful_norm (next_tensordict_select )
253253 else :
254254 self ._maybe_stateless_init (tensordict )
255255 next_tensordict_select = next_tensordict .select (
@@ -261,10 +261,10 @@ def _step(
261261 var = tensordict [f"{ self .prefix } _var" ]
262262 count = tensordict [f"{ self .prefix } _count" ]
263263
264- next_tensordict_norm = self ._stateless_norm (
264+ loc , var , count = self ._stateless_update (
265265 next_tensordict_select , loc , var , count
266266 )
267- loc , var , count = self ._stateless_update (
267+ next_tensordict_norm = self ._stateless_norm (
268268 next_tensordict_select , loc , var , count
269269 )
270270 # updates have been done in-place, we're good
@@ -328,14 +328,24 @@ def _in_keys_safe(self):
328328 return self .in_keys [:- 3 ]
329329 return self .in_keys
330330
331- def _norm (self , data , loc , var ):
331+ def _norm (self , data , loc , var , count ):
332332 if self .missing_tolerance :
333333 loc = loc .select (* data .keys (True , True ))
334334 var = var .select (* data .keys (True , True ))
335+ count = count .select (* data .keys (True , True ))
335336 if loc .is_empty ():
336337 return data
337338
339+ if self .decay < 1.0 :
340+ bias_correction = 1 - (count * math .log (self .decay )).exp ()
341+ bias_correction = bias_correction .apply (lambda x , y : x .to (y .dtype ), data )
342+ else :
343+ bias_correction = 1
344+
338345 var = var - loc .pow (2 )
346+ loc = loc / bias_correction
347+ var = var / bias_correction
348+
339349 scale = var .sqrt ().clamp_min (self .eps )
340350
341351 data_update = (data - loc ) / scale
@@ -348,7 +358,7 @@ def _norm(self, data, loc, var):
348358 return data_update
349359
350360 def _stateful_norm (self , data ):
351- return self ._norm (data , self ._loc , self ._var )
361+ return self ._norm (data , self ._loc , self ._var , self . _count )
352362
353363 def _stateful_update (self , data ):
354364 if self .frozen :
@@ -363,12 +373,7 @@ def _stateful_update(self, data):
363373 count = self ._count
364374 count += 1
365375 data = self ._maybe_cast_to_float (data )
366- if self .decay < 1.0 :
367- bias_correction = 1 - (count * math .log (self .decay )).exp ()
368- bias_correction = bias_correction .apply (lambda x , y : x .to (y .dtype ), data )
369- else :
370- bias_correction = 1
371- weight = (1 - self .decay ) / bias_correction
376+ weight = 1 - self .decay
372377 loc .lerp_ (end = data , weight = weight )
373378 var .lerp_ (end = data .pow (2 ), weight = weight )
374379
@@ -398,20 +403,15 @@ def _maybe_stateless_init(self, data):
398403 data [f"{ self .prefix } _var" ] = var
399404
400405 def _stateless_norm (self , data , loc , var , count ):
401- data = self ._norm (data , loc , var )
406+ data = self ._norm (data , loc , var , count )
402407 return data
403408
404409 def _stateless_update (self , data , loc , var , count ):
405410 if self .frozen :
406411 return loc , var , count
407412 count = count + 1
408413 data = self ._maybe_cast_to_float (data )
409- if self .decay < 1.0 :
410- bias_correction = 1 - (count * math .log (self .decay )).exp ()
411- bias_correction = bias_correction .apply (lambda x , y : x .to (y .dtype ), data )
412- else :
413- bias_correction = 1
414- weight = (1 - self .decay ) / bias_correction
414+ weight = 1 - self .decay
415415 loc = loc .lerp (end = data , weight = weight )
416416 var = var .lerp (end = data .pow (2 ), weight = weight )
417417 return loc , var , count
@@ -563,10 +563,18 @@ def to_observation_norm(self) -> Compose | ObservationNorm:
563563 def _get_loc_scale (self , loc_only : bool = False ) -> tuple :
564564 if self .stateful :
565565 loc = self ._loc
566+ count = self ._count
567+ if self .decay < 1.0 :
568+ bias_correction = 1 - (count * math .log (self .decay )).exp ()
569+ bias_correction = bias_correction .apply (lambda x , y : x .to (y .dtype ), loc )
570+ else :
571+ bias_correction = 1
566572 if loc_only :
567- return loc , None
573+ return loc / bias_correction , None
568574 var = self ._var
569575 var = var - loc .pow (2 )
576+ loc = loc / bias_correction
577+ var = var / bias_correction
570578 scale = var .sqrt ().clamp_min (self .eps )
571579 return loc , scale
572580 else :
0 commit comments