diff --git a/torchrl/collectors/collectors.py b/torchrl/collectors/collectors.py index 0d5443b22b4..afd8ae61765 100644 --- a/torchrl/collectors/collectors.py +++ b/torchrl/collectors/collectors.py @@ -398,6 +398,7 @@ class SyncDataCollector(DataCollectorBase): policy is ignored before it is called. This feature is mainly intended to be used in offline/model-based settings, where a batch of random trajectories can be used to initialize training. + If provided, it will be rounded up to the closest multiple of frames_per_batch. Defaults to ``None`` (i.e. no random frames). reset_at_each_iter (bool, optional): Whether environments should be reset at the beginning of a batch collection. @@ -599,13 +600,26 @@ def __init__( self.total_frames = total_frames self.reset_at_each_iter = reset_at_each_iter self.init_random_frames = init_random_frames + if ( + init_random_frames is not None + and init_random_frames % frames_per_batch != 0 + and RL_WARNINGS + ): + warnings.warn( + f"init_random_frames ({init_random_frames}) is not exactly a multiple of frames_per_batch ({frames_per_batch}), " + f" this results in more init_random_frames than requested" + f" ({-(-init_random_frames // frames_per_batch) * frames_per_batch})." + "To silence this message, set the environment variable RL_WARNINGS to False." + ) + self.postproc = postproc if self.postproc is not None and hasattr(self.postproc, "to"): self.postproc.to(self.storing_device) if frames_per_batch % self.n_env != 0 and RL_WARNINGS: warnings.warn( - f"frames_per_batch {frames_per_batch} is not exactly divisible by the number of batched environments {self.n_env}, " - f" this results in more frames_per_batch per iteration that requested." + f"frames_per_batch ({frames_per_batch}) is not exactly divisible by the number of batched environments ({self.n_env}), " + f" this results in more frames_per_batch per iteration that requested" + f" ({-(-frames_per_batch // self.n_env) * self.n_env})." "To silence this message, set the environment variable RL_WARNINGS to False." ) self.requested_frames_per_batch = frames_per_batch @@ -1026,6 +1040,7 @@ class _MultiDataCollector(DataCollectorBase): policy is ignored before it is called. This feature is mainly intended to be used in offline/model-based settings, where a batch of random trajectories can be used to initialize training. + If provided, it will be rounded up to the closest multiple of frames_per_batch. Defaults to ``None`` (i.e. no random frames). reset_at_each_iter (bool, optional): Whether environments should be reset at the beginning of a batch collection.