From 5725282b9340dc4b3efeb3b1fc2a99adc94446ba Mon Sep 17 00:00:00 2001
From: Matteo Bettini
Date: Tue, 3 Oct 2023 17:18:04 +0100
Subject: [PATCH 1/5] update

Signed-off-by: Matteo Bettini
---
 torchrl/data/replay_buffers/replay_buffers.py | 44 ++++++++++++++-----
 1 file changed, 32 insertions(+), 12 deletions(-)

diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py
index bb7a56b6304..02bb7b09e5d 100644
--- a/torchrl/data/replay_buffers/replay_buffers.py
+++ b/torchrl/data/replay_buffers/replay_buffers.py
@@ -662,7 +662,7 @@ def __init__(self, *, priority_key: str = "td_error", **kw) -> None:
         super().__init__(**kw)
         self.priority_key = priority_key
 
-    def _get_priority(self, tensordict: TensorDictBase) -> Optional[torch.Tensor]:
+    def _get_priority_item(self, tensordict: TensorDictBase) -> float:
         if "_data" in tensordict.keys():
             tensordict = tensordict.get("_data")
 
@@ -682,6 +682,23 @@ def _get_priority(self, tensordict: TensorDictBase) -> Optional[torch.Tensor]:
         )
         return priority
 
+    def _get_priority_vector(self, tensordict: TensorDictBase) -> torch.Tensor:
+        if "_data" in tensordict.keys():
+            tensordict = tensordict.get("_data")
+
+        priority = tensordict.get(self.priority_key, None)
+        if priority is None:
+            return torch.tensor(
+                [self._sampler.default_priority],
+                dtype=torch.float,
+                device=tensordict.device,
+            ).expand(tensordict.shape[0])
+
+        priority = priority.view(priority.shape[0], -1)
+        priority = _reduce(priority, self._sampler.reduction, dim=1)
+
+        return priority
+
     def add(self, data: TensorDictBase) -> int:
         if self._transform is not None:
             data = self._transform.inv(data)
@@ -757,13 +774,16 @@ def update_tensordict_priority(self, data: TensorDictBase) -> None:
         if not isinstance(self._sampler, PrioritizedSampler):
             return
         if data.ndim:
-            priority = torch.tensor(
-                [self._get_priority(td) for td in data],
-                dtype=torch.float,
-                device=data.device,
-            )
+            if isinstance(data, LazyStackedTensorDict):
+                priority = torch.tensor(
+                    [self._get_priority_item(td) for td in data],
+                    dtype=torch.float,
+                    device=data.device,
+                )
+            else:
+                priority = self._get_priority_vector(data)
         else:
-            priority = self._get_priority(data)
+            priority = self._get_priority_item(data)
         index = data.get("index")
         while index.shape != priority.shape:
             # reduce index
@@ -1010,16 +1030,16 @@ def __call__(self, list_of_tds):
         return self.out
 
 
-def _reduce(tensor: torch.Tensor, reduction: str):
+def _reduce(tensor: torch.Tensor, reduction: str, dim: Optional[int] = None):
     """Reduces a tensor given the reduction method."""
     if reduction == "max":
-        return tensor.max().item()
+        return tensor.max().item() if dim is None else tensor.max(dim=dim)[0]
     elif reduction == "min":
-        return tensor.min().item()
+        return tensor.min().item() if dim is None else tensor.min(dim=dim)[0]
     elif reduction == "mean":
-        return tensor.mean().item()
+        return tensor.mean().item() if dim is None else tensor.mean(dim=dim)
     elif reduction == "median":
-        return tensor.median().item()
+        return tensor.median().item() if dim is None else tensor.median(dim=dim)
     raise NotImplementedError(f"Unknown reduction method {reduction}")
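Note on patch 1: the functional change is that _reduce can now reduce along a
given dimension and return one priority per stored item, instead of collapsing
everything to a single scalar. A minimal sketch of the new semantics, assuming
the patch-1 definition of _reduce above is in scope:

    import torch

    # Batched priorities: 3 stored items, 4 priority values per item.
    priorities = torch.arange(12, dtype=torch.float).reshape(3, 4)

    _reduce(priorities, "max", dim=1)   # tensor([ 3.,  7., 11.]), one value per item
    _reduce(priorities, "mean", dim=1)  # tensor([1.5000, 5.5000, 9.5000])
    _reduce(priorities, "max")          # 11.0, the old scalar behaviour

This per-row reduction (dim=1) is what lets _get_priority_vector compute the
priorities of a whole batch in one call rather than looping over tensordicts.
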
From 81dcee7d20b9dde4c843c9789b1a45058ab5465c Mon Sep 17 00:00:00 2001
From: Matteo Bettini <55539777+matteobettini@users.noreply.github.com>
Date: Tue, 3 Oct 2023 21:05:33 +0100
Subject: [PATCH 2/5] Update torchrl/data/replay_buffers/replay_buffers.py

Co-authored-by: Vincent Moens
---
 torchrl/data/replay_buffers/replay_buffers.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py
index 02bb7b09e5d..4694fdaef80 100644
--- a/torchrl/data/replay_buffers/replay_buffers.py
+++ b/torchrl/data/replay_buffers/replay_buffers.py
@@ -689,7 +689,7 @@ def _get_priority_vector(self, tensordict: TensorDictBase) -> torch.Tensor:
         priority = tensordict.get(self.priority_key, None)
         if priority is None:
             return torch.tensor(
-                [self._sampler.default_priority],
+                self._sampler.default_priority,
                 dtype=torch.float,
                 device=tensordict.device,
             ).expand(tensordict.shape[0])

From bb7423816213253e1346f424cb50bb11cc138f9e Mon Sep 17 00:00:00 2001
From: Matteo Bettini <55539777+matteobettini@users.noreply.github.com>
Date: Tue, 3 Oct 2023 21:05:42 +0100
Subject: [PATCH 3/5] Update torchrl/data/replay_buffers/replay_buffers.py

Co-authored-by: Vincent Moens
---
 torchrl/data/replay_buffers/replay_buffers.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py
index 4694fdaef80..4f0a85a6934 100644
--- a/torchrl/data/replay_buffers/replay_buffers.py
+++ b/torchrl/data/replay_buffers/replay_buffers.py
@@ -694,7 +694,7 @@ def _get_priority_vector(self, tensordict: TensorDictBase) -> torch.Tensor:
                 device=tensordict.device,
             ).expand(tensordict.shape[0])
 
-        priority = priority.view(priority.shape[0], -1)
+        priority = priority.reshape(priority.shape[0], -1)
         priority = _reduce(priority, self._sampler.reduction, dim=1)
 
         return priority

From 25eab5ac60b8e56ec622e043768dd08b6f628dec Mon Sep 17 00:00:00 2001
From: Matteo Bettini
Date: Tue, 3 Oct 2023 21:15:17 +0100
Subject: [PATCH 4/5] update

Signed-off-by: Matteo Bettini
---
 torchrl/data/replay_buffers/replay_buffers.py | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py
index 02bb7b09e5d..4fdc9ff83f7 100644
--- a/torchrl/data/replay_buffers/replay_buffers.py
+++ b/torchrl/data/replay_buffers/replay_buffers.py
@@ -1030,17 +1030,23 @@ def __call__(self, list_of_tds):
         return self.out
 
 
-def _reduce(tensor: torch.Tensor, reduction: str, dim: Optional[int] = None):
+def _reduce(
+    tensor: torch.Tensor, reduction: str, dim: Optional[int] = None
+) -> Union[float, torch.Tensor]:
     """Reduces a tensor given the reduction method."""
     if reduction == "max":
-        return tensor.max().item() if dim is None else tensor.max(dim=dim)[0]
+        result = tensor.max(dim=dim)
     elif reduction == "min":
-        return tensor.min().item() if dim is None else tensor.min(dim=dim)[0]
+        result = tensor.min(dim=dim)
     elif reduction == "mean":
-        return tensor.mean().item() if dim is None else tensor.mean(dim=dim)
+        result = tensor.mean(dim=dim)
     elif reduction == "median":
-        return tensor.median().item() if dim is None else tensor.median(dim=dim)
-    raise NotImplementedError(f"Unknown reduction method {reduction}")
+        result = tensor.median(dim=dim)
+    else:
+        raise NotImplementedError(f"Unknown reduction method {reduction}")
+    if isinstance(result, tuple):
+        result = result[0]
+    return result.item() if dim is None else result
 
 
 def stack_tensors(list_of_tensor_iterators: List) -> Tuple[torch.Tensor]:
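Note on patch 4: the unified version leans on a PyTorch detail worth spelling
out. min, max and median called with a dim argument return a (values, indices)
pair, and those return types are tuple subclasses, while mean returns a bare
tensor; the isinstance(result, tuple) check is what separates the two cases.
A small illustration (a sketch, not part of the patch):

    import torch

    t = torch.tensor([[1.0, 5.0], [4.0, 2.0]])

    r = t.max(dim=1)      # torch.return_types.max(values=..., indices=...)
    isinstance(r, tuple)  # True: return_types subclass tuple
    r[0]                  # tensor([5., 4.]), values only; indices are dropped

    m = t.mean(dim=1)     # tensor([3., 3.]), a plain tensor
    isinstance(m, tuple)  # False, so it is returned unchanged
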
From 98514c0cc4f3cb1a1617a7f5138c2642f0c6743d Mon Sep 17 00:00:00 2001
From: Matteo Bettini
Date: Tue, 3 Oct 2023 22:04:12 +0100
Subject: [PATCH 5/5] update

Signed-off-by: Matteo Bettini
---
 torchrl/data/replay_buffers/replay_buffers.py | 76 ++++++++-----------
 1 file changed, 31 insertions(+), 45 deletions(-)

diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py
index 45d06386677..5d21d202eae 100644
--- a/torchrl/data/replay_buffers/replay_buffers.py
+++ b/torchrl/data/replay_buffers/replay_buffers.py
@@ -726,62 +726,48 @@ def add(self, data: TensorDictBase) -> int:
         self.update_tensordict_priority(data_add)
         return index
 
-    def extend(self, tensordicts: Union[List, TensorDictBase]) -> torch.Tensor:
-        if is_tensor_collection(tensordicts):
-            tensordicts = TensorDict(
-                {"_data": tensordicts},
-                batch_size=tensordicts.batch_size[:1],
-            )
-            if tensordicts.batch_dims > 1:
-                # we want the tensordict to have one dimension only. The batch size
-                # of the sampled tensordicts can be changed thereafter
-                if not isinstance(tensordicts, LazyStackedTensorDict):
-                    tensordicts = tensordicts.clone(recurse=False)
-                else:
-                    tensordicts = tensordicts.contiguous()
-                # we keep track of the batch size to reinstantiate it when sampling
-                if "_rb_batch_size" in tensordicts.keys():
-                    raise KeyError(
-                        "conflicting key '_rb_batch_size'. Consider removing from data."
-                    )
-                shape = torch.tensor(tensordicts.batch_size[1:]).expand(
-                    tensordicts.batch_size[0], tensordicts.batch_dims - 1
-                )
-                tensordicts.set("_rb_batch_size", shape)
-            tensordicts.set(
-                "index",
-                torch.zeros(
-                    tensordicts.shape, device=tensordicts.device, dtype=torch.int
-                ),
-            )
-
-        if not is_tensor_collection(tensordicts):
-            stacked_td = torch.stack(tensordicts, 0)
-        else:
-            stacked_td = tensordicts
+    def extend(self, tensordicts: TensorDictBase) -> torch.Tensor:
+
+        tensordicts = TensorDict(
+            {"_data": tensordicts},
+            batch_size=tensordicts.batch_size[:1],
+        )
+        if tensordicts.batch_dims > 1:
+            # we want the tensordict to have one dimension only. The batch size
+            # of the sampled tensordicts can be changed thereafter
+            if not isinstance(tensordicts, LazyStackedTensorDict):
+                tensordicts = tensordicts.clone(recurse=False)
+            else:
+                tensordicts = tensordicts.contiguous()
+            # we keep track of the batch size to reinstantiate it when sampling
+            if "_rb_batch_size" in tensordicts.keys():
+                raise KeyError(
+                    "conflicting key '_rb_batch_size'. Consider removing from data."
+                )
+            shape = torch.tensor(tensordicts.batch_size[1:]).expand(
+                tensordicts.batch_size[0], tensordicts.batch_dims - 1
+            )
+            tensordicts.set("_rb_batch_size", shape)
+        tensordicts.set(
+            "index",
+            torch.zeros(tensordicts.shape, device=tensordicts.device, dtype=torch.int),
+        )
+
         if self._transform is not None:
-            tensordicts = self._transform.inv(stacked_td.get("_data"))
-            stacked_td.set("_data", tensordicts)
-            if tensordicts.device is not None:
-                stacked_td = stacked_td.to(tensordicts.device)
+            data = self._transform.inv(tensordicts.get("_data"))
+            tensordicts.set("_data", data)
+            if data.device is not None:
+                tensordicts = tensordicts.to(data.device)
 
-        index = super()._extend(stacked_td)
-        self.update_tensordict_priority(stacked_td)
+        index = super()._extend(tensordicts)
+        self.update_tensordict_priority(tensordicts)
         return index
 
     def update_tensordict_priority(self, data: TensorDictBase) -> None:
         if not isinstance(self._sampler, PrioritizedSampler):
             return
         if data.ndim:
-            if isinstance(data, LazyStackedTensorDict):
-                priority = torch.tensor(
-                    [self._get_priority_item(td) for td in data],
-                    dtype=torch.float,
-                    device=data.device,
-                )
-            else:
-                priority = self._get_priority_vector(data)
+            priority = self._get_priority_vector(data)
         else:
             priority = self._get_priority_item(data)
         index = data.get("index")
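Note on patch 5: with the list input path gone, extend always wraps its input
in a {"_data": ...} tensordict and update_tensordict_priority takes the
vectorized branch for any batched input. An end-to-end usage sketch of the
path this series optimizes, assuming torchrl's public replay-buffer API of the
time (TensorDictReplayBuffer, LazyTensorStorage, PrioritizedSampler):

    import torch
    from tensordict import TensorDict
    from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
    from torchrl.data.replay_buffers import PrioritizedSampler

    rb = TensorDictReplayBuffer(
        storage=LazyTensorStorage(100),
        sampler=PrioritizedSampler(max_capacity=100, alpha=0.7, beta=0.9),
        priority_key="td_error",  # the key read by _get_priority_vector
        batch_size=4,
    )

    # A batch of 10 items with per-item TD errors. extend() routes this
    # through _get_priority_vector: priorities come out in one call, no loop.
    data = TensorDict(
        {"obs": torch.randn(10, 3), "td_error": torch.rand(10)},
        batch_size=[10],
    )
    rb.extend(data)

    sample = rb.sample()  # items with larger td_error are drawn more often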