TL;DR
Because of the argument order of `forward` in the classification models, JIT tracing either hardcodes the device into the graph or adds unwanted overhead. An easy (but possibly breaking) fix:
```python
# change this
# classification BERT
class BertForSequenceClassification(BertPreTrainedModel):
    ...
    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
                position_ids=None, head_mask=None):
        ...

# to this
# classification BERT
class BertForSequenceClassification(BertPreTrainedModel):
    ...
    def forward(self, input_ids, token_type_ids=None, attention_mask=None,
                position_ids=None, head_mask=None, labels=None):
        ...
```
Long Version
The order of the model inputs is problematic for JIT tracing because the classification models split up the inputs of the base BERT model: `labels` is inserted between `attention_mask` and `position_ids`. This is confusing in words but easy to see in code:
```python
# base BERT
class BertModel(BertPreTrainedModel):
    ...
    def forward(self, input_ids, token_type_ids=None, attention_mask=None, position_ids=None, head_mask=None):
        ...

# classification BERT
# notice the order where labels comes in
class BertForSequenceClassification(BertPreTrainedModel):
    ...
    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
                position_ids=None, head_mask=None):
        ...
```
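To make the binding problem concrete, here is a minimal sketch using only Python's standard `inspect` module. The two stub functions are hypothetical and mirror nothing but the two signatures (bodies omitted). It shows where a 4-tuple of positional inputs, which is how `torch.jit.trace` passes its example inputs, ends up under each ordering, and that keyword calls bind the same way under both.

```python
import inspect

# Hypothetical stubs that mirror only the two argument orders discussed here.
def forward_current(input_ids, token_type_ids=None, attention_mask=None, labels=None,
                    position_ids=None, head_mask=None):
    pass

def forward_proposed(input_ids, token_type_ids=None, attention_mask=None,
                     position_ids=None, head_mask=None, labels=None):
    pass

# torch.jit.trace hands its example inputs to forward() positionally, like this tuple.
example = ("input_ids", "token_type_ids", "attention_mask", "position_ids")

current = inspect.signature(forward_current).bind(*example)
proposed = inspect.signature(forward_proposed).bind(*example)

print(dict(current.arguments))   # the 4th value is bound to `labels`
print(dict(proposed.arguments))  # the 4th value is bound to `position_ids`

# Keyword calls (as used in the current examples) bind identically under both orders.
kw = dict(input_ids="ids", attention_mask="mask", labels="y")
assert (dict(inspect.signature(forward_current).bind(**kw).arguments)
        == dict(inspect.signature(forward_proposed).bind(**kw).arguments))
```

Only positional callers observe the difference, which is why the reordering matters for tracing but should be mostly invisible to keyword-based code.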
The problem arises because `torch.jit.trace` does not capture the Python logic of the embedding layer; it records the operations actually executed on the example inputs, so values such as `input_ids.device` are frozen at trace time. The line `position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)` effectively becomes `position_ids = torch.arange(seq_length, dtype=torch.long, device=torch.device("[device at time of jit]"))`. Importantly, `model.to(device)` does not change this hardcoded device in the embeddings, so the device is baked into the whole network and `model.to(device)` cannot be used as expected. One could circumvent this by explicitly passing `position_ids` at tracing time, but `torch.jit.trace` only takes a tuple of positional example inputs. Because `labels` comes before `position_ids`, you cannot trace the model with `position_ids` without also passing dummy labels, which adds the overhead of computing the loss, something you don't want in a graph used solely for inference.
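A minimal sketch of the tracing behaviour described above, assuming a reasonably recent PyTorch. `ToyEmbeddings` is a hypothetical stand-in for BERT's embedding layer, not the library code; it only reproduces the `position_ids is None` branch. Tracing it without `position_ids` records an `aten::arange` node (whose device argument is the value that gets fixed at trace time), while tracing with explicit `position_ids` leaves that node out of the graph entirely.

```python
import torch
import torch.nn as nn

class ToyEmbeddings(nn.Module):
    """Hypothetical stand-in for BERT's embedding layer (not the library code)."""

    def __init__(self, vocab_size=100, max_len=32, hidden=8):
        super().__init__()
        self.word = nn.Embedding(vocab_size, hidden)
        self.pos = nn.Embedding(max_len, hidden)

    def forward(self, input_ids, position_ids=None):
        if position_ids is None:
            # Same pattern as the quoted line: the tracer records this arange
            # call with the device it saw at trace time.
            seq_length = input_ids.size(1)
            position_ids = torch.arange(seq_length, dtype=torch.long,
                                        device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        return self.word(input_ids) + self.pos(position_ids)

model = ToyEmbeddings().eval()
ids = torch.ones(1, 15, dtype=torch.long)
pos = torch.arange(15, dtype=torch.long).unsqueeze(0)

traced_implicit = torch.jit.trace(model, (ids,))      # position_ids built inside
traced_explicit = torch.jit.trace(model, (ids, pos))  # position_ids passed in

# Only the trace that had to build position_ids itself contains an arange node.
print("aten::arange" in str(traced_implicit.graph))  # True
print("aten::arange" in str(traced_explicit.graph))  # False
```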
The simple solution is to change the order of the arguments so that `labels` comes after all the arguments of the base BERT model. Of course, this could break existing scripts that rely on the positional order, although the current examples use keyword arguments, so it should not be a problem there.
```python
# classification BERT
class BertForSequenceClassification(BertPreTrainedModel):
    ...
    def forward(self, input_ids, token_type_ids=None, attention_mask=None,
                position_ids=None, head_mask=None, labels=None):
        ...
```
If this were done, one could trace the model with explicit `position_ids` and no labels:
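```python
import torch

# model = any of the classification models, with the reordered forward() signature
model.eval()  # disable dropout before tracing
msl = 15  # max sequence length, which gets hardcoded into the network
inputs = (
    torch.ones(1, msl, dtype=torch.long),  # input_ids
    torch.ones(1, msl, dtype=torch.long),  # token_type_ids (segment ids)
    torch.ones(1, msl, dtype=torch.long),  # attention_mask
    torch.arange(msl, dtype=torch.long).unsqueeze(0),  # position_ids: 0, 1, ..., msl - 1
)
# example inputs must be a tuple; they are passed to forward() positionally
traced_model = torch.jit.trace(model, inputs)
```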
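Until the signature changes, one possible workaround (not something this issue proposes; only a sketch) is to trace a thin wrapper whose `forward` takes exactly the four tensors and calls the model with keyword arguments, so `labels` is never passed and no loss is computed. `ToyClassifier` and `TraceWrapper` are hypothetical names. The toy model copies the current argument order of `BertForSequenceClassification` but ignores `token_type_ids`, `attention_mask` and `head_mask`; a real classification model would be substituted in practice.

```python
import torch
import torch.nn as nn

class ToyClassifier(nn.Module):
    """Hypothetical model with the current argument order (labels before position_ids)."""

    def __init__(self, vocab_size=100, max_len=32, hidden=8, num_labels=2):
        super().__init__()
        self.word = nn.Embedding(vocab_size, hidden)
        self.pos = nn.Embedding(max_len, hidden)
        self.head = nn.Linear(hidden, num_labels)

    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
                position_ids=None, head_mask=None):
        # token_type_ids, attention_mask and head_mask are ignored in this toy.
        if position_ids is None:
            position_ids = torch.arange(input_ids.size(1), device=input_ids.device).unsqueeze(0)
        pooled = (self.word(input_ids) + self.pos(position_ids)).mean(dim=1)
        logits = self.head(pooled)
        if labels is not None:
            return nn.functional.cross_entropy(logits, labels)
        return logits

class TraceWrapper(nn.Module):
    """Maps positional trace inputs onto keyword arguments of the wrapped model."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, token_type_ids, attention_mask, position_ids):
        # labels is never passed, so the traced graph contains no loss computation.
        return self.model(input_ids=input_ids, token_type_ids=token_type_ids,
                          attention_mask=attention_mask, position_ids=position_ids)

msl = 15
inputs = (
    torch.ones(1, msl, dtype=torch.long),              # input_ids
    torch.ones(1, msl, dtype=torch.long),              # token_type_ids
    torch.ones(1, msl, dtype=torch.long),              # attention_mask
    torch.arange(msl, dtype=torch.long).unsqueeze(0),  # position_ids
)
wrapper = TraceWrapper(ToyClassifier()).eval()
traced = torch.jit.trace(wrapper, inputs)
print(traced(*inputs).shape)  # torch.Size([1, 2])
```

The wrapper fixes the positional order itself, so it works with the current signature; the reordering proposed here would simply make it unnecessary.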
Finally, and this is a judgement call, it's illogical to stick the labels parameter into the middle of the list of parameters, it probably should be at the end. But that is a minor, minor gripe in an otherwise fantastically built library.