examples/ddpg/ddpg.py (8 changes: 2 additions & 6 deletions)
@@ -120,18 +120,14 @@ def main(cfg: "DictConfig"):  # noqa: F821
             # Sample from replay buffer
             sampled_tensordict = replay_buffer.sample().clone()
 
-            # Compute loss
-            loss_td = loss_module(sampled_tensordict)
-
-            actor_loss = loss_td["loss_actor"]
-            q_loss = loss_td["loss_value"]
-
             # Update critic
+            q_loss, *_ = loss_module.loss_value(sampled_tensordict)
             optimizer_critic.zero_grad()
             q_loss.backward()
             optimizer_critic.step()
 
             # Update actor
+            actor_loss, *_ = loss_module.loss_actor(sampled_tensordict)
             optimizer_actor.zero_grad()
             actor_loss.backward()
             optimizer_actor.step()
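For orientation, here is a minimal, self-contained sketch of the two-optimizer pattern the script now follows: the critic and the actor each get their own loss and their own optimizer step instead of reading both losses out of a single combined forward pass. Everything below (network sizes, `target_q`, the frozen-critic trick) is a toy stand-in, not TorchRL's `DDPGLoss`, which handles the target computation and parameter detaching internally.

```python
import torch
from torch import nn

obs_dim, act_dim, batch = 4, 2, 32
actor = nn.Sequential(nn.Linear(obs_dim, act_dim), nn.Tanh())
critic = nn.Linear(obs_dim + act_dim, 1)
optimizer_actor = torch.optim.Adam(actor.parameters(), lr=1e-4)
optimizer_critic = torch.optim.Adam(critic.parameters(), lr=1e-3)

obs = torch.randn(batch, obs_dim)
act = torch.randn(batch, act_dim)
target_q = torch.randn(batch, 1)  # stand-in for the bootstrapped TD target

# Critic update: regress Q(s, a) onto the target value.
q_loss = nn.functional.mse_loss(critic(torch.cat([obs, act], -1)), target_q)
optimizer_critic.zero_grad()
q_loss.backward()
optimizer_critic.step()

# Actor update: maximize Q(s, pi(s)); freeze the critic so only the actor
# receives gradients, mirroring the detached value parameters in DDPGLoss.
for p in critic.parameters():
    p.requires_grad_(False)
actor_loss = -critic(torch.cat([obs, actor(obs)], -1)).mean()
optimizer_actor.zero_grad()
actor_loss.backward()
optimizer_actor.step()
for p in critic.parameters():
    p.requires_grad_(True)
```

The `*_` in the updated script discards the metadata dictionary that the new `loss_value` and `loss_actor` methods return alongside the scalar loss (see the objectives diff below).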
test/test_cost.py (2 changes: 1 addition & 1 deletion)
@@ -1765,7 +1765,7 @@ def test_ddpg_notensordict(self):
         with pytest.warns(UserWarning, match="No target network updater has been"):
             loss_val_td = loss(td)
             loss_val = loss(**kwargs)
-        for i, key in enumerate(loss_val_td.keys()):
+        for i, key in enumerate(loss.out_keys):
             torch.testing.assert_close(loss_val_td.get(key), loss_val[i])
         # test select
         loss.select_out_keys("loss_actor", "target_value")
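A short aside on why the assertion now iterates `loss.out_keys` rather than `loss_val_td.keys()`: the keyword-argument call path returns a plain tuple whose positions follow `out_keys`, while the tensordict produced by `forward` may now carry extra metadata entries (such as `td_error`) with no matching order guarantee. This reading of the change is a plausible inference from the diff, not a statement from the PR; the toy below uses made-up keys and values to illustrate the mismatch.

```python
import torch

# Suppose the tensordict-style output carries an extra metadata entry that the
# tuple-style output does not include.
td_out = {
    "td_error": torch.tensor(0.5),
    "loss_actor": torch.tensor(1.0),
    "loss_value": torch.tensor(2.0),
}
out_keys = ["loss_actor", "loss_value"]  # order of the tuple-style output
tuple_out = (torch.tensor(1.0), torch.tensor(2.0))

# Pairing by enumerate(td_out.keys()) would compare "td_error" against
# tuple_out[0]; pairing by out_keys keeps names and positions aligned.
for i, key in enumerate(out_keys):
    torch.testing.assert_close(td_out[key], tuple_out[i])
```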
torchrl/objectives/ddpg.py (53 changes: 29 additions & 24 deletions)
@@ -280,32 +280,18 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
             a tuple of 2 tensors containing the DDPG loss.
 
         """
-        loss_value, td_error, pred_val, target_value = self._loss_value(tensordict)
-        td_error = td_error.detach()
-        if tensordict.device is not None:
-            td_error = td_error.to(tensordict.device)
-        tensordict.set(
-            self.tensor_keys.priority,
-            td_error,
-            inplace=True,
-        )
-        loss_actor = self._loss_actor(tensordict)
+        loss_value, metadata = self.loss_value(tensordict)
+        loss_actor, metadata_actor = self.loss_actor(tensordict)
+        metadata.update(metadata_actor)
         return TensorDict(
-            source={
-                "loss_actor": loss_actor.mean(),
-                "loss_value": loss_value.mean(),
-                "pred_value": pred_val.mean().detach(),
-                "target_value": target_value.mean().detach(),
-                "pred_value_max": pred_val.max().detach(),
-                "target_value_max": target_value.max().detach(),
-            },
+            source={"loss_actor": loss_actor, "loss_value": loss_value, **metadata},
             batch_size=[],
         )
 
-    def _loss_actor(
+    def loss_actor(
         self,
         tensordict: TensorDictBase,
-    ) -> torch.Tensor:
+    ) -> [torch.Tensor, dict]:
         td_copy = tensordict.select(
             *self.actor_in_keys, *self.value_exclusive_keys
         ).detach()
@@ -317,12 +303,14 @@ def _loss_actor(
             td_copy,
             params=self._cached_detached_value_params,
         )
-        return -td_copy.get(self.tensor_keys.state_action_value)
+        loss_actor = -td_copy.get(self.tensor_keys.state_action_value)
+        metadata = {}
+        return loss_actor.mean(), metadata
 
-    def _loss_value(
+    def loss_value(
         self,
         tensordict: TensorDictBase,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    ) -> Tuple[torch.Tensor, dict]:
         # value loss
         td_copy = tensordict.select(*self.value_network.in_keys).detach()
         self.value_network(
@@ -340,7 +328,24 @@ def _loss_actor(
             pred_val, target_value, loss_function=self.loss_function
         )
 
-        return loss_value, (pred_val - target_value).pow(2), pred_val, target_value
+        td_error = (pred_val - target_value).pow(2)
+        td_error = td_error.detach()
+        if tensordict.device is not None:
+            td_error = td_error.to(tensordict.device)
+        tensordict.set(
+            self.tensor_keys.priority,
+            td_error,
+            inplace=True,
+        )
+        with torch.no_grad():
+            metadata = {
+                "td_error": td_error.mean(),
+                "pred_value": pred_val.mean(),
+                "target_value": target_value.mean(),
+                "target_value_max": target_value.max(),
+                "pred_value_max": pred_val.max(),
+            }
+        return loss_value.mean(), metadata
 
     def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams):
         if value_type is None:
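To summarize the shape of the new interface, here is a stripped-down, self-contained skeleton of the pattern this diff introduces: each component loss method returns a `(scalar_loss, metadata_dict)` pair, and `forward` composes the two and flattens the metadata into its output mapping, exactly as the refactored `forward` above does. It is a structural sketch only (plain tensors instead of tensordicts, no target network, no priority write-back), not TorchRL's `DDPGLoss`.

```python
from typing import Dict, Tuple

import torch
from torch import nn


class TwoHeadedLoss(nn.Module):
    """Toy illustration of the loss_value / loss_actor / forward split."""

    def __init__(self, actor: nn.Module, critic: nn.Module):
        super().__init__()
        self.actor = actor
        self.critic = critic

    def loss_value(
        self, obs: torch.Tensor, act: torch.Tensor, target_value: torch.Tensor
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        # Critic loss plus detached diagnostics, returned as metadata.
        pred_val = self.critic(torch.cat([obs, act], -1))
        loss_value = nn.functional.mse_loss(pred_val, target_value)
        with torch.no_grad():
            metadata = {
                "td_error": (pred_val - target_value).pow(2).mean(),
                "pred_value": pred_val.mean(),
                "target_value": target_value.mean(),
            }
        return loss_value, metadata

    def loss_actor(self, obs: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        # TorchRL evaluates the critic with detached parameters here; this toy
        # skips that detail and just computes -Q(s, pi(s)).
        loss_actor = -self.critic(torch.cat([obs, self.actor(obs)], -1)).mean()
        return loss_actor, {}

    def forward(
        self, obs: torch.Tensor, act: torch.Tensor, target_value: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        loss_value, metadata = self.loss_value(obs, act, target_value)
        loss_actor, metadata_actor = self.loss_actor(obs)
        metadata.update(metadata_actor)
        return {"loss_actor": loss_actor, "loss_value": loss_value, **metadata}


# Usage sketch with random data.
actor = nn.Sequential(nn.Linear(4, 2), nn.Tanh())
critic = nn.Linear(6, 1)
loss_module = TwoHeadedLoss(actor, critic)
out = loss_module(torch.randn(8, 4), torch.randn(8, 2), torch.randn(8, 1))
print({k: float(v) for k, v in out.items()})
```

Either entry point works for training: calling the module gives the aggregated mapping as before, while calling `loss_value` or `loss_actor` directly gives one head's loss plus its diagnostics, which is what the updated example script relies on.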