Skip to content

Commit be01d16

Browse files
author
Vincent Moens
committed
Update
[ghstack-poisoned]
1 parent 36d21ff commit be01d16

File tree

1 file changed

+62
-5
lines changed

1 file changed

+62
-5
lines changed

test/test_transforms.py

Lines changed: 62 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9490,11 +9490,68 @@ def test_vc1_spec_against_real(self, del_keys, device):
94909490
class TestVecNormV2:
94919491
SEED = -1
94929492

9493-
# @pytest.fixture(scope="class")
9494-
# def set_dtype(self):
9495-
# def_dtype = torch.get_default_dtype()
9496-
# yield torch.set_default_dtype(torch.double)
9497-
# torch.set_default_dtype(def_dtype)
9493+
class SimpleEnv(EnvBase):
9494+
def __init__(self, **kwargs):
9495+
super().__init__(**kwargs)
9496+
self.full_reward_spec = Composite(reward=Unbounded((1,)))
9497+
self.full_observation_spec = Composite(observation=Unbounded(()))
9498+
self.full_action_spec = Composite(action=Unbounded(()))
9499+
9500+
def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
9501+
tensordict = (
9502+
TensorDict()
9503+
.update(self.full_observation_spec.rand())
9504+
.update(self.full_done_spec.zero())
9505+
)
9506+
return tensordict
9507+
9508+
def _step(
9509+
self,
9510+
tensordict: TensorDictBase,
9511+
) -> TensorDictBase:
9512+
tensordict = (
9513+
TensorDict()
9514+
.update(self.full_observation_spec.rand())
9515+
.update(self.full_done_spec.zero())
9516+
)
9517+
tensordict["reward"] = self.reward_spec.rand()
9518+
return tensordict
9519+
9520+
def _set_seed(self, seed: int | None):
9521+
...
9522+
9523+
def test_vecnorm2_decay1(self):
9524+
env = self.SimpleEnv().append_transform(
9525+
VecNormV2(
9526+
in_keys=["reward", "observation"],
9527+
out_keys=["reward_norm", "obs_norm"],
9528+
decay=1,
9529+
)
9530+
)
9531+
s_ = env.reset()
9532+
ss = []
9533+
N = 20
9534+
for i in range(N):
9535+
s, s_ = env.step_and_maybe_reset(env.rand_action(s_))
9536+
ss.append(s)
9537+
sstack = torch.stack(ss)
9538+
if i >= 2:
9539+
for k in ("reward",):
9540+
loc = sstack[: i + 1]["next", k].mean(0)
9541+
scale = (
9542+
sstack[: i + 1]["next", k]
9543+
.std(0, unbiased=False)
9544+
.clamp_min(1e-6)
9545+
)
9546+
# Assert that loc and scale match the expected values
9547+
torch.testing.assert_close(
9548+
loc,
9549+
env.transform.loc[k],
9550+
), f"Loc mismatch at step {i}"
9551+
torch.testing.assert_close(
9552+
scale,
9553+
env.transform.scale[k],
9554+
), f"Scale mismatch at step {i}"
94989555

94999556
@pytest.mark.skipif(not _has_gym, reason="gym not available")
95009557
@pytest.mark.parametrize("stateful", [True, False])

0 commit comments

Comments
 (0)