
[BUG] Memory leak? #845

@vmoens

Describe the bug

I'm chasing a "memory leak" issue, but I'm not sure it's a real memory leak.
With Pong, the memory on the GPU I use for data collection keeps increasing.
The strange thing is that it correlates with training performance: the better the training, the higher the memory consumption. The obvious explanation (though not the full story) is that better performance <=> longer trajectories. Hence, for some reason, longer trajectories cause the memory to increase.
A few things could explain that:

  • some transform has indeed a memory leak that gets cleared by its reset method (unlikely)
  • the dataloader has a similar leak that gets cleared when calling reset (unlikely)
  • the most likely to me: the split_trajs option causes this. We essentially pad the values to fit all the trajectories in a [B x max_T] tensordict, where max_T is the maximum length of the trajectories collected. Now imagine you have 8 workers and a batch size of 128 elements per worker. 7 workers collect trajectories that are all < 10 steps for a batch of length 128 (i.e. roughly 7 x 128 / 10 ≈ 90 small trajectories), and one of them collects a single long trajectory of length 128. split_trajs will then deliver a batch with B ≈ 91 and max_T = 128, but roughly 90% of the values will be padding zeros (see the sketch after this list).
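
To make the blow-up concrete, here is a small standalone sketch with plain tensors and a made-up feature size; it mimics what split_trajs does conceptually, not its actual internals:

```python
import torch

frames_per_worker, short_len, feat = 128, 10, 4

# 7 workers deliver short trajectories (< 10 steps each)...
short_trajs = [torch.randn(short_len, feat)
               for _ in range(7 * frames_per_worker // short_len)]
# ...while the 8th worker delivers a single 128-step trajectory.
trajs = short_trajs + [torch.randn(frames_per_worker, feat)]

# Pad everything into a [B x max_T] batch, as split_trajs does conceptually.
max_T = max(t.shape[0] for t in trajs)              # 128
padded = torch.zeros(len(trajs), max_T, feat)       # B ~ 90, max_T = 128
for i, t in enumerate(trajs):
    padded[i, : t.shape[0]] = t

real = sum(t.numel() for t in trajs)
print(f"{real / padded.numel():.0%} of the padded batch is real data")  # ~9%
```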

Possible solutions

The main thing that worried me, and that made me use this trajectory splitting in the first place, is that consuming different trajectories sequentially may break some algos.
From my experiments with the advantage functions (TD0, TDLambda, GAE), only TDLambda suffers from this, and it is likely because we're not using the done flag appropriately.
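
For reference, here is a generic done-aware TD(lambda) return recursion (a minimal sketch, not TorchRL's implementation) that shows where the done flag has to enter: it must mask both the bootstrap value and the recursive term, otherwise returns leak across trajectory boundaries in a padded or concatenated batch.

```python
import torch

def td_lambda_returns(reward, next_value, done, gamma=0.99, lmbda=0.95):
    # reward, next_value, done are [T]-shaped; next_value[t] is V(s_{t+1})
    # and done[t] (bool) marks the last step of a trajectory.
    not_done = (~done).float()
    returns = torch.empty_like(reward)
    next_ret = next_value[-1]  # bootstrap for the final step
    for t in reversed(range(reward.shape[0])):
        # done[t] zeroes the whole discounted term, so neither the bootstrap
        # value nor the lambda recursion crosses a trajectory boundary.
        next_ret = reward[t] + gamma * not_done[t] * (
            (1 - lmbda) * next_value[t] + lmbda * next_ret
        )
        returns[t] = next_ret
    return returns
```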

  • Using nested tensors instead of padding (cc @matteobettini); see the first sketch after this list.
  • Keep split_trajs but turn it off by default. Fix the value functions to make them work in this context.
  • Refactor the dataloader devices: right now we can choose on which device the env sits, and on which device the policy sits. I'm not sure that really makes sense: when will the policy be so big that we can't transform data in the env with it? The logic was that, by default, the collected data would sit on the device of the env, not the policy (to avoid long rollouts filling the GPU, one can put the env on CPU and gather the data there). What we could do instead: policy and env sit on device, and passing_device (or similar) is the device where the data is dumped at each iteration (see the second sketch below).
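
A minimal sketch of the nested-tensor option (torch.nested is a prototype feature, so the exact API may vary across PyTorch versions):

```python
import torch

# Variable-length trajectories stored without any [B x max_T] zero block.
trajs = [torch.randn(T, 4) for T in (3, 7, 128)]
nt = torch.nested.nested_tensor(trajs)
print([t.shape for t in nt.unbind()])  # each trajectory keeps its own length

# When a dense batch is truly needed, the padding becomes explicit and local:
dense = torch.nested.to_padded_tensor(nt, padding=0.0)
print(dense.shape)                     # torch.Size([3, 128, 4])
```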

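And a runnable toy illustration of the device semantics proposed in the last bullet (the function and its arguments are hypothetical, not TorchRL's current collector API):

```python
import torch

def collect(policy, device="cpu", passing_device="cpu", steps=8):
    # Hypothetical semantics: env and policy both live on `device`; every
    # collected step is dumped to `passing_device`, so long rollouts never
    # accumulate on the compute device.
    obs = torch.zeros(4, device=device)      # stand-in for env.reset()
    out = []
    for _ in range(steps):
        action = policy(obs)                 # policy runs on `device`
        obs = obs + action                   # stand-in for env.step()
        out.append(obs.to(passing_device))   # data lands on passing_device
    return torch.stack(out)

batch = collect(lambda o: torch.ones_like(o))
print(batch.device)                          # cpu, i.e. passing_device
```
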
Some more context

I tried calling gc.collect(), but with the Pong example I was running it didn't change anything.
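
(For anyone reproducing this: a quick check along these lines helps tell a true leak apart from the CUDA caching allocator merely holding on to blocks.)

```python
import gc
import torch

gc.collect()                  # drop unreachable Python objects
torch.cuda.empty_cache()      # release cached blocks back to the driver
print(f"allocated: {torch.cuda.memory_allocated() / 2**20:.1f} MiB, "
      f"reserved: {torch.cuda.memory_reserved() / 2**20:.1f} MiB")
```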

@albertbou92 I know you had a similar issue; I'd be interested in your perspective on this.
@ShahRutav I believe that in your case split_trajs has no impact, so I doubt it is the cause of your problem. I'll keep digging.
