From 5daa49acf4af1284cd9291d633755eeb62e5b590 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 15 Jan 2024 21:08:03 +0000 Subject: [PATCH 1/5] init --- torchrl/objectives/a2c.py | 60 +++++++++++++++++++++++++++--------- torchrl/objectives/ppo.py | 64 +++++++++++++++++++++++++-------------- 2 files changed, 86 insertions(+), 38 deletions(-) diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index 4384ccef282..42d477f15f2 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import contextlib import warnings from copy import deepcopy from dataclasses import dataclass @@ -46,8 +47,8 @@ class A2CLoss(LossModule): https://arxiv.org/abs/1602.01783v2 Args: - actor (ProbabilisticTensorDictSequential): policy operator. - critic (ValueOperator): value operator. + actor_network (ProbabilisticTensorDictSequential): policy operator. + critic_network (ValueOperator): value operator. entropy_bonus (bool): if ``True``, an entropy bonus will be added to the loss to favour exploratory policies. samples_mc_entropy (int): if the distribution retrieved from the policy @@ -221,8 +222,8 @@ class _AcceptedKeys: def __init__( self, - actor: ProbabilisticTensorDictSequential, - critic: TensorDictModule, + actor_network: ProbabilisticTensorDictSequential, + critic_network: TensorDictModule, *, entropy_bonus: bool = True, samples_mc_entropy: int = 1, @@ -233,23 +234,44 @@ def __init__( separate_losses: bool = False, advantage_key: str = None, value_target_key: str = None, + functional: bool = True, + actor: ProbabilisticTensorDictSequential = None, + critic: ProbabilisticTensorDictSequential = None, ): + if actor is not None: + actor_network = actor + del actor + if critic is not None: + critic_network = critic + del critic + self._out_keys = None super().__init__() self._set_deprecated_ctor_keys( advantage=advantage_key, value_target=value_target_key ) - self.convert_to_functional( - actor, "actor", funs_to_decorate=["forward", "get_dist"] - ) + self.functional = functional + if functional: + self.convert_to_functional( + actor_network, "actor_network", funs_to_decorate=["forward", "get_dist"] + ) + else: + self.actor_network = actor_network + if separate_losses: # we want to make sure there are no duplicates in the params: the # params of critic must be refs to actor if they're shared policy_params = list(actor.parameters()) else: policy_params = None - self.convert_to_functional(critic, "critic", compare_against=policy_params) + if functional: + self.convert_to_functional( + critic_network, "critic_network", compare_against=policy_params + ) + else: + self.critic_network = critic_network + self.samples_mc_entropy = samples_mc_entropy self.entropy_bonus = entropy_bonus and entropy_coef @@ -265,6 +287,10 @@ def __init__( self.gamma = gamma self.loss_critic_type = loss_critic_type + @property + def actor(self): + return self.actor_network + @property def in_keys(self): keys = [ @@ -272,8 +298,8 @@ def in_keys(self): ("next", self.tensor_keys.reward), ("next", self.tensor_keys.done), ("next", self.tensor_keys.terminated), - *self.actor.in_keys, - *[("next", key) for key in self.actor.in_keys], + *self.actor_network.in_keys, + *[("next", key) for key in self.actor_network.in_keys], ] if self.critic_coef: keys.extend(self.critic.in_keys) @@ -326,9 +352,11 @@ def _log_probs( raise RuntimeError( f"tensordict stored {self.tensor_keys.action} require 
grad." ) - tensordict_clone = tensordict.select(*self.actor.in_keys).clone() - with self.actor_params.to_module(self.actor): - dist = self.actor.get_dist(tensordict_clone) + tensordict_clone = tensordict.select(*self.actor_network.in_keys).clone() + with self.actor_network_params.to_module( + self.actor_network + ) if self.functional else contextlib.nullcontext(): + dist = self.actor_network.get_dist(tensordict_clone) log_prob = dist.log_prob(action) log_prob = log_prob.unsqueeze(-1) return log_prob, dist @@ -339,7 +367,9 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor: # overhead that we could easily reduce. target_return = tensordict.get(self.tensor_keys.value_target) tensordict_select = tensordict.select(*self.critic.in_keys) - with self.critic_params.to_module(self.critic): + with self.critic_params.to_module( + self.critic + ) if self.functional else contextlib.nullcontext(): state_value = self.critic( tensordict_select, ).get(self.tensor_keys.value) @@ -407,7 +437,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams elif value_type == ValueEstimators.VTrace: # VTrace currently does not support functional call on the actor actor_with_params = repopulate_module( - deepcopy(self.actor), self.actor_params + deepcopy(self.actor_network), self.actor_network_params ) self._value_estimator = VTrace( value_network=self.critic, actor_network=actor_with_params, **hp diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index 11b5fef2ae7..c2d17324203 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -49,8 +49,8 @@ class PPOLoss(LossModule): https://arxiv.org/abs/1707.06347 Args: - actor (ProbabilisticTensorDictSequential): policy operator. - critic (ValueOperator): value operator. + actor_network (ProbabilisticTensorDictSequential): policy operator. + critic_network (ValueOperator): value operator. 
Keyword Args: entropy_bonus (bool, optional): if ``True``, an entropy bonus will be added to the @@ -259,8 +259,8 @@ class _AcceptedKeys: def __init__( self, - actor: ProbabilisticTensorDictSequential, - critic: TensorDictModule, + actor_network: ProbabilisticTensorDictSequential = None, + critic_network: TensorDictModule = None, *, entropy_bonus: bool = True, samples_mc_entropy: int = 1, @@ -273,18 +273,30 @@ def __init__( advantage_key: str = None, value_target_key: str = None, value_key: str = None, + functional: bool = True, + actor: ProbabilisticTensorDictSequential = None, + critic: ProbabilisticTensorDictSequential = None, ): + if actor is not None: + actor_network = actor + del actor + if critic is not None: + critic_network = critic + del critic + self._in_keys = None self._out_keys = None super().__init__() - self.convert_to_functional(actor, "actor") + self.convert_to_functional(actor_network, "actor_network") if separate_losses: # we want to make sure there are no duplicates in the params: the # params of critic must be refs to actor if they're shared policy_params = list(actor.parameters()) else: policy_params = None - self.convert_to_functional(critic, "critic", compare_against=policy_params) + self.convert_to_functional( + critic_network, "critic_network", compare_against=policy_params + ) self.samples_mc_entropy = samples_mc_entropy self.entropy_bonus = entropy_bonus self.separate_losses = separate_losses @@ -314,9 +326,9 @@ def _set_in_keys(self): ("next", self.tensor_keys.reward), ("next", self.tensor_keys.done), ("next", self.tensor_keys.terminated), - *self.actor.in_keys, - *[("next", key) for key in self.actor.in_keys], - *self.critic.in_keys, + *self.actor_network.in_keys, + *[("next", key) for key in self.actor_network.in_keys], + *self.critic_network.in_keys, ] self._in_keys = list(set(keys)) @@ -378,8 +390,8 @@ def _log_weight( f"tensordict stored {self.tensor_keys.action} requires grad." ) - with self.actor_params.to_module(self.actor): - dist = self.actor.get_dist(tensordict) + with self.actor_network_params.to_module(self.actor_network): + dist = self.actor_network.get_dist(tensordict) log_prob = dist.log_prob(action) prev_log_prob = tensordict.get(self.tensor_keys.sample_log_prob) @@ -405,8 +417,8 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor: f"can be used for the value loss." 
) - with self.critic_params.to_module(self.critic): - state_value_td = self.critic(tensordict) + with self.critic_network_params.to_module(self.critic_network): + state_value_td = self.critic_network(tensordict) try: state_value = state_value_td.get(self.tensor_keys.value) @@ -426,7 +438,7 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor: @property @_cache_values def _cached_critic_params_detached(self): - return self.critic_params.detach() + return self.critic_network_params.detach() @dispatch def forward(self, tensordict: TensorDictBase) -> TensorDictBase: @@ -465,20 +477,26 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams hp["gamma"] = self.gamma hp.update(hyperparams) if value_type == ValueEstimators.TD1: - self._value_estimator = TD1Estimator(value_network=self.critic, **hp) + self._value_estimator = TD1Estimator( + value_network=self.critic_network, **hp + ) elif value_type == ValueEstimators.TD0: - self._value_estimator = TD0Estimator(value_network=self.critic, **hp) + self._value_estimator = TD0Estimator( + value_network=self.critic_network, **hp + ) elif value_type == ValueEstimators.GAE: - self._value_estimator = GAE(value_network=self.critic, **hp) + self._value_estimator = GAE(value_network=self.critic_network, **hp) elif value_type == ValueEstimators.TDLambda: - self._value_estimator = TDLambdaEstimator(value_network=self.critic, **hp) + self._value_estimator = TDLambdaEstimator( + value_network=self.critic_network, **hp + ) elif value_type == ValueEstimators.VTrace: # VTrace currently does not support functional call on the actor actor_with_params = repopulate_module( - deepcopy(self.actor), self.actor_params + deepcopy(self.actor_network), self.actor_network_params ) self._value_estimator = VTrace( - value_network=self.critic, actor_network=actor_with_params, **hp + value_network=self.critic_network, actor_network=actor_with_params, **hp ) else: raise NotImplementedError(f"Unknown value type {value_type}") @@ -859,9 +877,9 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: log_weight, dist = self._log_weight(tensordict) neg_loss = log_weight.exp() * advantage - previous_dist = self.actor.build_dist_from_params(tensordict) - with self.actor_params.to_module(self.actor): - current_dist = self.actor.get_dist(tensordict) + previous_dist = self.actor_network.build_dist_from_params(tensordict) + with self.actor_network_params.to_module(self.actor_network): + current_dist = self.actor_network.get_dist(tensordict) try: kl = torch.distributions.kl.kl_divergence(previous_dist, current_dist) except NotImplementedError: From 3139d6bc58b6fca17cda448751059eba26b0417c Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 19 Jan 2024 17:48:15 +0000 Subject: [PATCH 2/5] amend --- test/test_cost.py | 49 ++++++++----- torchrl/objectives/a2c.py | 54 ++++++++++----- torchrl/objectives/ppo.py | 118 ++++++++++++++++++++++---------- torchrl/objectives/reinforce.py | 118 +++++++++++++++++++++++--------- 4 files changed, 233 insertions(+), 106 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index 8d704566c39..b8b5e265f8b 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -5820,7 +5820,10 @@ def _create_seq_mock_data_ppo( @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) - def test_ppo(self, loss_class, device, gradient_mode, advantage, td_est): + 
@pytest.mark.parametrize("functional", [True, False]) + def test_ppo( + self, loss_class, device, gradient_mode, advantage, td_est, functional + ): torch.manual_seed(self.seed) td = self._create_seq_mock_data_ppo(device=device) @@ -5850,7 +5853,7 @@ def test_ppo(self, loss_class, device, gradient_mode, advantage, td_est): else: raise NotImplementedError - loss_fn = loss_class(actor, value, loss_critic_type="l2") + loss_fn = loss_class(actor, value, loss_critic_type="l2", functional=functional) if advantage is not None: advantage(td) else: @@ -6328,7 +6331,7 @@ def test_ppo_notensordict( ) value = self._create_mock_value(observation_key=observation_key) - loss = loss_class(actor=actor, critic=value) + loss = loss_class(actor_network=actor, critic_network=value) loss.set_keys( action=action_key, reward=reward_key, @@ -6537,7 +6540,8 @@ def _create_seq_mock_data_a2c( @pytest.mark.parametrize("advantage", ("gae", "vtrace", "td", "td_lambda", None)) @pytest.mark.parametrize("device", get_default_devices()) @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) - def test_a2c(self, device, gradient_mode, advantage, td_est): + @pytest.mark.parametrize("functional", (True, False)) + def test_a2c(self, device, gradient_mode, advantage, td_est, functional): torch.manual_seed(self.seed) td = self._create_seq_mock_data_a2c(device=device) @@ -6567,7 +6571,7 @@ def test_a2c(self, device, gradient_mode, advantage, td_est): else: raise NotImplementedError - loss_fn = A2CLoss(actor, value, loss_critic_type="l2") + loss_fn = A2CLoss(actor, value, loss_critic_type="l2", functional=functional) # Check error is raised when actions require grads td["action"].requires_grad = True @@ -6629,7 +6633,9 @@ def test_a2c_state_dict(self, device, gradient_mode): def test_a2c_separate_losses(self, separate_losses): torch.manual_seed(self.seed) actor, critic, common, td = self._create_mock_common_layer_setup() - loss_fn = A2CLoss(actor=actor, critic=critic, separate_losses=separate_losses) + loss_fn = A2CLoss( + actor_network=actor, critic_network=critic, separate_losses=separate_losses + ) # Check error is raised when actions require grads td["action"].requires_grad = True @@ -6966,7 +6972,6 @@ def test_a2c_notensordict( class TestReinforce(LossModuleTestBase): seed = 0 - @pytest.mark.parametrize("delay_value", [True, False]) @pytest.mark.parametrize("gradient_mode", [True, False]) @pytest.mark.parametrize("advantage", ["gae", "td", "td_lambda", None]) @pytest.mark.parametrize( @@ -6979,7 +6984,12 @@ class TestReinforce(LossModuleTestBase): None, ], ) - def test_reinforce_value_net(self, advantage, gradient_mode, delay_value, td_est): + @pytest.mark.parametrize( + "delay_value,functional", [[False, True], [False, False], [True, True]] + ) + def test_reinforce_value_net( + self, advantage, gradient_mode, delay_value, td_est, functional + ): n_obs = 3 n_act = 5 batch = 4 @@ -7023,8 +7033,9 @@ def test_reinforce_value_net(self, advantage, gradient_mode, delay_value, td_est loss_fn = ReinforceLoss( actor_net, - critic=value_net, + critic_network=value_net, delay_value=delay_value, + functional=functional, ) td = TensorDict( @@ -7049,7 +7060,7 @@ def test_reinforce_value_net(self, advantage, gradient_mode, delay_value, td_est if advantage is not None: params = TensorDict.from_module(value_net) if delay_value: - target_params = loss_fn.target_critic_params + target_params = loss_fn.target_critic_network_params else: target_params = None advantage(td, params=params, target_params=target_params) @@ -7108,7 +7119,7 @@ 
def test_reinforce_tensordict_keys(self, td_est): loss_fn = ReinforceLoss( actor_net, - critic=value_net, + critic_network=value_net, ) default_keys = { @@ -7133,7 +7144,7 @@ def test_reinforce_tensordict_keys(self, td_est): loss_fn = ReinforceLoss( actor_net, - critic=value_net, + critic_network=value_net, ) key_mapping = { @@ -7207,14 +7218,14 @@ def test_reinforce_tensordict_separate_losses(self, separate_losses): torch.manual_seed(self.seed) actor, critic, common, td = self._create_mock_common_layer_setup() loss_fn = ReinforceLoss( - actor=actor, critic=critic, separate_losses=separate_losses + actor_network=actor, critic_network=critic, separate_losses=separate_losses ) loss = loss_fn(td) assert all( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.critic_params.values(True, True) + for p in loss_fn.critic_network_params.values(True, True) ) assert all( (p.grad is None) or (p.grad == 0).all() @@ -7234,14 +7245,14 @@ def test_reinforce_tensordict_separate_losses(self, separate_losses): for p in loss_fn.actor_network_params.values(True, True) ) common_layers = itertools.islice( - loss_fn.critic_params.values(True, True), + loss_fn.critic_network_params.values(True, True), common_layers_no, ) assert all( (p.grad is None) or (p.grad == 0).all() for p in common_layers ) critic_layers = itertools.islice( - loss_fn.critic_params.values(True, True), + loss_fn.critic_network_params.values(True, True), common_layers_no, None, ) @@ -7250,7 +7261,7 @@ def test_reinforce_tensordict_separate_losses(self, separate_losses): ) else: common_layers = itertools.islice( - loss_fn.critic_params.values(True, True), + loss_fn.critic_network_params.values(True, True), common_layers_no, ) assert not any( @@ -7266,7 +7277,7 @@ def test_reinforce_tensordict_separate_losses(self, separate_losses): ) assert not any( (p.grad is None) or (p.grad == 0).all() - for p in loss_fn.critic_params.values(True, True) + for p in loss_fn.critic_network_params.values(True, True) ) else: @@ -7297,7 +7308,7 @@ def test_reinforce_notensordict( in_keys=["loc", "scale"], spec=UnboundedContinuousTensorSpec(n_act), ) - loss = ReinforceLoss(actor=actor_net, critic=value_net) + loss = ReinforceLoss(actor_network=actor_net, critic_network=value_net) loss.set_keys( reward=reward_key, done=done_key, diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index 42d477f15f2..6cbd1792eaf 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -9,12 +9,7 @@ from typing import Tuple import torch -from tensordict.nn import ( - dispatch, - ProbabilisticTensorDictSequential, - repopulate_module, - TensorDictModule, -) +from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule from tensordict.tensordict import TensorDict, TensorDictBase from tensordict.utils import NestedKey from torch import distributions as d @@ -69,6 +64,10 @@ class A2CLoss(LossModule): The input tensordict key where the advantage is expected to be written. default: "advantage" value_target_key (str): [Deprecated, use set_keys() instead] the input tensordict key where the target state value is expected to be written. Defaults to ``"value_target"``. + functional (bool, optional): whether modules should be functionalized. + Functionalizing permits features like meta-RL, but makes it + impossible to use distributed models (DDP, FSDP, ...) and comes + with a little cost. Defaults to ``True``. .. 
note: The advantage (typically GAE) can be computed by the loss function or @@ -222,8 +221,8 @@ class _AcceptedKeys: def __init__( self, - actor_network: ProbabilisticTensorDictSequential, - critic_network: TensorDictModule, + actor_network: ProbabilisticTensorDictSequential = None, + critic_network: TensorDictModule = None, *, entropy_bonus: bool = True, samples_mc_entropy: int = 1, @@ -244,14 +243,18 @@ def __init__( if critic is not None: critic_network = critic del critic + if actor_network is None or critic_network is None: + raise TypeError( + "Missing positional arguments actor_network or critic_network." + ) + self._functional = functional self._out_keys = None super().__init__() self._set_deprecated_ctor_keys( advantage=advantage_key, value_target=value_target_key ) - self.functional = functional if functional: self.convert_to_functional( actor_network, "actor_network", funs_to_decorate=["forward", "get_dist"] @@ -262,7 +265,7 @@ def __init__( if separate_losses: # we want to make sure there are no duplicates in the params: the # params of critic must be refs to actor if they're shared - policy_params = list(actor.parameters()) + policy_params = list(actor_network.parameters()) else: policy_params = None if functional: @@ -271,6 +274,7 @@ def __init__( ) else: self.critic_network = critic_network + self.target_critic_network_params = None self.samples_mc_entropy = samples_mc_entropy self.entropy_bonus = entropy_bonus and entropy_coef @@ -287,10 +291,18 @@ def __init__( self.gamma = gamma self.loss_critic_type = loss_critic_type + @property + def functional(self): + return self._functional + @property def actor(self): return self.actor_network + @property + def critic(self): + return self.critic_network + @property def in_keys(self): keys = [ @@ -367,7 +379,7 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor: # overhead that we could easily reduce. 
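# A standalone sketch of the conditional context-manager pattern introduced in this
# patch (used in ``loss_critic`` just below): when ``functional=True`` the stored
# parameters are temporarily swapped into the module via ``to_module``; otherwise
# ``contextlib.nullcontext()`` is a no-op and the module keeps its own (possibly
# DDP/FSDP-wrapped) parameters. The toy value network and keys here are illustrative
# placeholders, not the TorchRL objectives API.
import contextlib

import torch
import torch.nn as nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule

value_net = TensorDictModule(nn.Linear(3, 1), in_keys=["obs"], out_keys=["state_value"])
value_params = TensorDict.from_module(value_net)  # functional-style parameter container
functional = True  # mirrors the new ``functional`` constructor flag

with value_params.to_module(value_net) if functional else contextlib.nullcontext():
    out = value_net(TensorDict({"obs": torch.randn(3)}, []))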
target_return = tensordict.get(self.tensor_keys.value_target) tensordict_select = tensordict.select(*self.critic.in_keys) - with self.critic_params.to_module( + with self.critic_network_params.to_module( self.critic ) if self.functional else contextlib.nullcontext(): state_value = self.critic( @@ -390,8 +402,10 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor: @property @_cache_values - def _cached_detach_critic_params(self): - return self.critic_params.detach() + def _cached_detach_critic_network_params(self): + if not self.functional: + return None + return self.critic_network_params.detach() @dispatch() def forward(self, tensordict: TensorDictBase) -> TensorDictBase: @@ -400,8 +414,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: if advantage is None: self.value_estimator( tensordict, - params=self._cached_detach_critic_params, - target_params=self.target_critic_params, + params=self._cached_detach_critic_network_params, + target_params=self.target_critic_network_params, ) advantage = tensordict.get(self.tensor_keys.advantage) assert not advantage.requires_grad @@ -436,11 +450,13 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams self._value_estimator = TDLambdaEstimator(value_network=self.critic, **hp) elif value_type == ValueEstimators.VTrace: # VTrace currently does not support functional call on the actor - actor_with_params = repopulate_module( - deepcopy(self.actor_network), self.actor_network_params - ) + if self.functional: + actor_with_params = deepcopy(self.actor_network) + self.actor_network_params.to_module(actor_with_params) + else: + actor_with_params = self.actor_network self._value_estimator = VTrace( - value_network=self.critic, actor_network=actor_with_params, **hp + value_network=self.critic_network, actor_network=actor_with_params, **hp ) else: raise NotImplementedError(f"Unknown value type {value_type}") diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index 32cef51aa9b..86e224058f8 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +import contextlib + import math import warnings from copy import deepcopy @@ -9,12 +11,7 @@ from typing import Tuple import torch -from tensordict.nn import ( - dispatch, - ProbabilisticTensorDictSequential, - repopulate_module, - TensorDictModule, -) +from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule from tensordict.tensordict import TensorDict, TensorDictBase from tensordict.utils import NestedKey from torch import distributions as d @@ -82,6 +79,10 @@ class PPOLoss(LossModule): value_key (str, optional): [Deprecated, use set_keys(value_key) instead] The input tensordict key where the state value is expected to be written. Defaults to ``"state_value"``. + functional (bool, optional): whether modules should be functionalized. + Functionalizing permits features like meta-RL, but makes it + impossible to use distributed models (DDP, FSDP, ...) and comes + with a little cost. Defaults to ``True``. .. note:: The advantage (typically GAE) can be computed by the loss function or @@ -283,20 +284,37 @@ def __init__( if critic is not None: critic_network = critic del critic + if actor_network is None or critic_network is None: + raise TypeError( + "Missing positional arguments actor_network or critic_network." 
+ ) + self._functional = functional self._in_keys = None self._out_keys = None super().__init__() - self.convert_to_functional(actor_network, "actor_network") + if functional: + self.convert_to_functional(actor_network, "actor_network") + else: + self.actor_network = actor_network + self.actor_network_params = None + self.target_actor_network_params = None + if separate_losses: # we want to make sure there are no duplicates in the params: the # params of critic must be refs to actor if they're shared - policy_params = list(actor.parameters()) + policy_params = list(actor_network.parameters()) else: policy_params = None - self.convert_to_functional( - critic_network, "critic_network", compare_against=policy_params - ) + if functional: + self.convert_to_functional( + critic_network, "critic_network", compare_against=policy_params + ) + else: + self.critic_network = critic_network + self.critic_network_params = None + self.target_critic_network_params = None + self.samples_mc_entropy = samples_mc_entropy self.entropy_bonus = entropy_bonus self.separate_losses = separate_losses @@ -319,6 +337,18 @@ def __init__( value=value_key, ) + @property + def functional(self): + return self._functional + + @property + def actor(self): + return self.actor_network + + @property + def critic(self): + return self.critic_network + def _set_in_keys(self): keys = [ self.tensor_keys.action, @@ -390,7 +420,9 @@ def _log_weight( f"tensordict stored {self.tensor_keys.action} requires grad." ) - with self.actor_network_params.to_module(self.actor_network): + with self.actor_network_params.to_module( + self.actor_network + ) if self.functional else contextlib.nullcontext(): dist = self.actor_network.get_dist(tensordict) log_prob = dist.log_prob(action) @@ -417,7 +449,9 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor: f"can be used for the value loss." 
) - with self.critic_network_params.to_module(self.critic_network): + with self.critic_network_params.to_module( + self.critic_network + ) if self.functional else contextlib.nullcontext(): state_value_td = self.critic_network(tensordict) try: @@ -438,6 +472,8 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor: @property @_cache_values def _cached_critic_params_detached(self): + if not self.functional: + return None return self.critic_network_params.detach() @dispatch @@ -447,8 +483,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: if advantage is None: self.value_estimator( tensordict, - params=self._cached_critic_params_detached, - target_params=self.target_critic_params, + params=self._cached_critic_network_params_detached, + target_params=self.target_critic_network_params, ) advantage = tensordict.get(self.tensor_keys.advantage) if self.normalize_advantage and advantage.numel() > 1: @@ -492,9 +528,11 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams ) elif value_type == ValueEstimators.VTrace: # VTrace currently does not support functional call on the actor - actor_with_params = repopulate_module( - deepcopy(self.actor_network), self.actor_network_params - ) + if self.functional: + actor_with_params = deepcopy(self.actor_network) + self.actor_network_params.to_module(actor_with_params) + else: + actor_with_params = self.actor_network self._value_estimator = VTrace( value_network=self.critic_network, actor_network=actor_with_params, **hp ) @@ -520,8 +558,8 @@ class ClipPPOLoss(PPOLoss): loss = -min( weight * advantage, min(max(weight, 1-eps), 1+eps) * advantage) Args: - actor (ProbabilisticTensorDictSequential): policy operator. - critic (ValueOperator): value operator. + actor_network (ProbabilisticTensorDictSequential): policy operator. + critic_network (ValueOperator): value operator. Keyword Args: clip_epsilon (scalar, optional): weight clipping threshold in the clipped PPO loss equation. @@ -555,6 +593,10 @@ class ClipPPOLoss(PPOLoss): value_key (str, optional): [Deprecated, use set_keys(value_key) instead] The input tensordict key where the state value is expected to be written. Defaults to ``"state_value"``. + functional (bool, optional): whether modules should be functionalized. + Functionalizing permits features like meta-RL, but makes it + impossible to use distributed models (DDP, FSDP, ...) and comes + with a little cost. Defaults to ``True``. .. 
note: The advantage (typically GAE) can be computed by the loss function or @@ -601,8 +643,8 @@ class ClipPPOLoss(PPOLoss): def __init__( self, - actor: ProbabilisticTensorDictSequential, - critic: TensorDictModule, + actor_network: ProbabilisticTensorDictSequential, + critic_network: TensorDictModule, *, clip_epsilon: float = 0.2, entropy_bonus: bool = True, @@ -616,8 +658,8 @@ def __init__( **kwargs, ): super(ClipPPOLoss, self).__init__( - actor, - critic, + actor_network, + critic_network, entropy_bonus=entropy_bonus, samples_mc_entropy=samples_mc_entropy, entropy_coef=entropy_coef, @@ -660,8 +702,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: if advantage is None: self.value_estimator( tensordict, - params=self._cached_critic_params_detached, - target_params=self.target_critic_params, + params=self._cached_critic_network_params_detached, + target_params=self.target_critic_network_params, ) advantage = tensordict.get(self.tensor_keys.advantage) if self.normalize_advantage and advantage.numel() > 1: @@ -707,8 +749,8 @@ class KLPENPPOLoss(PPOLoss): favouring a certain level of distancing between the two while still preventing them to be too much apart. Args: - actor (ProbabilisticTensorDictSequential): policy operator. - critic (ValueOperator): value operator. + actor_network (ProbabilisticTensorDictSequential): policy operator. + critic_network (ValueOperator): value operator. Keyword Args: dtarg (scalar, optional): target KL divergence. Defaults to ``0.01``. @@ -749,6 +791,10 @@ class KLPENPPOLoss(PPOLoss): value_key (str, optional): [Deprecated, use set_keys(value_key) instead] The input tensordict key where the state value is expected to be written. Defaults to ``"state_value"``. + functional (bool, optional): whether modules should be functionalized. + Functionalizing permits features like meta-RL, but makes it + impossible to use distributed models (DDP, FSDP, ...) and comes + with a little cost. Defaults to ``True``. .. 
note: @@ -796,8 +842,8 @@ class KLPENPPOLoss(PPOLoss): def __init__( self, - actor: ProbabilisticTensorDictSequential, - critic: TensorDictModule, + actor_network: ProbabilisticTensorDictSequential, + critic_network: TensorDictModule, *, dtarg: float = 0.01, beta: float = 1.0, @@ -815,8 +861,8 @@ def __init__( **kwargs, ): super(KLPENPPOLoss, self).__init__( - actor, - critic, + actor_network, + critic_network, entropy_bonus=entropy_bonus, samples_mc_entropy=samples_mc_entropy, entropy_coef=entropy_coef, @@ -866,8 +912,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: if advantage is None: self.value_estimator( tensordict, - params=self._cached_critic_params_detached, - target_params=self.target_critic_params, + params=self._cached_critic_network_params_detached, + target_params=self.target_critic_network_params, ) advantage = tensordict.get(self.tensor_keys.advantage) if self.normalize_advantage and advantage.numel() > 1: @@ -878,7 +924,9 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict: neg_loss = log_weight.exp() * advantage previous_dist = self.actor_network.build_dist_from_params(tensordict) - with self.actor_network_params.to_module(self.actor_network): + with self.actor_network_params.to_module( + self.actor_network + ) if self.functional else contextlib.nullcontext(): current_dist = self.actor_network.get_dist(tensordict) try: kl = torch.distributions.kl.kl_divergence(previous_dist, current_dist) diff --git a/torchrl/objectives/reinforce.py b/torchrl/objectives/reinforce.py index 832af829c64..2ba10ece317 100644 --- a/torchrl/objectives/reinforce.py +++ b/torchrl/objectives/reinforce.py @@ -2,19 +2,16 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + +import contextlib import warnings from copy import deepcopy from dataclasses import dataclass -from typing import Optional import torch -from tensordict.nn import ( - dispatch, - ProbabilisticTensorDictSequential, - repopulate_module, - TensorDictModule, -) +from tensordict.nn import dispatch, ProbabilisticTensorDictSequential, TensorDictModule from tensordict.tensordict import TensorDict, TensorDictBase from tensordict.utils import NestedKey from torchrl.objectives.common import LossModule @@ -41,10 +38,12 @@ class ReinforceLoss(LossModule): Args: - actor (ProbabilisticTensorDictSequential): policy operator. - critic (ValueOperator): value operator. + actor_network (ProbabilisticTensorDictSequential): policy operator. + critic_network (ValueOperator): value operator. + + Keyword Args: delay_value (bool, optional): if ``True``, a target network is needed - for the critic. Defaults to ``False``. + for the critic. Defaults to ``False``. Incompatible with ``functional=False``. loss_critic_type (str): loss function for the value discrepancy. Can be one of "l1", "l2" or "smooth_l1". Defaults to ``"smooth_l1"``. advantage_key (str): [Deprecated, use .set_keys(advantage_key=advantage_key) instead] @@ -57,6 +56,10 @@ class ReinforceLoss(LossModule): policy and critic will only be trained on the policy loss. Defaults to ``False``, ie. gradients are propagated to shared parameters for both policy and critic losses. + functional (bool, optional): whether modules should be functionalized. + Functionalizing permits features like meta-RL, but makes it + impossible to use distributed models (DDP, FSDP, ...) and comes + with a little cost. Defaults to ``True``. .. 
note: The advantage (typically GAE) can be computed by the loss function or @@ -208,8 +211,8 @@ def __new__(cls, *args, **kwargs): def __init__( self, - actor: ProbabilisticTensorDictSequential, - critic: Optional[TensorDictModule] = None, + actor_network: ProbabilisticTensorDictSequential, + critic_network: TensorDictModule | None = None, *, delay_value: bool = False, loss_critic_type: str = "smooth_l1", @@ -217,7 +220,27 @@ def __init__( advantage_key: str = None, value_target_key: str = None, separate_losses: bool = False, + functional: bool = True, + actor: ProbabilisticTensorDictSequential = None, + critic: ProbabilisticTensorDictSequential = None, ) -> None: + if actor is not None: + actor_network = actor + del actor + if critic is not None: + critic_network = critic + del critic + if actor_network is None or critic_network is None: + raise TypeError( + "Missing positional arguments actor_network or critic_network." + ) + if not functional and delay_value: + raise RuntimeError( + "delay_value and ~functional are incompatible, as delayed value currently relies on functional calls." + ) + + self._functional = functional + super().__init__() self.in_keys = None self._set_deprecated_ctor_keys( @@ -228,29 +251,50 @@ def __init__( self.loss_critic_type = loss_critic_type # Actor - self.convert_to_functional( - actor, - "actor_network", - create_target_params=False, - ) + if self.functional: + self.convert_to_functional( + actor_network, + "actor_network", + create_target_params=False, + ) + else: + self.actor_network = actor_network + if separate_losses: # we want to make sure there are no duplicates in the params: the # params of critic must be refs to actor if they're shared - policy_params = list(actor.parameters()) + policy_params = list(actor_network.parameters()) else: policy_params = None # Value - if critic is not None: - self.convert_to_functional( - critic, - "critic", - create_target_params=self.delay_value, - compare_against=policy_params, - ) + if critic_network is not None: + if self.functional: + self.convert_to_functional( + critic_network, + "critic_network", + create_target_params=self.delay_value, + compare_against=policy_params, + ) + else: + self.critic_network = critic_network + self.target_critic_network_params = None + if gamma is not None: warnings.warn(_GAMMA_LMBDA_DEPREC_WARNING, category=DeprecationWarning) self.gamma = gamma + @property + def functional(self): + return self._functional + + @property + def actor(self): + return self.actor_network + + @property + def critic(self): + return self.critic_network + def _forward_value_estimator_keys(self, **kwargs) -> None: if self._value_estimator is not None: self._value_estimator.set_keys( @@ -291,13 +335,17 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase: if advantage is None: self.value_estimator( tensordict, - params=self.critic_params.detach(), - target_params=self.target_critic_params, + params=self.critic_network_params.detach() if self.functional else None, + target_params=self.target_critic_network_params + if self.functional + else None, ) advantage = tensordict.get(self.tensor_keys.advantage) # compute log-prob - with self.actor_network_params.to_module(self.actor_network): + with self.actor_network_params.to_module( + self.actor_network + ) if self.functional else contextlib.nullcontext(): tensordict = self.actor_network(tensordict) log_prob = tensordict.get(self.tensor_keys.sample_log_prob) @@ -315,7 +363,9 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor: try: 
target_return = tensordict.get(self.tensor_keys.value_target) tensordict_select = tensordict.select(*self.critic.in_keys) - with self.critic_params.to_module(self.critic): + with self.critic_network_params.to_module( + self.critic + ) if self.functional else contextlib.nullcontext(): state_value = self.critic(tensordict_select).get(self.tensor_keys.value) loss_value = distance_loss( target_return, @@ -350,11 +400,13 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams self._value_estimator = TDLambdaEstimator(value_network=self.critic, **hp) elif value_type == ValueEstimators.VTrace: # VTrace currently does not support functional call on the actor - actor_with_params = repopulate_module( - deepcopy(self.actor), self.actor_params - ) + if self.functional: + actor_with_params = deepcopy(self.actor_network) + self.actor_network_params.to_module(actor_with_params) + else: + actor_with_params = self.actor_network self._value_estimator = VTrace( - value_network=self.critic, actor_network=actor_with_params, **hp + value_network=self.critic_network, actor_network=actor_with_params, **hp ) else: raise NotImplementedError(f"Unknown value type {value_type}") From 3eab2d488f8c44632ba042f4652a7415b595eedd Mon Sep 17 00:00:00 2001 From: vmoens Date: Fri, 19 Jan 2024 20:18:39 +0000 Subject: [PATCH 3/5] amend --- torchrl/objectives/ppo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index 86e224058f8..6379996d47a 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -471,7 +471,7 @@ def loss_critic(self, tensordict: TensorDictBase) -> torch.Tensor: @property @_cache_values - def _cached_critic_params_detached(self): + def _cached_critic_network_params_detached(self): if not self.functional: return None return self.critic_network_params.detach() From 92f763756b2c3362058478e9c35efba8e43b661b Mon Sep 17 00:00:00 2001 From: vmoens Date: Sat, 20 Jan 2024 06:09:17 +0000 Subject: [PATCH 4/5] amend --- benchmarks/test_objectives_benchmarks.py | 6 +++--- examples/a2c/a2c_atari.py | 4 ++-- examples/a2c/a2c_mujoco.py | 4 ++-- .../collectors/multi_nodes/ray_train.py | 2 +- examples/impala/impala_multi_node_ray.py | 4 ++-- examples/impala/impala_multi_node_submitit.py | 4 ++-- examples/impala/impala_single_node.py | 4 ++-- examples/multiagent/mappo_ippo.py | 4 ++-- examples/ppo/ppo_atari.py | 4 ++-- examples/ppo/ppo_mujoco.py | 4 ++-- torchrl/objectives/ppo.py | 14 ++++++++------ tutorials/sphinx-tutorials/coding_ppo.py | 4 ++-- tutorials/sphinx-tutorials/multiagent_ppo.py | 4 ++-- 13 files changed, 32 insertions(+), 30 deletions(-) diff --git a/benchmarks/test_objectives_benchmarks.py b/benchmarks/test_objectives_benchmarks.py index d07e8f5da90..4cfc8470a15 100644 --- a/benchmarks/test_objectives_benchmarks.py +++ b/benchmarks/test_objectives_benchmarks.py @@ -548,7 +548,7 @@ def test_a2c_speed( actor(td.clone()) critic(td.clone()) - loss = A2CLoss(actor=actor, critic=critic) + loss = A2CLoss(actor_network=actor, critic_network=critic) advantage = GAE(value_network=critic, gamma=0.99, lmbda=0.95, shifted=True) advantage(td) loss(td) @@ -605,7 +605,7 @@ def test_ppo_speed( actor(td.clone()) critic(td.clone()) - loss = ClipPPOLoss(actor=actor, critic=critic) + loss = ClipPPOLoss(actor_network=actor, critic_network=critic) advantage = GAE(value_network=critic, gamma=0.99, lmbda=0.95, shifted=True) advantage(td) loss(td) @@ -662,7 +662,7 @@ def test_reinforce_speed( actor(td.clone()) 
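# A compact usage sketch of the renamed constructor keywords that the tests and
# benchmarks in this patch switch to: ``actor_network``/``critic_network`` are the new
# names, the old ``actor``/``critic`` keywords are still accepted and remapped, and the
# new ``functional`` flag sits alongside them. The toy networks are placeholders.
import torch.nn as nn
from tensordict.nn import (
    NormalParamExtractor,
    ProbabilisticTensorDictModule,
    ProbabilisticTensorDictSequential,
    TensorDictModule,
)
from torchrl.modules import TanhNormal
from torchrl.objectives import A2CLoss

backbone = TensorDictModule(
    nn.Sequential(nn.Linear(3, 4), NormalParamExtractor()),
    in_keys=["observation"],
    out_keys=["loc", "scale"],
)
actor = ProbabilisticTensorDictSequential(
    backbone,
    ProbabilisticTensorDictModule(
        in_keys=["loc", "scale"],
        out_keys=["action"],
        distribution_class=TanhNormal,
        return_log_prob=True,
    ),
)
critic = TensorDictModule(
    nn.Linear(3, 1), in_keys=["observation"], out_keys=["state_value"]
)

loss = A2CLoss(actor_network=actor, critic_network=critic, functional=False)  # stateful modules
legacy = A2CLoss(actor=actor, critic=critic)  # deprecated keywords, remapped internally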
critic(td.clone()) - loss = ReinforceLoss(actor=actor, critic=critic) + loss = ReinforceLoss(actor_network=actor, critic_network=critic) advantage = GAE(value_network=critic, gamma=0.99, lmbda=0.95, shifted=True) advantage(td) loss(td) diff --git a/examples/a2c/a2c_atari.py b/examples/a2c/a2c_atari.py index 8d19080f223..0452d7d600f 100644 --- a/examples/a2c/a2c_atari.py +++ b/examples/a2c/a2c_atari.py @@ -69,8 +69,8 @@ def main(cfg: "DictConfig"): # noqa: F821 average_gae=True, ) loss_module = A2CLoss( - actor=actor, - critic=critic, + actor_network=actor, + critic_network=critic, loss_critic_type=cfg.loss.loss_critic_type, entropy_coef=cfg.loss.entropy_coef, critic_coef=cfg.loss.critic_coef, diff --git a/examples/a2c/a2c_mujoco.py b/examples/a2c/a2c_mujoco.py index 4076631f1ef..2628a6f388c 100644 --- a/examples/a2c/a2c_mujoco.py +++ b/examples/a2c/a2c_mujoco.py @@ -63,8 +63,8 @@ def main(cfg: "DictConfig"): # noqa: F821 average_gae=False, ) loss_module = A2CLoss( - actor=actor, - critic=critic, + actor_network=actor, + critic_network=critic, loss_critic_type=cfg.loss.loss_critic_type, entropy_coef=cfg.loss.entropy_coef, critic_coef=cfg.loss.critic_coef, diff --git a/examples/distributed/collectors/multi_nodes/ray_train.py b/examples/distributed/collectors/multi_nodes/ray_train.py index 955d97113fe..2db86b9f917 100644 --- a/examples/distributed/collectors/multi_nodes/ray_train.py +++ b/examples/distributed/collectors/multi_nodes/ray_train.py @@ -145,7 +145,7 @@ ) loss_module = ClipPPOLoss( actor=policy_module, - critic=value_module, + critic_network=value_module, advantage_key="advantage", clip_epsilon=clip_epsilon, entropy_bonus=bool(entropy_eps), diff --git a/examples/impala/impala_multi_node_ray.py b/examples/impala/impala_multi_node_ray.py index 46941529c00..49b3dd4bd4d 100644 --- a/examples/impala/impala_multi_node_ray.py +++ b/examples/impala/impala_multi_node_ray.py @@ -114,8 +114,8 @@ def main(cfg: "DictConfig"): # noqa: F821 average_adv=False, ) loss_module = A2CLoss( - actor=actor, - critic=critic, + actor_network=actor, + critic_network=critic, loss_critic_type=cfg.loss.loss_critic_type, entropy_coef=cfg.loss.entropy_coef, critic_coef=cfg.loss.critic_coef, diff --git a/examples/impala/impala_multi_node_submitit.py b/examples/impala/impala_multi_node_submitit.py index 7eef42ec98f..2b89ef046a1 100644 --- a/examples/impala/impala_multi_node_submitit.py +++ b/examples/impala/impala_multi_node_submitit.py @@ -106,8 +106,8 @@ def main(cfg: "DictConfig"): # noqa: F821 average_adv=False, ) loss_module = A2CLoss( - actor=actor, - critic=critic, + actor_network=actor, + critic_network=critic, loss_critic_type=cfg.loss.loss_critic_type, entropy_coef=cfg.loss.entropy_coef, critic_coef=cfg.loss.critic_coef, diff --git a/examples/impala/impala_single_node.py b/examples/impala/impala_single_node.py index 9a853e9bc76..f5b64e4718a 100644 --- a/examples/impala/impala_single_node.py +++ b/examples/impala/impala_single_node.py @@ -84,8 +84,8 @@ def main(cfg: "DictConfig"): # noqa: F821 average_adv=False, ) loss_module = A2CLoss( - actor=actor, - critic=critic, + actor_network=actor, + critic_network=critic, loss_critic_type=cfg.loss.loss_critic_type, entropy_coef=cfg.loss.entropy_coef, critic_coef=cfg.loss.critic_coef, diff --git a/examples/multiagent/mappo_ippo.py b/examples/multiagent/mappo_ippo.py index 95d340046fa..cb31eabcd37 100644 --- a/examples/multiagent/mappo_ippo.py +++ b/examples/multiagent/mappo_ippo.py @@ -137,8 +137,8 @@ def train(cfg: "DictConfig"): # noqa: F821 # Loss loss_module = 
ClipPPOLoss( - actor=policy, - critic=value_module, + actor_network=policy, + critic_network=value_module, clip_epsilon=cfg.loss.clip_epsilon, entropy_coef=cfg.loss.entropy_eps, normalize_advantage=False, diff --git a/examples/ppo/ppo_atari.py b/examples/ppo/ppo_atari.py index 86685fa2642..1e69dd7678d 100644 --- a/examples/ppo/ppo_atari.py +++ b/examples/ppo/ppo_atari.py @@ -70,8 +70,8 @@ def main(cfg: "DictConfig"): # noqa: F821 average_gae=False, ) loss_module = ClipPPOLoss( - actor=actor, - critic=critic, + actor_network=actor, + critic_network=critic, clip_epsilon=cfg.loss.clip_epsilon, loss_critic_type=cfg.loss.loss_critic_type, entropy_coef=cfg.loss.entropy_coef, diff --git a/examples/ppo/ppo_mujoco.py b/examples/ppo/ppo_mujoco.py index eca985c2069..90fe74650f5 100644 --- a/examples/ppo/ppo_mujoco.py +++ b/examples/ppo/ppo_mujoco.py @@ -70,8 +70,8 @@ def main(cfg: "DictConfig"): # noqa: F821 ) loss_module = ClipPPOLoss( - actor=actor, - critic=critic, + actor_network=actor, + critic_network=critic, clip_epsilon=cfg.loss.clip_epsilon, loss_critic_type=cfg.loss.loss_critic_type, entropy_coef=cfg.loss.entropy_coef, diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index 6379996d47a..260ab359ae6 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -2,6 +2,8 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. +from __future__ import annotations + import contextlib import math @@ -260,8 +262,8 @@ class _AcceptedKeys: def __init__( self, - actor_network: ProbabilisticTensorDictSequential = None, - critic_network: TensorDictModule = None, + actor_network: ProbabilisticTensorDictSequential | None = None, + critic_network: TensorDictModule | None = None, *, entropy_bonus: bool = True, samples_mc_entropy: int = 1, @@ -643,8 +645,8 @@ class ClipPPOLoss(PPOLoss): def __init__( self, - actor_network: ProbabilisticTensorDictSequential, - critic_network: TensorDictModule, + actor_network: ProbabilisticTensorDictSequential | None = None, + critic_network: TensorDictModule | None = None, *, clip_epsilon: float = 0.2, entropy_bonus: bool = True, @@ -842,8 +844,8 @@ class KLPENPPOLoss(PPOLoss): def __init__( self, - actor_network: ProbabilisticTensorDictSequential, - critic_network: TensorDictModule, + actor_network: ProbabilisticTensorDictSequential | None = None, + critic_network: TensorDictModule | None = None, *, dtarg: float = 0.01, beta: float = 1.0, diff --git a/tutorials/sphinx-tutorials/coding_ppo.py b/tutorials/sphinx-tutorials/coding_ppo.py index 51228e66da1..56f96221a40 100644 --- a/tutorials/sphinx-tutorials/coding_ppo.py +++ b/tutorials/sphinx-tutorials/coding_ppo.py @@ -555,8 +555,8 @@ ) loss_module = ClipPPOLoss( - actor=policy_module, - critic=value_module, + actor_network=policy_module, + critic_network=value_module, clip_epsilon=clip_epsilon, entropy_bonus=bool(entropy_eps), entropy_coef=entropy_eps, diff --git a/tutorials/sphinx-tutorials/multiagent_ppo.py b/tutorials/sphinx-tutorials/multiagent_ppo.py index d8726e804f4..f32d2d93b2f 100644 --- a/tutorials/sphinx-tutorials/multiagent_ppo.py +++ b/tutorials/sphinx-tutorials/multiagent_ppo.py @@ -595,8 +595,8 @@ # loss_module = ClipPPOLoss( - actor=policy, - critic=critic, + actor_network=policy, + critic_network=critic, clip_epsilon=clip_epsilon, entropy_coef=entropy_eps, normalize_advantage=False, # Important to avoid normalizing across the agent dimension From 55127fe811c68a145b8e223607ac394408703be0 Mon Sep 17 00:00:00 
2001 From: vmoens Date: Mon, 22 Jan 2024 21:14:44 +0000 Subject: [PATCH 5/5] amend --- examples/multiagent/mappo_ippo.py | 2 +- torchrl/objectives/a2c.py | 33 +++++++++++++++++++++++++++++++ torchrl/objectives/ppo.py | 33 +++++++++++++++++++++++++++++++ torchrl/objectives/reinforce.py | 33 +++++++++++++++++++++++++++++++ 4 files changed, 100 insertions(+), 1 deletion(-) diff --git a/examples/multiagent/mappo_ippo.py b/examples/multiagent/mappo_ippo.py index cb31eabcd37..b00bb18a2a0 100644 --- a/examples/multiagent/mappo_ippo.py +++ b/examples/multiagent/mappo_ippo.py @@ -174,7 +174,7 @@ def train(cfg: "DictConfig"): # noqa: F821 with torch.no_grad(): loss_module.value_estimator( tensordict_data, - params=loss_module.critic_params, + params=loss_module.critic_network_params, target_params=loss_module.target_critic_params, ) current_frames = tensordict_data.numel() diff --git a/torchrl/objectives/a2c.py b/torchrl/objectives/a2c.py index 6cbd1792eaf..397b9de4e23 100644 --- a/torchrl/objectives/a2c.py +++ b/torchrl/objectives/a2c.py @@ -3,6 +3,7 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. import contextlib +import logging import warnings from copy import deepcopy from dataclasses import dataclass @@ -297,12 +298,44 @@ def functional(self): @property def actor(self): + logging.warning( + f"{self.__class__.__name__}.actor is deprecated, use {self.__class__.__name__}.actor_network instead. This " + "link will be removed in v0.4." + ) return self.actor_network @property def critic(self): + logging.warning( + f"{self.__class__.__name__}.critic is deprecated, use {self.__class__.__name__}.critic_network instead. This " + "link will be removed in v0.4." + ) return self.critic_network + @property + def actor_params(self): + logging.warning( + f"{self.__class__.__name__}.actor_params is deprecated, use {self.__class__.__name__}.actor_network_params instead. This " + "link will be removed in v0.4." + ) + return self.actor_network_params + + @property + def critic_params(self): + logging.warning( + f"{self.__class__.__name__}.critic_params is deprecated, use {self.__class__.__name__}.critic_network_params instead. This " + "link will be removed in v0.4." + ) + return self.critic_network_params + + @property + def target_critic_params(self): + logging.warning( + f"{self.__class__.__name__}.target_critic_params is deprecated, use {self.__class__.__name__}.target_critic_network_params instead. This " + "link will be removed in v0.4." + ) + return self.target_critic_network_params + @property def in_keys(self): keys = [ diff --git a/torchrl/objectives/ppo.py b/torchrl/objectives/ppo.py index 260ab359ae6..0d4714e6a2d 100644 --- a/torchrl/objectives/ppo.py +++ b/torchrl/objectives/ppo.py @@ -5,6 +5,7 @@ from __future__ import annotations import contextlib +import logging import math import warnings @@ -345,12 +346,44 @@ def functional(self): @property def actor(self): + logging.warning( + f"{self.__class__.__name__}.actor is deprecated, use {self.__class__.__name__}.actor_network instead. This " + "link will be removed in v0.4." + ) return self.actor_network @property def critic(self): + logging.warning( + f"{self.__class__.__name__}.critic is deprecated, use {self.__class__.__name__}.critic_network instead. This " + "link will be removed in v0.4." 
+ ) return self.critic_network + @property + def actor_params(self): + logging.warning( + f"{self.__class__.__name__}.actor_params is deprecated, use {self.__class__.__name__}.actor_network_params instead. This " + "link will be removed in v0.4." + ) + return self.actor_network_params + + @property + def critic_params(self): + logging.warning( + f"{self.__class__.__name__}.critic_params is deprecated, use {self.__class__.__name__}.critic_network_params instead. This " + "link will be removed in v0.4." + ) + return self.critic_network_params + + @property + def target_critic_params(self): + logging.warning( + f"{self.__class__.__name__}.target_critic_params is deprecated, use {self.__class__.__name__}.target_critic_network_params instead. This " + "link will be removed in v0.4." + ) + return self.target_critic_network_params + def _set_in_keys(self): keys = [ self.tensor_keys.action, diff --git a/torchrl/objectives/reinforce.py b/torchrl/objectives/reinforce.py index 2ba10ece317..98c4d4d14d3 100644 --- a/torchrl/objectives/reinforce.py +++ b/torchrl/objectives/reinforce.py @@ -5,6 +5,7 @@ from __future__ import annotations import contextlib +import logging import warnings from copy import deepcopy from dataclasses import dataclass @@ -289,12 +290,44 @@ def functional(self): @property def actor(self): + logging.warning( + f"{self.__class__.__name__}.actor is deprecated, use {self.__class__.__name__}.actor_network instead. This " + "link will be removed in v0.4." + ) return self.actor_network @property def critic(self): + logging.warning( + f"{self.__class__.__name__}.critic is deprecated, use {self.__class__.__name__}.critic_network instead. This " + "link will be removed in v0.4." + ) return self.critic_network + @property + def actor_params(self): + logging.warning( + f"{self.__class__.__name__}.actor_params is deprecated, use {self.__class__.__name__}.actor_network_params instead. This " + "link will be removed in v0.4." + ) + return self.actor_network_params + + @property + def critic_params(self): + logging.warning( + f"{self.__class__.__name__}.critic_params is deprecated, use {self.__class__.__name__}.critic_network_params instead. This " + "link will be removed in v0.4." + ) + return self.critic_network_params + + @property + def target_critic_params(self): + logging.warning( + f"{self.__class__.__name__}.target_critic_params is deprecated, use {self.__class__.__name__}.target_critic_network_params instead. This " + "link will be removed in v0.4." + ) + return self.target_critic_network_params + def _forward_value_estimator_keys(self, **kwargs) -> None: if self._value_estimator is not None: self._value_estimator.set_keys(
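# A condensed, generic sketch of the deprecation shim added by the last patch: the old
# attribute names stay usable but log a warning and forward to the renamed
# ``*_network`` attributes. ``_Renamed`` is a stand-in class, not a TorchRL class.
import logging


class _Renamed:
    def __init__(self, critic_network):
        self.critic_network = critic_network

    @property
    def critic(self):
        logging.warning(
            f"{self.__class__.__name__}.critic is deprecated, use "
            f"{self.__class__.__name__}.critic_network instead. "
            "This link will be removed in v0.4."
        )
        return self.critic_network


obj = _Renamed(critic_network=object())
assert obj.critic is obj.critic_network  # same object; the alias only logs a warning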