@@ -9490,11 +9490,68 @@ def test_vc1_spec_against_real(self, del_keys, device):
 class TestVecNormV2:
     SEED = -1
 
-    # @pytest.fixture(scope="class")
-    # def set_dtype(self):
-    #     def_dtype = torch.get_default_dtype()
-    #     yield torch.set_default_dtype(torch.double)
-    #     torch.set_default_dtype(def_dtype)
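+    # Minimal dummy env: returns random observations and rewards and never terminates on its own.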
+    class SimpleEnv(EnvBase):
+        def __init__(self, **kwargs):
+            super().__init__(**kwargs)
+            self.full_reward_spec = Composite(reward=Unbounded((1,)))
+            self.full_observation_spec = Composite(observation=Unbounded(()))
+            self.full_action_spec = Composite(action=Unbounded(()))
+
+        def _reset(self, tensordict: TensorDictBase, **kwargs) -> TensorDictBase:
+            tensordict = (
+                TensorDict()
+                .update(self.full_observation_spec.rand())
+                .update(self.full_done_spec.zero())
+            )
+            return tensordict
+
+        def _step(
+            self,
+            tensordict: TensorDictBase,
+        ) -> TensorDictBase:
+            tensordict = (
+                TensorDict()
+                .update(self.full_observation_spec.rand())
+                .update(self.full_done_spec.zero())
+            )
+            tensordict["reward"] = self.reward_spec.rand()
+            return tensordict
+
+        def _set_seed(self, seed: int | None):
+            ...
+
+    def test_vecnorm2_decay1(self):
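+        # With decay=1 there is no forgetting, so the transform's running loc/scale
+        # should equal the empirical mean/std of all rewards collected so far.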
+        env = self.SimpleEnv().append_transform(
+            VecNormV2(
+                in_keys=["reward", "observation"],
+                out_keys=["reward_norm", "obs_norm"],
+                decay=1,
+            )
+        )
+        s_ = env.reset()
+        ss = []
+        N = 20
+        for i in range(N):
+            s, s_ = env.step_and_maybe_reset(env.rand_action(s_))
+            ss.append(s)
+            sstack = torch.stack(ss)
+            if i >= 2:
+                for k in ("reward",):
+                    loc = sstack[: i + 1]["next", k].mean(0)
+                    scale = (
+                        sstack[: i + 1]["next", k]
+                        .std(0, unbiased=False)
+                        .clamp_min(1e-6)
+                    )
+                    # Assert that loc and scale match the expected values
+                    torch.testing.assert_close(
+                        loc,
+                        env.transform.loc[k],
+                        msg=f"Loc mismatch at step {i}",
+                    )
+                    torch.testing.assert_close(
+                        scale,
+                        env.transform.scale[k],
+                        msg=f"Scale mismatch at step {i}",
+                    )
 
     @pytest.mark.skipif(not _has_gym, reason="gym not available")
     @pytest.mark.parametrize("stateful", [True, False])