Description
Describe the bug
The collectors currently force the actual number of collected frames per batch to be divisible by the number of batched environments (which can be collector workers or parallel workers; following #828 these could also be vectorized dimensions in the batch size).
As a result, a user passes a desired frames_per_batch at collector creation and actually gets more frames than requested, as the following example shows:
```python
import torch.nn as nn
from tensordict.nn import TensorDictModule
from torchrl.collectors import SyncDataCollector
from torchrl.envs import ParallelEnv
from torchrl.envs.libs.gym import GymEnv

gym_env = lambda: GymEnv("Pendulum-v1", device="cpu")
gym_parallel_env = lambda: ParallelEnv(10, gym_env)
pendulum_policy = TensorDictModule(
    nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
)
coll = SyncDataCollector(
    gym_parallel_env,
    pendulum_policy,
    total_frames=20000,
    max_frames_per_traj=5,
    frames_per_batch=145,  # we ask for 145 frames...
    split_trajs=False,
)
for data in coll:
    print("Ending", data)  # ...but batch_size=torch.Size([10, 15]), i.e. 150 frames
    break
```
This is caused, for example, by code like the following, which performs a ceiling division and thereby rounds frames_per_batch up to the next multiple of the number of envs:
```python
self.frames_per_batch = -(-frames_per_batch // self.n_env)
```
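For the numbers above, the rounding works out as follows (plain Python reproducing the arithmetic, not actual torchrl internals):

```python
frames_per_batch = 145
n_env = 10

# -(-a // b) is ceiling division: ceil(145 / 10) = 15 frames per env
frames_per_env = -(-frames_per_batch // n_env)
assert frames_per_env == 15

# the collector therefore returns 15 * 10 = 150 frames, 5 more than requested
assert frames_per_env * n_env == 150
```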
This behavior can be dangerous: users may believe they are training on x frames at each iteration when they are actually training on x + y frames.
Solutions
- Throw an error if the frames_per_batch is not divisible by the number of batched envs
- Throw a warning if the frames_per_batch is not divisible by the number of batched envs (a sketch of what the first two options could look like follows this list)
- Find a way to return only the requested amount of frames by discarding some of the collected data
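For illustration, here is a minimal sketch of the first two options, assuming a hypothetical `resolve_frames_per_batch` helper inside the collector; the name, the `mode` flag, and the messages are illustrative, not actual torchrl code:

```python
import warnings


def resolve_frames_per_batch(frames_per_batch: int, n_env: int, mode: str = "warn") -> int:
    """Hypothetical helper: reconcile the requested frames_per_batch with n_env.

    Returns the number of frames each batched env should collect.
    """
    if frames_per_batch % n_env == 0:
        return frames_per_batch // n_env
    if mode == "error":
        # option 1: refuse non-divisible values outright
        raise ValueError(
            f"frames_per_batch ({frames_per_batch}) must be divisible by "
            f"the number of batched envs ({n_env})"
        )
    # option 2: keep the current rounding-up behavior, but tell the user
    frames_per_env = -(-frames_per_batch // n_env)  # ceiling division
    warnings.warn(
        f"frames_per_batch ({frames_per_batch}) is not divisible by the number of "
        f"batched envs ({n_env}); collecting {frames_per_env * n_env} frames per batch"
    )
    return frames_per_env
```

The third option could instead collect the rounded-up batch and slice it back down, e.g. something like `data.reshape(-1)[:frames_per_batch]`, at the cost of losing the `[n_env, T]` batch structure.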