diff --git a/examples/ddpg/ddpg.py b/examples/ddpg/ddpg.py
index 273947569be..5688e561ae5 100644
--- a/examples/ddpg/ddpg.py
+++ b/examples/ddpg/ddpg.py
@@ -120,18 +120,14 @@ def main(cfg: "DictConfig"):  # noqa: F821
                 # Sample from replay buffer
                 sampled_tensordict = replay_buffer.sample().clone()
 
-                # Compute loss
-                loss_td = loss_module(sampled_tensordict)
-
-                actor_loss = loss_td["loss_actor"]
-                q_loss = loss_td["loss_value"]
-
                 # Update critic
+                q_loss, *_ = loss_module.loss_value(sampled_tensordict)
                 optimizer_critic.zero_grad()
                 q_loss.backward()
                 optimizer_critic.step()
 
                 # Update actor
+                actor_loss, *_ = loss_module.loss_actor(sampled_tensordict)
                 optimizer_actor.zero_grad()
                 actor_loss.backward()
                 optimizer_actor.step()
diff --git a/test/test_cost.py b/test/test_cost.py
index 6c38e6a8b65..a65b3d00809 100644
--- a/test/test_cost.py
+++ b/test/test_cost.py
@@ -1765,7 +1765,7 @@ def test_ddpg_notensordict(self):
         with pytest.warns(UserWarning, match="No target network updater has been"):
             loss_val_td = loss(td)
             loss_val = loss(**kwargs)
-        for i, key in enumerate(loss_val_td.keys()):
+        for i, key in enumerate(loss.out_keys):
             torch.testing.assert_close(loss_val_td.get(key), loss_val[i])
         # test select
         loss.select_out_keys("loss_actor", "target_value")
diff --git a/torchrl/objectives/ddpg.py b/torchrl/objectives/ddpg.py
index d72afb09f7b..1795f785716 100644
--- a/torchrl/objectives/ddpg.py
+++ b/torchrl/objectives/ddpg.py
@@ -280,32 +280,18 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
             a tuple of 2 tensors containing the DDPG loss.
 
         """
-        loss_value, td_error, pred_val, target_value = self._loss_value(tensordict)
-        td_error = td_error.detach()
-        if tensordict.device is not None:
-            td_error = td_error.to(tensordict.device)
-        tensordict.set(
-            self.tensor_keys.priority,
-            td_error,
-            inplace=True,
-        )
-        loss_actor = self._loss_actor(tensordict)
+        loss_value, metadata = self.loss_value(tensordict)
+        loss_actor, metadata_actor = self.loss_actor(tensordict)
+        metadata.update(metadata_actor)
         return TensorDict(
-            source={
-                "loss_actor": loss_actor.mean(),
-                "loss_value": loss_value.mean(),
-                "pred_value": pred_val.mean().detach(),
-                "target_value": target_value.mean().detach(),
-                "pred_value_max": pred_val.max().detach(),
-                "target_value_max": target_value.max().detach(),
-            },
+            source={"loss_actor": loss_actor, "loss_value": loss_value, **metadata},
             batch_size=[],
         )
 
-    def _loss_actor(
+    def loss_actor(
         self,
         tensordict: TensorDictBase,
-    ) -> torch.Tensor:
+    ) -> [torch.Tensor, dict]:
         td_copy = tensordict.select(
             *self.actor_in_keys, *self.value_exclusive_keys
         ).detach()
@@ -317,12 +303,14 @@ def _loss_actor(
             td_copy,
             params=self._cached_detached_value_params,
         )
-        return -td_copy.get(self.tensor_keys.state_action_value)
+        loss_actor = -td_copy.get(self.tensor_keys.state_action_value)
+        metadata = {}
+        return loss_actor.mean(), metadata
 
-    def _loss_value(
+    def loss_value(
         self,
         tensordict: TensorDictBase,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, dict]:
         # value loss
         td_copy = tensordict.select(*self.value_network.in_keys).detach()
         self.value_network(
@@ -340,7 +328,24 @@ def _loss_value(
             pred_val, target_value, loss_function=self.loss_function
        )
 
-        return loss_value, (pred_val - target_value).pow(2), pred_val, target_value
+        td_error = (pred_val - target_value).pow(2)
+        td_error = td_error.detach()
+        if tensordict.device is not None:
+            td_error = td_error.to(tensordict.device)
+        tensordict.set(
+            self.tensor_keys.priority,
+            td_error,
+            inplace=True,
+        )
+        with torch.no_grad():
+            metadata = {
+                "td_error": td_error.mean(),
+                "pred_value": pred_val.mean(),
+                "target_value": target_value.mean(),
+                "target_value_max": target_value.max(),
+                "pred_value_max": pred_val.max(),
+            }
+        return loss_value.mean(), metadata
 
     def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
         if value_type is None:
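For reference, a minimal, self-contained sketch of the training-loop pattern this diff enables, using the now-public DDPGLoss.loss_value() and DDPGLoss.loss_actor() methods. The network sizes, keys, learning rates, SoftUpdate eps and the hand-built batch below are illustrative assumptions, not part of the change; in the real script the batch comes from replay_buffer.sample().clone() as shown in examples/ddpg/ddpg.py above.

import torch
from tensordict import TensorDict
from torchrl.modules import MLP, Actor, ValueOperator
from torchrl.objectives import DDPGLoss, SoftUpdate

obs_dim, act_dim, batch = 4, 2, 8

# Deterministic policy: "observation" -> "action"
actor = Actor(MLP(in_features=obs_dim, out_features=act_dim, num_cells=[32]))
# Q-network: ("observation", "action") -> "state_action_value"
qvalue = ValueOperator(
    MLP(in_features=obs_dim + act_dim, out_features=1, num_cells=[32]),
    in_keys=["observation", "action"],
)

loss_module = DDPGLoss(actor_network=actor, value_network=qvalue)
target_updater = SoftUpdate(loss_module, eps=0.99)

# Separate optimizers over the loss module's actor / value parameters,
# mirroring the updated examples/ddpg/ddpg.py loop.
optimizer_actor = torch.optim.Adam(
    loss_module.actor_network_params.values(True, True), lr=1e-4
)
optimizer_critic = torch.optim.Adam(
    loss_module.value_network_params.values(True, True), lr=1e-3
)

# Hand-built batch standing in for replay_buffer.sample().clone()
sampled_tensordict = TensorDict(
    {
        "observation": torch.randn(batch, obs_dim),
        "action": torch.randn(batch, act_dim),
        ("next", "observation"): torch.randn(batch, obs_dim),
        ("next", "reward"): torch.randn(batch, 1),
        ("next", "done"): torch.zeros(batch, 1, dtype=torch.bool),
        ("next", "terminated"): torch.zeros(batch, 1, dtype=torch.bool),
    },
    batch_size=[batch],
)

# Critic update: loss_value() returns the mean value loss plus a dict of
# detached diagnostics, and writes the "td_error" priority key in-place.
q_loss, metadata = loss_module.loss_value(sampled_tensordict)
optimizer_critic.zero_grad()
q_loss.backward()
optimizer_critic.step()

# Actor update: loss_actor() returns the mean of -Q(s, pi(s)) and an empty metadata dict.
actor_loss, _ = loss_module.loss_actor(sampled_tensordict)
optimizer_actor.zero_grad()
actor_loss.backward()
optimizer_actor.step()

target_updater.step()
print({key: value.item() for key, value in metadata.items()})

Keeping forward() as a thin aggregator over the two public methods preserves the keys exposed in the returned TensorDict (now sourced from the metadata dicts), while scripts that want distinct optimizer steps per network can call loss_value() and loss_actor() directly and skip forward() entirely.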