diff --git a/stanza/models/common/crf.py b/stanza/models/common/crf.py index 778a14a56..acfdbae73 100644 --- a/stanza/models/common/crf.py +++ b/stanza/models/common/crf.py @@ -134,14 +134,14 @@ def log_sum_exp(value, dim=None, keepdim=False): value.exp().sum(dim, keepdim).log() """ if dim is not None: - m, _ = torch.max(value, dim=dim, keepdim=True) + m = torch.amax(value, dim=dim, keepdim=True) value0 = value - m if keepdim is False: m = m.squeeze(dim) return m + torch.log(torch.sum(torch.exp(value0), dim=dim, keepdim=keepdim)) else: - m = torch.max(value) + m = torch.amax(value) sum_exp = torch.sum(torch.exp(value - m)) if isinstance(sum_exp, Number): return m + math.log(sum_exp)