[BUG] Throughput vs Gym AsyncVectorEnv #1325

@ShaneFlandermeyer

Description

Describe the bug

Hello, I'm performing experiments that use a relatively small number of parallel environments (8-16). Using the PongNoFrameskip-v4 environment with no wrappers, it seems that TorchRL is 4-5x slower than Gym's AsyncVectorEnv (2600 vs 11000 FPS) with a random policy. Given the throughput results in Table 2 of the paper, I would expect comparable performance. Am I setting up the environments incorrectly?

To Reproduce

This is a very simple adaptation of the script in examples/distributed/single_machine/generic.py. Although it's not shown here, I observe similar performance with ParallelEnv and a synchronous collector (a rough sketch of that variant is included after the script below).

import time
from argparse import ArgumentParser

import torch
import tqdm

from torchrl.collectors.collectors import (
    MultiaSyncDataCollector,
    MultiSyncDataCollector,
    RandomPolicy,
)
from torchrl.envs import EnvCreator
from torchrl.envs.libs.gym import GymEnv
import gymnasium as gym

parser = ArgumentParser()
parser.add_argument(
    "--num_workers", default=8, type=int, help="Number of workers in each node."
)
parser.add_argument(
    "--total_frames",
    default=500_000,
    type=int,
    help="Total number of frames collected by the collector. Must be "
    "divisible by the product of nodes and workers.",
)
parser.add_argument(
    "--env",
    default="PongNoFrameskip-v4",
    help="Gym environment to be run.",
)
if __name__ == "__main__":
    args = parser.parse_args()
    num_workers = args.num_workers
    frames_per_batch = 10 * args.num_workers

    # Test asynchronous gym collector
    env = gym.vector.AsyncVectorEnv([lambda: gym.make(args.env) for _ in range(num_workers)])
    env.reset()
    global_step = 0
    start = time.time()
    for _ in range(args.total_frames//num_workers):
        global_step += num_workers
        env.step(env.action_space.sample())
        stop = time.time()
        if global_step % int(num_workers*1_000) == 0:
            print('FPS:', global_step / (stop - start))
    env.close()

    # Test multiprocess TorchRL collector
    device = 'cuda:0'
    make_env = EnvCreator(lambda: GymEnv(args.env, device=device))
    action_spec = make_env().action_spec
    collector = MultiaSyncDataCollector(
        [make_env] * num_workers,
        policy=RandomPolicy(action_spec),
        total_frames=args.total_frames,
        frames_per_batch=frames_per_batch,
        devices=device,
        storing_devices=device,
    )
    counter = 0
    for i, data in enumerate(collector):
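        # warm-up: skip the first 10 batches, then start the timer and frame counter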
        if i == 10:
            pbar = tqdm.tqdm(total=collector.total_frames)
            t0 = time.time()
        if i >= 10:
            counter += data.numel()
            pbar.update(data.numel())
            pbar.set_description(f"data shape: {data.shape}, data device: {data.device}")
    t1 = time.time()  # stop the clock before shutdown so teardown time is not counted
    collector.shutdown()
    print(f"time elapsed: {t1 - t0}s, rate: {counter / (t1 - t0)} fps")
    exit()
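
For reference, the ParallelEnv + synchronous-collector variant mentioned above looks roughly like the sketch below. This is a minimal sketch assuming the same v0.1.1 API as the script above (ParallelEnv, SyncDataCollector, RandomPolicy), not the exact code I ran, and the counting/timing loop is elided.

from torchrl.collectors.collectors import RandomPolicy, SyncDataCollector
from torchrl.envs import EnvCreator, ParallelEnv
from torchrl.envs.libs.gym import GymEnv

num_workers = 8
device = "cuda:0"

# A single ParallelEnv batches all workers inside one synchronous collector.
env = ParallelEnv(
    num_workers, EnvCreator(lambda: GymEnv("PongNoFrameskip-v4", device=device))
)
collector = SyncDataCollector(
    env,
    policy=RandomPolicy(env.action_spec),
    total_frames=500_000,
    frames_per_batch=10 * num_workers,
    device=device,
)
for data in collector:
    pass  # same warm-up, frame counting, and timing as in the script above
collector.shutdown()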

System info

TorchRL installed via pip (v0.1.1)

import torchrl, numpy, sys
print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)

None 1.22.0 3.10.11 (main, Apr 20 2023, 19:02:41) [GCC 11.2.0] linux

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)
