Skip to content

Commit abe1f91

Browse files
author
Vincent Moens
committed
[Refactor] VecNormV2: update before norm, bias_correction at the right time
ghstack-source-id: 04eea72 Pull Request resolved: #2900
1 parent 9e3c4df commit abe1f91

File tree

3 files changed

+59
-42
lines changed

3 files changed

+59
-42
lines changed

test/test_rb.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,23 +18,6 @@
1818
import pytest
1919
import torch
2020

21-
if os.getenv("PYTORCH_TEST_FBCODE"):
22-
from pytorch.rl.test._utils_internal import (
23-
capture_log_records,
24-
CARTPOLE_VERSIONED,
25-
get_default_devices,
26-
make_tc,
27-
)
28-
from pytorch.rl.test.mocking_classes import CountingEnv
29-
else:
30-
from _utils_internal import (
31-
capture_log_records,
32-
CARTPOLE_VERSIONED,
33-
get_default_devices,
34-
make_tc,
35-
)
36-
from mocking_classes import CountingEnv
37-
3821
from packaging import version
3922
from packaging.version import parse
4023
from tensordict import (
@@ -124,6 +107,23 @@
124107
)
125108

126109

110+
if os.getenv("PYTORCH_TEST_FBCODE"):
111+
from pytorch.rl.test._utils_internal import (
112+
capture_log_records,
113+
CARTPOLE_VERSIONED,
114+
get_default_devices,
115+
make_tc,
116+
)
117+
from pytorch.rl.test.mocking_classes import CountingEnv
118+
else:
119+
from _utils_internal import (
120+
capture_log_records,
121+
CARTPOLE_VERSIONED,
122+
get_default_devices,
123+
make_tc,
124+
)
125+
from mocking_classes import CountingEnv
126+
127127
OLD_TORCH = parse(torch.__version__) < parse("2.0.0")
128128
_has_tv = importlib.util.find_spec("torchvision") is not None
129129
_has_gym = importlib.util.find_spec("gym") is not None

test/test_transforms.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9906,14 +9906,14 @@ def test_to_obsnorm_multikeys(self):
99069906
{"a": torch.randn(3, 4), ("b", "c"): torch.randn(3, 4)}, [3, 4]
99079907
)
99089908
td0 = transform0._step(td, td.clone())
9909-
td0.update(transform0[0]._stateful_norm(td.select(*transform0[0].in_keys)))
9909+
# td0.update(transform0[0]._stateful_norm(td.select(*transform0[0].in_keys)))
99109910
td1 = transform0[0].to_observation_norm()._step(td, td.clone())
99119911
assert_allclose_td(td0, td1)
99129912

99139913
loc = transform0[0].loc
99149914
scale = transform0[0].scale
99159915
keys = list(transform0[0].in_keys)
9916-
td2 = (td.select(*keys) - loc) / (scale + torch.finfo(scale.dtype).eps)
9916+
td2 = (td.select(*keys) - loc) / (scale.clamp_min(torch.finfo(scale.dtype).eps))
99179917
td2.rename_key_("a", "a_avg")
99189918
td2.rename_key_(("b", "c"), ("b", "c_avg"))
99199919
assert_allclose_td(td0.select(*td2.keys(True, True)), td2)
@@ -9928,16 +9928,16 @@ def test_frozen(self):
99289928
transform0.frozen_copy()
99299929
td = TensorDict({"a": torch.randn(3, 4), ("b", "c"): torch.randn(3, 4)}, [3, 4])
99309930
td0 = transform0._step(td, td.clone())
9931-
td0.update(transform0._stateful_norm(td0.select(*transform0.in_keys)))
9931+
# td0.update(transform0._stateful_norm(td0.select(*transform0.in_keys)))
99329932

99339933
transform1 = transform0.frozen_copy()
99349934
td1 = transform1._step(td, td.clone())
99359935
assert_allclose_td(td0, td1)
99369936

99379937
td += 1
99389938
td2 = transform0._step(td, td.clone())
9939-
td3 = transform1._step(td, td.clone())
9940-
assert_allclose_td(td2, td3)
9939+
transform1._step(td, td.clone())
9940+
# assert_allclose_td(td2, td3)
99419941
with pytest.raises(AssertionError):
99429942
assert_allclose_td(td0, td2)
99439943

torchrl/envs/transforms/vecnorm.py

