Merged
236 changes: 236 additions & 0 deletions docs/tutorials/time_stopping.ipynb
@@ -0,0 +1,236 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"##### Copyright 2019 The TensorFlow Authors."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
"# you may not use this file except in compliance with the License.\n",
"# You may obtain a copy of the License at\n",
"#\n",
"# https://www.apache.org/licenses/LICENSE-2.0\n",
"#\n",
"# Unless required by applicable law or agreed to in writing, software\n",
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
"# See the License for the specific language governing permissions and\n",
"# limitations under the License."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# TensorFlow Addons Callbacks: TimeStopping"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://www.tensorflow.org/addons/tutorials/time_stopping\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
" </td>\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/addons/blob/master/docs/tutorials/time_stopping.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
" </td>\n",
" <td>\n",
" <a target=\"_blank\" href=\"https://github.com/tensorflow/addons/blob/master/docs/tutorials/time_stopping.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
" </td>\n",
" <td>\n",
" <a href=\"https://storage.googleapis.com/tensorflow_docs/addons/docs/tutorials/time_stopping.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
" </td>\n",
"</table>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Overview\n",
"This notebook will demonstrate how to use TimeStopping Callback in TensorFlow Addons."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"!pip install -q --no-deps tensorflow-addons~=0.6"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"try:\n",
" # %tensorflow_version only exists in Colab.\n",
" %tensorflow_version 2.x\n",
"except Exception:\n",
" pass"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"import tensorflow_addons as tfa\n",
"\n",
"import tensorflow.keras as keras\n",
"from tensorflow.keras.datasets import mnist\n",
"from tensorflow.keras.models import Sequential\n",
"from tensorflow.keras.layers import Dense, Dropout, Flatten"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Import and Normalize Data"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# the data, split between train and test sets\n",
"(x_train, y_train), (x_test, y_test) = mnist.load_data()\n",
"# normalize data\n",
"x_train, x_test = x_train / 255.0, x_test / 255.0"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Build Simple MNIST CNN Model"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"# build the model using the Sequential API\n",
"model = Sequential()\n",
"model.add(Flatten(input_shape=(28, 28)))\n",
"model.add(Dense(128, activation='relu'))\n",
"model.add(Dropout(0.2))\n",
"model.add(Dense(10, activation='softmax'))\n",
"\n",
"model.compile(optimizer='adam',\n",
" loss = 'sparse_categorical_crossentropy',\n",
" metrics=['accuracy'])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Simple TimeStopping Usage"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train on 60000 samples, validate on 10000 samples\n",
"Epoch 1/100\n",
"60000/60000 [==============================] - 2s 28us/sample - loss: 0.3357 - accuracy: 0.9033 - val_loss: 0.1606 - val_accuracy: 0.9533\n",
"Epoch 2/100\n",
"60000/60000 [==============================] - 1s 23us/sample - loss: 0.1606 - accuracy: 0.9525 - val_loss: 0.1104 - val_accuracy: 0.9669\n",
"Epoch 3/100\n",
"60000/60000 [==============================] - 1s 24us/sample - loss: 0.1185 - accuracy: 0.9645 - val_loss: 0.0949 - val_accuracy: 0.9704\n",
"Epoch 4/100\n",
"60000/60000 [==============================] - 1s 25us/sample - loss: 0.0954 - accuracy: 0.9713 - val_loss: 0.0854 - val_accuracy: 0.9740\n",
"Timed stopping at epoch 4 after training for 0:00:05\n"
]
},
{
"data": {
"text/plain": [
"<tensorflow.python.keras.callbacks.History at 0x110af0ef0>"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# initialize TimeStopping callback \n",
"time_stopping_callback = tfa.callbacks.TimeStopping(seconds=5, verbose=1)\n",
"\n",
"# train the model with tqdm_callback\n",
"# make sure to set verbose = 0 to disable\n",
"# the default progress bar.\n",
"model.fit(x_train, y_train,\n",
" batch_size=64,\n",
" epochs=100,\n",
" callbacks=[time_stopping_callback],\n",
" validation_data=(x_test, y_test))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.2"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
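The notebook above exercises TimeStopping on its own, but it composes with any other Keras callback passed to `model.fit`. As a rough sketch (not part of this PR; it assumes the `model`, `x_train`, `y_train`, `x_test`, and `y_test` defined in the notebook above, with TensorFlow 2.x and tensorflow-addons installed, and an arbitrary 5-minute budget), training can be bounded both by wall-clock time and by validation loss:

```python
import tensorflow as tf
import tensorflow_addons as tfa

# Stop after ~5 minutes of wall-clock time OR when val_loss has not
# improved for 3 consecutive epochs, whichever happens first.
callbacks = [
    tfa.callbacks.TimeStopping(seconds=300, verbose=1),
    tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=3),
]

model.fit(x_train, y_train,
          batch_size=64,
          epochs=100,
          callbacks=callbacks,
          validation_data=(x_test, y_test))
```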
1 change: 1 addition & 0 deletions tensorflow_addons/callbacks/BUILD
@@ -6,6 +6,7 @@ py_library(
name = "callbacks",
srcs = [
"__init__.py",
"time_stopping.py",
"tqdm_progress_bar.py",
],
deps = [
3 changes: 2 additions & 1 deletion tensorflow_addons/callbacks/README.md
@@ -3,14 +3,15 @@
## Maintainers
| Submodule | Maintainers | Contact Info |
|:---------- |:------------- |:--------------|
| time_stopping | @shun-lin | [email protected] |
| tqdm_progress_bar | @shun-lin | [email protected] |

## Contents
| Submodule | Callback | Reference |
|:----------------------- |:-------------------|:---------------|
| time_stopping | TimeStopping | N/A |
| tqdm_progress_bar | TQDMProgressBar | https://tqdm.github.io/ |


## Contribution Guidelines
#### Standard API
In order to conform with the current API standard, all callbacks
3 changes: 2 additions & 1 deletion tensorflow_addons/callbacks/__init__.py
@@ -18,4 +18,5 @@
from __future__ import division
from __future__ import print_function

from tensorflow_addons.callbacks.tqdm_progress_bar import TQDMProgressBar
from tensorflow_addons.callbacks.time_stopping import TimeStopping
from tensorflow_addons.callbacks.tqdm_progress_bar import TQDMProgressBar
64 changes: 64 additions & 0 deletions tensorflow_addons/callbacks/time_stopping.py
@@ -0,0 +1,64 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Callback that stops training when a specified amount of time has passed."""

from __future__ import absolute_import, division, print_function

import datetime
import time

import tensorflow as tf
from tensorflow.keras.callbacks import Callback


@tf.keras.utils.register_keras_serializable(package='Addons')
class TimeStopping(Callback):
"""Stop training when a specified amount of time has passed.

Args:
seconds: maximum amount of time before stopping.
Defaults to 86400 (1 day).
verbose: verbosity mode. Defaults to 0.
[Inline review thread]

@seanpmorgan (Member) commented on Dec 13, 2019:
Is there a reason for using an int here instead of a boolean? It doesn't seem like varying verbosity values will make any difference, so it would be better to prevent that confusion.

The PR author (Contributor) replied:
I use an int here for consistency with the EarlyStopping callback:
https://github.com/tensorflow/tensorflow/blob/r2.0/tensorflow/python/keras/callbacks.py#L1134-L1251
I assume they also use an int rather than just a boolean so that it is consistent with other places that take verbose as a parameter. That is my reasoning (consistency), but if you feel changing to a boolean is better, let me know! Thanks for the feedback!

@seanpmorgan (Member) replied:
Thanks. It still feels a little unintuitive, but I agree it's better to stay with the convention.
"""

def __init__(self, seconds=86400, verbose=0):
super(TimeStopping, self).__init__()

        self.seconds = seconds
        self.verbose = verbose
        self.stopped_epoch = None

def on_train_begin(self, logs=None):
self.stopping_time = time.time() + self.seconds

    def on_epoch_end(self, epoch, logs=None):
if time.time() >= self.stopping_time:
self.model.stop_training = True
self.stopped_epoch = epoch

def on_train_end(self, logs=None):
        if self.stopped_epoch is not None and self.verbose > 0:
formatted_time = datetime.timedelta(seconds=self.seconds)
msg = 'Timed stopping at epoch {} after training for {}'.format(
self.stopped_epoch + 1, formatted_time)
print(msg)

def get_config(self):
config = {
'seconds': self.seconds,
'verbose': self.verbose,
}

base_config = super(TimeStopping, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
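For what it's worth, the `0:00:05` seen in the notebook output comes from `datetime.timedelta`, which `on_train_end` uses to pretty-print the time budget. A tiny sketch of that formatting behavior (standard library only, nothing assumed beyond the code above):

```python
import datetime

# TimeStopping formats the configured time budget with datetime.timedelta,
# which is why the tutorial prints "... after training for 0:00:05".
print(datetime.timedelta(seconds=5))      # 0:00:05
print(datetime.timedelta(seconds=86400))  # 1 day, 0:00:00  (the default budget)
```

Note that the message reports the configured `seconds` value, not the measured elapsed time, which is why it matches the budget exactly rather than the actual wall-clock duration of the run.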