Skip to content

Commit 679333c

Browse files
Added tests for the time stopping callback (#1543)
* Added tests for time stopping. * Small refactoring. * Rename. * Refactoring. * Calming the angry bazel. * Added the dependency to optimizers.
1 parent f3d6ee5 commit 679333c

File tree

1 file changed

+69
-0
lines changed

1 file changed

+69
-0
lines changed
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import time
2+
3+
import pytest
4+
import numpy as np
5+
import tensorflow as tf
6+
from tensorflow.keras.models import Sequential
7+
from tensorflow.keras.layers import Dense
8+
9+
from tensorflow_addons.callbacks.time_stopping import TimeStopping
10+
11+
12+
class SleepLayer(tf.keras.layers.Layer):
13+
def __init__(self, secs):
14+
self.secs = secs
15+
super().__init__(dynamic=True)
16+
17+
def call(self, inputs):
18+
time.sleep(self.secs)
19+
return inputs
20+
21+
22+
def get_data_and_model(secs):
23+
X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
24+
y = np.array([[0], [1], [1], [0]])
25+
26+
model = Sequential()
27+
model.add(SleepLayer(secs))
28+
model.add(Dense(1))
29+
model.compile(loss="mean_squared_error")
30+
31+
# In case there is some initialization going on.
32+
model.fit(X, y, epochs=1, verbose=0)
33+
return X, y, model
34+
35+
36+
def test_stop_at_the_right_time():
37+
X, y, model = get_data_and_model(0.1)
38+
39+
time_stopping = TimeStopping(2, verbose=0)
40+
history = model.fit(X, y, epochs=30, verbose=0, callbacks=[time_stopping])
41+
42+
assert len(history.epoch) <= 20
43+
44+
45+
def test_default_value():
46+
X, y, model = get_data_and_model(0.1)
47+
48+
time_stopping = TimeStopping()
49+
history = model.fit(X, y, epochs=15, verbose=0, callbacks=[time_stopping])
50+
51+
assert len(history.epoch) == 15
52+
53+
54+
@pytest.mark.parametrize("verbose", [0, 1])
55+
def test_time_stopping_verbose(capsys, verbose):
56+
X, y, model = get_data_and_model(0.25)
57+
58+
time_stopping = TimeStopping(1, verbose=verbose)
59+
60+
capsys.readouterr() # flush the stdout/stderr buffer.
61+
history = model.fit(X, y, epochs=10, verbose=0, callbacks=[time_stopping])
62+
fit_stdout = capsys.readouterr().out
63+
nb_epochs_run = len(history.epoch)
64+
message = "Timed stopping at epoch " + str(nb_epochs_run)
65+
if verbose:
66+
assert message in fit_stdout
67+
else:
68+
assert message not in fit_stdout
69+
assert len(history.epoch) <= 4

0 commit comments

Comments
 (0)