Merged
Commits
112 commits
9fbf53b
init
vmoens Mar 20, 2023
5488d4d
lint
vmoens Mar 20, 2023
93deeeb
amend
vmoens Mar 20, 2023
00b1e5a
Merge branch 'main' into ddpg_loss_tuto
vmoens Mar 21, 2023
f511020
dqn (1)
vmoens Mar 21, 2023
2586b74
amend
vmoens Mar 21, 2023
b330b16
edit training dqn
vmoens Mar 21, 2023
bfef8ee
dqn
vmoens Mar 21, 2023
972217a
amend
vmoens Mar 23, 2023
fb81fc3
empty
vmoens Mar 24, 2023
d137048
init
vmoens Mar 23, 2023
adad97d
init
vmoens Mar 23, 2023
ea20603
amend
vmoens Mar 24, 2023
e8bf4c1
amend
vmoens Mar 24, 2023
d178f93
value_estimate and sac init
vmoens Mar 24, 2023
efb57a8
temp
vmoens Mar 24, 2023
48c227a
tmp
vmoens Mar 27, 2023
1e34ef9
SAC
vmoens Mar 27, 2023
dd6ac56
Merge branch 'main' into add_adv_module_lossesd
vmoens Mar 27, 2023
01f1ae7
amend
vmoens Mar 27, 2023
acd0ec1
test
vmoens Mar 27, 2023
7855ef6
smooth deprecation
vmoens Mar 28, 2023
55361cc
amend
vmoens Mar 28, 2023
fd87457
amend
vmoens Mar 28, 2023
11946ee
Merge branch 'main' into add_adv_module_lossesd
vmoens Mar 28, 2023
d8ceb3d
Merge branch 'add_adv_module_lossesd' into ddpg_loss_tuto
vmoens Mar 28, 2023
f008034
amend
vmoens Mar 28, 2023
400cfd1
amend
vmoens Mar 28, 2023
64768b0
amend
vmoens Mar 28, 2023
f5550df
amend
vmoens Mar 28, 2023
218ab1a
amend
vmoens Mar 28, 2023
6cdcc8e
amend
vmoens Mar 28, 2023
b47dee2
amend
vmoens Mar 28, 2023
bbf97ad
Merge branch 'add_adv_module_lossesd' into ddpg_loss_tuto
vmoens Mar 28, 2023
e9bb239
amend
vmoens Mar 28, 2023
34469f2
differentiable=True
vmoens Mar 28, 2023
aae2bbe
differentiable=True
vmoens Mar 28, 2023
f6d2da4
Merge branch 'add_adv_module_lossesd' into ddpg_loss_tuto
vmoens Mar 28, 2023
c9c106b
amend
vmoens Mar 28, 2023
efe9b09
Merge branch 'main' into ddpg_loss_tuto
vmoens Mar 28, 2023
c957916
fix trainer
vmoens Mar 28, 2023
0300728
no grad
vmoens Mar 29, 2023
86915fe
init
vmoens Mar 29, 2023
d23af8b
tests
vmoens Mar 29, 2023
68c3442
empty commit
vmoens Mar 29, 2023
dec5c56
tests
vmoens Mar 29, 2023
7612098
tests
vmoens Mar 29, 2023
69eb921
amend
vmoens Mar 29, 2023
5d32e10
amend
vmoens Mar 29, 2023
c471b96
fix examples
vmoens Mar 29, 2023
d9ab477
fix dqn updater
vmoens Mar 29, 2023
fb7d5de
fix doc
vmoens Mar 29, 2023
180b5b2
print td shape
vmoens Mar 29, 2023
921b91b
fix recorder
vmoens Mar 29, 2023
8984654
fix examples
vmoens Mar 29, 2023
a3f76d1
tmp
vmoens Mar 29, 2023
a10900a
tmp
vmoens Mar 29, 2023
2e65eef
tmp
vmoens Mar 29, 2023
aca6946
tmp
vmoens Mar 29, 2023
206830a
tmp
vmoens Mar 29, 2023
3b4e0e7
tmp
vmoens Mar 29, 2023
6d2ff4b
amend
vmoens Mar 29, 2023
965f77d
Merge branch 'fix_rb' into ddpg_loss_tuto
vmoens Mar 29, 2023
a0caddb
Merge branch 'main' into ddpg_loss_tuto
vmoens Mar 29, 2023
1411cf4
amend
vmoens Mar 29, 2023
259a1be
amend
vmoens Mar 30, 2023
bad0d6a
init
vmoens Mar 30, 2023
713869c
Merge branch 'main' into ddpg_loss_tuto
vmoens Mar 30, 2023
9960792
Merge branch 'fix_explo' into ddpg_loss_tuto
vmoens Mar 30, 2023
2fe0f82
amend
vmoens Mar 30, 2023
370134f
Merge branch 'fix_explo' into ddpg_loss_tuto
vmoens Mar 30, 2023
91fa500
amend
vmoens Mar 30, 2023
1d7ffa8
Merge branch 'fix_explo' into ddpg_loss_tuto
vmoens Mar 30, 2023
c251d1a
bf
vmoens Mar 31, 2023
6e28e5f
Merge branch 'fix_explo' into ddpg_loss_tuto
vmoens Mar 31, 2023
495acff
bf
vmoens Mar 31, 2023
59da7a2
amend
vmoens Mar 31, 2023
3025b07
Merge branch 'fix_explo' into ddpg_loss_tuto
vmoens Mar 31, 2023
2498b5f
amend
vmoens Mar 31, 2023
91cab8a
Merge branch 'fix_explo' into ddpg_loss_tuto
vmoens Mar 31, 2023
388552a
Merge branch 'main' into ddpg_loss_tuto
vmoens Mar 31, 2023
f3537fd
Merge branch 'main' into ddpg_loss_tuto
vmoens Mar 31, 2023
f1da081
stateful functional modules
vmoens Apr 2, 2023
00e75f7
Merge branch 'main' into ddpg_loss_tuto
vmoens Apr 2, 2023
14e0a73
amend
vmoens Apr 3, 2023
94ec94e
amend
vmoens Apr 3, 2023
3f16a49
amend
vmoens Apr 3, 2023
833bf58
revert
vmoens Apr 3, 2023
094d49b
amend
vmoens Apr 3, 2023
effa4fc
log_dir
vmoens Apr 3, 2023
4afd785
amend
vmoens Apr 3, 2023
e50f578
amend
vmoens Apr 3, 2023
ac6c83b
amend
vmoens Apr 4, 2023
7180e6c
amend
vmoens Apr 4, 2023
6223494
init
vmoens Apr 4, 2023
b0d9629
empty commit
vmoens Apr 4, 2023
822f518
amend
vmoens Apr 4, 2023
4ad5fb9
amend
vmoens Apr 4, 2023
8f3bf1c
Merge branch 'fix_collector_reset' into ddpg_loss_tuto
vmoens Apr 4, 2023
ff54f0a
amend
vmoens Apr 4, 2023
d978824
amend
vmoens Apr 4, 2023
cf1ba97
amend
vmoens Apr 4, 2023
0f59058
Merge branch 'main' into ddpg_loss_tuto
vmoens Apr 4, 2023
0e4e6b4
theme
vmoens Apr 4, 2023
66da336
amend
vmoens Apr 5, 2023
89e7b1b
amend
vmoens Apr 5, 2023
7d65ca4
remove prints
vmoens Apr 5, 2023
7144490
Merge branch 'main' into ddpg_loss_tuto
vmoens Apr 5, 2023
0d238d5
amend
vmoens Apr 5, 2023
33133cb
amend
vmoens Apr 6, 2023
c221982
amend
vmoens Apr 6, 2023
86aefc2
Merge branch 'main' into ddpg_loss_tuto
vmoens Apr 6, 2023
Binary file added docs/source/_static/img/replaybuffer_traj.png
3,824 changes: 2 additions & 3,822 deletions docs/source/_static/js/theme.js

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/source/reference/data.rst
@@ -218,7 +218,7 @@ Check the :obj:`torchrl.envs.utils.check_env_specs` method for a sanity check.
Utils
-----

.. currentmodule:: torchrl.data.datasets
.. currentmodule:: torchrl.data

.. autosummary::
:toctree: generated/
1 change: 0 additions & 1 deletion docs/source/reference/envs.rst
@@ -114,7 +114,6 @@ provides more information on how to design a custom environment from scratch.
EnvBase
GymLikeEnv
EnvMetaData
Specs

Vectorized envs
---------------
2 changes: 1 addition & 1 deletion docs/source/reference/modules.rst
@@ -32,7 +32,7 @@ TensorDict modules

Hooks
-----
.. currentmodule:: torchrl.modules.tensordict_module.actors
.. currentmodule:: torchrl.modules

.. autosummary::
:toctree: generated/
8 changes: 5 additions & 3 deletions docs/source/reference/objectives.rst
@@ -16,13 +16,15 @@ The main characteristics of TorchRL losses are:
method will receive a tensordict as input that contains all the necessary
information to return a loss value.
- They output a :class:`tensordict.TensorDict` instance with the loss values
written under a ``"loss_<smth>`` where ``smth`` is a string describing the
written under a ``"loss_<smth>"`` where ``smth`` is a string describing the
loss. Additional keys in the tensordict may be useful metrics to log during
training time.
.. note::
The reason we return independent losses is to let the user use a different
optimizer for different sets of parameters, for instance. Summing the losses
can be simply done via ``sum(loss for key, loss in loss_vals.items() if key.startswith("loss_")``.
can be simply done via

>>> loss_val = sum(loss for key, loss in loss_vals.items() if key.startswith("loss_"))

Training value functions
------------------------
@@ -216,5 +218,5 @@ Utils
next_state_value
SoftUpdate
HardUpdate
ValueFunctions
ValueEstimators
default_value_kwargs
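
A quick illustration of the note above (a self-contained sketch, not part of the diff — the loss names and parameter grouping are hypothetical):

```python
import torch
from tensordict import TensorDict

# Stand-in for what a loss module's forward() returns: one entry per loss
# term, plus extra metrics that should not enter the backward pass.
params_actor = torch.nn.Parameter(torch.randn(3))
params_value = torch.nn.Parameter(torch.randn(3))
loss_vals = TensorDict(
    {
        "loss_actor": params_actor.pow(2).sum(),
        "loss_value": params_value.pow(2).sum(),
        "entropy": torch.tensor(0.1),  # a logged metric, not a loss
    },
    batch_size=[],
)

# Pattern 1: sum the "loss_" entries and use a single optimizer.
loss_val = sum(loss for key, loss in loss_vals.items() if key.startswith("loss_"))
loss_val.backward()

# Pattern 2: keep the terms separate, one optimizer per parameter group.
optim_actor = torch.optim.Adam([params_actor], lr=3e-4)
optim_value = torch.optim.Adam([params_value], lr=1e-3)
```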
2 changes: 1 addition & 1 deletion docs/source/reference/trainers.rst
@@ -73,7 +73,7 @@ Hooks can be split into 3 categories: **data processing** (:obj:`"batch_process"`
- **Data processing** hooks update a tensordict of data. A hook's :obj:`__call__` method should accept
a :obj:`TensorDict` object as input and update it given some strategy.
Examples of such hooks include Replay Buffer extension (:obj:`ReplayBufferTrainer.extend`), data normalization (including normalization
constants update), data subsampling (:doc:`BatchSubSampler`) and such.
constants update), data subsampling (:class:`torchrl.trainers.BatchSubSampler`) and such.

- **Logging** hooks take a batch of data presented as a :obj:`TensorDict` and write in the logger
some information retrieved from that data. Examples include the :obj:`Recorder` hook, the reward
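
To make the hook contract concrete, here is a minimal sketch of a data-processing hook (the reward scaling and the trainer instance are assumptions, not part of this PR):

```python
from tensordict import TensorDict


def scale_reward_hook(batch: TensorDict) -> TensorDict:
    # A "batch_process" hook: receive the batch tensordict and update it
    # given some strategy -- here, plain reward scaling.
    batch.set(("next", "reward"), batch.get(("next", "reward")) * 0.01)
    return batch


# Registration at the "batch_process" entry point, assuming a Trainer
# instance named `trainer`:
# trainer.register_op("batch_process", scale_reward_hook)
```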
17 changes: 8 additions & 9 deletions test/test_trainer.py
@@ -89,11 +89,10 @@ class MockingLossModule(nn.Module):

def mocking_trainer(file=None, optimizer=_mocking_optim) -> Trainer:
trainer = Trainer(
MockingCollector(),
*[
None,
]
* 2,
collector=MockingCollector(),
total_frames=None,
frame_skip=None,
optim_steps_per_batch=None,
loss_module=MockingLossModule(),
optimizer=optimizer,
save_trainer_file=file,
@@ -862,7 +861,7 @@ def test_recorder(self, N=8):
with tempfile.TemporaryDirectory() as folder:
logger = TensorboardLogger(exp_name=folder)

recorder = transformed_env_constructor(
environment = transformed_env_constructor(
args,
video_tag="tmp",
norm_obs_only=True,
@@ -874,7 +873,7 @@
record_frames=args.record_frames,
frame_skip=args.frame_skip,
policy_exploration=None,
recorder=recorder,
environment=environment,
record_interval=args.record_interval,
)
trainer = mocking_trainer()
@@ -936,7 +935,7 @@ def _make_recorder_and_trainer(tmpdirname):
raise NotImplementedError
trainer = mocking_trainer(file)

recorder = transformed_env_constructor(
environment = transformed_env_constructor(
args,
video_tag="tmp",
norm_obs_only=True,
@@ -948,7 +947,7 @@
record_frames=args.record_frames,
frame_skip=args.frame_skip,
policy_exploration=None,
recorder=recorder,
environment=environment,
record_interval=args.record_interval,
)
recorder.register(trainer)
1 change: 1 addition & 0 deletions torchrl/data/__init__.py
@@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from . import datasets
from .postprocs import MultiStep
from .replay_buffers import (
LazyMemmapStorage,
1 change: 1 addition & 0 deletions torchrl/data/datasets/__init__.py
@@ -1 +1,2 @@
from .d4rl import D4RLExperienceReplay
from .openml import OpenMLExperienceReplay
9 changes: 7 additions & 2 deletions torchrl/data/datasets/openml.py
@@ -8,8 +8,13 @@
import numpy as np
from tensordict.tensordict import TensorDict

from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers import Sampler, SamplerWithoutReplacement, Writer
from torchrl.data.replay_buffers import (
LazyMemmapStorage,
Sampler,
SamplerWithoutReplacement,
TensorDictReplayBuffer,
Writer,
)


class OpenMLExperienceReplay(TensorDictReplayBuffer):
14 changes: 7 additions & 7 deletions torchrl/data/postprocs/postprocs.py
@@ -82,9 +82,9 @@ def _get_reward(
class MultiStep(nn.Module):
"""Multistep reward transform.

Presented in 'Sutton, R. S. 1988. Learning to
predict by the methods of temporal differences. Machine learning 3(
1):9–44.'
Presented in

| Sutton, R. S. 1988. Learning to predict by the methods of temporal differences. Machine learning 3(1):9–44.

This module maps the "next" observation to the t + n "next" observation.
It is an identity transform whenever :attr:`n_steps` is 0.
@@ -153,6 +153,10 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
"""
tensordict = tensordict.clone(False)
done = tensordict.get(("next", "done"))
truncated = tensordict.get(
("next", "truncated"), torch.zeros((), dtype=done.dtype, device=done.device)
)
done = done | truncated

# we'll be using the done states to index the tensordict.
# if the shapes don't match we're in trouble.
@@ -175,10 +179,6 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
"(trailing singleton dimension excluded)."
) from err

truncated = tensordict.get(
("next", "truncated"), torch.zeros((), dtype=done.dtype, device=done.device)
)
done = done | truncated
mask = tensordict.get(("collector", "mask"), None)
reward = tensordict.get(("next", "reward"))
*batch, T = tensordict.batch_size
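
As a usage illustration only (not from the diff; the constructor argument names are assumed from the docstring above):

```python
import torch
from tensordict import TensorDict
from torchrl.data.postprocs import MultiStep

T = 5  # a fake single trajectory with batch shape [1, T]
done = torch.zeros(1, T, 1, dtype=torch.bool)
done[0, -1] = True  # the trajectory ends on the last step
data = TensorDict(
    {
        "observation": torch.arange(T, dtype=torch.float32).reshape(1, T, 1),
        "next": TensorDict(
            {
                "observation": torch.arange(1, T + 1, dtype=torch.float32).reshape(1, T, 1),
                "reward": torch.ones(1, T, 1),
                "done": done,
            },
            batch_size=[1, T],
        ),
    },
    batch_size=[1, T],
)
# After the transform, ("next", ...) entries point up to n steps ahead and
# the reward holds the discounted multi-step return.
out = MultiStep(gamma=0.99, n_steps=2)(data)
```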
18 changes: 4 additions & 14 deletions torchrl/data/replay_buffers/replay_buffers.py
@@ -11,7 +11,7 @@

import torch
from tensordict.tensordict import LazyStackedTensorDict, TensorDict, TensorDictBase
from tensordict.utils import expand_right
from tensordict.utils import expand_as_right

from torchrl.data.utils import DEVICE_TYPING

@@ -708,6 +708,8 @@ def extend(self, tensordicts: Union[List, TensorDictBase]) -> torch.Tensor:
return index

def update_tensordict_priority(self, data: TensorDictBase) -> None:
if not isinstance(self._sampler, PrioritizedSampler):
return
priority = torch.tensor(
[self._get_priority(td) for td in data],
dtype=torch.float,
@@ -753,19 +755,7 @@ def sample(
data, info = super().sample(batch_size, return_info=True)
if include_info in (True, None):
for k, v in info.items():
data.set(k, torch.tensor(v, device=data.device))
if "_batch_size" in data.keys():
# we need to reset the batch-size
shape = data.pop("_batch_size")
shape = shape[0]
shape = torch.Size([data.shape[0], *shape])
# we may need to update some values in the data
for key, value in data.items():
if value.ndim >= len(shape):
continue
value = expand_right(value, shape)
data.set(key, value)
data.batch_size = shape
data.set(k, expand_as_right(torch.tensor(v, device=data.device), data))
if return_info:
return data, info
return data
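
For context, a hedged sketch of the workflow this early return protects: ``update_tensordict_priority`` is only meaningful when the buffer samples with a ``PrioritizedSampler``; with a uniform sampler it is now a silent no-op instead of an error.

```python
import torch
from tensordict import TensorDict
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
from torchrl.data.replay_buffers import PrioritizedSampler

size = 100
rb = TensorDictReplayBuffer(
    storage=LazyMemmapStorage(size),
    sampler=PrioritizedSampler(max_capacity=size, alpha=0.7, beta=0.5),
    priority_key="td_error",
)
rb.extend(TensorDict({"obs": torch.randn(10, 4), "td_error": torch.rand(10)}, [10]))

sampled = rb.sample(4)                  # carries an "index" entry
sampled.set("td_error", torch.rand(4))  # e.g. refreshed TD errors
rb.update_tensordict_priority(sampled)  # re-weights the sampled items
```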
37 changes: 35 additions & 2 deletions torchrl/data/replay_buffers/storages.py
@@ -14,6 +14,7 @@
from tensordict.memmap import MemmapTensor
from tensordict.prototype import is_tensorclass
from tensordict.tensordict import is_tensor_collection, TensorDict, TensorDictBase
from tensordict.utils import expand_right

from torchrl._utils import _CKPT_BACKEND, VERBOSE
from torchrl.data.replay_buffers.utils import INT_CLASSES
@@ -423,10 +424,42 @@ def _mem_map_tensor_as_tensor(mem_map_tensor: MemmapTensor) -> torch.Tensor:
return mem_map_tensor._tensor


def _reset_batch_size(x):
"""Resets the batch size of a tensordict.

In some cases we save the original shape of the tensordict as a tensor (or memmap tensor).

This function will read that tensor, extract its items and reset the shape
of the tensordict to it. If items have an incompatible shape (e.g. "index")
they will be expanded to the right to match it.

"""
shape = x.pop("_batch_size", None)
if shape is not None:
# we need to reset the batch-size
if isinstance(shape, MemmapTensor):
shape = shape.as_tensor()
locked = x.is_locked
if locked:
x.unlock_()
shape = [s.item() for s in shape[0]]
shape = torch.Size([x.shape[0], *shape])
# we may need to update some values in the data
for key, value in x.items():
if value.ndim >= len(shape):
continue
value = expand_right(value, shape)
x.set(key, value)
x.batch_size = shape
if locked:
x.lock_()
return x


def _collate_list_tensordict(x):
out = torch.stack(x, 0)
if isinstance(out, TensorDictBase):
return out.to_tensordict()
return _reset_batch_size(out.to_tensordict())
return out


@@ -436,7 +469,7 @@ def _collate_list_tensors(*x):

def _collate_contiguous(x):
if isinstance(x, TensorDictBase):
return x.to_tensordict()
return _reset_batch_size(x).to_tensordict()
return x.clone()


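
The key step in ``_reset_batch_size`` is ``expand_right``, which broadcasts a tensor over new trailing dimensions. A small self-contained illustration:

```python
import torch
from tensordict.utils import expand_right

idx = torch.arange(3)                 # shape [3], e.g. an "index" entry
expanded = expand_right(idx, (3, 4))  # shape [3, 4]: values repeated rightward
assert (expanded[:, 0] == idx).all() and (expanded[:, 3] == idx).all()
```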
60 changes: 51 additions & 9 deletions torchrl/envs/transforms/transforms.py
@@ -2602,6 +2602,13 @@ class VecNorm(Transform):
default: 0.99
eps (number, optional): lower bound of the running standard
deviation (for numerical underflow). Default is 1e-4.
shapes (List[torch.Size], optional): if provided, gives the shape of
each entry pointed to by ``in_keys``. Its length must match that of
``in_keys``, and each shape must match the trailing dimensions of the
corresponding entry. If not provided, the feature dimensions of each
entry (i.e. all dims that do not belong to the tensordict batch size)
are treated as the feature shape.

Examples:
>>> from torchrl.envs.libs.gym import GymEnv
@@ -2629,6 +2636,7 @@ def __init__(
lock: mp.Lock = None,
decay: float = 0.9999,
eps: float = 1e-4,
shapes: List[torch.Size] = None,
) -> None:
if lock is None:
lock = mp.Lock()
@@ -2656,8 +2664,14 @@

self.lock = lock
self.decay = decay
self.shapes = shapes
self.eps = eps

def _key_str(self, key):
if not isinstance(key, str):
key = "_".join(key)
return key

def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
if self.lock is not None:
self.lock.acquire()
@@ -2681,17 +2695,44 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
forward = _call

def _init(self, tensordict: TensorDictBase, key: str) -> None:
if self._td is None or key + "_sum" not in self._td.keys():
td_view = tensordict.view(-1)
td_select = td_view[0]
d = {key + "_sum": torch.zeros_like(td_select.get(key))}
d.update({key + "_ssq": torch.zeros_like(td_select.get(key))})
key_str = self._key_str(key)
if self._td is None or key_str + "_sum" not in self._td.keys():
if key is not key_str and key_str in tensordict.keys():
raise RuntimeError(
f"Conflicting key names: {key_str} from VecNorm and input tensordict keys."
)
if self.shapes is None:
td_view = tensordict.view(-1)
td_select = td_view[0]
item = td_select.get(key)
d = {key_str + "_sum": torch.zeros_like(item)}
d.update({key_str + "_ssq": torch.zeros_like(item)})
else:
idx = 0
for in_key in self.in_keys:
if in_key != key:
idx += 1
else:
break
shape = self.shapes[idx]
item = tensordict.get(key)
d = {
key_str
+ "_sum": torch.zeros(shape, device=item.device, dtype=item.dtype)
}
d.update(
{
key_str
+ "_ssq": torch.zeros(
shape, device=item.device, dtype=item.dtype
)
}
)

d.update(
{
key
+ "_count": torch.zeros(
1, device=td_select.get(key).device, dtype=torch.float
)
key_str
+ "_count": torch.zeros(1, device=item.device, dtype=torch.float)
}
)
if self._td is None:
@@ -2702,6 +2743,7 @@ def _init(self, tensordict: TensorDictBase, key: str) -> None:
pass

def _update(self, key, value, N) -> torch.Tensor:
key = self._key_str(key)
_sum = self._td.get(key + "_sum")
_ssq = self._td.get(key + "_ssq")
_count = self._td.get(key + "_count")
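
A hedged usage sketch for the new ``shapes`` argument (the key and sizes are illustrative):

```python
import torch
from tensordict import TensorDict
from torchrl.envs.transforms import VecNorm

# Pin the feature shape of "observation" explicitly instead of inferring it
# from the tensordict batch size.
vecnorm = VecNorm(
    in_keys=["observation"], decay=0.9999, eps=1e-4, shapes=[torch.Size([4])]
)
td = TensorDict({"observation": torch.randn(8, 4)}, batch_size=[8])
td = vecnorm(td)  # running sum/ssq/count statistics are created, then updated
```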
2 changes: 2 additions & 0 deletions torchrl/modules/__init__.py
@@ -41,10 +41,12 @@
ActorValueOperator,
AdditiveGaussianWrapper,
DistributionalQValueActor,
DistributionalQValueHook,
EGreedyWrapper,
OrnsteinUhlenbeckProcessWrapper,
ProbabilisticActor,
QValueActor,
QValueHook,
SafeModule,
SafeProbabilisticModule,
SafeProbabilisticTensorDictSequential,
2 changes: 2 additions & 0 deletions torchrl/modules/tensordict_module/__init__.py
@@ -9,8 +9,10 @@
ActorCriticWrapper,
ActorValueOperator,
DistributionalQValueActor,
DistributionalQValueHook,
ProbabilisticActor,
QValueActor,
QValueHook,
ValueOperator,
)
from .common import SafeModule
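
With ``QValueHook`` and ``DistributionalQValueHook`` now re-exported at the package level, they can be attached to a plain module. A minimal sketch (the ``action_space`` string is an assumption):

```python
import torch
from torch import nn
from torchrl.modules import QValueHook

value_net = nn.Linear(4, 3)  # one Q-value per discrete action
value_net.register_forward_hook(QValueHook(action_space="one_hot"))
action, values, chosen = value_net(torch.randn(1, 4))
# action: one-hot argmax over the 3 Q-values; chosen: the selected Q-value
```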