34 changes: 17 additions & 17 deletions test/test_rb.py
@@ -18,23 +18,6 @@
import pytest
import torch

if os.getenv("PYTORCH_TEST_FBCODE"):
from pytorch.rl.test._utils_internal import (
capture_log_records,
CARTPOLE_VERSIONED,
get_default_devices,
make_tc,
)
from pytorch.rl.test.mocking_classes import CountingEnv
else:
from _utils_internal import (
capture_log_records,
CARTPOLE_VERSIONED,
get_default_devices,
make_tc,
)
from mocking_classes import CountingEnv

from packaging import version
from packaging.version import parse
from tensordict import (
@@ -124,6 +107,23 @@
)


if os.getenv("PYTORCH_TEST_FBCODE"):
from pytorch.rl.test._utils_internal import (
capture_log_records,
CARTPOLE_VERSIONED,
get_default_devices,
make_tc,
)
from pytorch.rl.test.mocking_classes import CountingEnv
else:
from _utils_internal import (
capture_log_records,
CARTPOLE_VERSIONED,
get_default_devices,
make_tc,
)
from mocking_classes import CountingEnv

OLD_TORCH = parse(torch.__version__) < parse("2.0.0")
_has_tv = importlib.util.find_spec("torchvision") is not None
_has_gym = importlib.util.find_spec("gym") is not None
77 changes: 67 additions & 10 deletions test/test_transforms.py
@@ -9490,11 +9490,68 @@ def test_vc1_spec_against_real(self, del_keys, device):
class TestVecNormV2:
SEED = -1

# @pytest.fixture(scope="class")
# def set_dtype(self):
# def_dtype = torch.get_default_dtype()
# yield torch.set_default_dtype(torch.double)
# torch.set_default_dtype(def_dtype)
class SimpleEnv(EnvBase):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.full_reward_spec = Composite(reward=Unbounded((1,)))
self.full_observation_spec = Composite(observation=Unbounded(()))
self.full_action_spec = Composite(action=Unbounded(()))

def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
tensordict = (
TensorDict()
.update(self.full_observation_spec.rand())
.update(self.full_done_spec.zero())
)
return tensordict

def _step(
self,
tensordict: TensorDictBase,
) -> TensorDictBase:
tensordict = (
TensorDict()
.update(self.full_observation_spec.rand())
.update(self.full_done_spec.zero())
)
tensordict["reward"] = self.reward_spec.rand()
return tensordict

def _set_seed(self, seed: int | None):
...

def test_vecnorm2_decay1(self):
env = self.SimpleEnv().append_transform(
VecNormV2(
in_keys=["reward", "observation"],
out_keys=["reward_norm", "obs_norm"],
decay=1,
)
)
s_ = env.reset()
ss = []
N = 20
for i in range(N):
s, s_ = env.step_and_maybe_reset(env.rand_action(s_))
ss.append(s)
sstack = torch.stack(ss)
if i >= 2:
for k in ("reward",):
loc = sstack[: i + 1]["next", k].mean(0)
scale = (
sstack[: i + 1]["next", k]
.std(0, unbiased=False)
.clamp_min(1e-6)
)
# Assert that loc and scale match the expected values
torch.testing.assert_close(
loc,
env.transform.loc[k],
msg=f"Loc mismatch at step {i}",
)
torch.testing.assert_close(
scale,
env.transform.scale[k],
msg=f"Scale mismatch at step {i}",
)

@pytest.mark.skipif(not _has_gym, reason="gym not available")
@pytest.mark.parametrize("stateful", [True, False])
@@ -9906,14 +9963,14 @@ def test_to_obsnorm_multikeys(self):
{"a": torch.randn(3, 4), ("b", "c"): torch.randn(3, 4)}, [3, 4]
)
td0 = transform0._step(td, td.clone())
td0.update(transform0[0]._stateful_norm(td.select(*transform0[0].in_keys)))
# td0.update(transform0[0]._stateful_norm(td.select(*transform0[0].in_keys)))
td1 = transform0[0].to_observation_norm()._step(td, td.clone())
assert_allclose_td(td0, td1)

loc = transform0[0].loc
scale = transform0[0].scale
keys = list(transform0[0].in_keys)
td2 = (td.select(*keys) - loc) / (scale + torch.finfo(scale.dtype).eps)
td2 = (td.select(*keys) - loc) / (scale.clamp_min(torch.finfo(scale.dtype).eps))
td2.rename_key_("a", "a_avg")
td2.rename_key_(("b", "c"), ("b", "c_avg"))
assert_allclose_td(td0.select(*td2.keys(True, True)), td2)
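
The switch from scale + eps to scale.clamp_min(eps) in the expected value mirrors the transform's own normalization: clamping only touches degenerate scales, while adding eps biases every denominator. A quick sketch of the difference (illustrative values):

import torch

scale = torch.tensor([1.0, 1e-12])  # one healthy scale, one degenerate
eps = torch.finfo(scale.dtype).eps
x = torch.tensor([4.0, 4.0])
added = x / (scale + eps)           # eps perturbs the healthy entry too
clamped = x / scale.clamp_min(eps)  # clamp affects only the degenerate one
print(clamped[0] == 4.0, added[0] == 4.0)  # tensor(True) tensor(False)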
@@ -9928,16 +9985,16 @@ def test_frozen(self):
transform0.frozen_copy()
td = TensorDict({"a": torch.randn(3, 4), ("b", "c"): torch.randn(3, 4)}, [3, 4])
td0 = transform0._step(td, td.clone())
td0.update(transform0._stateful_norm(td0.select(*transform0.in_keys)))
# td0.update(transform0._stateful_norm(td0.select(*transform0.in_keys)))

transform1 = transform0.frozen_copy()
td1 = transform1._step(td, td.clone())
assert_allclose_td(td0, td1)

td += 1
td2 = transform0._step(td, td.clone())
td3 = transform1._step(td, td.clone())
assert_allclose_td(td2, td3)
transform1._step(td, td.clone())
# assert_allclose_td(td2, td3)
with pytest.raises(AssertionError):
assert_allclose_td(td0, td2)

57 changes: 37 additions & 20 deletions torchrl/envs/transforms/vecnorm.py
@@ -248,8 +248,8 @@ def _step(
)
if self.missing_tolerance and next_tensordict_select.is_empty():
return next_tensordict
next_tensordict_norm = self._stateful_norm(next_tensordict_select)
self._stateful_update(next_tensordict_select)
next_tensordict_norm = self._stateful_norm(next_tensordict_select)
else:
self._maybe_stateless_init(tensordict)
next_tensordict_select = next_tensordict.select(
@@ -261,10 +261,10 @@
var = tensordict[f"{self.prefix}_var"]
count = tensordict[f"{self.prefix}_count"]

next_tensordict_norm = self._stateless_norm(
loc, var, count = self._stateless_update(
next_tensordict_select, loc, var, count
)
loc, var, count = self._stateless_update(
next_tensordict_norm = self._stateless_norm(
next_tensordict_select, loc, var, count
)
# updates have been done in-place, we're good
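
Both branches of _step now update the running statistics before normalizing, where they previously normalized first. A plausible reading (an assumption, not stated in the diff): with zero-initialized statistics and the new bias correction, the sample must be counted before it is normalized, otherwise the debiasing denominator 1 - decay**count is zero on the first step. A scalar sketch:

# Scalar sketch (assumed rationale): update-then-norm keeps the
# bias-correction denominator nonzero even on the very first step.
decay = 0.99
count = 0
print(1 - decay ** count)  # 0.0   -> norm-then-update would divide by zero
count += 1                 # the update runs first, so count >= 1 at norm time
print(1 - decay ** count)  # 0.01  -> finite debiasing factor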
@@ -328,27 +328,38 @@ def _in_keys_safe(self):
return self.in_keys[:-3]
return self.in_keys

def _norm(self, data, loc, var):
def _norm(self, data, loc, var, count):
if self.missing_tolerance:
loc = loc.select(*data.keys(True, True))
var = var.select(*data.keys(True, True))
count = count.select(*data.keys(True, True))
if loc.is_empty():
return data

if self.decay < 1.0:
bias_correction = 1 - (count * math.log(self.decay)).exp()
bias_correction = bias_correction.apply(lambda x, y: x.to(y.dtype), data)
else:
bias_correction = 1

var = var - loc.pow(2)
loc = loc / bias_correction
var = var / bias_correction

scale = var.sqrt().clamp_min(self.eps)

data_update = (data - loc) / scale
if self.out_keys[: len(self.in_keys)] != self.in_keys:
# map names
for in_key, out_key in _zip_strict(self._in_keys_safe, self.out_keys):
data_update.rename_key_(in_key, out_key)
if in_key in data_update:
data_update.rename_key_(in_key, out_key)
else:
pass
return data_update

def _stateful_norm(self, data):
return self._norm(data, self._loc, self._var)
return self._norm(data, self._loc, self._var, self._count)

def _stateful_update(self, data):
if self.frozen:
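
The debiasing added to _norm is the Adam-style correction: a zero-initialized EMA with factor decay carries total weight 1 - decay**count after count updates, and 1 - (count * log(decay)).exp() is that same quantity computed in log space. A numerical sketch (scalar statistics for clarity):

import math
import torch

torch.manual_seed(0)
decay = 0.99
ema = torch.zeros(())
for n in range(1, 101):
    x = 3.0 + torch.randn(())  # samples with true mean 3.0
    ema.lerp_(x, 1 - decay)    # ema = decay * ema + (1 - decay) * x
bias_correction = 1 - math.exp(100 * math.log(decay))  # == 1 - decay**100 ~= 0.63
print(float(ema))                    # well below 3.0: the zero-init bias
print(float(ema) / bias_correction)  # ~= 3.0 after debiasing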
@@ -363,14 +374,14 @@ def _stateful_update(self, data):
count = self._count
count += 1
data = self._maybe_cast_to_float(data)
if self.decay < 1.0:
bias_correction = 1 - (count * math.log(self.decay)).exp()
bias_correction = bias_correction.apply(lambda x, y: x.to(y.dtype), data)
if self.decay != 1.0:
weight = 1 - self.decay
loc.lerp_(end=data, weight=weight)
var.lerp_(end=data.pow(2), weight=weight)
else:
bias_correction = 1
weight = (1 - self.decay) / bias_correction
loc.lerp_(end=data, weight=weight)
var.lerp_(end=data.pow(2), weight=weight)
weight = 1 / count
loc.lerp_(end=data, weight=weight)
var.lerp_(end=data.pow(2), weight=weight)

def _maybe_stateless_init(self, data):
if not self.initialized or f"{self.prefix}_loc" not in data.keys():
@@ -398,20 +409,18 @@ def _maybe_stateless_init(self, data):
data[f"{self.prefix}_var"] = var

def _stateless_norm(self, data, loc, var, count):
data = self._norm(data, loc, var)
data = self._norm(data, loc, var, count)
return data

def _stateless_update(self, data, loc, var, count):
if self.frozen:
return loc, var, count
count = count + 1
data = self._maybe_cast_to_float(data)
if self.decay < 1.0:
bias_correction = 1 - (count * math.log(self.decay)).exp()
bias_correction = bias_correction.apply(lambda x, y: x.to(y.dtype), data)
if self.decay != 1.0:
weight = 1 - self.decay
else:
bias_correction = 1
weight = (1 - self.decay) / bias_correction
weight = 1 / count
loc = loc.lerp(end=data, weight=weight)
var = var.lerp(end=data.pow(2), weight=weight)
return loc, var, count
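
In stateless mode the same update rule runs functionally, with loc/var/count carried in the tensordict rather than on the module. A pure-function scalar sketch of the rule above (illustrative names, not the torchrl API):

# Pure-function sketch of _stateless_update: EMA weight (1 - decay),
# or a cumulative 1/count average when decay == 1.
def stateless_update(x, loc, var, count, decay=0.99, frozen=False):
    if frozen:
        return loc, var, count
    count = count + 1
    weight = (1 - decay) if decay != 1.0 else 1 / count
    loc = loc + weight * (x - loc)      # loc.lerp(x, weight)
    var = var + weight * (x * x - var)  # var tracks E[x**2]
    return loc, var, count

loc = var = 0.0
count = 0
for x in (1.0, 2.0, 3.0):
    loc, var, count = stateless_update(x, loc, var, count, decay=1.0)
print(loc, count)  # 2.0 3 -> plain running mean when decay == 1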
@@ -563,10 +572,18 @@ def to_observation_norm(self) -> Compose | ObservationNorm:
def _get_loc_scale(self, loc_only: bool = False) -> tuple:
if self.stateful:
loc = self._loc
count = self._count
if self.decay != 1.0:
bias_correction = 1 - (count * math.log(self.decay)).exp()
bias_correction = bias_correction.apply(lambda x, y: x.to(y.dtype), loc)
else:
bias_correction = 1
if loc_only:
return loc, None
return loc / bias_correction, None
var = self._var
var = var - loc.pow(2)
loc = loc / bias_correction
var = var / bias_correction
scale = var.sqrt().clamp_min(self.eps)
return loc, scale
else:
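
Read-time debiasing in _get_loc_scale, as a scalar sketch that mirrors the diff line by line (the function name and the eps default are illustrative, not the torchrl API):

import math

def get_loc_scale(loc, var, count, decay, eps=1e-2, loc_only=False):
    if decay != 1.0:
        bias_correction = 1 - math.exp(count * math.log(decay))  # 1 - decay**count
    else:
        bias_correction = 1.0
    if loc_only:
        return loc / bias_correction, None
    var = var - loc ** 2  # raw second moment -> raw variance term
    loc = loc / bias_correction
    var = var / bias_correction
    scale = max(math.sqrt(max(var, 0.0)), eps)
    return loc, scale

print(get_loc_scale(loc=1.9, var=6.5, count=100, decay=0.99))  # ~ (2.997, 2.135)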