Skip to content

Commit aeb689b

Browse files
Merge branch 'master' of https://github.com/tensorflow/addons into add_metrics
add new changes
2 parents a8b9f40 + 6c7f559 commit aeb689b

File tree

5 files changed

+15
-9
lines changed

5 files changed

+15
-9
lines changed

tensorflow_addons/activations/sparsemax.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,8 @@ def _compute_2d_sparsemax(logits, name=None):
126126
p = tf.math.maximum(
127127
tf.cast(0, logits.dtype), z - tf.expand_dims(tau_z, -1))
128128
# If k_z = 0 or if z = nan, then the input is invalid
129-
p_safe = tf.where(
129+
# TODO: Adjust dimension order for TF2 broadcasting
130+
p_safe = tf.compat.v1.where(
130131
tf.math.logical_or(
131132
tf.math.equal(k_z, 0), tf.math.is_nan(z_cumsum[:, -1])),
132133
tf.fill([obs, dims], tf.cast(float("nan"), logits.dtype)), p)

tensorflow_addons/optimizers/lazy_adam.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def _resource_apply_sparse(self, grad, var, indices):
5555
local_step = tf.cast(self.iterations + 1, var_dtype)
5656
beta_1_power = tf.math.pow(beta_1_t, local_step)
5757
beta_2_power = tf.math.pow(beta_2_t, local_step)
58-
epsilon_t = self._get_hyper('epsilon', var_dtype)
58+
epsilon_t = tf.convert_to_tensor(self.epsilon, var_dtype)
5959
lr = (lr_t * tf.math.sqrt(1 - beta_2_power) / (1 - beta_1_power))
6060

6161
# \\(m := beta1 * m + (1 - beta1) * g_t\\)

tensorflow_addons/optimizers/weight_decay_optimizers_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ def testBasicCallableParams(self):
221221
learning_rate=lambda: 0.001,
222222
beta_1=lambda: 0.9,
223223
beta_2=lambda: 0.999,
224-
epsilon=lambda: 1e-8,
224+
epsilon=1e-8,
225225
weight_decay=lambda: WEIGHT_DECAY)
226226

227227

tensorflow_addons/seq2seq/decoder.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -406,8 +406,10 @@ def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
406406
# Zero out output values past finish
407407
if impute_finished:
408408
emit = tf.nest.map_structure(
409-
lambda out, zero: tf.where(finished, zero, out),
410-
next_outputs, zero_outputs)
409+
# TODO: Adjust dimension order for TF2 broadcasting
410+
lambda out, zero: tf.compat.v1.where(finished, zero, out),
411+
next_outputs,
412+
zero_outputs)
411413
else:
412414
emit = next_outputs
413415

@@ -419,7 +421,9 @@ def _maybe_copy_state(new, cur):
419421
else:
420422
new.set_shape(cur.shape)
421423
pass_through = (new.shape.ndims == 0)
422-
return new if pass_through else tf.where(finished, cur, new)
424+
# TODO: Adjust dimension order for TF2 broadcasting
425+
return new if pass_through else tf.compat.v1.where(
426+
finished, cur, new)
423427

424428
if impute_finished:
425429
next_state = tf.nest.map_structure(_maybe_copy_state,

tensorflow_addons/seq2seq/sampler.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -448,9 +448,10 @@ def maybe_concatenate_auxiliary_inputs(outputs_, indices=None):
448448
auxiliary_inputs)
449449

450450
if self.next_inputs_fn is None:
451-
return tf.where(sample_ids,
452-
maybe_concatenate_auxiliary_inputs(outputs),
453-
base_next_inputs)
451+
# TODO: Adjust dimension order for TF2 broadcasting
452+
return tf.compat.v1.where(
453+
sample_ids, maybe_concatenate_auxiliary_inputs(outputs),
454+
base_next_inputs)
454455

455456
where_sampling = tf.cast(tf.where(sample_ids), tf.int32)
456457
where_not_sampling = tf.cast(

0 commit comments

Comments (0)