Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions returnn/tf/layers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2340,6 +2340,43 @@ def _check_init(self):
assert expected_output_dim == self.output.dim, (
"Expected output dim is %r but the output has dim %r. " % (expected_output_dim, self.output.dim) +
"Target: %s, output: %s" % (self.target, self.output))
if self.base_network.get_config().bool("debug_runtime_sanity_checks", False):
with tf.name_scope("Loss_debug_runtime_sanity_checks"):
checks = [self.output.get_runtime_sanity_check_op(), self.target.get_runtime_sanity_check_op()]
out_shape = tf.shape(self.output.placeholder)
target_shape = tf.shape(self.target.placeholder)
if self.output.have_batch_axis() and self.target.have_batch_axis():
out_batch_dim = out_shape[self.output.batch_dim_axis]
target_batch_dim = target_shape[self.target.batch_dim_axis]
checks += [tf.Assert(
tf.equal(out_batch_dim, target_batch_dim),
["Loss_debug_runtime_sanity_checks", "batch dim mismatch",
"output:", str(self.output), "shape", out_shape,
"target:", str(self.target), "shape", target_shape])]
if not self.recurrent: # framewise
if self.output.have_time_axis() and self.target.have_time_axis():
out_time_dim = out_shape[self.output.time_dim_axis]
target_time_dim = target_shape[self.target.time_dim_axis]
checks += [tf.Assert(
tf.equal(out_time_dim, target_time_dim),
["Loss_debug_runtime_sanity_checks", "time dim mismatch",
"output:", str(self.output), "shape", out_shape,
"target:", str(self.target), "shape", target_shape])]
if self.output.has_dynamic_size(self.output.time_dim_axis):
assert self.target.has_dynamic_size(self.target.time_dim_axis)
out_sizes = self.output.get_dynamic_size(self.output.time_dim_axis)
target_sizes = self.target.get_dynamic_size(self.target.time_dim_axis)
checks += [tf.Assert(
tf.reduce_all(tf.equal(out_sizes, target_sizes)),
["Loss_debug_runtime_sanity_checks", "dyn seq len mismatch",
"output:", str(self.output), "shape", out_shape, "sizes", out_sizes,
"target:", str(self.target), "shape", target_shape, "sizes", target_sizes], summarize=20)]
with tf.control_dependencies(checks):
if self.target_flat is not None:
self.target_flat = tf.identity(self.target_flat)
else:
self.target = self.target.copy()
self.target.placeholder = tf.identity(self.target.placeholder)

def get_error(self):
"""
Expand Down