
"example_input_array" depends on ordering of modules #1556

@awaelchli

Description

🐛 Bug

To Reproduce

  1. Open pl_examples/basic_examples/LightningTemplateModel.py.
  2. In the __build_model method, change the order of the module definitions from:
    def __build_model(self):
        self.c_d1 = nn.Linear(in_features=self.hparams.in_features,
                              out_features=self.hparams.hidden_dim)
        self.c_d1_bn = nn.BatchNorm1d(self.hparams.hidden_dim)
        self.c_d1_drop = nn.Dropout(self.hparams.drop_prob)

        self.c_d2 = nn.Linear(in_features=self.hparams.hidden_dim,
                              out_features=self.hparams.out_features)

to:

    def __build_model(self):
        self.c_d1 = nn.Linear(in_features=self.hparams.in_features,
                              out_features=self.hparams.hidden_dim)
        # move the layer definition up here
        self.c_d2 = nn.Linear(in_features=self.hparams.hidden_dim,
                              out_features=self.hparams.out_features)

        self.c_d1_bn = nn.BatchNorm1d(self.hparams.hidden_dim)
        self.c_d1_drop = nn.Dropout(self.hparams.drop_prob)

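The model summary then fails with a size-mismatch error. The input and output sizes are computed by passing example_input_array through the layers in the order they are defined, so with this order c_d1_bn receives the output of c_d2 (out_features wide) instead of a hidden_dim-wide tensor.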
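A minimal sketch of why definition order breaks, assuming (as a simplification of the summary code) that sizes are obtained by chaining the example input through the child modules in definition order. ReorderedModel and sizes_in_definition_order are hypothetical names, and the small layer sizes and the tanh activation in forward() are assumptions chosen for illustration. The real forward pass works; only the definition-order chain fails.

```python
import torch
from torch import nn


class ReorderedModel(nn.Module):
    """Hypothetical stand-in for the template model, with c_d2 defined before c_d1_bn."""

    def __init__(self, in_features=4, hidden_dim=8, out_features=3, drop_prob=0.2):
        super().__init__()
        self.c_d1 = nn.Linear(in_features, hidden_dim)
        self.c_d2 = nn.Linear(hidden_dim, out_features)  # defined second...
        self.c_d1_bn = nn.BatchNorm1d(hidden_dim)
        self.c_d1_drop = nn.Dropout(drop_prob)

    def forward(self, x):
        x = torch.tanh(self.c_d1(x))
        x = self.c_d1_bn(x)
        x = self.c_d1_drop(x)
        return self.c_d2(x)  # ...but executed last


def sizes_in_definition_order(model, example_input):
    """Naive size computation: chain the input through the children in definition order."""
    sizes = []
    x = example_input
    for name, module in model.named_children():
        in_size = list(x.shape)
        x = module(x)  # fails at c_d1_bn, which receives c_d2's output
        sizes.append((name, in_size, list(x.shape)))
    return sizes


model = ReorderedModel()
example_input = torch.randn(2, 4)
print(model(example_input).shape)  # the real forward pass works
try:
    sizes_in_definition_order(model, example_input)
except RuntimeError as err:
    print("definition-order summary failed:", err)
```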

Expected behavior

Input and output sizes should be computed in the order of execution, not the order of definition. This matters because PyTorch builds the graph dynamically on every forward pass, so the execution order of the layers is not known beforehand.

Proposed Fix

I propose registering a forward hook on each submodule and recording the input and output sizes while the example input runs through the model's actual forward pass.
I have already started validating the fix and would like to submit a PR soon, if you agree.
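A minimal sketch of the proposed approach, assuming PyTorch's nn.Module.register_forward_hook, which calls the hook with the module, the tuple of positional inputs, and the output. ReorderedModel and sizes_in_execution_order are hypothetical names for illustration, not Lightning's actual summary code, and running the dry pass in eval mode under torch.no_grad() is a design assumption.

```python
import torch
from torch import nn


class ReorderedModel(nn.Module):
    """Hypothetical stand-in for the template model, with c_d2 defined before c_d1_bn."""

    def __init__(self, in_features=4, hidden_dim=8, out_features=3, drop_prob=0.2):
        super().__init__()
        self.c_d1 = nn.Linear(in_features, hidden_dim)
        self.c_d2 = nn.Linear(hidden_dim, out_features)
        self.c_d1_bn = nn.BatchNorm1d(hidden_dim)
        self.c_d1_drop = nn.Dropout(drop_prob)

    def forward(self, x):
        x = torch.tanh(self.c_d1(x))
        x = self.c_d1_drop(self.c_d1_bn(x))
        return self.c_d2(x)


def sizes_in_execution_order(model, example_input):
    """Record (name, input size, output size) for each submodule in the order it runs."""
    sizes = []
    handles = []

    def make_hook(name):
        def hook(module, inputs, output):
            # inputs is the tuple of positional arguments passed to forward()
            first = inputs[0] if inputs else None
            in_size = list(first.shape) if isinstance(first, torch.Tensor) else None
            out_size = list(output.shape) if isinstance(output, torch.Tensor) else None
            sizes.append((name, in_size, out_size))
        return hook

    # Hook every submodule; named_modules() also yields the root itself with name "".
    for name, module in model.named_modules():
        if name:
            handles.append(module.register_forward_hook(make_hook(name)))

    was_training = model.training
    model.eval()  # avoid updating BatchNorm running statistics during the dry run
    try:
        with torch.no_grad():
            model(example_input)
    finally:
        # Remove all hooks and restore train/eval mode
        # (assumes all submodules shared the same mode).
        for handle in handles:
            handle.remove()
        model.train(was_training)
    return sizes


model = ReorderedModel()
for name, in_size, out_size in sizes_in_execution_order(model, torch.randn(2, 4)):
    print(f"{name:10s} {in_size} -> {out_size}")
```

Because the hooks fire during the real forward pass, the recorded order follows execution rather than definition. A module called twice would be recorded twice and a module that is never called would not appear, so the summary would still need to decide how to display those cases.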

Additional Context

This error can be confusing: a user might think something is wrong with their code, even though the model itself runs fine.

Metadata

Assignees: No one assigned
Labels: bug (Something isn't working), help wanted (Open to be worked on)
Type: No type
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests