Describe the bug
Hello, I'm running experiments that use a relatively small number of parallel environments (8-16). Using the PongNoFrameskip-v4 environment with no wrappers, TorchRL appears to be 4-5x slower than Gym's AsyncVectorEnv (~2,600 vs ~11,000 FPS) with a random policy. Given the throughput results in Table 2 of the paper, I would expect comparable performance. Am I setting up the environments incorrectly?
To Reproduce
This is a very simple adaptation of the script in examples/distributed/single_machine/generic.py. Although it's not shown here, I observe similar performance with ParallelEnv and a synchronous collector (a sketch of that variant is included after the script below).
import time
from argparse import ArgumentParser

import torch
import tqdm
from torchrl.collectors.collectors import (
    MultiaSyncDataCollector,
    MultiSyncDataCollector,
    RandomPolicy,
)
from torchrl.envs import EnvCreator
from torchrl.envs.libs.gym import GymEnv
import gymnasium as gym

parser = ArgumentParser()
parser.add_argument(
    "--num_workers", default=8, type=int, help="Number of workers in each node."
)
parser.add_argument(
    "--total_frames",
    default=500_000,
    type=int,
    help="Total number of frames collected by the collector. Must be "
    "divisible by the product of nodes and workers.",
)
parser.add_argument(
    "--env",
    default="PongNoFrameskip-v4",
    help="Gym environment to be run.",
)

if __name__ == "__main__":
    args = parser.parse_args()
    num_workers = args.num_workers
    frames_per_batch = 10 * args.num_workers

    # Test asynchronous gym collector
    env = gym.vector.AsyncVectorEnv(
        [lambda: gym.make(args.env) for _ in range(num_workers)]
    )
    env.reset()
    global_step = 0
    start = time.time()
    for _ in range(args.total_frames // num_workers):
        global_step += num_workers
        env.step(env.action_space.sample())
        stop = time.time()
        if global_step % int(num_workers * 1_000) == 0:
            print("FPS:", global_step / (stop - start))
    env.close()

    # Test multiprocess TorchRL collector
    device = "cuda:0"
    make_env = EnvCreator(lambda: GymEnv(args.env, device=device))
    action_spec = make_env().action_spec
    collector = MultiaSyncDataCollector(
        [make_env] * num_workers,
        policy=RandomPolicy(action_spec),
        total_frames=args.total_frames,
        frames_per_batch=frames_per_batch,
        devices=device,
        storing_devices=device,
    )
    counter = 0
    for i, data in enumerate(collector):
        if i == 10:
            pbar = tqdm.tqdm(total=collector.total_frames)
            t0 = time.time()
        if i >= 10:
            counter += data.numel()
            pbar.update(data.numel())
            pbar.set_description(f"data shape: {data.shape}, data device: {data.device}")
    collector.shutdown()
    t1 = time.time()
    print(f"time elapsed: {t1 - t0}s, rate: {counter / (t1 - t0)} fps")

    exit()
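
For reference, the ParallelEnv / synchronous-collector variant mentioned above looks roughly like the following. This is a minimal sketch rather than the exact script I ran; the use of SyncDataCollector and the hard-coded worker and frame counts are my assumptions here, chosen to mirror the defaults of the script above.

import time

from torchrl.collectors.collectors import RandomPolicy, SyncDataCollector
from torchrl.envs import EnvCreator, ParallelEnv
from torchrl.envs.libs.gym import GymEnv

# Same environment factory as above, but batched with ParallelEnv and
# collected synchronously (worker count mirrors --num_workers=8).
make_env = EnvCreator(lambda: GymEnv("PongNoFrameskip-v4", device="cuda:0"))
penv = ParallelEnv(8, make_env)

collector = SyncDataCollector(
    penv,
    policy=RandomPolicy(penv.action_spec),
    total_frames=500_000,
    frames_per_batch=80,  # 10 * num_workers, as above
)

# Measure throughput the same way as in the script above.
counter = 0
t0 = time.time()
for data in collector:
    counter += data.numel()
collector.shutdown()
t1 = time.time()
print(f"rate: {counter / (t1 - t0)} fps")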
System info
TorchRL installed via pip (v0.1.1)
import torchrl, numpy, sys
print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)
None 1.22.0 3.10.11 (main, Apr 20 2023, 19:02:41) [GCC 11.2.0] linux
Checklist
- I have checked that there is no similar issue in the repo (required)
- I have read the documentation (required)
- I have provided a minimal working example to reproduce the bug (required)