Skip to content

Commit f4d223c

Browse files
author
Vincent Moens
committed
Update
[ghstack-poisoned]
2 parents b01901c + c163b3b commit f4d223c

File tree

6 files changed

+60
-88
lines changed

6 files changed

+60
-88
lines changed

docs/source/reference/data.rst

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -972,6 +972,26 @@ The following classes are deprecated and just point to the classes above:
972972
UnboundedContinuousTensorSpec
973973
UnboundedDiscreteTensorSpec
974974

975+
Trees and Forests
976+
-----------------
977+
978+
TorchRL offers a set of classes and functions that can be used to represent trees and forests efficiently.
979+
980+
.. currentmodule:: torchrl.data
981+
982+
.. autosummary::
983+
:toctree: generated/
984+
:template: rl_template.rst
985+
986+
BinaryToDecimal
987+
HashToInt
988+
QueryModule
989+
RandomProjectionHash
990+
SipHash
991+
TensorDictMap
992+
TensorMap
993+
994+
975995
Reinforcement Learning From Human Feedback (RLHF)
976996
-------------------------------------------------
977997

torchrl/data/replay_buffers/storages.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
)
2727
from tensordict.base import _NESTED_TENSORS_AS_LISTS
2828
from tensordict.memmap import MemoryMappedTensor
29+
from tensordict.utils import _zip_strict
2930
from torch import multiprocessing as mp
3031
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
3132
from torchrl._utils import _make_ordinal_device, implement_for, logger as torchrl_logger
@@ -258,7 +259,7 @@ def set(
258259
np.ndarray,
259260
),
260261
):
261-
for _cursor, _data in zip(cursor, data, strict=True):
262+
for _cursor, _data in _zip_strict(cursor, data):
262263
self.set(_cursor, _data, set_cursor=set_cursor)
263264
else:
264265
raise TypeError(

torchrl/objectives/cql.py

Lines changed: 22 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -514,32 +514,21 @@ def out_keys(self, values):
514514

515515
@dispatch
516516
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
517-
shape = None
518-
if tensordict.ndimension() > 1:
519-
shape = tensordict.shape
520-
tensordict_reshape = tensordict.reshape(-1)
521-
else:
522-
tensordict_reshape = tensordict
523-
524-
q_loss, metadata = self.q_loss(tensordict_reshape)
525-
cql_loss, cql_metadata = self.cql_loss(tensordict_reshape)
517+
q_loss, metadata = self.q_loss(tensordict)
518+
cql_loss, cql_metadata = self.cql_loss(tensordict)
526519
if self.with_lagrange:
527-
alpha_prime_loss, alpha_prime_metadata = self.alpha_prime_loss(
528-
tensordict_reshape
529-
)
520+
alpha_prime_loss, alpha_prime_metadata = self.alpha_prime_loss(tensordict)
530521
metadata.update(alpha_prime_metadata)
531-
loss_actor_bc, bc_metadata = self.actor_bc_loss(tensordict_reshape)
532-
loss_actor, actor_metadata = self.actor_loss(tensordict_reshape)
522+
loss_actor_bc, bc_metadata = self.actor_bc_loss(tensordict)
523+
loss_actor, actor_metadata = self.actor_loss(tensordict)
533524
loss_alpha, alpha_metadata = self.alpha_loss(actor_metadata)
534525
metadata.update(bc_metadata)
535526
metadata.update(cql_metadata)
536527
metadata.update(actor_metadata)
537528
metadata.update(alpha_metadata)
538-
tensordict_reshape.set(
529+
tensordict.set(
539530
self.tensor_keys.priority, metadata.pop("td_error").detach().max(0).values
540531
)
541-
if shape:
542-
tensordict.update(tensordict_reshape.view(shape))
543532
out = {
544533
"loss_actor": loss_actor,
545534
"loss_actor_bc": loss_actor_bc,
@@ -682,7 +671,9 @@ def _get_value_v(self, tensordict, _alpha, actor_params, qval_params):
682671
)
683672
# take max over actions
684673
state_action_value = state_action_value.reshape(
685-
self.num_qvalue_nets, tensordict.shape[0], self.num_random, -1
674+
torch.Size(
675+
[self.num_qvalue_nets, *tensordict.shape, self.num_random, -1]
676+
)
686677
).max(-2)[0]
687678
# take min over qvalue nets
688679
next_state_value = state_action_value.min(0)[0]
@@ -739,14 +730,13 @@ def cql_loss(self, tensordict: TensorDictBase) -> Tuple[Tensor, dict]:
739730
"This could be caused by calling cql_loss method before q_loss method."
740731
)
741732

742-
random_actions_tensor = (
743-
torch.FloatTensor(
744-
tensordict.shape[0] * self.num_random,
733+
random_actions_tensor = pred_q1.new_empty(
734+
(
735+
*tensordict.shape[:-1],
736+
tensordict.shape[-1] * self.num_random,
745737
tensordict[self.tensor_keys.action].shape[-1],
746738
)
747-
.uniform_(-1, 1)
748-
.to(tensordict.device)
749-
)
739+
).uniform_(-1, 1)
750740
curr_actions_td, curr_log_pis = self._get_policy_actions(
751741
tensordict.copy(),
752742
self.actor_network_params,
@@ -833,31 +823,31 @@ def filter_and_repeat(name, x):
833823
q_new[0] - new_log_pis.detach().unsqueeze(-1),
834824
q_curr[0] - curr_log_pis.detach().unsqueeze(-1),
835825
],
836-
1,
826+
-1,
837827
)
838828
cat_q2 = torch.cat(
839829
[
840830
q_random[1] - random_density,
841831
q_new[1] - new_log_pis.detach().unsqueeze(-1),
842832
q_curr[1] - curr_log_pis.detach().unsqueeze(-1),
843833
],
844-
1,
834+
-1,
845835
)
846836

847837
min_qf1_loss = (
848-
torch.logsumexp(cat_q1 / self.temperature, dim=1)
838+
torch.logsumexp(cat_q1 / self.temperature, dim=-1)
849839
* self.min_q_weight
850840
* self.temperature
851841
)
852842
min_qf2_loss = (
853-
torch.logsumexp(cat_q2 / self.temperature, dim=1)
843+
torch.logsumexp(cat_q2 / self.temperature, dim=-1)
854844
* self.min_q_weight
855845
* self.temperature
856846
)
857847

858848
# Subtract the log likelihood of data
859-
cql_q1_loss = min_qf1_loss - pred_q1 * self.min_q_weight
860-
cql_q2_loss = min_qf2_loss - pred_q2 * self.min_q_weight
849+
cql_q1_loss = min_qf1_loss.flatten() - pred_q1 * self.min_q_weight
850+
cql_q2_loss = min_qf2_loss.flatten() - pred_q2 * self.min_q_weight
861851

862852
# write cql losses in tensordict for alpha prime loss
863853
tensordict.set(self.tensor_keys.cql_q1_loss, cql_q1_loss)
@@ -1080,9 +1070,9 @@ def __init__(
10801070
self.loss_function = loss_function
10811071
if action_space is None:
10821072
# infer from value net
1083-
try:
1073+
if hasattr(value_network, "action_space"):
10841074
action_space = value_network.spec
1085-
except AttributeError:
1075+
else:
10861076
# let's try with action_space then
10871077
try:
10881078
action_space = value_network.action_space
@@ -1205,8 +1195,6 @@ def value_loss(
12051195
with torch.no_grad():
12061196
td_error = (pred_val_index - target_value).pow(2)
12071197
td_error = td_error.unsqueeze(-1)
1208-
if tensordict.device is not None:
1209-
td_error = td_error.to(tensordict.device)
12101198

12111199
tensordict.set(
12121200
self.tensor_keys.priority,

torchrl/objectives/crossq.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -495,23 +495,14 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
495495
To see what keys are expected in the input tensordict and what keys are expected as output, check the
496496
class's `"in_keys"` and `"out_keys"` attributes.
497497
"""
498-
shape = None
499-
if tensordict.ndimension() > 1:
500-
shape = tensordict.shape
501-
tensordict_reshape = tensordict.reshape(-1)
502-
else:
503-
tensordict_reshape = tensordict
504-
505-
loss_qvalue, value_metadata = self.qvalue_loss(tensordict_reshape)
506-
loss_actor, metadata_actor = self.actor_loss(tensordict_reshape)
498+
loss_qvalue, value_metadata = self.qvalue_loss(tensordict)
499+
loss_actor, metadata_actor = self.actor_loss(tensordict)
507500
loss_alpha = self.alpha_loss(log_prob=metadata_actor["log_prob"])
508-
tensordict_reshape.set(self.tensor_keys.priority, value_metadata["td_error"])
501+
tensordict.set(self.tensor_keys.priority, value_metadata["td_error"])
509502
if loss_actor.shape != loss_qvalue.shape:
510503
raise RuntimeError(
511504
f"Losses shape mismatch: {loss_actor.shape} and {loss_qvalue.shape}"
512505
)
513-
if shape:
514-
tensordict.update(tensordict_reshape.view(shape))
515506
entropy = -metadata_actor["log_prob"]
516507
out = {
517508
"loss_actor": loss_actor,

torchrl/objectives/iql.py

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -373,16 +373,9 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
373373

374374
@dispatch
375375
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
376-
shape = None
377-
if tensordict.ndimension() > 1:
378-
shape = tensordict.shape
379-
tensordict_reshape = tensordict.reshape(-1)
380-
else:
381-
tensordict_reshape = tensordict
382-
383-
loss_actor, metadata = self.actor_loss(tensordict_reshape)
384-
loss_qvalue, metadata_qvalue = self.qvalue_loss(tensordict_reshape)
385-
loss_value, metadata_value = self.value_loss(tensordict_reshape)
376+
loss_actor, metadata = self.actor_loss(tensordict)
377+
loss_qvalue, metadata_qvalue = self.qvalue_loss(tensordict)
378+
loss_value, metadata_value = self.value_loss(tensordict)
386379
metadata.update(metadata_qvalue)
387380
metadata.update(metadata_value)
388381

@@ -392,13 +385,10 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
392385
raise RuntimeError(
393386
f"Losses shape mismatch: {loss_actor.shape}, {loss_qvalue.shape} and {loss_value.shape}"
394387
)
395-
tensordict_reshape.set(
388+
tensordict.set(
396389
self.tensor_keys.priority, metadata.pop("td_error").detach().max(0).values
397390
)
398-
if shape:
399-
tensordict.update(tensordict_reshape.view(shape))
400-
401-
entropy = -tensordict_reshape.get(self.tensor_keys.log_prob).detach()
391+
entropy = -tensordict.get(self.tensor_keys.log_prob).detach()
402392
out = {
403393
"loss_actor": loss_actor,
404394
"loss_qvalue": loss_qvalue,

torchrl/objectives/sac.py

Lines changed: 8 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -577,30 +577,21 @@ def out_keys(self, values):
577577

578578
@dispatch
579579
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
580-
shape = None
581-
if tensordict.ndimension() > 1:
582-
shape = tensordict.shape
583-
tensordict_reshape = tensordict.reshape(-1)
584-
else:
585-
tensordict_reshape = tensordict
586-
587580
if self._version == 1:
588-
loss_qvalue, value_metadata = self._qvalue_v1_loss(tensordict_reshape)
589-
loss_value, _ = self._value_loss(tensordict_reshape)
581+
loss_qvalue, value_metadata = self._qvalue_v1_loss(tensordict)
582+
loss_value, _ = self._value_loss(tensordict)
590583
else:
591-
loss_qvalue, value_metadata = self._qvalue_v2_loss(tensordict_reshape)
584+
loss_qvalue, value_metadata = self._qvalue_v2_loss(tensordict)
592585
loss_value = None
593-
loss_actor, metadata_actor = self._actor_loss(tensordict_reshape)
586+
loss_actor, metadata_actor = self._actor_loss(tensordict)
594587
loss_alpha = self._alpha_loss(log_prob=metadata_actor["log_prob"])
595-
tensordict_reshape.set(self.tensor_keys.priority, value_metadata["td_error"])
588+
tensordict.set(self.tensor_keys.priority, value_metadata["td_error"])
596589
if (loss_actor.shape != loss_qvalue.shape) or (
597590
loss_value is not None and loss_actor.shape != loss_value.shape
598591
):
599592
raise RuntimeError(
600593
f"Losses shape mismatch: {loss_actor.shape}, {loss_qvalue.shape} and {loss_value.shape}"
601594
)
602-
if shape:
603-
tensordict.update(tensordict_reshape.view(shape))
604595
entropy = -metadata_actor["log_prob"]
605596
out = {
606597
"loss_actor": loss_actor,
@@ -1158,26 +1149,17 @@ def in_keys(self, values):
11581149

11591150
@dispatch
11601151
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
1161-
shape = None
1162-
if tensordict.ndimension() > 1:
1163-
shape = tensordict.shape
1164-
tensordict_reshape = tensordict.reshape(-1)
1165-
else:
1166-
tensordict_reshape = tensordict
1167-
1168-
loss_value, metadata_value = self._value_loss(tensordict_reshape)
1169-
loss_actor, metadata_actor = self._actor_loss(tensordict_reshape)
1152+
loss_value, metadata_value = self._value_loss(tensordict)
1153+
loss_actor, metadata_actor = self._actor_loss(tensordict)
11701154
loss_alpha = self._alpha_loss(
11711155
log_prob=metadata_actor["log_prob"],
11721156
)
11731157

1174-
tensordict_reshape.set(self.tensor_keys.priority, metadata_value["td_error"])
1158+
tensordict.set(self.tensor_keys.priority, metadata_value["td_error"])
11751159
if loss_actor.shape != loss_value.shape:
11761160
raise RuntimeError(
11771161
f"Losses shape mismatch: {loss_actor.shape}, and {loss_value.shape}"
11781162
)
1179-
if shape:
1180-
tensordict.update(tensordict_reshape.view(shape))
11811163
entropy = -metadata_actor["log_prob"]
11821164
out = {
11831165
"loss_actor": loss_actor,

0 commit comments

Comments
 (0)