@@ -374,9 +374,14 @@ def _stateful_update(self, data):
374374 count = self ._count
375375 count += 1
376376 data = self ._maybe_cast_to_float (data )
377- weight = 1 - self .decay
378- loc .lerp_ (end = data , weight = weight )
379- var .lerp_ (end = data .pow (2 ), weight = weight )
377+ if self .decay != 1.0 :
378+ weight = 1 - self .decay
379+ loc .lerp_ (end = data , weight = weight )
380+ var .lerp_ (end = data .pow (2 ), weight = weight )
381+ else :
382+ weight = 1 / count
383+ loc .lerp_ (end = data , weight = weight )
384+ var .lerp_ (end = data .pow (2 ), weight = weight )
380385
381386 def _maybe_stateless_init (self , data ):
382387 if not self .initialized or f"{ self .prefix } _loc" not in data .keys ():
@@ -412,7 +417,10 @@ def _stateless_update(self, data, loc, var, count):
412417 return loc , var , count
413418 count = count + 1
414419 data = self ._maybe_cast_to_float (data )
415- weight = 1 - self .decay
420+ if self .decay != 1.0 :
421+ weight = 1 - self .decay
422+ else :
423+ weight = 1 / count
416424 loc = loc .lerp (end = data , weight = weight )
417425 var = var .lerp (end = data .pow (2 ), weight = weight )
418426 return loc , var , count
@@ -565,7 +573,7 @@ def _get_loc_scale(self, loc_only: bool = False) -> tuple:
565573 if self .stateful :
566574 loc = self ._loc
567575 count = self ._count
568- if self .decay < 1.0 :
576+ if self .decay != 1.0 :
569577 bias_correction = 1 - (count * math .log (self .decay )).exp ()
570578 bias_correction = bias_correction .apply (lambda x , y : x .to (y .dtype ), loc )
571579 else :
0 commit comments