
Commit 74d259d

Revert "[Doc] Refactor DDPG and DQN tutos to narrow the scope (#979)"
This reverts commit c3765cf.

29 files changed: +5422 additions, −1470 deletions
Binary file not shown (−246 KB).

docs/source/_static/js/theme.js

Lines changed: 3822 additions & 2 deletions
Large diffs are not rendered by default.

docs/source/reference/data.rst

Lines changed: 1 addition & 1 deletion
@@ -218,7 +218,7 @@ Check the :obj:`torchrl.envs.utils.check_env_specs` method for a sanity check.
 Utils
 -----
 
-.. currentmodule:: torchrl.data
+.. currentmodule:: torchrl.data.datasets
 
 .. autosummary::
     :toctree: generated/

docs/source/reference/envs.rst

Lines changed: 1 addition & 0 deletions
@@ -114,6 +114,7 @@ provides more information on how to design a custom environment from scratch.
     EnvBase
     GymLikeEnv
     EnvMetaData
+    Specs
 
 Vectorized envs
 ---------------

docs/source/reference/modules.rst

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ TensorDict modules
 
 Hooks
 -----
-.. currentmodule:: torchrl.modules
+.. currentmodule:: torchrl.modules.tensordict_module.actors
 
 .. autosummary::
     :toctree: generated/

docs/source/reference/objectives.rst

Lines changed: 3 additions & 5 deletions
@@ -16,15 +16,13 @@ The main characteristics of TorchRL losses are:
   method will receive a tensordict as input that contains all the necessary
   information to return a loss value.
 - They output a :class:`tensordict.TensorDict` instance with the loss values
-  written under a ``"loss_<smth>"`` where ``smth`` is a string describing the
+  written under a ``"loss_<smth>`` where ``smth`` is a string describing the
   loss. Additional keys in the tensordict may be useful metrics to log during
   training time.
 .. note::
   The reason we return independent losses is to let the user use a different
   optimizer for different sets of parameters for instance. Summing the losses
-  can be simply done via
-
-  >>> loss_val = sum(loss for key, loss in loss_vals.items() if key.startswith("loss_"))
+  can be simply done via ``sum(loss for key, loss in loss_vals.items() if key.startswith("loss_")``.
 
 Training value functions
 ------------------------
@@ -218,5 +216,5 @@ Utils
     next_state_value
     SoftUpdate
     HardUpdate
-    ValueEstimators
+    ValueFunctions
     default_value_kwargs
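
For context, the first hunk above edits the documented loss-summing pattern. A minimal sketch of that pattern, assuming a generic TorchRL loss module named loss_module and a sampled tensordict named batch (both names are illustrative placeholders, not part of this commit):

# Loss modules return a TensorDict whose scalar losses live under "loss_<smth>" keys.
loss_vals = loss_module(batch)

# Single-optimizer case: sum every "loss_" entry before back-propagating.
loss_val = sum(loss for key, loss in loss_vals.items() if key.startswith("loss_"))
loss_val.backward()

# The entries are kept separate so that, alternatively, each "loss_<smth>" term
# can be back-propagated with its own optimizer.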

docs/source/reference/trainers.rst

Lines changed: 1 addition & 1 deletion
@@ -73,7 +73,7 @@ Hooks can be split into 3 categories: **data processing** (:obj:`"batch_process"
 - **Data processing** hooks update a tensordict of data. Hooks :obj:`__call__` method should accept
   a :obj:`TensorDict` object as input and update it given some strategy.
   Examples of such hooks include Replay Buffer extension (:obj:`ReplayBufferTrainer.extend`), data normalization (including normalization
-  constants update), data subsampling (:class:`torchrl.trainers.BatchSubSampler`) and such.
+  constants update), data subsampling (:doc:`BatchSubSampler`) and such.
 
 - **Logging** hooks take a batch of data presented as a :obj:`TensorDict` and write in the logger
   some information retrieved from that data. Examples include the :obj:`Recorder` hook, the reward

test/test_trainer.py

Lines changed: 9 additions & 8 deletions
@@ -89,10 +89,11 @@ class MockingLossModule(nn.Module):
 
 def mocking_trainer(file=None, optimizer=_mocking_optim) -> Trainer:
     trainer = Trainer(
-        collector=MockingCollector(),
-        total_frames=None,
-        frame_skip=None,
-        optim_steps_per_batch=None,
+        MockingCollector(),
+        *[
+            None,
+        ]
+        * 2,
         loss_module=MockingLossModule(),
         optimizer=optimizer,
         save_trainer_file=file,
@@ -861,7 +862,7 @@ def test_recorder(self, N=8):
         with tempfile.TemporaryDirectory() as folder:
             logger = TensorboardLogger(exp_name=folder)
 
-            environment = transformed_env_constructor(
+            recorder = transformed_env_constructor(
                 args,
                 video_tag="tmp",
                 norm_obs_only=True,
@@ -873,7 +874,7 @@ def test_recorder(self, N=8):
                 record_frames=args.record_frames,
                 frame_skip=args.frame_skip,
                 policy_exploration=None,
-                environment=environment,
+                recorder=recorder,
                 record_interval=args.record_interval,
             )
             trainer = mocking_trainer()
@@ -935,7 +936,7 @@ def _make_recorder_and_trainer(tmpdirname):
                 raise NotImplementedError
             trainer = mocking_trainer(file)
 
-            environment = transformed_env_constructor(
+            recorder = transformed_env_constructor(
                 args,
                 video_tag="tmp",
                 norm_obs_only=True,
@@ -947,7 +948,7 @@ def _make_recorder_and_trainer(tmpdirname):
                 record_frames=args.record_frames,
                 frame_skip=args.frame_skip,
                 policy_exploration=None,
-                environment=environment,
+                recorder=recorder,
                 record_interval=args.record_interval,
             )
             recorder.register(trainer)

torchrl/data/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -3,7 +3,6 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-from . import datasets
 from .postprocs import MultiStep
 from .replay_buffers import (
     LazyMemmapStorage,

torchrl/data/datasets/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -1,2 +1 @@
 from .d4rl import D4RLExperienceReplay
-from .openml import OpenMLExperienceReplay
