
Commit 770a87d

Author: Vincent Moens

[Feature] TensorDictMap Query module

ghstack-source-id: a118271
Pull Request resolved: #2305

1 parent 1713b93, commit 770a87d

4 files changed: 263 additions & 2 deletions


test/test_storage_map.py

Lines changed: 65 additions & 1 deletion
@@ -9,7 +9,8 @@

 import torch

-from torchrl.data.map import BinaryToDecimal, RandomProjectionHash, SipHash
+from tensordict import TensorDict
+from torchrl.data.map import BinaryToDecimal, QueryModule, RandomProjectionHash, SipHash

 _has_gym = importlib.util.find_spec("gymnasium", None) or importlib.util.find_spec(
     "gym", None
@@ -51,6 +52,69 @@ def test_randomprojection_hash(self, n_components, scale):
         assert y.unique().numel() == y.numel()


+class TestQuery:
+    def test_query_construct(self):
+        query_module = QueryModule(
+            in_keys=[(("key1",),), (("another",), "key2")],
+            index_key=("some", ("_index",)),
+            hash_module=SipHash(),
+            clone=False,
+        )
+        assert not query_module.clone
+        assert query_module.in_keys == ["key1", ("another", "key2")]
+        assert query_module.index_key == ("some", "_index")
+        assert isinstance(query_module.hash_module, dict)
+        assert isinstance(
+            query_module.aggregator,
+            type(query_module.hash_module[query_module.in_keys[0]]),
+        )
+        query_module = QueryModule(
+            in_keys=[(("key1",),), (("another",), "key2")],
+            index_key=("some", ("_index",)),
+            hash_module=SipHash(),
+            clone=False,
+            aggregator=SipHash(),
+        )
+        # assert query_module.aggregator is not query_module.hash_module[0]
+        assert isinstance(query_module.aggregator, SipHash)
+        query_module = QueryModule(
+            in_keys=[(("key1",),), (("another",), "key2")],
+            index_key=("some", ("_index",)),
+            hash_module=[SipHash(), SipHash()],
+            clone=False,
+        )
+        # assert query_module.aggregator is not query_module.hash_module[0]
+        assert isinstance(query_module.aggregator, SipHash)
+
+    @pytest.mark.parametrize("index_key", ["index", ("another", "index")])
+    @pytest.mark.parametrize("clone", [True, False])
+    def test_query(self, clone, index_key):
+        query_module = QueryModule(
+            in_keys=["key1", "key2"],
+            index_key=index_key,
+            hash_module=SipHash(),
+            clone=clone,
+        )
+
+        query = TensorDict(
+            {
+                "key1": torch.Tensor([[1], [1], [1], [2]]),
+                "key2": torch.Tensor([[3], [3], [2], [3]]),
+            },
+            batch_size=(4,),
+        )
+        res = query_module(query)
+        if clone:
+            assert res is not query
+        else:
+            assert res is query
+        assert index_key in res
+
+        assert res[index_key][0] == res[index_key][1]
+        for i in range(1, 3):
+            assert res[index_key][i].item() != res[index_key][i + 1].item()
+
+
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
     pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
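Note: a minimal usage sketch of what the parametrized test above exercises, assuming this commit's torchrl build and tensordict are installed. The tensors mirror the test fixture, and the nested index key is one of the parametrized values.

import torch
from tensordict import TensorDict
from torchrl.data.map import QueryModule, SipHash

# Rows 0 and 1 agree on both keys; every other adjacent pair differs
# in at least one key, so it must receive a distinct index.
query_module = QueryModule(
    in_keys=["key1", "key2"],
    index_key=("another", "index"),  # nested index keys are supported
    hash_module=SipHash(),
    clone=True,  # return a shallow clone instead of writing in place
)
query = TensorDict(
    {
        "key1": torch.Tensor([[1], [1], [1], [2]]),
        "key2": torch.Tensor([[3], [3], [2], [3]]),
    },
    batch_size=(4,),
)
res = query_module(query)
assert res is not query  # clone=True leaves the input untouched
idx = res["another", "index"]
assert idx[0] == idx[1]
assert idx[1] != idx[2] and idx[2] != idx[3]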

torchrl/data/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
55

6-
from .map import BinaryToDecimal, RandomProjectionHash, SipHash
6+
from .map import BinaryToDecimal, HashToInt, QueryModule, RandomProjectionHash, SipHash
77
from .postprocs import MultiStep
88
from .replay_buffers import (
99
Flat2TED,

torchrl/data/map/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -4,3 +4,4 @@
 # LICENSE file in the root directory of this source tree.

 from .hash import BinaryToDecimal, RandomProjectionHash, SipHash
+from .query import HashToInt, QueryModule

torchrl/data/map/query.py

Lines changed: 196 additions & 0 deletions
@@ -0,0 +1,196 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from copy import deepcopy
+from typing import Any, Callable, Dict, List, Mapping, TypeVar
+
+import torch
+import torch.nn as nn
+from tensordict import NestedKey, TensorDictBase
+from tensordict.nn.common import TensorDictModuleBase
+from torchrl._utils import logger as torchrl_logger
+from torchrl.data.map import SipHash
+
+K = TypeVar("K")
+V = TypeVar("V")
+
+
+class HashToInt(nn.Module):
+    """Converts a hash value to an integer that can be used for indexing a contiguous storage."""
+
+    def __init__(self):
+        super().__init__()
+        self._index_to_index = {}
+
+    def __call__(self, key: torch.Tensor, extend: bool = False) -> torch.Tensor:
+        result = []
+        if extend:
+            # register unseen hashes, assigning each the next available index
+            for _item in key.tolist():
+                result.append(
+                    self._index_to_index.setdefault(_item, len(self._index_to_index))
+                )
+        else:
+            # read-only lookup; unseen hashes fall back to len(self._index_to_index)
+            for _item in key.tolist():
+                result.append(
+                    self._index_to_index.get(_item, len(self._index_to_index))
+                )
+        return torch.tensor(result, device=key.device, dtype=key.dtype)
+
+    def state_dict(self) -> Dict[str, torch.Tensor]:
+        values = torch.tensor(list(self._index_to_index.values()))
+        keys = torch.tensor(list(self._index_to_index.keys()))
+        return {"keys": keys, "values": values}
+
+    def load_state_dict(
+        self, state_dict: Mapping[str, Any], strict: bool = True, assign: bool = False
+    ):
+        keys = state_dict["keys"]
+        values = state_dict["values"]
+        self._index_to_index = {
+            key: val for key, val in zip(keys.tolist(), values.tolist())
+        }
+
+
+class QueryModule(TensorDictModuleBase):
+    """A module to generate compatible indices for storage.
+
+    This module queries a storage and returns the required index of that
+    storage. Currently, it only outputs integer indices (torch.int64).
+
+    Args:
+        in_keys (list of NestedKey): keys of the input tensordict that
+            will be used to generate the hash value.
+        index_key (NestedKey): the output key where the index value will be written.
+            Defaults to ``"_index"``.
+
+    Keyword Args:
+        hash_key (NestedKey): the output key where the hash value will be written.
+            Defaults to ``"_hash"``.
+        hash_module (Callable[[Any], int] or a list of these, optional): a hash
+            module similar to :class:`~torchrl.data.map.SipHash` (default).
+            If a list of callables is provided, its length must match the number
+            of ``in_keys``.
+        hash_to_int (Callable[[int], int], optional): a stateful function that
+            maps a hash value to a non-negative integer corresponding to an index in a
+            storage. Defaults to :class:`~torchrl.data.map.HashToInt`.
+        aggregator (Callable[[int], int], optional): a hash function to group multiple hashes
+            together. This argument should only be passed when there is more than one ``in_key``.
+            If a single ``hash_module`` is provided and no aggregator is passed, the
+            aggregator takes the value of the ``hash_module``. If no ``hash_module`` or a
+            list of hash modules is provided and no aggregator is passed, the aggregator
+            defaults to ``SipHash``.
+        clone (bool, optional): if ``True``, a shallow clone of the input TensorDict will be
+            returned. This can be used to retrieve the integer index within the storage,
+            corresponding to a given input tensordict.
+            Defaults to ``False``.
+
+    Examples:
+        >>> query_module = QueryModule(
+        ...     in_keys=["key1", "key2"],
+        ...     index_key="index",
+        ...     hash_module=SipHash(),
+        ... )
+        >>> query = TensorDict(
+        ...     {
+        ...         "key1": torch.Tensor([[1], [1], [1], [2]]),
+        ...         "key2": torch.Tensor([[3], [3], [2], [3]]),
+        ...         "other": torch.randn(4),
+        ...     },
+        ...     batch_size=(4,),
+        ... )
+        >>> res = query_module(query)
+        >>> # The first two pairs of key1 and key2 match
+        >>> assert res["index"][0] == res["index"][1]
+        >>> # The last three pairs of key1 and key2 have at least one mismatching value
+        >>> assert res["index"][1] != res["index"][2]
+        >>> assert res["index"][2] != res["index"][3]
+    """
+
+    def __init__(
+        self,
+        in_keys: List[NestedKey],
+        index_key: NestedKey = "_index",
+        hash_key: NestedKey = "_hash",
+        *,
+        hash_module: Callable[[Any], int] | List[Callable[[Any], int]] | None = None,
+        hash_to_int: Callable[[int], int] | None = None,
+        aggregator: Callable[[Any], int] | None = None,
+        clone: bool = False,
+    ):
+        if len(in_keys) == 0:
+            raise ValueError("`in_keys` cannot be empty.")
+        in_keys = in_keys if isinstance(in_keys, list) else [in_keys]
+
+        super().__init__()
+        in_keys = self.in_keys = in_keys
+        self.out_keys = [index_key, hash_key]
+        index_key = self.out_keys[0]
+        self.hash_key = self.out_keys[1]
+
+        if aggregator is not None and len(self.in_keys) == 1:
+            torchrl_logger.warn(
+                "An aggregator was provided but there is only one in-key to be read. "
+                "The aggregator will be ignored."
+            )
+        elif aggregator is None:
+            if hash_module is not None and not isinstance(hash_module, list):
+                aggregator = hash_module
+            else:
+                aggregator = SipHash()
+        if hash_module is None:
+            hash_module = [SipHash() for _ in range(len(self.in_keys))]
+        elif not isinstance(hash_module, list):
+            try:
+                hash_module = [
+                    deepcopy(hash_module) if len(self.in_keys) > 1 else hash_module
+                    for _ in range(len(self.in_keys))
+                ]
+            except Exception as err:
+                raise RuntimeError(
+                    "Failed to deepcopy the hash module. Please provide a list of hash modules instead."
+                ) from err
+        elif len(hash_module) != len(self.in_keys):
+            raise ValueError(
+                "The number of hash_modules must match the number of in_keys. "
+                f"Got {len(hash_module)} hash modules but {len(in_keys)} in_keys."
+            )
+        if hash_to_int is None:
+            hash_to_int = HashToInt()
+
+        self.aggregator = aggregator
+        self.hash_module = dict(zip(self.in_keys, hash_module))
+        self.hash_to_int = hash_to_int
+
+        self.index_key = index_key
+        self.clone = clone
+
+    def forward(
+        self,
+        tensordict: TensorDictBase,
+        extend: bool = True,
+        write_hash: bool = True,
+    ) -> TensorDictBase:
+        # hash each in_key entry with its dedicated hash module
+        hash_values = []
+        for k in self.in_keys:
+            hash_values.append(self.hash_module[k](tensordict.get(k)))
+        if len(self.in_keys) > 1:
+            # stack the per-key hashes and aggregate them into a single hash
+            hash_values = torch.stack(
+                hash_values,
+                dim=-1,
+            )
+            hash_values = self.aggregator(hash_values)
+        else:
+            hash_values = hash_values[0]
+
+        # map hash values to contiguous storage indices
+        td_hash_value = self.hash_to_int(hash_values, extend=extend)
+
+        if self.clone:
+            output = tensordict.copy()
+        else:
+            output = tensordict
+
+        output.set(self.index_key, td_hash_value)
+        if write_hash:
+            output.set(self.hash_key, hash_values)
+        return output
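As a sketch of the HashToInt contract defined above, assuming this commit is installed (the hash values are illustrative placeholders, not real SipHash outputs):

import torch
from torchrl.data.map import HashToInt

h = HashToInt()
hashes = torch.tensor([1234567, 42, 1234567])  # illustrative hash values

# extend=True registers unseen hashes and assigns contiguous indices
assert h(hashes, extend=True).tolist() == [0, 1, 0]

# extend=False only looks up; unseen hashes fall back to len(map)
assert h(torch.tensor([42, 999]), extend=False).tolist() == [1, 2]

# the mapping round-trips through state_dict/load_state_dict
h2 = HashToInt()
h2.load_state_dict(h.state_dict())
assert h2(hashes).tolist() == [0, 1, 0]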

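And a hedged sketch of the multi-hash path in QueryModule.forward: one hash module per in_key plus an explicit aggregator, with the extend/write_hash flags spelled out. The key names and tensor values are illustrative, not taken from the commit.

import torch
from tensordict import TensorDict
from torchrl.data.map import QueryModule, SipHash

# With several in_keys, forward() stacks the per-key hashes along
# dim=-1 and re-hashes the stack with the aggregator.
module = QueryModule(
    in_keys=["obs", ("agent", "action")],  # illustrative key names
    index_key="index",
    hash_module=[SipHash(), SipHash()],
    aggregator=SipHash(),
)
data = TensorDict(
    {
        "obs": torch.Tensor([[0.0], [0.0], [1.0]]),
        ("agent", "action"): torch.Tensor([[1.0], [1.0], [1.0]]),
    },
    batch_size=(3,),
)
# extend=True (default) lets HashToInt register new hashes;
# write_hash=True also stores the aggregated hash under "_hash".
out = module(data, extend=True, write_hash=True)
assert out["index"][0] == out["index"][1]  # identical rows share an index
assert "_hash" in out.keys()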