|  | 
|  | 1 | +import time | 
|  | 2 | + | 
|  | 3 | +import pytest | 
|  | 4 | +import numpy as np | 
|  | 5 | +import tensorflow as tf | 
|  | 6 | +from tensorflow.keras.models import Sequential | 
|  | 7 | +from tensorflow.keras.layers import Dense | 
|  | 8 | + | 
|  | 9 | +from tensorflow_addons.callbacks.time_stopping import TimeStopping | 
|  | 10 | + | 
|  | 11 | + | 
|  | 12 | +class SleepLayer(tf.keras.layers.Layer): | 
|  | 13 | +    def __init__(self, secs): | 
|  | 14 | +        self.secs = secs | 
|  | 15 | +        super().__init__(dynamic=True) | 
|  | 16 | + | 
|  | 17 | +    def call(self, inputs): | 
|  | 18 | +        time.sleep(self.secs) | 
|  | 19 | +        return inputs | 
|  | 20 | + | 
|  | 21 | + | 
|  | 22 | +def get_data_and_model(secs): | 
|  | 23 | +    X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]]) | 
|  | 24 | +    y = np.array([[0], [1], [1], [0]]) | 
|  | 25 | + | 
|  | 26 | +    model = Sequential() | 
|  | 27 | +    model.add(SleepLayer(secs)) | 
|  | 28 | +    model.add(Dense(1)) | 
|  | 29 | +    model.compile(loss="mean_squared_error") | 
|  | 30 | + | 
|  | 31 | +    # In case there is some initialization going on. | 
|  | 32 | +    model.fit(X, y, epochs=1, verbose=0) | 
|  | 33 | +    return X, y, model | 
|  | 34 | + | 
|  | 35 | + | 
|  | 36 | +def test_stop_at_the_right_time(): | 
|  | 37 | +    X, y, model = get_data_and_model(0.1) | 
|  | 38 | + | 
|  | 39 | +    time_stopping = TimeStopping(2, verbose=0) | 
|  | 40 | +    history = model.fit(X, y, epochs=30, verbose=0, callbacks=[time_stopping]) | 
|  | 41 | + | 
|  | 42 | +    assert len(history.epoch) <= 20 | 
|  | 43 | + | 
|  | 44 | + | 
|  | 45 | +def test_default_value(): | 
|  | 46 | +    X, y, model = get_data_and_model(0.1) | 
|  | 47 | + | 
|  | 48 | +    time_stopping = TimeStopping() | 
|  | 49 | +    history = model.fit(X, y, epochs=15, verbose=0, callbacks=[time_stopping]) | 
|  | 50 | + | 
|  | 51 | +    assert len(history.epoch) == 15 | 
|  | 52 | + | 
|  | 53 | + | 
|  | 54 | +@pytest.mark.parametrize("verbose", [0, 1]) | 
|  | 55 | +def test_time_stopping_verbose(capsys, verbose): | 
|  | 56 | +    X, y, model = get_data_and_model(0.25) | 
|  | 57 | + | 
|  | 58 | +    time_stopping = TimeStopping(1, verbose=verbose) | 
|  | 59 | + | 
|  | 60 | +    capsys.readouterr()  # flush the stdout/stderr buffer. | 
|  | 61 | +    history = model.fit(X, y, epochs=10, verbose=0, callbacks=[time_stopping]) | 
|  | 62 | +    fit_stdout = capsys.readouterr().out | 
|  | 63 | +    nb_epochs_run = len(history.epoch) | 
|  | 64 | +    message = "Timed stopping at epoch " + str(nb_epochs_run) | 
|  | 65 | +    if verbose: | 
|  | 66 | +        assert message in fit_stdout | 
|  | 67 | +    else: | 
|  | 68 | +        assert message not in fit_stdout | 
|  | 69 | +    assert len(history.epoch) <= 4 | 
0 commit comments