Skip to content

SigmoidFocalCrossEntropy loss fails to compile model #876

@swghosh

Description

@swghosh

System information

  • OS Platform and Distribution: Deep Learning Linux VM (Debian Stretch) GCE Image
  • TensorFlow version and how it was installed: Binary v2.1.0
  • TensorFlow-Addons version and how it was installed: Binary v0.7.0
  • Python version: Python 3.5.3
  • Is GPU used? (yes/no): No

Describe the bug

model.compile calls the loss with test values for y_true and y_pred, which causes a shape-mismatch error in tfa.losses.sigmoid_focal_crossentropy. The following check is what raises the error. It may also be a more general problem with tf.keras models when they are used with custom losses.

if y_true.shape != y_pred.shape:
raise ValueError("Shape mismatch for y_true: {} and y_pred: {}".format(
tf.shape(y_true), tf.shape(y_pred)))

Also, it is worth noting that the same model (with exactly the same code) compiles successfully with tensorflow-addons v0.6.0 (and training works perfectly fine).

Code to reproduce the issue

from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import Reduction

from tensorflow_addons.losses import SigmoidFocalCrossEntropy

def create_mlp(input_size, num_classes):
    """Build, summarize, and compile a two-layer MLP that uses sigmoid focal loss.

    Args:
        input_size: Number of input features; used as the hidden layer's
            ``input_shape``.
        num_classes: Number of units in the output layer.

    Returns:
        The compiled ``Sequential`` model.
    """
    model = Sequential([
        # Hidden layer: 100 ReLU units fed by `input_size` features.
        Dense(100, activation='relu', input_shape=(input_size,), name='hidden'),
        # Output layer: one softmax unit per class.
        Dense(num_classes, activation='softmax', name='output')
    ], name='mlp')

    # Focal loss from tensorflow_addons, reduced with SUM_OVER_BATCH_SIZE.
    loss = SigmoidFocalCrossEntropy(reduction=Reduction.SUM_OVER_BATCH_SIZE)
    opt = Adam()

    # The summary is printed first, which is why the layer table precedes the
    # traceback in the log of this report.
    model.summary()
    # The traceback in this report points at this call as the failing line.
    model.compile(opt, loss)

    return model

# Reproduction entry point: 1000 input features, 5 output classes.
mlp = create_mlp(1000, 5)

Other info / logs

Model: "mlp"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
hidden (Dense)               (None, 100)               100100    
_________________________________________________________________
output (Dense)               (None, 5)                 505       
=================================================================
Total params: 100,605
Trainable params: 100,605
Non-trainable params: 0
_________________________________________________________________
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-10-2891a1b09888> in <module>()
----> 1 mlp = create_mlp_classifier(1000, 5)

14 frames
<ipython-input-8-686d7bad7efc> in create_mlp_classifier(input_size, num_classes)
     20 
     21     model.summary()
---> 22     model.compile(opt, loss)
     23 
     24     return model

/tensorflow-2.1.0/python3.6/tensorflow_core/python/training/tracking/base.py in _method_wrapper(self, *args, **kwargs)
    455     self._self_setattr_tracking = False  # pylint: disable=protected-access
    456     try:
--> 457       result = method(self, *args, **kwargs)
    458     finally:
    459       self._self_setattr_tracking = previous_value  # pylint: disable=protected-access

/tensorflow-2.1.0/python3.6/tensorflow_core/python/keras/engine/training.py in compile(self, optimizer, loss, metrics, loss_weights, sample_weight_mode, weighted_metrics, target_tensors, distribute, **kwargs)
    444 
    445       # Creates the model loss and weighted metrics sub-graphs.
--> 446       self._compile_weights_loss_and_weighted_metrics()
    447 
    448       # Functions for train, test and predict will

/tensorflow-2.1.0/python3.6/tensorflow_core/python/training/tracking/base.py in _method_wrapper(self, *args, **kwargs)
    455     self._self_setattr_tracking = False  # pylint: disable=protected-access
    456     try:
--> 457       result = method(self, *args, **kwargs)
    458     finally:
    459       self._self_setattr_tracking = previous_value  # pylint: disable=protected-access

