diff --git a/torchrl/data/replay_buffers/replay_buffers.py b/torchrl/data/replay_buffers/replay_buffers.py
index bb7a56b6304..5d21d202eae 100644
--- a/torchrl/data/replay_buffers/replay_buffers.py
+++ b/torchrl/data/replay_buffers/replay_buffers.py
@@ -662,7 +662,7 @@ def __init__(self, *, priority_key: str = "td_error", **kw) -> None:
         super().__init__(**kw)
         self.priority_key = priority_key
 
-    def _get_priority(self, tensordict: TensorDictBase) -> Optional[torch.Tensor]:
+    def _get_priority_item(self, tensordict: TensorDictBase) -> float:
         if "_data" in tensordict.keys():
             tensordict = tensordict.get("_data")
 
@@ -682,6 +682,23 @@ def _get_priority(self, tensordict: TensorDictBase) -> Optional[torch.Tensor]:
             )
         return priority
 
+    def _get_priority_vector(self, tensordict: TensorDictBase) -> torch.Tensor:
+        if "_data" in tensordict.keys():
+            tensordict = tensordict.get("_data")
+
+        priority = tensordict.get(self.priority_key, None)
+        if priority is None:
+            return torch.tensor(
+                self._sampler.default_priority,
+                dtype=torch.float,
+                device=tensordict.device,
+            ).expand(tensordict.shape[0])
+
+        priority = priority.reshape(priority.shape[0], -1)
+        priority = _reduce(priority, self._sampler.reduction, dim=1)
+
+        return priority
+
     def add(self, data: TensorDictBase) -> int:
         if self._transform is not None:
             data = self._transform.inv(data)
@@ -709,61 +726,50 @@ def add(self, data: TensorDictBase) -> int:
         self.update_tensordict_priority(data_add)
         return index
 
-    def extend(self, tensordicts: Union[List, TensorDictBase]) -> torch.Tensor:
-        if is_tensor_collection(tensordicts):
-            tensordicts = TensorDict(
-                {"_data": tensordicts},
-                batch_size=tensordicts.batch_size[:1],
-            )
-            if tensordicts.batch_dims > 1:
-                # we want the tensordict to have one dimension only. The batch size
-                # of the sampled tensordicts can be changed thereafter
-                if not isinstance(tensordicts, LazyStackedTensorDict):
-                    tensordicts = tensordicts.clone(recurse=False)
-                else:
-                    tensordicts = tensordicts.contiguous()
-                # we keep track of the batch size to reinstantiate it when sampling
-                if "_rb_batch_size" in tensordicts.keys():
-                    raise KeyError(
-                        "conflicting key '_rb_batch_size'. Consider removing from data."
-                    )
-                shape = torch.tensor(tensordicts.batch_size[1:]).expand(
-                    tensordicts.batch_size[0], tensordicts.batch_dims - 1
+    def extend(self, tensordicts: TensorDictBase) -> torch.Tensor:
+
+        tensordicts = TensorDict(
+            {"_data": tensordicts},
+            batch_size=tensordicts.batch_size[:1],
+        )
+        if tensordicts.batch_dims > 1:
+            # we want the tensordict to have one dimension only. The batch size
+            # of the sampled tensordicts can be changed thereafter
+            if not isinstance(tensordicts, LazyStackedTensorDict):
+                tensordicts = tensordicts.clone(recurse=False)
+            else:
+                tensordicts = tensordicts.contiguous()
+            # we keep track of the batch size to reinstantiate it when sampling
+            if "_rb_batch_size" in tensordicts.keys():
+                raise KeyError(
+                    "conflicting key '_rb_batch_size'. Consider removing from data."
                 )
-                tensordicts.set("_rb_batch_size", shape)
-            tensordicts.set(
-                "index",
-                torch.zeros(
-                    tensordicts.shape, device=tensordicts.device, dtype=torch.int
-                ),
+            shape = torch.tensor(tensordicts.batch_size[1:]).expand(
+                tensordicts.batch_size[0], tensordicts.batch_dims - 1
             )
-
-        if not is_tensor_collection(tensordicts):
-            stacked_td = torch.stack(tensordicts, 0)
-        else:
-            stacked_td = tensordicts
+            tensordicts.set("_rb_batch_size", shape)
+        tensordicts.set(
+            "index",
+            torch.zeros(tensordicts.shape, device=tensordicts.device, dtype=torch.int),
+        )
 
         if self._transform is not None:
-            tensordicts = self._transform.inv(stacked_td.get("_data"))
-            stacked_td.set("_data", tensordicts)
-            if tensordicts.device is not None:
-                stacked_td = stacked_td.to(tensordicts.device)
+            data = self._transform.inv(tensordicts.get("_data"))
+            tensordicts.set("_data", data)
+            if data.device is not None:
+                tensordicts = tensordicts.to(data.device)
 
-        index = super()._extend(stacked_td)
-        self.update_tensordict_priority(stacked_td)
+        index = super()._extend(tensordicts)
+        self.update_tensordict_priority(tensordicts)
         return index
 
     def update_tensordict_priority(self, data: TensorDictBase) -> None:
         if not isinstance(self._sampler, PrioritizedSampler):
             return
         if data.ndim:
-            priority = torch.tensor(
-                [self._get_priority(td) for td in data],
-                dtype=torch.float,
-                device=data.device,
-            )
+            priority = self._get_priority_vector(data)
         else:
-            priority = self._get_priority(data)
+            priority = self._get_priority_item(data)
         index = data.get("index")
         while index.shape != priority.shape:
             # reduce index
@@ -1010,17 +1016,23 @@ def __call__(self, list_of_tds):
         return self.out
 
 
-def _reduce(tensor: torch.Tensor, reduction: str):
+def _reduce(
+    tensor: torch.Tensor, reduction: str, dim: Optional[int] = None
+) -> Union[float, torch.Tensor]:
     """Reduces a tensor given the reduction method."""
     if reduction == "max":
-        return tensor.max().item()
+        result = tensor.max(dim=dim)
     elif reduction == "min":
-        return tensor.min().item()
+        result = tensor.min(dim=dim)
     elif reduction == "mean":
-        return tensor.mean().item()
+        result = tensor.mean(dim=dim)
     elif reduction == "median":
-        return tensor.median().item()
-    raise NotImplementedError(f"Unknown reduction method {reduction}")
+        result = tensor.median(dim=dim)
+    else:
+        raise NotImplementedError(f"Unknown reduction method {reduction}")
+    if isinstance(result, tuple):
+        result = result[0]
+    return result.item() if dim is None else result
 
 
 def stack_tensors(list_of_tensor_iterators: List) -> Tuple[torch.Tensor]:
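
For reference, a minimal sketch (not part of the patch) of what the vectorized path buys: the dim-aware _reduce collapses a [batch, numel] view of the priorities in one call, which is what _get_priority_vector relies on instead of the old per-tensordict Python loop. The tensor shapes and values below are made up for illustration.

    import torch

    td_error = torch.rand(5, 3, 4)                  # 5 stored transitions, a 3x4 TD-error map each
    flat = td_error.reshape(td_error.shape[0], -1)  # [5, 12], as _get_priority_vector reshapes it

    # vectorized per-row reduction, what _reduce(flat, "max", dim=1) computes
    vector_priority = flat.max(dim=1).values        # shape [5]

    # the old path produced the same values one .item() at a time
    loop_priority = torch.tensor([t.max().item() for t in flat])

    assert torch.allclose(vector_priority, loop_priority)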