Skip to content
Merged
11 changes: 6 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -260,11 +260,12 @@ when requested, as critics. We provide a set of base models (layers) and a Seque
different layers. All the models can be used with or without parameter sharing within an
agent group. Here is a table of the models implemented in BenchMARL

| Name | Decentralized | Centralized with local inputs | Centralized with global input |
|--------------------------------|:-------------:|:-----------------------------:|:-----------------------------:|
| [MLP](benchmarl/models/mlp.py) | Yes | Yes | Yes |
| [GNN](benchmarl/models/gnn.py) | Yes | Yes | No |
| [CNN](benchmarl/models/cnn.py) | Yes | Yes | Yes |
| Name | Decentralized | Centralized with local inputs | Centralized with global input |
|------------------------------------------|:-------------:|:-----------------------------:|:-----------------------------:|
| [MLP](benchmarl/models/mlp.py) | Yes | Yes | Yes |
| [GNN](benchmarl/models/gnn.py) | Yes | Yes | No |
| [CNN](benchmarl/models/cnn.py) | Yes | Yes | Yes |
| [Deepsets](benchmarl/models/deepsets.py) | Yes | Yes | Yes |

And the ones that are _work in progress_

Expand Down
9 changes: 9 additions & 0 deletions benchmarl/conf/model/layers/deepsets.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@

name: deepsets

aggr: "sum"
local_nn_num_cells: [128, 128]
local_nn_activation_class: torch.nn.Tanh
out_features_local_nn: 256
global_nn_num_cells: [256, 256]
global_nn_activation_class: torch.nn.Tanh
19 changes: 17 additions & 2 deletions benchmarl/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,24 @@

from .cnn import Cnn, CnnConfig
from .common import Model, ModelConfig, SequenceModel, SequenceModelConfig
from .deepsets import Deepsets, DeepsetsConfig
from .gnn import Gnn, GnnConfig
from .mlp import Mlp, MlpConfig

classes = ["Mlp", "MlpConfig", "Gnn", "GnnConfig", "Cnn", "CnnConfig"]
classes = [
"Mlp",
"MlpConfig",
"Gnn",
"GnnConfig",
"Cnn",
"CnnConfig",
"Deepsets",
"DeepsetsConfig",
]

model_config_registry = {"mlp": MlpConfig, "gnn": GnnConfig, "cnn": CnnConfig}
model_config_registry = {
"mlp": MlpConfig,
"gnn": GnnConfig,
"cnn": CnnConfig,
"deepsets": DeepsetsConfig,
}
Loading