/tensorflow-2.1.0/python3.6/tensorflow_core/python/keras/engine/training.py in _compile_weights_loss_and_weighted_metrics(self, sample_weights)
   1590       #                   loss_weight_2 * output_2_loss_fn(...) +
   1591       #                   layer losses.
-> 1592       self.total_loss = self._prepare_total_loss(masks)
   1593 
   1594   def _prepare_skip_target_masks(self):

/tensorflow-2.1.0/python3.6/tensorflow_core/python/keras/engine/training.py in _prepare_total_loss(self, masks)
   1650 
   1651           if hasattr(loss_fn, 'reduction'):
-> 1652             per_sample_losses = loss_fn.call(y_true, y_pred)
   1653             weighted_losses = losses_utils.compute_weighted_loss(
   1654                 per_sample_losses,

/usr/local/lib/python3.6/dist-packages/tensorflow_addons/losses/focal_loss.py in call(self, y_true, y_pred)
     86             alpha=self.alpha,
     87             gamma=self.gamma,
---> 88             from_logits=self.from_logits)
     89 
     90     def get_config(self):

/tensorflow-2.1.0/python3.6/tensorflow_core/python/eager/def_function.py in __call__(self, *args, **kwds)
    566         xla_context.Exit()
    567     else:
--> 568       result = self._call(*args, **kwds)
    569 
    570     if tracing_count == self._get_tracing_count():

/tensorflow-2.1.0/python3.6/tensorflow_core/python/eager/def_function.py in _call(self, *args, **kwds)
    604       # In this case we have not created variables on the first call. So we can
    605       # run the first trace but we should fail if variables are created.
--> 606       results = self._stateful_fn(*args, **kwds)
    607       if self._created_variables:
    608         raise ValueError("Creating variables on a non-first call to a function"

/tensorflow-2.1.0/python3.6/tensorflow_core/python/eager/function.py in __call__(self, *args, **kwargs)
   2360     """Calls a graph function specialized to the inputs."""
   2361     with self._lock:
-> 2362       graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
   2363     return graph_function._filtered_call(args, kwargs)  # pylint: disable=protected-access
   2364 

/tensorflow-2.1.0/python3.6/tensorflow_core/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   2701 
   2702       self._function_cache.missed.add(call_context_key)
-> 2703       graph_function = self._create_graph_function(args, kwargs)
   2704       self._function_cache.primary[cache_key] = graph_function
   2705       return graph_function, args, kwargs

/tensorflow-2.1.0/python3.6/tensorflow_core/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   2591             arg_names=arg_names,
   2592             override_flat_arg_shapes=override_flat_arg_shapes,
-> 2593             capture_by_value=self._capture_by_value),
   2594         self._function_attributes,
   2595         # Tell the ConcreteFunction to clean up its graph once it goes out of

/tensorflow-2.1.0/python3.6/tensorflow_core/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
    976                                           converted_func)
    977 
--> 978       func_outputs = python_func(*func_args, **func_kwargs)
    979 
    980       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

/tensorflow-2.1.0/python3.6/tensorflow_core/python/eager/def_function.py in wrapped_fn(*args, **kwds)
    437         # __wrapped__ allows AutoGraph to swap in a converted function. We give
    438         # the function a weak reference to itself to avoid a reference cycle.
--> 439         return weak_wrapped_fn().__wrapped__(*args, **kwds)
    440     weak_wrapped_fn = weakref.ref(wrapped_fn)
    441 

/tensorflow-2.1.0/python3.6/tensorflow_core/python/framework/func_graph.py in wrapper(*args, **kwargs)
    966           except Exception as e:  # pylint:disable=broad-except
    967             if hasattr(e, "ag_error_metadata"):
--> 968               raise e.ag_error_metadata.to_exception(e)
    969             else:
    970               raise

ValueError: in converted code:

    /usr/local/lib/python3.6/dist-packages/tensorflow_addons/losses/focal_loss.py:126 sigmoid_focal_crossentropy  *
        raise ValueError("Shape mismatch for y_true: {} and y_pred: {}".format(

    ValueError: Shape mismatch for y_true: Tensor("Shape:0", shape=(2,), dtype=int32) and y_pred: Tensor("Shape_1:0", shape=(2,),

Thanks.

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug (Something isn't working), losses

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions