diff --git a/tests/keras/activations_test.py b/tests/keras/activations_test.py index 4c0bb3683031..7b6396336dcd 100644 --- a/tests/keras/activations_test.py +++ b/tests/keras/activations_test.py @@ -118,7 +118,6 @@ def softplus(x): expected = softplus(test_values) assert_allclose(result, expected, rtol=1e-05) - def test_softsign(): """Test using a reference softsign implementation. """ @@ -153,6 +152,23 @@ def ref_sigmoid(x): expected = sigmoid(test_values) assert_allclose(result, expected, rtol=1e-05) +def test_mish(): + """Test mish implementation. + """ + def ref_mish(x): + softplus = np.log(np.ones_like(x) + np.exp(x)) + tanh_softplus_x = (np.exp(softplus) - np.exp(-softplus))/(np.exp(softplus) + np.exp(-softplus)) + return x*tanh_softplus_x + + mish = np.vectorize(ref_mish) + + x = K.placeholder(ndim=2) + f = K.function([x], [activations.mish(x)]) + test_values = get_standard_values() + + result = f([test_values])[0] + expected = mish(test_values) + assert_allclose(result, expected, rtol=1e-05) def test_hard_sigmoid(): """Test using a reference hard sigmoid implementation.