In prioritised tensordict replay buffers, the `update_tensordict_priority`
method performs a Python for loop over the batch dimension:
```python
priority = torch.tensor(
    [self._get_priority(td) for td in data],
    dtype=torch.float,
    device=data.device,
)
```
This causes significant slowdowns because the batch dimension is the vectorised dimension used in the training pipelines and can grow very large.
The method is called every time the buffer is extended or the priorities are updated.
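To illustrate the issue, here is a minimal sketch contrasting the per-item loop with a vectorised alternative. The function names, the `"td_error"` priority key, and the `reduction` parameter are assumptions for illustration, not the library's actual API; the idea is simply to read the priority entry for the whole batch in one tensor op instead of calling a Python function once per element:

```python
import torch

def update_priority_loop(items, get_priority):
    # Current pattern: one Python-level call per batch element.
    # Cost grows linearly with batch size in interpreter overhead alone.
    return torch.tensor([get_priority(td) for td in items], dtype=torch.float)

def update_priority_vectorised(data, priority_key="td_error", reduction="max"):
    # Hypothetical vectorised alternative: fetch the priority tensor for the
    # whole batch at once, then reduce any trailing dims to one value per item.
    p = data[priority_key].to(torch.float)
    if p.ndim > 1:
        p = p.flatten(1)
        p = p.max(dim=1).values if reduction == "max" else p.mean(dim=1)
    return p

# Loop version: a list of per-item mappings, one Python call each.
items = [{"td_error": torch.tensor([0.5, 0.2])} for _ in range(4)]
loop_p = update_priority_loop(items, lambda td: td["td_error"].max().item())

# Vectorised version: a single batched mapping, one tensor op for all items.
batch = {"td_error": torch.abs(torch.randn(1024, 2))}
vec_p = update_priority_vectorised(batch)
```

Both return one float priority per batch element, but the vectorised version avoids the interpreter overhead that dominates when the batch dimension is large.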