-
Notifications
You must be signed in to change notification settings - Fork 617
Closed
Labels
Description
System information
- OS: Ubuntu 18.04
- TensorFlow: 2.0.0, pip install tensorflow-gpu
- TensorFlow-Addons: 0.6.0, pip install tensorflow-addons
- Python version: 3.6.7
- Is GPU used? (yes/no): yes
- CUDA: 10.0
- CUDNN: 7.6.4
Describe the bug
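With int64 `tag_indices` and `sequence_lengths`, `crf_log_likelihood` fails with two dtype mismatches:
1. In `crf_binary_score`, `tf.shape` returns int32 (`num_tags`), but the Mul op expects int64 to match the int64 tag indices.
2. In `crf_forward`, `tf.maximum` returns int64 (`sequence_lengths`), but the Pack op (`tf.stack`) expects int32 to match the int32 `tf.range` output.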
Code to reproduce the issue
import tensorflow as tf
import tensorflow_addons as tfa
import numpy as np
inputs = np.ones([16, 20, 5], dtype=np.float32)
tags = tf.convert_to_tensor(np.ones([16, 20], dtype=np.int64))
seq_lens = np.ones([16,], dtype=np.int64) * 20
loss, _ = tfa.text.crf_log_likelihood(
inputs=inputs,
tag_indices=tags,
sequence_lengths=seq_lens
)
Other info / logs
Error logs:
InvalidArgumentError Traceback (most recent call last)
<ipython-input-2-41e7d4eaa2ab> in <module>
6 inputs=inputs,
7 tag_indices=tags,
----> 8 sequence_lengths=seq_lens
9 )
~/Deployment/tf_sequence_labeler/env2/lib/python3.6/site-packages/tensorflow_addons/text/crf.py in crf_log_likelihood(inputs, tag_indices, sequence_lengths, transition_params)
194
195 sequence_scores = crf_sequence_score(inputs, tag_indices, sequence_lengths,
--> 196 transition_params)
197 log_norm = crf_log_norm(inputs, sequence_lengths, transition_params)
198
~/Deployment/tf_sequence_labeler/env2/lib/python3.6/site-packages/tensorflow_addons/text/crf.py in crf_sequence_score(inputs, tag_indices, sequence_lengths, transition_params)
66 return _single_seq_fn()
67 else:
---> 68 return _multi_seq_fn()
69
70
~/Deployment/tf_sequence_labeler/env2/lib/python3.6/site-packages/tensorflow_addons/text/crf.py in _multi_seq_fn()
59 unary_scores = crf_unary_score(tag_indices, sequence_lengths, inputs)
60 binary_scores = crf_binary_score(tag_indices, sequence_lengths,
---> 61 transition_params)
62 sequence_scores = unary_scores + binary_scores
63 return sequence_scores
~/Deployment/tf_sequence_labeler/env2/lib/python3.6/site-packages/tensorflow_addons/text/crf.py in crf_binary_score(tag_indices, sequence_lengths, transition_params)
257 # Encode the indices in a flattened representation.
258 flattened_transition_indices = start_tag_indices * \
--> 259 num_tags + end_tag_indices
260 flattened_transition_params = tf.reshape(transition_params, [-1])
261
~/Deployment/tf_sequence_labeler/env2/lib/python3.6/site-packages/tensorflow_core/python/ops/math_ops.py in binary_op_wrapper(x, y)
897 with ops.name_scope(None, op_name, [x, y]) as name:
898 if isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor):
--> 899 return func(x, y, name=name)
900 elif not isinstance(y, sparse_tensor.SparseTensor):
901 try:
~/Deployment/tf_sequence_labeler/env2/lib/python3.6/site-packages/tensorflow_core/python/ops/math_ops.py in _mul_dispatch(x, y, name)
1204 is_tensor_y = isinstance(y, ops.Tensor)
1205 if is_tensor_y:
-> 1206 return gen_math_ops.mul(x, y, name=name)
1207 else:
1208 assert isinstance(y, sparse_tensor.SparseTensor) # Case: Dense * Sparse.
~/Deployment/tf_sequence_labeler/env2/lib/python3.6/site-packages/tensorflow_core/python/ops/gen_math_ops.py in mul(x, y, name)
6696 else:
6697 message = e.message
-> 6698 _six.raise_from(_core._status_to_exception(e.code, message), None)
6699 # Add nodes to the TensorFlow graph.
6700 _, _, _op = _op_def_lib._apply_op_helper(
~/Deployment/tf_sequence_labeler/env2/lib/python3.6/site-packages/six.py in raise_from(value, from_value)
InvalidArgumentError: cannot compute Mul as input #1(zero-based) was expected to be a int64 tensor but is a int32 tensor [Op:Mul] name: mul/
This can be fixed by casting num_tags to int64 in the crf_binary_score function:
num_tags = tf.cast(num_tags, dtype=tf.int64)
After fixing this error, another error comes up:
InvalidArgumentError Traceback (most recent call last)
<ipython-input-2-41e7d4eaa2ab> in <module>
6 inputs=inputs,
7 tag_indices=tags,
----> 8 sequence_lengths=seq_lens
9 )
~/Deployment/tf_sequence_labeler/env2/lib/python3.6/site-packages/tensorflow_addons/text/crf.py in crf_log_likelihood(inputs, tag_indices, sequence_lengths, transition_params)
195 sequence_scores = crf_sequence_score(inputs, tag_indices, sequence_lengths,
196 transition_params)
--> 197 log_norm = crf_log_norm(inputs, sequence_lengths, transition_params)
198
199 # Normalize the scores to get the log-likelihood per example.
~/Deployment/tf_sequence_labeler/env2/lib/python3.6/site-packages/tensorflow_addons/text/crf.py in crf_log_norm(inputs, sequence_lengths, transition_params)
161 return _single_seq_fn()
162 else:
--> 163 return _multi_seq_fn()
164
165
~/Deployment/tf_sequence_labeler/env2/lib/python3.6/site-packages/tensorflow_addons/text/crf.py in _multi_seq_fn()
150
151 alphas = crf_forward(rest_of_input, first_input, transition_params,
--> 152 sequence_lengths)
153 log_norm = tf.reduce_logsumexp(alphas, [1])
154 # Mask `log_norm` of the sequences with length <= zero.
~/Deployment/tf_sequence_labeler/env2/lib/python3.6/site-packages/tensorflow_addons/text/crf.py in crf_forward(inputs, state, transition_params, sequence_lengths)
304 all_alphas = tf.transpose(tf.scan(_scan_fn, inputs, state), [1, 0, 2])
305 idxs = tf.stack(
--> 306 [tf.range(tf.shape(sequence_lengths)[0]), sequence_lengths], axis=1)
307 return tf.gather_nd(all_alphas, idxs)
308
~/Deployment/tf_sequence_labeler/env2/lib/python3.6/site-packages/tensorflow_core/python/util/dispatch.py in wrapper(*args, **kwargs)
178 """Call target, and fall back on dispatchers if there is a TypeError."""
179 try:
--> 180 return target(*args, **kwargs)
181 except (TypeError, ValueError):
182 # Note: convert_to_eager_tensor currently raises a ValueError, not a
~/Deployment/tf_sequence_labeler/env2/lib/python3.6/site-packages/tensorflow_core/python/ops/array_ops.py in stack(values, axis, name)
1163 (axis, -expanded_num_dims, expanded_num_dims))
1164
-> 1165 return gen_array_ops.pack(values, axis=axis, name=name)
1166
1167
~/Deployment/tf_sequence_labeler/env2/lib/python3.6/site-packages/tensorflow_core/python/ops/gen_array_ops.py in pack(values, axis, name)
6291 else:
6292 message = e.message
-> 6293 _six.raise_from(_core._status_to_exception(e.code, message), None)
6294 # Add nodes to the TensorFlow graph.
6295 if not isinstance(values, (list, tuple)):
~/Deployment/tf_sequence_labeler/env2/lib/python3.6/site-packages/six.py in raise_from(value, from_value)
InvalidArgumentError: cannot compute Pack as input #1(zero-based) was expected to be a int32 tensor but is a int64 tensor [Op:Pack] name: stack
This can be fixed by casting sequence_lengths to int32 in the crf_forward function:
sequence_lengths = tf.cast(sequence_lengths, dtype=tf.int32)
Is this a bug, or am I using the API incorrectly?