[BUG] LossModule.convert_to_functional can NOT correctly deal with a Linear module in the PyG lib #1613

@Sefank

Describe the bug

LossModule.convert_to_functional does not correctly handle the Linear module from the PyG library (torch_geometric.nn.dense.Linear).

Several PyG modules are built on PyG's own Linear rather than torch.nn.Linear, so this failure makes those modules unusable inside a TorchRL loss module.

To Reproduce

This example is taken almost verbatim from the TorchRL documentation; the only change is that torch.nn.Linear is replaced with torch_geometric.nn.dense.Linear.

Code
import torch
from torch import nn
from torch_geometric.nn.dense import Linear
from torchrl.data.tensor_specs import BoundedTensorSpec
from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal
from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
from torchrl.modules.tensordict_module.common import SafeModule
from torchrl.objectives.ppo import PPOLoss
from tensordict.tensordict import TensorDict
n_act, n_obs = 4, 3
spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
base_layer = Linear(n_obs, 5)
net = NormalParamWrapper(nn.Sequential(base_layer, Linear(5, 2 * n_act)))
module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
actor = ProbabilisticActor(
    module=module,
    distribution_class=TanhNormal,
    in_keys=["loc", "scale"],
    spec=spec)
module = nn.Sequential(base_layer, Linear(5, 1))
value = ValueOperator(
    module=module,
    in_keys=["observation"])
loss = PPOLoss(actor, value)
batch = [2, ]
action = spec.rand(batch)
data = TensorDict({"observation": torch.randn(*batch, n_obs),
        "action": action,
        "sample_log_prob": torch.randn_like(action[..., 1]),
        ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
        ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool),
        ("next", "reward"): torch.randn(*batch, 1),
        ("next", "observation"): torch.randn(*batch, n_obs),
    }, batch)
loss(data)
Traceback Info
Traceback (most recent call last):
  File "<project_root>/issuse.py", line 35, in <module>
    loss(data)
  File "<project_root>/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<project_root>/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1568, in _call_impl
    result = forward_call(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<project_root>/venv/lib/python3.11/site-packages/tensordict/_contextlib.py", line 126, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "<project_root>/venv/lib/python3.11/site-packages/tensordict/nn/common.py", line 282, in wrapper
    return func(_self, tensordict, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<project_root>/venv/lib/python3.11/site-packages/torchrl/objectives/ppo.py", line 433, in forward
    self.value_estimator(
  File "<project_root>/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<project_root>/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<project_root>/venv/lib/python3.11/site-packages/torchrl/objectives/value/advantages.py", line 63, in new_func
    return fun(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<project_root>/venv/lib/python3.11/site-packages/torchrl/objectives/value/advantages.py", line 52, in new_fun
    return fun(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<project_root>/venv/lib/python3.11/site-packages/tensordict/nn/common.py", line 282, in wrapper
    return func(_self, tensordict, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<project_root>/venv/lib/python3.11/site-packages/torchrl/objectives/value/advantages.py", line 1224, in forward
    value, next_value = _call_value_nets(
                        ^^^^^^^^^^^^^^^^^
  File "<project_root>/venv/lib/python3.11/site-packages/torchrl/objectives/value/advantages.py", line 138, in _call_value_nets
    data_out = vmap(value_net, (0, 0))(data_in, params_stack)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<project_root>/venv/lib/python3.11/site-packages/torch/_functorch/apis.py", line 188, in wrapped
    return vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<project_root>/venv/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 266, in vmap_impl
    return _flat_vmap(
           ^^^^^^^^^^^
  File "<project_root>/venv/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 38, in fn
    return f(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^
  File "<project_root>/venv/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 379, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<project_root>/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<project_root>/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<project_root>/venv/lib/python3.11/site-packages/tensordict/nn/functional_modules.py", line 572, in new_fun
    old_params = _assign_params(
                 ^^^^^^^^^^^^^^^
  File "<project_root>/venv/lib/python3.11/site-packages/tensordict/nn/functional_modules.py", line 649, in _assign_params
    return _swap_state(module, params, make_stateless, return_old_tensordict)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<project_root>/venv/lib/python3.11/site-packages/tensordict/nn/functional_modules.py", line 389, in _swap_state
    _old_value = _swap_state(
                 ^^^^^^^^^^^^
  File "<project_root>/venv/lib/python3.11/site-packages/tensordict/nn/functional_modules.py", line 389, in _swap_state
    _old_value = _swap_state(
                 ^^^^^^^^^^^^
  File "<project_root>/venv/lib/python3.11/site-packages/tensordict/nn/functional_modules.py", line 378, in _swap_state
    raise Exception(f"{model}\nhas no stateless attribute.")
Exception: Linear(3, 5, bias=False)
has no stateless attribute.

The traceback shows that the _is_stateless attribute is missing from the module. That attribute should have been carried over by the deepcopy call in LossModule.convert_to_functional:

functional_module = deepcopy(module)

However, PyG's Linear overrides the default deepcopy behavior:

This code is copied from the PyG repo.

    def __deepcopy__(self, memo):
        out = Linear(self.in_channels, self.out_channels, self.bias
                     is not None, self.weight_initializer,
                     self.bias_initializer)
        if self.in_channels > 0:
            out.weight = copy.deepcopy(self.weight, memo)
        if self.bias is not None:
            out.bias = copy.deepcopy(self.bias, memo)
        return out

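Because __deepcopy__ builds a fresh Linear from the constructor arguments and copies only weight and bias, any extra attribute that was attached to the module afterwards, and that Linear itself does not know about, is dropped from the copy.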
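
The effect can be reproduced without TorchRL or PyG. The sketch below is only an illustration: RebuildingLayer and PlainLayer are hypothetical stand-in classes (not part of any library), and _is_stateless is set by hand to mimic an attribute attached after construction.

```python
import copy


class PlainLayer:
    """Hypothetical layer that relies on Python's default deepcopy."""

    def __init__(self, in_channels, out_channels):
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.weight = [[0.0] * in_channels for _ in range(out_channels)]


class RebuildingLayer(PlainLayer):
    """Hypothetical layer whose __deepcopy__ rebuilds itself from
    constructor arguments, in the same style as the PyG snippet above."""

    def __deepcopy__(self, memo):
        out = RebuildingLayer(self.in_channels, self.out_channels)
        out.weight = copy.deepcopy(self.weight, memo)
        return out


plain = PlainLayer(3, 5)
rebuilding = RebuildingLayer(3, 5)

# Attach an attribute after construction; neither class knows about it.
plain._is_stateless = False
rebuilding._is_stateless = False

plain_copy = copy.deepcopy(plain)
rebuilding_copy = copy.deepcopy(rebuilding)

# The default deepcopy copies the whole instance __dict__.
print(hasattr(plain_copy, "_is_stateless"))       # True
# The custom __deepcopy__ only copies what it names explicitly.
print(hasattr(rebuilding_copy, "_is_stateless"))  # False
```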

Expected behavior

LossModule.convert_to_functional should handle PyG's Linear module correctly, just as it handles the standard Linear module in PyTorch.
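
One possible direction, shown here only as a sketch and not as TorchRL's or tensordict's actual fix, is to restore any instance attributes that a custom __deepcopy__ left behind. The helper deepcopy_keeping_extras and the Rebuilt class are hypothetical names introduced for illustration, and the helper only looks at the top-level object, not at nested submodules.

```python
import copy


def deepcopy_keeping_extras(obj):
    """Deep-copy obj, then re-attach instance attributes the copy lacks.

    Hypothetical helper: top-level attributes only, nested objects are
    left to whatever their own deepcopy logic produces.
    """
    clone = copy.deepcopy(obj)
    for name, value in vars(obj).items():
        if name not in vars(clone):
            setattr(clone, name, copy.deepcopy(value))
    return clone


class Rebuilt:
    """Hypothetical class whose __deepcopy__ drops unknown attributes."""

    def __init__(self, size):
        self.size = size

    def __deepcopy__(self, memo):
        return Rebuilt(self.size)


layer = Rebuilt(3)
layer._is_stateless = False  # attached after construction

clone = deepcopy_keeping_extras(layer)
print(hasattr(clone, "_is_stateless"))  # True
```

Whether such a helper is appropriate for real modules depends on how the extra attributes are stored, so it should be read as a starting point for discussion rather than a drop-in patch.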

System info

  • TorchRL 0.2.0
  • Torch 2.1.0
  • Python 3.11.6 | packaged by conda-forge
  • OS: WSL2 - Ubuntu 22.04 LTS

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)

Labels

bug