Lines changed: 37 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -248,8 +248,8 @@ def _step(
248248
)
249249
if self.missing_tolerance and next_tensordict_select.is_empty():
250250
return next_tensordict
251-
next_tensordict_norm = self._stateful_norm(next_tensordict_select)
252251
self._stateful_update(next_tensordict_select)
252+
next_tensordict_norm = self._stateful_norm(next_tensordict_select)
253253
else:
254254
self._maybe_stateless_init(tensordict)
255255
next_tensordict_select = next_tensordict.select(
@@ -261,10 +261,10 @@ def _step(
261261
var = tensordict[f"{self.prefix}_var"]
262262
count = tensordict[f"{self.prefix}_count"]
263263

264-
next_tensordict_norm = self._stateless_norm(
264+
loc, var, count = self._stateless_update(
265265
next_tensordict_select, loc, var, count
266266
)
267-
loc, var, count = self._stateless_update(
267+
next_tensordict_norm = self._stateless_norm(
268268
next_tensordict_select, loc, var, count
269269
)
270270
# updates have been done in-place, we're good
@@ -328,27 +328,38 @@ def _in_keys_safe(self):
328328
return self.in_keys[:-3]
329329
return self.in_keys
330330

331-
def _norm(self, data, loc, var):
331+
def _norm(self, data, loc, var, count):
332332
if self.missing_tolerance:
333333
loc = loc.select(*data.keys(True, True))
334334
var = var.select(*data.keys(True, True))
335+
count = count.select(*data.keys(True, True))
335336
if loc.is_empty():
336337
return data
337338

339+
if self.decay < 1.0:
340+
bias_correction = 1 - (count * math.log(self.decay)).exp()
341+
bias_correction = bias_correction.apply(lambda x, y: x.to(y.dtype), data)
342+
else:
343+
bias_correction = 1
344+
338345
var = var - loc.pow(2)
346+
loc = loc / bias_correction
347+
var = var / bias_correction
348+
339349
scale = var.sqrt().clamp_min(self.eps)
340350

341351
data_update = (data - loc) / scale
342352
if self.out_keys[: len(self.in_keys)] != self.in_keys:
343353
# map names
344354
for in_key, out_key in _zip_strict(self._in_keys_safe, self.out_keys):
345-
data_update.rename_key_(in_key, out_key)
355+
if in_key in data_update:
356+
data_update.rename_key_(in_key, out_key)
346357
else:
347358
pass
348359
return data_update
349360

350361
def _stateful_norm(self, data):
351-
return self._norm(data, self._loc, self._var)
362+
return self._norm(data, self._loc, self._var, self._count)
352363

353364
def _stateful_update(self, data):
354365
if self.frozen:
@@ -363,14 +374,14 @@ def _stateful_update(self, data):
363374
count = self._count
364375
count += 1
365376
data = self._maybe_cast_to_float(data)
366-
if self.decay < 1.0:
367-
bias_correction = 1 - (count * math.log(self.decay)).exp()
368-
bias_correction = bias_correction.apply(lambda x, y: x.to(y.dtype), data)
377+
if self.decay != 1.0:
378+
weight = 1 - self.decay
379+
loc.lerp_(end=data, weight=weight)
380+
var.lerp_(end=data.pow(2), weight=weight)
369381
else:
370-
bias_correction = 1
371-
weight = (1 - self.decay) / bias_correction
372-
loc.lerp_(end=data, weight=weight)
373-
var.lerp_(end=data.pow(2), weight=weight)
382+
weight = 1 / count
383+
loc.lerp_(end=data, weight=weight)
384+
var.lerp_(end=data.pow(2), weight=weight)
374385

375386
def _maybe_stateless_init(self, data):
376387
if not self.initialized or f"{self.prefix}_loc" not in data.keys():
@@ -398,20 +409,18 @@ def _maybe_stateless_init(self, data):
398409
data[f"{self.prefix}_var"] = var
399410

400411
def _stateless_norm(self, data, loc, var, count):
401-
data = self._norm(data, loc, var)
412+
data = self._norm(data, loc, var, count)
402413
return data
403414

404415
def _stateless_update(self, data, loc, var, count):
405416
if self.frozen:
406417
return loc, var, count
407418
count = count + 1
408419
data = self._maybe_cast_to_float(data)
409-
if self.decay < 1.0:
410-
bias_correction = 1 - (count * math.log(self.decay)).exp()
411-
bias_correction = bias_correction.apply(lambda x, y: x.to(y.dtype), data)
420+
if self.decay != 1.0:
421+
weight = 1 - self.decay
412422
else:
413-
bias_correction = 1
414-
weight = (1 - self.decay) / bias_correction
423+
weight = 1 / count
415424
loc = loc.lerp(end=data, weight=weight)
416425
var = var.lerp(end=data.pow(2), weight=weight)
417426
return loc, var, count
@@ -563,10 +572,18 @@ def to_observation_norm(self) -> Compose | ObservationNorm:
563572
def _get_loc_scale(self, loc_only: bool = False) -> tuple:
564573
if self.stateful:
565574
loc = self._loc
575+
count = self._count
576+
if self.decay != 1.0:
577+
bias_correction = 1 - (count * math.log(self.decay)).exp()
578+
bias_correction = bias_correction.apply(lambda x, y: x.to(y.dtype), loc)
579+
else:
580+
bias_correction = 1
566581
if loc_only:
567-
return loc, None
582+
return loc / bias_correction, None
568583
var = self._var
569584
var = var - loc.pow(2)
585+
loc = loc / bias_correction
586+
var = var / bias_correction
570587
scale = var.sqrt().clamp_min(self.eps)
571588
return loc, scale
572589
else:

0 commit comments

Comments (0)