diff --git a/docs/tutorials/time_stopping.ipynb b/docs/tutorials/time_stopping.ipynb
new file mode 100644
index 0000000000..9c3c7c2f86
--- /dev/null
+++ b/docs/tutorials/time_stopping.ipynb
@@ -0,0 +1,236 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "##### Copyright 2019 The TensorFlow Authors."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
+ "# you may not use this file except in compliance with the License.\n",
+ "# You may obtain a copy of the License at\n",
+ "#\n",
+ "# https://www.apache.org/licenses/LICENSE-2.0\n",
+ "#\n",
+ "# Unless required by applicable law or agreed to in writing, software\n",
+ "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
+ "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
+ "# See the License for the specific language governing permissions and\n",
+ "# limitations under the License."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# TensorFlow Addons Callbacks: TimeStopping"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "
"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Overview\n",
+ "This notebook will demonstrate how to use TimeStopping Callback in TensorFlow Addons."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Setup"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install -q --no-deps tensorflow-addons~=0.6"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "try:\n",
+ " # %tensorflow_version only exists in Colab.\n",
+ " %tensorflow_version 2.x\n",
+ "except Exception:\n",
+ " pass"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import tensorflow as tf\n",
+ "import tensorflow_addons as tfa\n",
+ "\n",
+ "import tensorflow.keras as keras\n",
+ "from tensorflow.keras.datasets import mnist\n",
+ "from tensorflow.keras.models import Sequential\n",
+ "from tensorflow.keras.layers import Dense, Dropout, Flatten"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Import and Normalize Data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# the data, split between train and test sets\n",
+ "(x_train, y_train), (x_test, y_test) = mnist.load_data()\n",
+ "# normalize data\n",
+ "x_train, x_test = x_train / 255.0, x_test / 255.0"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Build Simple MNIST CNN Model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# build the model using the Sequential API\n",
+ "model = Sequential()\n",
+ "model.add(Flatten(input_shape=(28, 28)))\n",
+ "model.add(Dense(128, activation='relu'))\n",
+ "model.add(Dropout(0.2))\n",
+ "model.add(Dense(10, activation='softmax'))\n",
+ "\n",
+ "model.compile(optimizer='adam',\n",
+ " loss = 'sparse_categorical_crossentropy',\n",
+ " metrics=['accuracy'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Simple TimeStopping Usage"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {
+ "scrolled": true
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Train on 60000 samples, validate on 10000 samples\n",
+ "Epoch 1/100\n",
+ "60000/60000 [==============================] - 2s 28us/sample - loss: 0.3357 - accuracy: 0.9033 - val_loss: 0.1606 - val_accuracy: 0.9533\n",
+ "Epoch 2/100\n",
+ "60000/60000 [==============================] - 1s 23us/sample - loss: 0.1606 - accuracy: 0.9525 - val_loss: 0.1104 - val_accuracy: 0.9669\n",
+ "Epoch 3/100\n",
+ "60000/60000 [==============================] - 1s 24us/sample - loss: 0.1185 - accuracy: 0.9645 - val_loss: 0.0949 - val_accuracy: 0.9704\n",
+ "Epoch 4/100\n",
+ "60000/60000 [==============================] - 1s 25us/sample - loss: 0.0954 - accuracy: 0.9713 - val_loss: 0.0854 - val_accuracy: 0.9740\n",
+ "Timed stopping at epoch 4 after training for 0:00:05\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ ""
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# initialize TimeStopping callback \n",
+ "time_stopping_callback = tfa.callbacks.TimeStopping(seconds=5, verbose=1)\n",
+ "\n",
+ "# train the model with tqdm_callback\n",
+ "# make sure to set verbose = 0 to disable\n",
+ "# the default progress bar.\n",
+ "model.fit(x_train, y_train,\n",
+ " batch_size=64,\n",
+ " epochs=100,\n",
+ " callbacks=[time_stopping_callback],\n",
+ " validation_data=(x_test, y_test))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.6.2"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/tensorflow_addons/callbacks/BUILD b/tensorflow_addons/callbacks/BUILD
index 48d88ba3b3..c353c295f5 100644
--- a/tensorflow_addons/callbacks/BUILD
+++ b/tensorflow_addons/callbacks/BUILD
@@ -6,6 +6,7 @@ py_library(
name = "callbacks",
srcs = [
"__init__.py",
+ "time_stopping.py",
"tqdm_progress_bar.py",
],
deps = [
diff --git a/tensorflow_addons/callbacks/README.md b/tensorflow_addons/callbacks/README.md
index 1233c8389a..9b6ec8e320 100644
--- a/tensorflow_addons/callbacks/README.md
+++ b/tensorflow_addons/callbacks/README.md
@@ -3,14 +3,15 @@
## Maintainers
| Submodule | Maintainers | Contact Info |
|:---------- |:------------- |:--------------|
+| time_stopping | @shun-lin | shunlin@google.com |
| tqdm_progress_bar | @shun-lin | shunlin@google.com |
## Contents
| Submodule | Callback | Reference |
|:----------------------- |:-------------------|:---------------|
+| time_stopping | TimeStopping | N/A |
| tqdm_progress_bar | TQDMProgressBar | https://tqdm.github.io/ |
-
## Contribution Guidelines
#### Standard API
In order to conform with the current API standard, all callbacks
diff --git a/tensorflow_addons/callbacks/__init__.py b/tensorflow_addons/callbacks/__init__.py
index eb15d09df5..0e214d815b 100755
--- a/tensorflow_addons/callbacks/__init__.py
+++ b/tensorflow_addons/callbacks/__init__.py
@@ -18,4 +18,5 @@
from __future__ import division
from __future__ import print_function
-from tensorflow_addons.callbacks.tqdm_progress_bar import TQDMProgressBar
\ No newline at end of file
+from tensorflow_addons.callbacks.time_stopping import TimeStopping
+from tensorflow_addons.callbacks.tqdm_progress_bar import TQDMProgressBar
diff --git a/tensorflow_addons/callbacks/time_stopping.py b/tensorflow_addons/callbacks/time_stopping.py
new file mode 100644
index 0000000000..5f863867c9
--- /dev/null
+++ b/tensorflow_addons/callbacks/time_stopping.py
@@ -0,0 +1,64 @@
+# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Callback that stops training when a specified amount of time has passed."""
+
+from __future__ import absolute_import, division, print_function
+
+import datetime
+import time
+
+import tensorflow as tf
+from tensorflow.keras.callbacks import Callback
+
+
+@tf.keras.utils.register_keras_serializable(package='Addons')
+class TimeStopping(Callback):
+ """Stop training when a specified amount of time has passed.
+
+ Args:
+ seconds: maximum amount of time before stopping.
+ Defaults to 86400 (1 day).
+ verbose: verbosity mode. Defaults to 0.
+ """
+
+ def __init__(self, seconds=86400, verbose=0):
+ super(TimeStopping, self).__init__()
+
+ self.seconds = seconds
+ self.verbose = verbose
+
+ def on_train_begin(self, logs=None):
+ self.stopping_time = time.time() + self.seconds
+
+ def on_epoch_end(self, epoch, logs={}):
+ if time.time() >= self.stopping_time:
+ self.model.stop_training = True
+ self.stopped_epoch = epoch
+
+ def on_train_end(self, logs=None):
+ if self.verbose > 0:
+ formatted_time = datetime.timedelta(seconds=self.seconds)
+ msg = 'Timed stopping at epoch {} after training for {}'.format(
+ self.stopped_epoch + 1, formatted_time)
+ print(msg)
+
+ def get_config(self):
+ config = {
+ 'seconds': self.seconds,
+ 'verbose': self.verbose,
+ }
+
+ base_config = super(TimeStopping, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))