diff --git a/returnn/tf/layers/base.py b/returnn/tf/layers/base.py index 7244c49011..f043e7c5a1 100644 --- a/returnn/tf/layers/base.py +++ b/returnn/tf/layers/base.py @@ -2340,6 +2340,43 @@ def _check_init(self): assert expected_output_dim == self.output.dim, ( "Expected output dim is %r but the output has dim %r. " % (expected_output_dim, self.output.dim) + "Target: %s, output: %s" % (self.target, self.output)) + if self.base_network.get_config().bool("debug_runtime_sanity_checks", False): + with tf.name_scope("Loss_debug_runtime_sanity_checks"): + checks = [self.output.get_runtime_sanity_check_op(), self.target.get_runtime_sanity_check_op()] + out_shape = tf.shape(self.output.placeholder) + target_shape = tf.shape(self.target.placeholder) + if self.output.have_batch_axis() and self.target.have_batch_axis(): + out_batch_dim = out_shape[self.output.batch_dim_axis] + target_batch_dim = target_shape[self.target.batch_dim_axis] + checks += [tf.Assert( + tf.equal(out_batch_dim, target_batch_dim), + ["Loss_debug_runtime_sanity_checks", "batch dim mismatch", + "output:", str(self.output), "shape", out_shape, + "target:", str(self.target), "shape", target_shape])] + if not self.recurrent: # framewise + if self.output.have_time_axis() and self.target.have_time_axis(): + out_time_dim = out_shape[self.output.time_dim_axis] + target_time_dim = target_shape[self.target.time_dim_axis] + checks += [tf.Assert( + tf.equal(out_time_dim, target_time_dim), + ["Loss_debug_runtime_sanity_checks", "time dim mismatch", + "output:", str(self.output), "shape", out_shape, + "target:", str(self.target), "shape", target_shape])] + if self.output.has_dynamic_size(self.output.time_dim_axis): + assert self.target.has_dynamic_size(self.target.time_dim_axis) + out_sizes = self.output.get_dynamic_size(self.output.time_dim_axis) + target_sizes = self.target.get_dynamic_size(self.target.time_dim_axis) + checks += [tf.Assert( + tf.reduce_all(tf.equal(out_sizes, target_sizes)), + ["Loss_debug_runtime_sanity_checks", "dyn seq len mismatch", + "output:", str(self.output), "shape", out_shape, "sizes", out_sizes, + "target:", str(self.target), "shape", target_shape, "sizes", target_sizes], summarize=20)] + with tf.control_dependencies(checks): + if self.target_flat is not None: + self.target_flat = tf.identity(self.target_flat) + else: + self.target = self.target.copy() + self.target.placeholder = tf.identity(self.target.placeholder) def get_error(self): """