
[Performance] Prioritised TensorDict replay buffers use for loops over the batch dimension #1574

@matteobettini

Description


In prioritised tensordict replay buffers, the `update_tensordict_priority` method performs a for loop over the batch dimension:

```python
priority = torch.tensor(
    [self._get_priority(td) for td in data],
    dtype=torch.float,
    device=data.device,
)
```
This causes significant slowdowns, since this is the vectorised dimension used in training pipelines and it can grow very large.

This method is called every time the buffer is extended or the priorities are updated.
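To illustrate the cost, here is a minimal sketch contrasting the reported per-item loop with a fully vectorised computation. The proportional-priority formula `(|td_error| + eps) ** alpha` and the function names `looped_priority` / `batched_priority` are illustrative assumptions, not TorchRL's actual `_get_priority` implementation:

```python
import torch

def looped_priority(td_error: torch.Tensor, alpha: float = 0.6, eps: float = 1e-8) -> torch.Tensor:
    # Analogous to the reported pattern: a Python loop over the batch
    # dimension, building one scalar priority per sample.
    return torch.tensor(
        [((e.abs() + eps) ** alpha).item() for e in td_error],
        dtype=td_error.dtype,
        device=td_error.device,
    )

def batched_priority(td_error: torch.Tensor, alpha: float = 0.6, eps: float = 1e-8) -> torch.Tensor:
    # Vectorised alternative: a single tensor op over the whole batch,
    # avoiding per-sample Python overhead entirely.
    return (td_error.abs() + eps) ** alpha

td_error = torch.randn(1024)
assert torch.allclose(looped_priority(td_error), batched_priority(td_error))
```

Both versions produce identical priorities; the difference is that the loop pays Python-interpreter and per-element tensor-construction overhead proportional to the batch size, which is exactly the dimension that grows large in training pipelines.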

Labels: bug (Something isn't working)