diff --git a/docs/tutorials/time_stopping.ipynb b/docs/tutorials/time_stopping.ipynb new file mode 100644 index 0000000000..9c3c7c2f86 --- /dev/null +++ b/docs/tutorials/time_stopping.ipynb @@ -0,0 +1,236 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### Copyright 2019 The TensorFlow Authors." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# TensorFlow Addons Callbacks: TimeStopping" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + " \n", + " \n", + " \n", + " \n", + "
\n", + " View on TensorFlow.org\n", + " \n", + " Run in Google Colab\n", + " \n", + " View source on GitHub\n", + " \n", + " Download notebook\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Overview\n", + "This notebook will demonstrate how to use TimeStopping Callback in TensorFlow Addons." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install -q --no-deps tensorflow-addons~=0.6" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "try:\n", + " # %tensorflow_version only exists in Colab.\n", + " %tensorflow_version 2.x\n", + "except Exception:\n", + " pass" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "import tensorflow_addons as tfa\n", + "\n", + "import tensorflow.keras as keras\n", + "from tensorflow.keras.datasets import mnist\n", + "from tensorflow.keras.models import Sequential\n", + "from tensorflow.keras.layers import Dense, Dropout, Flatten" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Import and Normalize Data" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "# the data, split between train and test sets\n", + "(x_train, y_train), (x_test, y_test) = mnist.load_data()\n", + "# normalize data\n", + "x_train, x_test = x_train / 255.0, x_test / 255.0" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Build Simple MNIST CNN Model" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "# build the model using the Sequential API\n", + "model = Sequential()\n", + "model.add(Flatten(input_shape=(28, 28)))\n", + "model.add(Dense(128, activation='relu'))\n", + "model.add(Dropout(0.2))\n", + "model.add(Dense(10, activation='softmax'))\n", + "\n", + "model.compile(optimizer='adam',\n", + " loss = 'sparse_categorical_crossentropy',\n", + " metrics=['accuracy'])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Simple TimeStopping Usage" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train on 60000 samples, validate on 10000 samples\n", + "Epoch 1/100\n", + "60000/60000 [==============================] - 2s 28us/sample - loss: 0.3357 - accuracy: 0.9033 - val_loss: 0.1606 - val_accuracy: 0.9533\n", + "Epoch 2/100\n", + "60000/60000 [==============================] - 1s 23us/sample - loss: 0.1606 - accuracy: 0.9525 - val_loss: 0.1104 - val_accuracy: 0.9669\n", + "Epoch 3/100\n", + "60000/60000 [==============================] - 1s 24us/sample - loss: 0.1185 - accuracy: 0.9645 - val_loss: 0.0949 - val_accuracy: 0.9704\n", + "Epoch 4/100\n", + "60000/60000 [==============================] - 1s 25us/sample - loss: 0.0954 - accuracy: 0.9713 - val_loss: 0.0854 - val_accuracy: 0.9740\n", + "Timed stopping at epoch 4 after training for 0:00:05\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# initialize TimeStopping callback \n", + "time_stopping_callback = tfa.callbacks.TimeStopping(seconds=5, verbose=1)\n", + "\n", + "# train the model with tqdm_callback\n", + "# make sure to set verbose = 0 to disable\n", + "# the default progress bar.\n", + "model.fit(x_train, y_train,\n", + " batch_size=64,\n", + " epochs=100,\n", + " callbacks=[time_stopping_callback],\n", + " validation_data=(x_test, y_test))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tensorflow_addons/callbacks/BUILD b/tensorflow_addons/callbacks/BUILD index 48d88ba3b3..c353c295f5 100644 --- a/tensorflow_addons/callbacks/BUILD +++ b/tensorflow_addons/callbacks/BUILD @@ -6,6 +6,7 @@ py_library( name = "callbacks", srcs = [ "__init__.py", + "time_stopping.py", "tqdm_progress_bar.py", ], deps = [ diff --git a/tensorflow_addons/callbacks/README.md b/tensorflow_addons/callbacks/README.md index 1233c8389a..9b6ec8e320 100644 --- a/tensorflow_addons/callbacks/README.md +++ b/tensorflow_addons/callbacks/README.md @@ -3,14 +3,15 @@ ## Maintainers | Submodule | Maintainers | Contact Info | |:---------- |:------------- |:--------------| +| time_stopping | @shun-lin | shunlin@google.com | | tqdm_progress_bar | @shun-lin | shunlin@google.com | ## Contents | Submodule | Callback | Reference | |:----------------------- |:-------------------|:---------------| +| time_stopping | TimeStopping | N/A | | tqdm_progress_bar | TQDMProgressBar | https://tqdm.github.io/ | - ## Contribution Guidelines #### Standard API In order to conform with the current API standard, all callbacks diff --git a/tensorflow_addons/callbacks/__init__.py b/tensorflow_addons/callbacks/__init__.py index eb15d09df5..0e214d815b 100755 --- a/tensorflow_addons/callbacks/__init__.py +++ b/tensorflow_addons/callbacks/__init__.py @@ -18,4 +18,5 @@ from __future__ import division from __future__ import print_function -from tensorflow_addons.callbacks.tqdm_progress_bar import TQDMProgressBar \ No newline at end of file +from tensorflow_addons.callbacks.time_stopping import TimeStopping +from tensorflow_addons.callbacks.tqdm_progress_bar import TQDMProgressBar diff --git a/tensorflow_addons/callbacks/time_stopping.py b/tensorflow_addons/callbacks/time_stopping.py new file mode 100644 index 0000000000..5f863867c9 --- /dev/null +++ b/tensorflow_addons/callbacks/time_stopping.py @@ -0,0 +1,64 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Callback that stops training when a specified amount of time has passed.""" + +from __future__ import absolute_import, division, print_function + +import datetime +import time + +import tensorflow as tf +from tensorflow.keras.callbacks import Callback + + +@tf.keras.utils.register_keras_serializable(package='Addons') +class TimeStopping(Callback): + """Stop training when a specified amount of time has passed. + + Args: + seconds: maximum amount of time before stopping. + Defaults to 86400 (1 day). + verbose: verbosity mode. Defaults to 0. + """ + + def __init__(self, seconds=86400, verbose=0): + super(TimeStopping, self).__init__() + + self.seconds = seconds + self.verbose = verbose + + def on_train_begin(self, logs=None): + self.stopping_time = time.time() + self.seconds + + def on_epoch_end(self, epoch, logs={}): + if time.time() >= self.stopping_time: + self.model.stop_training = True + self.stopped_epoch = epoch + + def on_train_end(self, logs=None): + if self.verbose > 0: + formatted_time = datetime.timedelta(seconds=self.seconds) + msg = 'Timed stopping at epoch {} after training for {}'.format( + self.stopped_epoch + 1, formatted_time) + print(msg) + + def get_config(self): + config = { + 'seconds': self.seconds, + 'verbose': self.verbose, + } + + base_config = super(TimeStopping, self).get_config() + return dict(list(base_config.items()) + list(config.items()))