From 57c19081b1130f5dd02f60c2b3d1b61579632a57 Mon Sep 17 00:00:00 2001 From: Shun Lin Date: Thu, 12 Dec 2019 01:20:12 -0800 Subject: [PATCH 1/4] added tutorial and code for TimeStopping --- docs/tutorials/time_stopping.ipynb | 236 +++++++++++++++++++ tensorflow_addons/callbacks/BUILD | 1 + tensorflow_addons/callbacks/README.md | 3 +- tensorflow_addons/callbacks/__init__.py | 3 +- tensorflow_addons/callbacks/time_stopping.py | 64 +++++ 5 files changed, 305 insertions(+), 2 deletions(-) create mode 100644 docs/tutorials/time_stopping.ipynb create mode 100644 tensorflow_addons/callbacks/time_stopping.py diff --git a/docs/tutorials/time_stopping.ipynb b/docs/tutorials/time_stopping.ipynb new file mode 100644 index 0000000000..37c41d28a1 --- /dev/null +++ b/docs/tutorials/time_stopping.ipynb @@ -0,0 +1,236 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### Copyright 2019 The TensorFlow Authors." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# TensorFlow Addons Callbacks: TimeStopping" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + " \n", + " \n", + " \n", + " \n", + "
\n", + " View on TensorFlow.org\n", + " \n", + " Run in Google Colab\n", + " \n", + " View source on GitHub\n", + " \n", + " Download notebook\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Overview\n", + "This notebook will demonstrate how to use TimeStopping Callback in TensorFlow Addons." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install -q --no-deps tensorflow-addons~=0.7" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "try:\n", + " # %tensorflow_version only exists in Colab.\n", + " %tensorflow_version 2.x\n", + "except Exception:\n", + " pass" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "import tensorflow_addons as tfa\n", + "\n", + "import tensorflow.keras as keras\n", + "from tensorflow.keras.datasets import mnist\n", + "from tensorflow.keras.models import Sequential\n", + "from tensorflow.keras.layers import Dense, Dropout, Flatten" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Import and Normalize Data" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "# the data, split between train and test sets\n", + "(x_train, y_train), (x_test, y_test) = mnist.load_data()\n", + "# normalize data\n", + "x_train, x_test = x_train / 255.0, x_test / 255.0" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Build Simple MNIST CNN Model" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "# build the model using the Sequential API\n", + "model = Sequential()\n", + "model.add(Flatten(input_shape=(28, 28)))\n", + "model.add(Dense(128, activation='relu'))\n", + "model.add(Dropout(0.2))\n", + "model.add(Dense(10, activation='softmax'))\n", + "\n", + "model.compile(optimizer='adam',\n", + " loss = 'sparse_categorical_crossentropy',\n", + " metrics=['accuracy'])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Simple TimeStopping Usage" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train on 60000 samples, validate on 10000 samples\n", + "Epoch 1/100\n", + "60000/60000 [==============================] - 2s 28us/sample - loss: 0.3357 - accuracy: 0.9033 - val_loss: 0.1606 - val_accuracy: 0.9533\n", + "Epoch 2/100\n", + "60000/60000 [==============================] - 1s 23us/sample - loss: 0.1606 - accuracy: 0.9525 - val_loss: 0.1104 - val_accuracy: 0.9669\n", + "Epoch 3/100\n", + "60000/60000 [==============================] - 1s 24us/sample - loss: 0.1185 - accuracy: 0.9645 - val_loss: 0.0949 - val_accuracy: 0.9704\n", + "Epoch 4/100\n", + "60000/60000 [==============================] - 1s 25us/sample - loss: 0.0954 - accuracy: 0.9713 - val_loss: 0.0854 - val_accuracy: 0.9740\n", + "Timed stopping at epoch 4 after training for 0:00:05\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# initialize TimeStopping callback \n", + "time_stopping_callback = tfa.callbacks.TimeStopping(seconds=5, verbose=1)\n", + "\n", + "# train the model with tqdm_callback\n", + "# make sure to set verbose = 0 to disable\n", + "# the default progress bar.\n", + 
"model.fit(x_train, y_train,\n", + " batch_size=64,\n", + " epochs=100,\n", + " callbacks=[time_stopping_callback],\n", + " validation_data=(x_test, y_test))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tensorflow_addons/callbacks/BUILD b/tensorflow_addons/callbacks/BUILD index 48d88ba3b3..c353c295f5 100644 --- a/tensorflow_addons/callbacks/BUILD +++ b/tensorflow_addons/callbacks/BUILD @@ -6,6 +6,7 @@ py_library( name = "callbacks", srcs = [ "__init__.py", + "time_stopping.py", "tqdm_progress_bar.py", ], deps = [ diff --git a/tensorflow_addons/callbacks/README.md b/tensorflow_addons/callbacks/README.md index 1233c8389a..e0f820554d 100644 --- a/tensorflow_addons/callbacks/README.md +++ b/tensorflow_addons/callbacks/README.md @@ -4,12 +4,13 @@ | Submodule | Maintainers | Contact Info | |:---------- |:------------- |:--------------| | tqdm_progress_bar | @shun-lin | shunlin@google.com | +| time_stopping | @shun-lin | shunlin@google.com | ## Contents | Submodule | Callback | Reference | |:----------------------- |:-------------------|:---------------| | tqdm_progress_bar | TQDMProgressBar | https://tqdm.github.io/ | - +| time_stopping | TimeStopping | N/A | ## Contribution Guidelines #### Standard API diff --git a/tensorflow_addons/callbacks/__init__.py b/tensorflow_addons/callbacks/__init__.py index eb15d09df5..0e214d815b 100755 --- a/tensorflow_addons/callbacks/__init__.py +++ b/tensorflow_addons/callbacks/__init__.py @@ -18,4 +18,5 @@ from __future__ import division from __future__ import print_function -from tensorflow_addons.callbacks.tqdm_progress_bar import TQDMProgressBar \ No newline at end of file +from tensorflow_addons.callbacks.time_stopping import TimeStopping +from tensorflow_addons.callbacks.tqdm_progress_bar import TQDMProgressBar diff --git a/tensorflow_addons/callbacks/time_stopping.py b/tensorflow_addons/callbacks/time_stopping.py new file mode 100644 index 0000000000..67d7a5975b --- /dev/null +++ b/tensorflow_addons/callbacks/time_stopping.py @@ -0,0 +1,64 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Callback that stops training when a specified amount of time has passed.""" + +from __future__ import absolute_import, division, print_function + +import datetime +import time + +import tensorflow as tf +from tensorflow.keras.callbacks import Callback + + +@tf.keras.utils.register_keras_serializable(package='Addons') +class TimeStopping(Callback): + """Stop training when a specified amount of time has passed. + + Args: + seconds: maximum amount of time before stopping. + Defaults to 86400 (1 day). + verbose: verbosity mode. Defaults to 0. + """ + + def __init__(self, seconds=86400, verbose=0): + super(TimeStopping, self).__init__() + + self.seconds = seconds + self.verbose = verbose + + def on_train_begin(self, logs=None): + self.stopping_time = time.time() + self.seconds + + def on_epoch_end(self, epoch, logs={}): + if time.time() >= self.stopping_time: + self.model.stop_training = True + self.stopped_epoch = epoch + + def on_train_end(self, logs=None): + if self.verbose > 0: + formatted_time = datetime.timedelta(seconds=self.seconds) + print( + 'Timed stopping at epoch {} after training for {}'.format( + self.stopped_epoch + 1, formatted_time)) + + def get_config(self): + config = { + 'seconds': self.seconds, + 'verbose': self.verbose, + } + + base_config = super(TimeStopping, self).get_config() + return dict(list(base_config.items()) + list(config.items())) From 2ad070837bb69a026df8952942db29145baaa946 Mon Sep 17 00:00:00 2001 From: Shun Lin Date: Thu, 12 Dec 2019 01:27:21 -0800 Subject: [PATCH 2/4] formatted file --- docs/tutorials/time_stopping.ipynb | 2 +- tensorflow_addons/callbacks/time_stopping.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/tutorials/time_stopping.ipynb b/docs/tutorials/time_stopping.ipynb index 37c41d28a1..9c3c7c2f86 100644 --- a/docs/tutorials/time_stopping.ipynb +++ b/docs/tutorials/time_stopping.ipynb @@ -74,7 +74,7 @@ "metadata": {}, "outputs": [], "source": [ - "!pip install -q --no-deps tensorflow-addons~=0.7" + "!pip install -q --no-deps tensorflow-addons~=0.6" ] }, { diff --git a/tensorflow_addons/callbacks/time_stopping.py b/tensorflow_addons/callbacks/time_stopping.py index 67d7a5975b..c1081011c1 100644 --- a/tensorflow_addons/callbacks/time_stopping.py +++ b/tensorflow_addons/callbacks/time_stopping.py @@ -50,9 +50,9 @@ def on_epoch_end(self, epoch, logs={}): def on_train_end(self, logs=None): if self.verbose > 0: formatted_time = datetime.timedelta(seconds=self.seconds) - print( - 'Timed stopping at epoch {} after training for {}'.format( - self.stopped_epoch + 1, formatted_time)) + message = 'Timed stopping at epoch {} after training for {}'.format( + self.stopped_epoch + 1, formatted_time) + print(message) def get_config(self): config = { From 4be27486303991f93e30f0b50f37069032d75cc5 Mon Sep 17 00:00:00 2001 From: Shun Lin Date: Thu, 12 Dec 2019 18:17:16 -0800 Subject: [PATCH 3/4] change orders in README.MD --- tensorflow_addons/callbacks/README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow_addons/callbacks/README.md b/tensorflow_addons/callbacks/README.md index e0f820554d..9b6ec8e320 100644 --- a/tensorflow_addons/callbacks/README.md +++ b/tensorflow_addons/callbacks/README.md @@ -3,14 +3,14 @@ ## Maintainers | Submodule | Maintainers | Contact Info | |:---------- |:------------- |:--------------| -| tqdm_progress_bar | @shun-lin | shunlin@google.com | | time_stopping | @shun-lin | 
shunlin@google.com | +| tqdm_progress_bar | @shun-lin | shunlin@google.com | ## Contents | Submodule | Callback | Reference | |:----------------------- |:-------------------|:---------------| -| tqdm_progress_bar | TQDMProgressBar | https://tqdm.github.io/ | | time_stopping | TimeStopping | N/A | +| tqdm_progress_bar | TQDMProgressBar | https://tqdm.github.io/ | ## Contribution Guidelines #### Standard API From 2d0b3d605df6009784e8aa7df62e358fd3174dfa Mon Sep 17 00:00:00 2001 From: Shun Lin Date: Thu, 12 Dec 2019 20:05:42 -0800 Subject: [PATCH 4/4] fixed line too long issue --- tensorflow_addons/callbacks/time_stopping.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow_addons/callbacks/time_stopping.py b/tensorflow_addons/callbacks/time_stopping.py index c1081011c1..5f863867c9 100644 --- a/tensorflow_addons/callbacks/time_stopping.py +++ b/tensorflow_addons/callbacks/time_stopping.py @@ -50,9 +50,9 @@ def on_epoch_end(self, epoch, logs={}): def on_train_end(self, logs=None): if self.verbose > 0: formatted_time = datetime.timedelta(seconds=self.seconds) - message = 'Timed stopping at epoch {} after training for {}'.format( + msg = 'Timed stopping at epoch {} after training for {}'.format( self.stopped_epoch + 1, formatted_time) - print(message) + print(msg) def get_config(self): config = {
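
For quick reference, a minimal end-to-end usage sketch of the TimeStopping callback introduced by this patch series is included below. It condenses the tutorial notebook above into one script and assumes a TensorFlow 2.x environment with a tensorflow-addons build that already ships tfa.callbacks.TimeStopping; the model, 5-second budget, and fit arguments mirror the tutorial.

import tensorflow as tf
import tensorflow_addons as tfa

# Load and normalize MNIST, exactly as in the tutorial notebook.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# Same small dense model as the tutorial builds.
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax'),
])
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

# Stop training once 5 seconds have elapsed, regardless of the epoch;
# verbose=1 makes the callback print the epoch at which it stopped.
time_stopping_callback = tfa.callbacks.TimeStopping(seconds=5, verbose=1)

model.fit(x_train, y_train,
          batch_size=64,
          epochs=100,
          callbacks=[time_stopping_callback],
          validation_data=(x_test, y_test))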