[BUG] Collectors of batched environments return more frames than requested #846

@matteobettini

Describe the bug

The collectors currently force the actual number of collected frames per batch (frames_per_batch) to be divisible by the number of batched environments, which can be collector workers or parallel workers (and, with #828, vectorized dimensions in the batch size).

As a result, a user passes a desired frames_per_batch at collector creation and actually gets more frames than requested, as the following example shows:

from torch import nn
from tensordict.nn import TensorDictModule
from torchrl.collectors import SyncDataCollector
from torchrl.envs import ParallelEnv
from torchrl.envs.libs.gym import GymEnv

gym_env = lambda: GymEnv("Pendulum-v1", device="cpu")
gym_parallel_env = lambda: ParallelEnv(10, gym_env)

pendulum_policy = TensorDictModule(
    nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
)

coll = SyncDataCollector(
    gym_parallel_env,
    pendulum_policy,
    total_frames=20000,
    max_frames_per_traj=5,
    frames_per_batch=145,  # 145 frames requested
    split_trajs=False,
)

for data in coll:
    print("Ending", data)  # batch_size=torch.Size([10, 15]), i.e. 150 frames collected
    break

This is caused, for example, by code like the following ceiling division:

# rounds frames_per_batch up to the next multiple of the number of envs
self.frames_per_batch = -(-frames_per_batch // self.n_env)
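
Plugging in the numbers from the example above, the ceiling division rounds 145 / 10 up to 15 frames per worker, so 150 frames are returned in total:

frames_per_batch = 145
n_env = 10

# -(-a // b) is ceiling division
frames_per_worker = -(-frames_per_batch // n_env)       # 15
total_collected = frames_per_worker * n_env             # 150, i.e. 5 more than requested
print(frames_per_worker, total_collected)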

This behavior might be dangerous for some users, who may think that at each iteration they are training on x frames when they are in fact training on x + y frames.

Solutions

  1. Throw an error if frames_per_batch is not divisible by the number of batched envs.
  2. Throw a warning if frames_per_batch is not divisible by the number of batched envs (a sketch of options 1 and 2 follows this list).
  3. Find a way to return only the requested number of frames by discarding some of the collected data.
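
A minimal sketch of options 1 and 2, assuming a check at collector construction time (check_frames_per_batch is a hypothetical helper, not an existing torchrl function):

import warnings

def check_frames_per_batch(frames_per_batch: int, n_env: int, strict: bool = False) -> int:
    """Hypothetical helper: validate that frames_per_batch splits evenly across n_env workers."""
    remainder = frames_per_batch % n_env
    if remainder == 0:
        return frames_per_batch // n_env
    if strict:
        # Option 1: refuse the configuration outright
        raise ValueError(
            f"frames_per_batch={frames_per_batch} is not divisible by the "
            f"number of batched envs ({n_env})."
        )
    # Option 2: keep the current rounding but make it explicit to the user
    warnings.warn(
        f"frames_per_batch={frames_per_batch} is not divisible by n_env={n_env}; "
        f"{n_env - remainder} extra frames will be collected per batch."
    )
    return -(-frames_per_batch // n_env)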

Labels

Good first issue (A good way to start hacking torchrl!), bug (Something isn't working)
