diff --git a/src/utils/operations.py b/src/utils/operations.py index d485aa8..5d08578 100644 --- a/src/utils/operations.py +++ b/src/utils/operations.py @@ -66,7 +66,7 @@ def masked_nllloss(logprobs, target, lengths, device): return ( (loss_raw * device( Variable(loss_mask.view(-1)))).sum() - / loss_mask.sum() + / device(loss_mask).sum() )