
Order of inputs of forward function problematic for jit with Classification models #1010

@dhpollack

Description


TL;DR

Because of the order of the arguments of forward in the classification models, the device either gets hardcoded during jit tracing or tracing incurs unwanted overhead. Easy solution (but possibly breaking):

# change this
# classification BERT
class BertForSequenceClassification(BertPreTrainedModel):
    ...
    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
                position_ids=None, head_mask=None):
        ...
# to this
# classification BERT
class BertForSequenceClassification(BertPreTrainedModel):
    ...
    def forward(self, input_ids, token_type_ids=None, attention_mask=None,
                position_ids=None, head_mask=None, labels=None):
        ...

Long Version

The order of the inputs of the models is problematic for jit tracing, because the classification models split up the inputs of the base BERT model: labels is inserted in the middle of them. Confusing in words, but easy to see in code:

# base BERT
class BertModel(BertPreTrainedModel):
    ...
    def forward(self, input_ids, token_type_ids=None, attention_mask=None, position_ids=None, head_mask=None):
        ...

# classification BERT
# notice the order where labels comes in
class BertForSequenceClassification(BertPreTrainedModel):
    ...
    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
                position_ids=None, head_mask=None):
        ...

The problem arises because torch.jit.trace records the operations that actually run on the example inputs, not the Python logic that produced their arguments. When position_ids is not supplied, the embedding layer computes position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device), and the trace records this as position_ids = torch.arange(seq_length, dtype=torch.long, device=torch.device("[device at time of jit]")). Importantly, model.to(device) does not change this hardcoded device in the embeddings. Thus the device used at tracing time gets baked into the whole network, and model.to(device) cannot be used as expected.

One could circumvent this problem by explicitly passing position_ids at tracing time, but torch.jit.trace only takes a tuple of positional inputs. Because labels comes before position_ids, you cannot trace the model with position_ids without also passing dummy labels, which adds the overhead of computing the loss, something you don't want in a graph used solely for inference.
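To make the ordering problem concrete without needing torch, here is a minimal sketch using only the standard library. forward_current is a hypothetical stub that copies just the signature of the classification forward, and bind_positional is a hypothetical helper that shows which parameter each element of a positional example-input tuple lands in, assuming the tuple is unpacked as forward(*inputs), which is how torch.jit.trace passes a tuple of example inputs to a module. Strings stand in for the tensors.

```python
import inspect


def forward_current(input_ids, token_type_ids=None, attention_mask=None, labels=None,
                    position_ids=None, head_mask=None):
    """Hypothetical stub: copies only the current classification-model signature."""


def bind_positional(fn, example_inputs):
    """Return {parameter name: value} as fn(*example_inputs) would bind them."""
    bound = inspect.signature(fn).bind(*example_inputs)
    return dict(bound.arguments)


# Four example inputs intended as input_ids, token_type_ids, attention_mask, position_ids
example = ("input_ids", "token_type_ids", "attention_mask", "position_ids")
binding = bind_positional(forward_current, example)
print(binding["labels"])          # the intended position_ids value lands in labels
print("position_ids" in binding)  # position_ids itself is never set

# Reaching position_ids positionally requires a placeholder in the labels slot
padded = ("input_ids", "token_type_ids", "attention_mask", "dummy_labels", "position_ids")
print(bind_positional(forward_current, padded)["labels"])
```

With the current order, the only way to get a value into position_ids positionally is to occupy the labels slot first, which is exactly the dummy-labels overhead described above.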

The simple solution is to change the order of the arguments so that labels comes after all the arguments of the base BERT model. Of course, this could break existing scripts that rely on the current positional order, although the current examples use kwargs, so it should not be a problem for them.
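A quick way to check which callers the reorder would affect is to compare how the same call binds under both signatures. The sketch below uses hypothetical stubs that copy only the two signatures under discussion, plus a hypothetical binding helper built on inspect.Signature.bind; strings stand in for tensors.

```python
import inspect


# Hypothetical stubs: copies of the current and proposed signatures only
def forward_current(input_ids, token_type_ids=None, attention_mask=None, labels=None,
                    position_ids=None, head_mask=None):
    pass


def forward_proposed(input_ids, token_type_ids=None, attention_mask=None,
                     position_ids=None, head_mask=None, labels=None):
    pass


def binding(fn, *args, **kwargs):
    """Map each parameter name to the value it receives for this call."""
    return dict(inspect.signature(fn).bind(*args, **kwargs).arguments)


kw = dict(input_ids="ids", attention_mask="mask", labels="y")
# Keyword callers bind identically under both orders
print(binding(forward_current, **kw) == binding(forward_proposed, **kw))

# Positional callers that pass labels fourth are the ones the reorder would break
pos = ("ids", "types", "mask", "y")
print(binding(forward_current, *pos)["labels"])
print(binding(forward_proposed, *pos)["position_ids"])
```

Keyword calls are unaffected; only code that passes labels as the fourth positional argument would silently start feeding it into position_ids.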

# classification BERT
class BertForSequenceClassification(BertPreTrainedModel):
    ...
    def forward(self, input_ids, token_type_ids=None, attention_mask=None,
                position_ids=None, head_mask=None, labels=None):
        ...

If this were done, one could trace any of the classification models like this:

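import torch

#  model = any of the classification models
model.eval()  # disable dropout so the traced graph is deterministic
msl = 15  # max sequence length, which gets hardcoded into the network
# torch.jit.trace takes a tuple (not a list) of positional example inputs
inputs = (
    torch.ones(1, msl, dtype=torch.long),  # input_ids
    torch.ones(1, msl, dtype=torch.long),  # segment_ids (token_type_ids)
    torch.ones(1, msl, dtype=torch.long),  # attention_masks
    torch.arange(msl, dtype=torch.long).unsqueeze(0),  # position_ids, passed explicitly so the embeddings don't create them on a fixed device
)
traced_model = torch.jit.trace(model, inputs)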

Finally, and this is a judgement call, it's illogical to put the labels parameter in the middle of the parameter list; it should probably go at the end. But that is a minor, minor gripe in an otherwise fantastically built library.
