
InvalidArgumentError caused by int32 / int64 mismatch in crf_log_likelihood #623

@xiayandi

System information

  • OS: Ubuntu 18.04
  • TensorFlow: 2.0.0, pip install tensorflow-gpu
  • TensorFlow-Addons: 0.6.0, pip install tensorflow-addons
  • Python version: 3.6.7
  • Is GPU used? (yes/no): yes
  • CUDA: 10.0
  • CUDNN: 7.6.4

Describe the bug

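tf.shape returns int32, but the Mul op in the crf_binary_score function expects int64 (to match the int64 tag_indices).
tf.maximum returns int64, but the Pack op (tf.stack) in the crf_forward function expects int32 (to match tf.range(tf.shape(sequence_lengths)[0])).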
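The dtype behavior described above can be checked in isolation. This is a minimal sketch, not TFA code: the tensor names and shapes are hypothetical stand-ins chosen to match the reproduction below (5 tags, int64 lengths), and the "- 1" offset passed to tf.maximum is illustrative only.

```python
import tensorflow as tf

# Hypothetical stand-ins matching the reproduction: 5 tags, int64 lengths.
transition_params = tf.zeros([5, 5], dtype=tf.float32)
sequence_lengths = tf.constant([20] * 16, dtype=tf.int64)

# tf.shape defaults to out_type=tf.int32, regardless of the input's dtype.
num_tags = tf.shape(transition_params)[0]

# tf.maximum keeps the dtype of its (matching) inputs, so int64 stays int64.
# The "- 1" offset is illustrative, not copied from TFA.
clipped_lengths = tf.maximum(
    tf.constant(0, dtype=sequence_lengths.dtype), sequence_lengths - 1)

# tf.range over an int32 size also yields int32.
batch_range = tf.range(tf.shape(sequence_lengths)[0])

print(num_tags.dtype, clipped_lengths.dtype, batch_range.dtype)
```

Combining the int32 and int64 results above in a single Mul or Pack op is what triggers the two errors in the logs.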

Code to reproduce the issue

import tensorflow as tf
import tensorflow_addons as tfa
import numpy as np

inputs = np.ones([16, 20, 5], dtype=np.float32)
tags = tf.convert_to_tensor(np.ones([16, 20], dtype=np.int64))
seq_lens = np.ones([16,], dtype=np.int64) * 20
loss, _ = tfa.text.crf_log_likelihood(
    inputs=inputs,
    tag_indices=tags,
    sequence_lengths=seq_lens
)

Other info / logs

Error logs:

InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-2-41e7d4eaa2ab> in <module>
      6     inputs=inputs,
      7     tag_indices=tags,
----> 8     sequence_lengths=seq_lens
      9 )

~/Deployment/tf_sequence_labeler/env2/lib/python3.6/site-packages/tensorflow_addons/text/crf.py in crf_log_likelihood(inputs, tag_indices, sequence_lengths, transition_params)
    194 
    195     sequence_scores = crf_sequence_score(inputs, tag_indices, sequence_lengths,
--> 196                                          transition_params)
    197     log_norm = crf_log_norm(inputs, sequence_lengths, transition_params)
    198 

~/Deployment/tf_sequence_labeler/env2/lib/python3.6/site-packages/tensorflow_addons/text/crf.py in crf_sequence_score(inputs, tag_indices, sequence_lengths, transition_params)
     66         return _single_seq_fn()
     67     else:
---> 68         return _multi_seq_fn()
     69 
     70 

~/Deployment/tf_sequence_labeler/env2/lib/python3.6/site-packages/tensorflow_addons/text/crf.py in _multi_seq_fn()
     59         unary_scores = crf_unary_score(tag_indices, sequence_lengths, inputs)
     60         binary_scores = crf_binary_score(tag_indices, sequence_lengths,
---> 61                                          transition_params)
     62         sequence_scores = unary_scores + binary_scores
     63         return sequence_scores

~/Deployment/tf_sequence_labeler/env2/lib/python3.6/site-packages/tensorflow_addons/text/crf.py in crf_binary_score(tag_indices, sequence_lengths, transition_params)
    257     # Encode the indices in a flattened representation.
    258     flattened_transition_indices = start_tag_indices * \
--> 259         num_tags + end_tag_indices
    260     flattened_transition_params = tf.reshape(transition_params, [-1])
    261 

~/Deployment/tf_sequence_labeler/env2/lib/python3.6/site-packages/tensorflow_core/python/ops/math_ops.py in binary_op_wrapper(x, y)
    897     with ops.name_scope(None, op_name, [x, y]) as name:
    898       if isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor):
--> 899         return func(x, y, name=name)
    900       elif not isinstance(y, sparse_tensor.SparseTensor):
    901         try:

~/Deployment/tf_sequence_labeler/env2/lib/python3.6/site-packages/tensorflow_core/python/ops/math_ops.py in _mul_dispatch(x, y, name)
   1204   is_tensor_y = isinstance(y, ops.Tensor)
   1205   if is_tensor_y:
-> 1206     return gen_math_ops.mul(x, y, name=name)
   1207   else:
   1208     assert isinstance(y, sparse_tensor.SparseTensor)  # Case: Dense * Sparse.

~/Deployment/tf_sequence_labeler/env2/lib/python3.6/site-packages/tensorflow_core/python/ops/gen_math_ops.py in mul(x, y, name)
   6696       else:
   6697         message = e.message
-> 6698       _six.raise_from(_core._status_to_exception(e.code, message), None)
   6699   # Add nodes to the TensorFlow graph.
   6700   _, _, _op = _op_def_lib._apply_op_helper(

~/Deployment/tf_sequence_labeler/env2/lib/python3.6/site-packages/six.py in raise_from(value, from_value)

InvalidArgumentError: cannot compute Mul as input #1(zero-based) was expected to be a int64 tensor but is a int32 tensor [Op:Mul] name: mul/

This can be fixed by casting num_tags to int64 in crf_binary_score:

num_tags = tf.cast(num_tags, dtype=tf.int64)
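A standalone sketch of why this cast resolves the first error. It rebuilds the flattened-index arithmetic shown in the traceback (start_tag_indices * num_tags + end_tag_indices) with hypothetical inputs; the slicing of tag_indices into start/end tags and the tf.gather lookup are assumptions about the surrounding code, not a copy of TFA's crf_binary_score, and flattened_transition_scores is a hypothetical helper name.

```python
import tensorflow as tf

def flattened_transition_scores(tag_indices, transition_params):
    """Look up transition scores for consecutive tag pairs (illustrative sketch)."""
    num_tags = tf.shape(transition_params)[0]  # int32 by default
    # Cast so the Mul below sees matching dtypes when tag_indices is int64.
    num_tags = tf.cast(num_tags, dtype=tag_indices.dtype)

    start_tag_indices = tag_indices[:, :-1]  # tags at positions 0..T-2
    end_tag_indices = tag_indices[:, 1:]     # tags at positions 1..T-1

    # Row-major index into the flattened [num_tags, num_tags] matrix.
    flattened_transition_indices = start_tag_indices * num_tags + end_tag_indices
    flattened_transition_params = tf.reshape(transition_params, [-1])
    return tf.gather(flattened_transition_params, flattened_transition_indices)

tags = tf.constant([[0, 1, 2]], dtype=tf.int64)
transitions = tf.reshape(tf.range(9, dtype=tf.float32), [3, 3])
print(flattened_transition_scores(tags, transitions).numpy())
```

Casting to tag_indices.dtype instead of hard-coding tf.int64 is a design choice of this sketch: it keeps int32 tag indices working as well.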

After fixing this error, another error comes up:

InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-2-41e7d4eaa2ab> in <module>
      6     inputs=inputs,
      7     tag_indices=tags,
----> 8     sequence_lengths=seq_lens
      9 )

~/Deployment/tf_sequence_labeler/env2/lib/python3.6/site-packages/tensorflow_addons/text/crf.py in crf_log_likelihood(inputs, tag_indices, sequence_lengths, transition_params)
    195     sequence_scores = crf_sequence_score(inputs, tag_indices, sequence_lengths,
    196                                          transition_params)
--> 197     log_norm = crf_log_norm(inputs, sequence_lengths, transition_params)
    198 
    199     # Normalize the scores to get the log-likelihood per example.

~/Deployment/tf_sequence_labeler/env2/lib/python3.6/site-packages/tensorflow_addons/text/crf.py in crf_log_norm(inputs, sequence_lengths, transition_params)
    161         return _single_seq_fn()
    162     else:
--> 163         return _multi_seq_fn()
    164 
    165 

~/Deployment/tf_sequence_labeler/env2/lib/python3.6/site-packages/tensorflow_addons/text/crf.py in _multi_seq_fn()
    150 
    151         alphas = crf_forward(rest_of_input, first_input, transition_params,
--> 152                              sequence_lengths)
    153         log_norm = tf.reduce_logsumexp(alphas, [1])
    154         # Mask `log_norm` of the sequences with length <= zero.

~/Deployment/tf_sequence_labeler/env2/lib/python3.6/site-packages/tensorflow_addons/text/crf.py in crf_forward(inputs, state, transition_params, sequence_lengths)
    304     all_alphas = tf.transpose(tf.scan(_scan_fn, inputs, state), [1, 0, 2])
    305     idxs = tf.stack(
--> 306         [tf.range(tf.shape(sequence_lengths)[0]), sequence_lengths], axis=1)
    307     return tf.gather_nd(all_alphas, idxs)
    308 

~/Deployment/tf_sequence_labeler/env2/lib/python3.6/site-packages/tensorflow_core/python/util/dispatch.py in wrapper(*args, **kwargs)
    178     """Call target, and fall back on dispatchers if there is a TypeError."""
    179     try:
--> 180       return target(*args, **kwargs)
    181     except (TypeError, ValueError):
    182       # Note: convert_to_eager_tensor currently raises a ValueError, not a

~/Deployment/tf_sequence_labeler/env2/lib/python3.6/site-packages/tensorflow_core/python/ops/array_ops.py in stack(values, axis, name)
   1163                        (axis, -expanded_num_dims, expanded_num_dims))
   1164 
-> 1165   return gen_array_ops.pack(values, axis=axis, name=name)
   1166 
   1167 

~/Deployment/tf_sequence_labeler/env2/lib/python3.6/site-packages/tensorflow_core/python/ops/gen_array_ops.py in pack(values, axis, name)
   6291       else:
   6292         message = e.message
-> 6293       _six.raise_from(_core._status_to_exception(e.code, message), None)
   6294   # Add nodes to the TensorFlow graph.
   6295   if not isinstance(values, (list, tuple)):

~/Deployment/tf_sequence_labeler/env2/lib/python3.6/site-packages/six.py in raise_from(value, from_value)

InvalidArgumentError: cannot compute Pack as input #1(zero-based) was expected to be a int32 tensor but is a int64 tensor [Op:Pack] name: stack

This can be fixed by casting sequence_lengths to int32 in the crf_forward function:

sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)
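A standalone sketch of the second fix, mirroring the idxs construction shown in the crf_forward traceback (tf.stack of a batch range and the lengths, then tf.gather_nd). The helper name gather_last_alphas and the toy alphas tensor are hypothetical; only the cast-then-stack pattern comes from the issue.

```python
import tensorflow as tf

def gather_last_alphas(all_alphas, sequence_lengths):
    """Pick all_alphas[b, sequence_lengths[b], :] for each batch entry b (sketch)."""
    # tf.range(tf.shape(...)[0]) is int32, so bring the lengths to int32 before stacking.
    sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)
    batch_idx = tf.range(tf.shape(sequence_lengths)[0])
    idxs = tf.stack([batch_idx, sequence_lengths], axis=1)  # shape [batch, 2]
    return tf.gather_nd(all_alphas, idxs)

# Hypothetical alphas of shape [batch=2, time=3, num_tags=2].
alphas = tf.reshape(tf.range(12, dtype=tf.float32), [2, 3, 2])
lengths = tf.constant([0, 2], dtype=tf.int64)
print(gather_last_alphas(alphas, lengths).numpy())
```

Without the cast, tf.stack receives one int32 and one int64 tensor, which is the Pack error in the second log.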

Is this a bug, or am I using it the wrong way?
