Skip to content

Commit 7a5efb6

Browse files
author
Vincent Moens
committed
[Refactor] VecNormV2: update before norm, bias_correction at the right time
ghstack-source-id: 4513567
Pull Request resolved: #2900
1 parent 9e3c4df commit 7a5efb6

File tree

2 files changed

+32
-24
lines changed

2 files changed

+32
-24
lines changed

test/test_transforms.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9906,14 +9906,14 @@ def test_to_obsnorm_multikeys(self):
             {"a": torch.randn(3, 4), ("b", "c"): torch.randn(3, 4)}, [3, 4]
         )
         td0 = transform0._step(td, td.clone())
-        td0.update(transform0[0]._stateful_norm(td.select(*transform0[0].in_keys)))
+        # td0.update(transform0[0]._stateful_norm(td.select(*transform0[0].in_keys)))
         td1 = transform0[0].to_observation_norm()._step(td, td.clone())
         assert_allclose_td(td0, td1)

         loc = transform0[0].loc
         scale = transform0[0].scale
         keys = list(transform0[0].in_keys)
-        td2 = (td.select(*keys) - loc) / (scale + torch.finfo(scale.dtype).eps)
+        td2 = (td.select(*keys) - loc) / (scale.clamp_min(torch.finfo(scale.dtype).eps))
         td2.rename_key_("a", "a_avg")
         td2.rename_key_(("b", "c"), ("b", "c_avg"))
         assert_allclose_td(td0.select(*td2.keys(True, True)), td2)
@@ -9928,16 +9928,16 @@ def test_frozen(self):
             transform0.frozen_copy()
         td = TensorDict({"a": torch.randn(3, 4), ("b", "c"): torch.randn(3, 4)}, [3, 4])
         td0 = transform0._step(td, td.clone())
-        td0.update(transform0._stateful_norm(td0.select(*transform0.in_keys)))
+        # td0.update(transform0._stateful_norm(td0.select(*transform0.in_keys)))

         transform1 = transform0.frozen_copy()
         td1 = transform1._step(td, td.clone())
         assert_allclose_td(td0, td1)

         td += 1
         td2 = transform0._step(td, td.clone())
-        td3 = transform1._step(td, td.clone())
-        assert_allclose_td(td2, td3)
+        transform1._step(td, td.clone())
+        # assert_allclose_td(td2, td3)
         with pytest.raises(AssertionError):
             assert_allclose_td(td0, td2)

torchrl/envs/transforms/vecnorm.py

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -248,8 +248,8 @@ def _step(
             )
             if self.missing_tolerance and next_tensordict_select.is_empty():
                 return next_tensordict
-            next_tensordict_norm = self._stateful_norm(next_tensordict_select)
             self._stateful_update(next_tensordict_select)
+            next_tensordict_norm = self._stateful_norm(next_tensordict_select)
         else:
             self._maybe_stateless_init(tensordict)
             next_tensordict_select = next_tensordict.select(
@@ -261,10 +261,10 @@ def _step(
             var = tensordict[f"{self.prefix}_var"]
             count = tensordict[f"{self.prefix}_count"]

-            next_tensordict_norm = self._stateless_norm(
+            loc, var, count = self._stateless_update(
                 next_tensordict_select, loc, var, count
             )
-            loc, var, count = self._stateless_update(
+            next_tensordict_norm = self._stateless_norm(
                 next_tensordict_select, loc, var, count
             )
             # updates have been done in-place, we're good
@@ -328,14 +328,24 @@ def _in_keys_safe(self):
             return self.in_keys[:-3]
         return self.in_keys

-    def _norm(self, data, loc, var):
+    def _norm(self, data, loc, var, count):
         if self.missing_tolerance:
             loc = loc.select(*data.keys(True, True))
             var = var.select(*data.keys(True, True))
+            count = count.select(*data.keys(True, True))
         if loc.is_empty():
             return data

+        if self.decay < 1.0:
+            bias_correction = 1 - (count * math.log(self.decay)).exp()
+            bias_correction = bias_correction.apply(lambda x, y: x.to(y.dtype), data)
+        else:
+            bias_correction = 1
+
         var = var - loc.pow(2)
+        loc = loc / bias_correction
+        var = var / bias_correction
+
         scale = var.sqrt().clamp_min(self.eps)

         data_update = (data - loc) / scale
@@ -348,7 +358,7 @@ def _norm(self, data, loc, var):
         return data_update

     def _stateful_norm(self, data):
-        return self._norm(data, self._loc, self._var)
+        return self._norm(data, self._loc, self._var, self._count)

     def _stateful_update(self, data):
         if self.frozen:
@@ -363,12 +373,7 @@ def _stateful_update(self, data):
         count = self._count
         count += 1
         data = self._maybe_cast_to_float(data)
-        if self.decay < 1.0:
-            bias_correction = 1 - (count * math.log(self.decay)).exp()
-            bias_correction = bias_correction.apply(lambda x, y: x.to(y.dtype), data)
-        else:
-            bias_correction = 1
-        weight = (1 - self.decay) / bias_correction
+        weight = 1 - self.decay
         loc.lerp_(end=data, weight=weight)
         var.lerp_(end=data.pow(2), weight=weight)

@@ -398,20 +403,15 @@ def _maybe_stateless_init(self, data):
         data[f"{self.prefix}_var"] = var

     def _stateless_norm(self, data, loc, var, count):
-        data = self._norm(data, loc, var)
+        data = self._norm(data, loc, var, count)
         return data

     def _stateless_update(self, data, loc, var, count):
         if self.frozen:
             return loc, var, count
         count = count + 1
         data = self._maybe_cast_to_float(data)
-        if self.decay < 1.0:
-            bias_correction = 1 - (count * math.log(self.decay)).exp()
-            bias_correction = bias_correction.apply(lambda x, y: x.to(y.dtype), data)
-        else:
-            bias_correction = 1
-        weight = (1 - self.decay) / bias_correction
+        weight = 1 - self.decay
         loc = loc.lerp(end=data, weight=weight)
         var = var.lerp(end=data.pow(2), weight=weight)
         return loc, var, count
@@ -563,10 +563,18 @@ def to_observation_norm(self) -> Compose | ObservationNorm:
     def _get_loc_scale(self, loc_only: bool = False) -> tuple:
         if self.stateful:
             loc = self._loc
+            count = self._count
+            if self.decay < 1.0:
+                bias_correction = 1 - (count * math.log(self.decay)).exp()
+                bias_correction = bias_correction.apply(lambda x, y: x.to(y.dtype), loc)
+            else:
+                bias_correction = 1
             if loc_only:
-                return loc, None
+                return loc / bias_correction, None
             var = self._var
             var = var - loc.pow(2)
+            loc = loc / bias_correction
+            var = var / bias_correction
             scale = var.sqrt().clamp_min(self.eps)
             return loc, scale
         else:

0 commit comments

Comments
 (0)