diff --git a/examples/a2c/a2c_atari.py b/examples/a2c/a2c_atari.py
index d3393e4308e..3eeba1c31dc 100644
--- a/examples/a2c/a2c_atari.py
+++ b/examples/a2c/a2c_atari.py
@@ -75,6 +75,9 @@ def main(cfg: "DictConfig"):  # noqa: F821
         critic_coef=cfg.loss.critic_coef,
     )
 
+    # use end-of-life as done key
+    loss_module.set_keys(done="eol", terminated="eol")
+
     # Create optimizer
     optim = torch.optim.Adam(
         loss_module.parameters(),
diff --git a/examples/a2c/utils_atari.py b/examples/a2c/utils_atari.py
index 42b75473b20..d1ad2c5c54e 100644
--- a/examples/a2c/utils_atari.py
+++ b/examples/a2c/utils_atari.py
@@ -3,20 +3,19 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-import gymnasium as gym
 import numpy as np
 import torch.nn
 import torch.optim
 from tensordict.nn import TensorDictModule
-from torchrl.data import CompositeSpec
+from torchrl.data import CompositeSpec, UnboundedDiscreteTensorSpec
 from torchrl.data.tensor_specs import DiscreteBox
 from torchrl.envs import (
     CatFrames,
-    default_info_dict_reader,
     DoubleToFloat,
     EnvCreator,
     ExplorationType,
     GrayScale,
+    GymEnv,
     NoopResetEnv,
     ParallelEnv,
     Resize,
@@ -24,10 +23,10 @@
     RewardSum,
     StepCounter,
     ToTensorImage,
+    Transform,
     TransformedEnv,
     VecNorm,
 )
-from torchrl.envs.libs.gym import GymWrapper
 from torchrl.modules import (
     ActorValueOperator,
     ConvNet,
@@ -43,43 +42,52 @@
 # --------------------------------------------------------------------
 
 
-class EpisodicLifeEnv(gym.Wrapper):
-    def __init__(self, env):
-        """Make end-of-life == end-of-episode, but only reset on true game over.
-        Done by DeepMind for the DQN and co. It helps value estimation.
-        """
-        gym.Wrapper.__init__(self, env)
-        self.lives = 0
+class EndOfLifeTransform(Transform):
+    """Registers the end-of-life signal from a Gym env with a `lives` method.
 
-    def step(self, action):
-        obs, rew, done, truncate, info = self.env.step(action)
-        lives = self.env.unwrapped.ale.lives()
-        info["end_of_life"] = False
-        if (lives < self.lives) or done:
-            info["end_of_life"] = True
-        self.lives = lives
-        return obs, rew, done, truncate, info
+    Done by DeepMind for the DQN and co. It helps value estimation.
+ """ - def reset(self, **kwargs): - reset_data = self.env.reset(**kwargs) - self.lives = self.env.unwrapped.ale.lives() - return reset_data + def _step(self, tensordict, next_tensordict): + lives = self.parent.base_env._env.unwrapped.ale.lives() + end_of_life = torch.tensor( + [tensordict["lives"] < lives], device=self.parent.device + ) + end_of_life = end_of_life | next_tensordict.get("done") + next_tensordict.set("eol", end_of_life) + next_tensordict.set("lives", lives) + return next_tensordict + + def reset(self, tensordict): + lives = self.parent.base_env._env.unwrapped.ale.lives() + end_of_life = False + tensordict.set("eol", [end_of_life]) + tensordict.set("lives", lives) + return tensordict + + def transform_observation_spec(self, observation_spec): + full_done_spec = self.parent.output_spec["full_done_spec"] + observation_spec["eol"] = full_done_spec["done"].clone() + observation_spec["lives"] = UnboundedDiscreteTensorSpec( + self.parent.batch_size, device=self.parent.device + ) + return observation_spec def make_base_env( env_name="BreakoutNoFrameskip-v4", frame_skip=4, device="cpu", is_test=False ): - env = gym.make(env_name) - if not is_test: - env = EpisodicLifeEnv(env) - env = GymWrapper( - env, frame_skip=frame_skip, from_pixels=True, pixels_only=False, device=device + env = GymEnv( + env_name, + frame_skip=frame_skip, + from_pixels=True, + pixels_only=False, + device=device, ) env = TransformedEnv(env) env.append_transform(NoopResetEnv(noops=30, random=True)) if not is_test: - reader = default_info_dict_reader(["end_of_life"]) - env.set_info_dict_reader(reader) + env.append_transform(EndOfLifeTransform()) return env diff --git a/examples/ppo/ppo_atari.py b/examples/ppo/ppo_atari.py index 2bb7cc6a3e8..2ef08ad976e 100644 --- a/examples/ppo/ppo_atari.py +++ b/examples/ppo/ppo_atari.py @@ -79,7 +79,7 @@ def main(cfg: "DictConfig"): # noqa: F821 ) # use end-of-life as done key - loss_module.set_keys(done="eol") + loss_module.set_keys(done="eol", terminated="eol") # Create optimizer optim = torch.optim.Adam( diff --git a/examples/ppo/utils_atari.py b/examples/ppo/utils_atari.py index ddb69555c19..478a9ed7326 100644 --- a/examples/ppo/utils_atari.py +++ b/examples/ppo/utils_atari.py @@ -3,7 +3,6 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-import gymnasium as gym
 import torch.nn
 import torch.optim
 from tensordict.nn import TensorDictModule
@@ -11,11 +10,11 @@
 from torchrl.data.tensor_specs import DiscreteBox, UnboundedDiscreteTensorSpec
 from torchrl.envs import (
     CatFrames,
-    default_info_dict_reader,
     DoubleToFloat,
     EnvCreator,
     ExplorationType,
     GrayScale,
+    GymEnv,
     NoopResetEnv,
     ParallelEnv,
     Resize,
@@ -27,7 +26,6 @@
     TransformedEnv,
     VecNorm,
 )
-from torchrl.envs.libs.gym import GymWrapper
 from torchrl.modules import (
     ActorValueOperator,
     ConvNet,
@@ -78,15 +76,17 @@ def transform_observation_spec(self, observation_spec):
 def make_base_env(
     env_name="BreakoutNoFrameskip-v4", frame_skip=4, device="cpu", is_test=False
 ):
-    env = gym.make(env_name)
-    env = GymWrapper(
-        env, frame_skip=frame_skip, from_pixels=True, pixels_only=False, device=device
+    env = GymEnv(
+        env_name,
+        frame_skip=frame_skip,
+        from_pixels=True,
+        pixels_only=False,
+        device=device,
     )
-    env = TransformedEnv(env, EndOfLifeTransform())
+    env = TransformedEnv(env)
     env.append_transform(NoopResetEnv(noops=30, random=True))
     if not is_test:
-        reader = default_info_dict_reader(["end_of_life"])
-        env.set_info_dict_reader(reader)
+        env.append_transform(EndOfLifeTransform())
     return env
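
Note (not part of the patch): the transform registers "eol" and "lives" next to the usual done flags, and both trainers now point the loss module's done and terminated keys at "eol"; TorchRL's value estimators bootstrap from the terminated entry, so the end-of-life signal must be visible under both keys. A minimal smoke test for the patched builder — a sketch only, assuming the edited examples/a2c/utils_atari.py is importable as `utils_atari` and that gymnasium plus ale-py are installed — could look like:

    from torchrl.envs.utils import check_env_specs

    from utils_atari import make_base_env  # the module patched above

    env = make_base_env("BreakoutNoFrameskip-v4")
    check_env_specs(env)  # validates specs, incl. the new "eol"/"lives" entries

    td = env.reset()
    assert not td["eol"].any()  # no life lost at reset time

    td = env.rand_step(td)
    # ("next", "eol") is True when a life was lost or the episode truly
    # ended, mirroring EndOfLifeTransform._step above
    print(td["next", "eol"], td["next", "lives"])

Keeping the end-of-life bookkeeping inside TorchRL's transform stack (rather than in a gymnasium wrapper) is what lets the loss modules consume the signal directly via set_keys.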