-
Notifications
You must be signed in to change notification settings - Fork 413
Description
Describe the bug
`LossModule.convert_to_functional` cannot correctly handle the `Linear` module from the PyG library.
Because some PyG modules rely on PyG's own `Linear` module to work, this failure causes severe incompatibility with the PyG library.
To Reproduce
This example is mostly copied from TorchRL's documentation; the only change is substituting `torch.nn.Linear` with `torch_geometric.nn.dense.Linear`.
Code
# Minimal reproduction: TorchRL's PPOLoss documentation example with
# torch.nn.Linear replaced by torch_geometric.nn.dense.Linear (imported below).
import torch
from torch import nn
from torch_geometric.nn.dense import Linear
from torchrl.data.tensor_specs import BoundedTensorSpec
from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal
from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
from torchrl.modules.tensordict_module.common import SafeModule
from torchrl.objectives.ppo import PPOLoss
from tensordict.tensordict import TensorDict
n_act, n_obs = 4, 3
# Action spec bounded to [-1, 1] in each of the n_act dimensions.
spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
# PyG Linear layer; this same instance is reused by the actor AND the value net.
base_layer = Linear(n_obs, 5)
# Actor head emits 2 * n_act values; presumably NormalParamWrapper splits them
# into the "loc" and "scale" outputs declared below — confirm against TorchRL.
net = NormalParamWrapper(nn.Sequential(base_layer, Linear(5, 2 * n_act)))
module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
actor = ProbabilisticActor(
module=module,
distribution_class=TanhNormal,
in_keys=["loc", "scale"],
spec=spec)
# Value network: shares base_layer with the actor, adds a 1-output PyG Linear.
module = nn.Sequential(base_layer, Linear(5, 1))
value = ValueOperator(
module=module,
in_keys=["observation"])
# Per the issue text, the loss converts these networks to functional form via
# LossModule.convert_to_functional, which deep-copies them.
loss = PPOLoss(actor, value)
batch = [2, ]
action = spec.rand(batch)
# Synthetic batch: current observation/action/log-prob plus ("next", ...) entries.
data = TensorDict({"observation": torch.randn(*batch, n_obs),
"action": action,
"sample_log_prob": torch.randn_like(action[..., 1]),
("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool),
("next", "reward"): torch.randn(*batch, 1),
("next", "observation"): torch.randn(*batch, n_obs),
}, batch)
# Raises "... has no stateless attribute." from tensordict's _swap_state
# (see the traceback in this report).
loss(data)
Traceback Info
Traceback (most recent call last):
File "<project_root>/issuse.py", line 35, in <module>
loss(data)
File "<project_root>/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "<project_root>/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
result = forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "<project_root>/venv/lib/python3.11/site-packages/tensordict/_contextlib.py", line 126, in decorate_context
return func(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^
File "<project_root>/venv/lib/python3.11/site-packages/tensordict/nn/common.py", line 282, in wrapper
return func(_self, tensordict, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "<project_root>/venv/lib/python3.11/site-packages/torchrl/objectives/ppo.py", line 433, in forward
self.value_estimator(
File "<project_root>/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "<project_root>/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "<project_root>/venv/lib/python3.11/site-packages/torchrl/objectives/value/advantages.py", line 63, in new_func
return fun(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "<project_root>/venv/lib/python3.11/site-packages/torchrl/objectives/value/advantages.py", line 52, in new_fun
return fun(self, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^
File "<project_root>/venv/lib/python3.11/site-packages/tensordict/nn/common.py", line 282, in wrapper
return func(_self, tensordict, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "<project_root>/venv/lib/python3.11/site-packages/torchrl/objectives/value/advantages.py", line 1224, in forward
value, next_value = _call_value_nets(
^^^^^^^^^^^^^^^^^
File "<project_root>/venv/lib/python3.11/site-packages/torchrl/objectives/value/advantages.py", line 138, in _call_value_nets
data_out = vmap(value_net, (0, 0))(data_in, params_stack)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "<project_root>/venv/lib/python3.11/site-packages/torch/_functorch/apis.py", line 188, in wrapped
return vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "<project_root>/venv/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 266, in vmap_impl
return _flat_vmap(
^^^^^^^^^^^
File "<project_root>/venv/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 38, in fn
return f(*args, **kwargs)
^^^^^^^^^^^^^^^^^^
File "<project_root>/venv/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 379, in _flat_vmap
batched_outputs = func(*batched_inputs, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "<project_root>/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "<project_root>/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "<project_root>/venv/lib/python3.11/site-packages/tensordict/nn/functional_modules.py", line 572, in new_fun
old_params = _assign_params(
^^^^^^^^^^^^^^^
File "<project_root>/venv/lib/python3.11/site-packages/tensordict/nn/functional_modules.py", line 649, in _assign_params
return _swap_state(module, params, make_stateless, return_old_tensordict)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "<project_root>/venv/lib/python3.11/site-packages/tensordict/nn/functional_modules.py", line 389, in _swap_state
_old_value = _swap_state(
^^^^^^^^^^^^
File "<project_root>/venv/lib/python3.11/site-packages/tensordict/nn/functional_modules.py", line 389, in _swap_state
_old_value = _swap_state(
^^^^^^^^^^^^
File "<project_root>/venv/lib/python3.11/site-packages/tensordict/nn/functional_modules.py", line 378, in _swap_state
raise Exception(f"{model}\nhas no stateless attribute.")
Exception: Linear(3, 5, bias=False)
has no stateless attribute.
The traceback shows that the state attribute `_is_stateless` is missing; it should have been copied in `LossModule.convert_to_functional`:
rl/torchrl/objectives/common.py
Line 247 in bf264e0
# deepcopy() defers to a module's own __deepcopy__ when one is defined, so an
# override (such as PyG's Linear) decides which attributes survive the copy.
functional_module = deepcopy(module)
However, PyG's `Linear` overrides the default `deepcopy` behavior (the following code is copied from the PyG repository):
def __deepcopy__(self, memo):
    # PyG's Linear.__deepcopy__ (quoted from PyG): instead of copying the
    # instance's attributes, it builds a brand-new Linear from the constructor
    # arguments and then copies only the weight and bias tensors.
    # Consequently any extra attribute set on the original module (e.g. the
    # `_is_stateless` flag this issue is about) is NOT carried over.
    out = Linear(self.in_channels, self.out_channels, self.bias
                 is not None, self.weight_initializer,
                 self.bias_initializer)
    # Weight is copied only when in_channels > 0; presumably a non-positive
    # in_channels marks a lazily-initialized layer — confirm against PyG.
    if self.in_channels > 0:
        out.weight = copy.deepcopy(self.weight, memo)
    if self.bias is not None:
        out.bias = copy.deepcopy(self.bias, memo)
    # Only weight/bias (and constructor arguments) are transferred to `out`.
    return out
As a result, any extra attribute that is added to the module but unknown to PyG's `Linear` (such as `_is_stateless`) is ignored and lost in the copy.
Expected behavior
`LossModule.convert_to_functional` should correctly handle PyG's `Linear` module, just as it handles the standard `Linear` module in PyTorch.
System info
- TorchRL 0.2.0
- Torch 2.1.0
- Python 3.11.6 | packaged by conda-forge
- OS: WSL2 - Ubuntu 22.04 LTS
Checklist
- I have checked that there is no similar issue in the repo (required)
- I have read the documentation (required)
- I have provided a minimal working example to reproduce the bug (required)