Skip to content

Commit 36d21ff

Browse files
author
Vincent Moens
committed
Update
[ghstack-poisoned]
1 parent 7434396 commit 36d21ff

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

torchrl/envs/transforms/vecnorm.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -374,9 +374,14 @@ def _stateful_update(self, data):
374374
count = self._count
375375
count += 1
376376
data = self._maybe_cast_to_float(data)
377-
weight = 1 - self.decay
378-
loc.lerp_(end=data, weight=weight)
379-
var.lerp_(end=data.pow(2), weight=weight)
377+
if self.decay != 1.0:
378+
weight = 1 - self.decay
379+
loc.lerp_(end=data, weight=weight)
380+
var.lerp_(end=data.pow(2), weight=weight)
381+
else:
382+
weight = 1 / count
383+
loc.lerp_(end=data, weight=weight)
384+
var.lerp_(end=data.pow(2), weight=weight)
380385

381386
def _maybe_stateless_init(self, data):
382387
if not self.initialized or f"{self.prefix}_loc" not in data.keys():
@@ -412,7 +417,10 @@ def _stateless_update(self, data, loc, var, count):
412417
return loc, var, count
413418
count = count + 1
414419
data = self._maybe_cast_to_float(data)
415-
weight = 1 - self.decay
420+
if self.decay != 1.0:
421+
weight = 1 - self.decay
422+
else:
423+
weight = 1 / count
416424
loc = loc.lerp(end=data, weight=weight)
417425
var = var.lerp(end=data.pow(2), weight=weight)
418426
return loc, var, count
@@ -565,7 +573,7 @@ def _get_loc_scale(self, loc_only: bool = False) -> tuple:
565573
if self.stateful:
566574
loc = self._loc
567575
count = self._count
568-
if self.decay < 1.0:
576+
if self.decay != 1.0:
569577
bias_correction = 1 - (count * math.log(self.decay)).exp()
570578
bias_correction = bias_correction.apply(lambda x, y: x.to(y.dtype), loc)
571579
else:

0 commit comments

Comments
 (0)