From 2410e31373418ca352b9ea20d6049264e463cb24 Mon Sep 17 00:00:00 2001 From: Shun Lin Date: Tue, 17 Sep 2019 21:58:53 -0700 Subject: [PATCH] Added TQDMCallback in callbacks --- examples/TQDM_Callbacks.ipynb | 368 +++++++++++++++++++ examples/mnist_with_TQDMCallback.py | 68 ++++ tensorflow_addons/callbacks/BUILD | 11 +- tensorflow_addons/callbacks/__init__.py | 2 + tensorflow_addons/callbacks/tqdm_callback.py | 148 ++++++++ 5 files changed, 592 insertions(+), 5 deletions(-) create mode 100644 examples/TQDM_Callbacks.ipynb create mode 100644 examples/mnist_with_TQDMCallback.py create mode 100644 tensorflow_addons/callbacks/tqdm_callback.py diff --git a/examples/TQDM_Callbacks.ipynb b/examples/TQDM_Callbacks.ipynb new file mode 100644 index 0000000000..1f633ddf35 --- /dev/null +++ b/examples/TQDM_Callbacks.ipynb @@ -0,0 +1,368 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# TensorFlow Addons Callbacks: TQDM" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Overview\n", + "This notebook will demonstrate how to use TQDMCallback in TensorFlow Addons." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: ipywidgets in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (7.5.1)\n", + "Requirement already satisfied: nbformat>=4.2.0 in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from ipywidgets) (4.4.0)\n", + "Requirement already satisfied: widgetsnbextension~=3.5.0 in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from ipywidgets) (3.5.1)\n", + "Requirement already satisfied: traitlets>=4.3.1 in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from ipywidgets) (4.3.2)\n", + "Requirement already satisfied: ipykernel>=4.5.1 in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from ipywidgets) (5.1.2)\n", + "Requirement already satisfied: ipython>=4.0.0; python_version >= \"3.3\" in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from ipywidgets) (7.8.0)\n", + "Requirement already satisfied: jsonschema!=2.5.0,>=2.4 in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from nbformat>=4.2.0->ipywidgets) (3.0.2)\n", + "Requirement already satisfied: jupyter-core in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from nbformat>=4.2.0->ipywidgets) (4.5.0)\n", + "Requirement already satisfied: ipython-genutils in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from nbformat>=4.2.0->ipywidgets) (0.2.0)\n", + "Requirement already satisfied: notebook>=4.4.1 in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from widgetsnbextension~=3.5.0->ipywidgets) (6.0.1)\n", + "Requirement already satisfied: six in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from traitlets>=4.3.1->ipywidgets) (1.12.0)\n", + "Requirement already satisfied: decorator in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from traitlets>=4.3.1->ipywidgets) (4.4.0)\n", + "Requirement already satisfied: jupyter-client in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from ipykernel>=4.5.1->ipywidgets) (5.3.3)\n", + "Requirement already satisfied: tornado>=4.2 in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from ipykernel>=4.5.1->ipywidgets) (6.0.3)\n", + "Requirement already satisfied: jedi>=0.10 in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets) (0.15.1)\n", + "Requirement already satisfied: backcall in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets) (0.1.0)\n", + "Requirement already satisfied: pickleshare in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets) (0.7.5)\n", + "Requirement already satisfied: pexpect; sys_platform != \"win32\" in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets) (4.7.0)\n", + "Requirement already satisfied: pygments in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets) (2.4.2)\n", + "Requirement already satisfied: prompt-toolkit<2.1.0,>=2.0.0 in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets) (2.0.9)\n", + "Requirement already satisfied: setuptools>=18.5 in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets) (41.2.0)\n", + "Requirement already satisfied: pyrsistent>=0.14.0 in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from jsonschema!=2.5.0,>=2.4->nbformat>=4.2.0->ipywidgets) (0.15.4)\n", + "Requirement already satisfied: attrs>=17.4.0 in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from jsonschema!=2.5.0,>=2.4->nbformat>=4.2.0->ipywidgets) (19.1.0)\n", + "Requirement already satisfied: prometheus-client in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (0.7.1)\n", + "Requirement already satisfied: terminado>=0.8.1 in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (0.8.2)\n", + "Requirement already satisfied: Send2Trash in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (1.5.0)\n", + "Requirement already satisfied: pyzmq>=17 in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (18.1.0)\n", + "Requirement already satisfied: nbconvert in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (5.6.0)\n", + "Requirement already satisfied: jinja2 in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (2.10.1)\n", + "Requirement already satisfied: python-dateutil>=2.1 in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from jupyter-client->ipykernel>=4.5.1->ipywidgets) (2.8.0)\n", + "Requirement already satisfied: parso>=0.5.0 in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from jedi>=0.10->ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets) (0.5.1)\n", + "Requirement already satisfied: ptyprocess>=0.5 in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from pexpect; sys_platform != \"win32\"->ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets) (0.6.0)\n", + "Requirement already satisfied: wcwidth in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from prompt-toolkit<2.1.0,>=2.0.0->ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets) (0.1.7)\n", + "Requirement already satisfied: mistune<2,>=0.8.1 in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (0.8.4)\n", + "Requirement already satisfied: pandocfilters>=1.4.1 in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (1.4.2)\n", + "Requirement already satisfied: entrypoints>=0.2.2 in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (0.3)\n", + "Requirement already satisfied: bleach in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (3.1.0)\n", + "Requirement already satisfied: defusedxml in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (0.6.0)\n", + "Requirement already satisfied: testpath in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (0.4.2)\n", + "Requirement already satisfied: MarkupSafe>=0.23 in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from jinja2->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (1.1.1)\n", + "Requirement already satisfied: webencodings in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from bleach->nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (0.5.1)\n", + "Enabling notebook extension jupyter-js-widgets/extension...\n", + " - Validating: \u001b[32mOK\u001b[0m\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Using TensorFlow backend.\n" + ] + } + ], + "source": [ + "!pip install -q tensorflow-gpu==2.0.0rc0\n", + "!pip install -q tensorflow-addons~=0.5\n", + "!pip install -q tqdm\n", + "!pip install -q keras\n", + "\n", + "!pip install ipywidgets\n", + "!jupyter nbextension enable --py widgetsnbextension --sys-prefix\n", + "\n", + "import numpy as np\n", + "import tensorflow as tf\n", + "import tensorflow_addons as tfa\n", + "import keras\n", + "from keras.datasets import mnist\n", + "from keras.models import Sequential\n", + "from keras.layers import Dense, Dropout, Flatten\n", + "from keras.layers import Conv2D, MaxPooling2D\n", + "from keras import backend as K" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Import Data" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# the data, split between train and test sets\n", + "(x_train, y_train), (x_test, y_test) = mnist.load_data()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Preprocess Data" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "x_train shape: (60000, 28, 28, 1)\n", + "60000 train samples\n", + "10000 test samples\n" + ] + } + ], + "source": [ + "batch_size = 128\n", + "num_classes = 10\n", + "epochs = 3\n", + "\n", + "# input image dimensions\n", + "img_rows, img_cols = 28, 28\n", + "\n", + "if K.image_data_format() == 'channels_first':\n", + " x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)\n", + " x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)\n", + " input_shape = (1, img_rows, img_cols)\n", + "else:\n", + " x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)\n", + " x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)\n", + " input_shape = (img_rows, img_cols, 1)\n", + "\n", + "x_train = x_train.astype('float32')\n", + "x_test = x_test.astype('float32')\n", + "x_train /= 255\n", + "x_test /= 255\n", + "print('x_train shape:', x_train.shape)\n", + "print(x_train.shape[0], 'train samples')\n", + "print(x_test.shape[0], 'test samples')\n", + "\n", + "# convert class vectors to binary class matrices\n", + "y_train = keras.utils.to_categorical(y_train, num_classes)\n", + "y_test = keras.utils.to_categorical(y_test, num_classes)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Build Simple MNIST CNN Model" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "model = Sequential()\n", + "model.add(Conv2D(32, kernel_size=(3, 3),\n", + " activation='relu',\n", + " input_shape=input_shape))\n", + "model.add(Conv2D(64, (3, 3), activation='relu'))\n", + "model.add(MaxPooling2D(pool_size=(2, 2)))\n", + "model.add(Dropout(0.25))\n", + "model.add(Flatten())\n", + "model.add(Dense(128, activation='relu'))\n", + "model.add(Dropout(0.5))\n", + "model.add(Dense(num_classes, activation='softmax'))\n", + "\n", + "model.compile(loss=keras.losses.categorical_crossentropy,\n", + " optimizer=keras.optimizers.Adadelta(),\n", + " metrics=['accuracy'])\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## TQDMCallback example usage 1" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "2ecc9bbf7e8d4ee2a420d4b6341a33a9", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(IntProgress(value=0, description='Training', max=3, style=ProgressStyle(description_width='init…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "5daaabe6392847559ef1cbd8fcdebf49", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(IntProgress(value=0, description='Epoch: 0', max=60000, style=ProgressStyle(description_width='…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "6fa1dd5a90ce461483b64b6b090c4038", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(IntProgress(value=0, description='Epoch: 1', max=60000, style=ProgressStyle(description_width='…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "8167fc3e3b6042a78e8c9f7a28ec9d77", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(IntProgress(value=0, description='Epoch: 2', max=60000, style=ProgressStyle(description_width='…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tqdm_callback = tfa.callbacks.TQDMCallback(leave_outer=True)\n", + "model.fit(x_train, y_train,\n", + " batch_size=batch_size,\n", + " epochs=epochs,\n", + " verbose=0,\n", + " callbacks = [tqdm_callback],\n", + " validation_data=(x_test, y_test))" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Test loss: 2.3033354560852053\n", + "Test accuracy: 0.07720000296831131\n" + ] + } + ], + "source": [ + "score = model.evaluate(x_test, y_test, verbose=0)\n", + "print('Test loss:', score[0])\n", + "print('Test accuracy:', score[1])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.8" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/examples/mnist_with_TQDMCallback.py b/examples/mnist_with_TQDMCallback.py new file mode 100644 index 0000000000..a90867ef01 --- /dev/null +++ b/examples/mnist_with_TQDMCallback.py @@ -0,0 +1,68 @@ +import numpy as np +import tensorflow as tf +import tensorflow_addons as tfa +import keras +from keras.datasets import mnist +from keras.models import Sequential +from keras.layers import Dense, Dropout, Flatten +from keras.layers import Conv2D, MaxPooling2D +from keras import backend as K + +batch_size = 128 +num_classes = 10 +epochs = 3 + +# the data, split between train and test sets +(x_train, y_train), (x_test, y_test) = mnist.load_data() + +# input image dimensions +img_rows, img_cols = 28, 28 + +if K.image_data_format() == 'channels_first': + x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols) + x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols) + input_shape = (1, img_rows, img_cols) +else: + x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1) + x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1) + input_shape = (img_rows, img_cols, 1) + +x_train = x_train.astype('float32') +x_test = x_test.astype('float32') +x_train /= 255 +x_test /= 255 +print('x_train shape:', x_train.shape) +print(x_train.shape[0], 'train samples') +print(x_test.shape[0], 'test samples') + +# convert class vectors to binary class matrices +y_train = keras.utils.to_categorical(y_train, num_classes) +y_test = keras.utils.to_categorical(y_test, num_classes) + +model = Sequential() +model.add(Conv2D(32, kernel_size=(3, 3), + activation='relu', + input_shape=input_shape)) +model.add(Conv2D(64, (3, 3), activation='relu')) +model.add(MaxPooling2D(pool_size=(2, 2))) +model.add(Dropout(0.25)) +model.add(Flatten()) +model.add(Dense(128, activation='relu')) +model.add(Dropout(0.5)) +model.add(Dense(num_classes, activation='softmax')) + +model.compile(loss=keras.losses.categorical_crossentropy, + optimizer=keras.optimizers.Adadelta(), + metrics=['accuracy']) + +tqdm_callback = tfa.callbacks.TQDMCallback() +model.fit(x_train, y_train, + batch_size=batch_size, + epochs=epochs, + verbose=0, + callbacks=[tqdm_callback], + validation_data=(x_test, y_test)) + +score = model.evaluate(x_test, y_test, verbose=0) +print('Test loss:', score[0]) +print('Test accuracy:', score[1]) diff --git a/tensorflow_addons/callbacks/BUILD b/tensorflow_addons/callbacks/BUILD index e0388beaa0..f75519bbaa 100644 --- a/tensorflow_addons/callbacks/BUILD +++ b/tensorflow_addons/callbacks/BUILD @@ -1,14 +1,15 @@ licenses(["notice"]) # Apache 2.0 -package(default_visibility = ["//visibility:public"]) +package(default_visibility=["//visibility:public"]) py_library( - name = "callbacks", - srcs = [ + name="callbacks", + srcs=[ "__init__.py", + "tqdm_callback.py" ], - srcs_version = "PY2AND3", - deps = [ + srcs_version="PY2AND3", + deps=[ "//tensorflow_addons/utils", ], ) diff --git a/tensorflow_addons/callbacks/__init__.py b/tensorflow_addons/callbacks/__init__.py index 3d79cb58bf..65cf4ae559 100755 --- a/tensorflow_addons/callbacks/__init__.py +++ b/tensorflow_addons/callbacks/__init__.py @@ -17,3 +17,5 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function + +from tensorflow_addons.callbacks.tqdm_callback import TQDMCallback diff --git a/tensorflow_addons/callbacks/tqdm_callback.py b/tensorflow_addons/callbacks/tqdm_callback.py new file mode 100644 index 0000000000..1b8380b8f7 --- /dev/null +++ b/tensorflow_addons/callbacks/tqdm_callback.py @@ -0,0 +1,148 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from sys import stderr + +import numpy as np +import six +from tensorflow.keras.callbacks import Callback +from tqdm.auto import tqdm +from tensorflow_addons.utils import keras_utils + + +@keras_utils.register_keras_custom_object +class TQDMCallback(Callback): + """TQDM Progress bar for Keras. + + Arguments: + outer_description: string for outer progress bar + inner_description_initial: initial format for epoch ("Epoch: {epoch}") + inner_description_update: format after metrics collected ("Epoch: {epoch} - {metrics}") + metric_format: format for each metric name/value pair ("{name}: {value:0.3f}") + separator: separator between metrics (", ") + leave_inner: True to leave inner bars + leave_outer: True to leave outer bars + show_inner: False to hide inner bars + show_outer: False to hide outer bar + output_file: output file (default sys.stderr) + """ + + def __init__(self, outer_description="Training", + inner_description_initial="Epoch: {epoch}", + inner_description_update="Epoch: {epoch} - {metrics}", + metric_format="{name}: {value:0.3f}", + separator=", ", + leave_inner=True, + leave_outer=True, + show_inner=True, + show_outer=True, + output_file=stderr): + + self.outer_description = outer_description + self.inner_description_initial = inner_description_initial + self.inner_description_update = inner_description_update + self.metric_format = metric_format + self.separator = separator + self.leave_inner = leave_inner + self.leave_outer = leave_outer + self.show_inner = show_inner + self.show_outer = show_outer + self.output_file = output_file + self.tqdm_outer = None + self.tqdm_inner = None + self.epoch = None + self.running_logs = None + self.inner_count = None + + def on_epoch_begin(self, epoch, logs={}): + self.epoch = epoch + desc = self.inner_description_initial.format(epoch=self.epoch) + if self.mode == 'sample': + self.inner_total = self.params['samples'] + else: + self.inner_total = self.params['steps'] + if self.show_inner: + self.tqdm_inner = tqdm( + desc=desc, total=self.inner_total, leave=self.leave_inner) + self.inner_count = 0 + self.running_logs = {} + + def on_epoch_end(self, epoch, logs={}): + metrics = self.format_metrics(logs) + desc = self.inner_description_update.format( + epoch=epoch, metrics=metrics) + if self.show_inner: + self.tqdm_inner.desc = desc + # set miniters and mininterval to 0 so last update displays + self.tqdm_inner.miniters = 0 + self.tqdm_inner.mininterval = 0 + self.tqdm_inner.update(self.inner_total - self.tqdm_inner.n) + self.tqdm_inner.close() + if self.show_outer: + self.tqdm_outer.update(1) + + def on_batch_begin(self, batch, logs={}): + pass + + def on_batch_end(self, batch, logs={}): + if self.mode == "sample": + update = logs['size'] + else: + update = 1 + self.inner_count += update + if self.inner_count < self.inner_total: + self.append_logs(logs) + metrics = self.format_metrics(self.running_logs) + desc = self.inner_description_update.format( + epoch=self.epoch, metrics=metrics) + if self.show_inner: + self.tqdm_inner.desc = desc + self.tqdm_inner.update(update) + + def on_train_begin(self, logs={}): + if self.show_outer: + epochs = self.params['epochs'] + self.tqdm_outer = tqdm( + desc=self.outer_description, total=epochs, leave=self.leave_outer) + + # set counting mode + if 'samples' in self.params: + self.mode = 'sample' + else: + self.mode = 'step' + + def on_train_end(self, logs={}): + if self.show_outer: + self.tqdm_outer.close() + + def append_logs(self, logs): + """append logs seen in a batch to the running log to display updated + metrics values in real time.""" + metrics = self.params['metrics'] + for metric, value in six.iteritems(logs): + if metric in metrics: + if metric in self.running_logs: + self.running_logs[metric].append(value[()]) + else: + self.running_logs[metric] = [value[()]] + + def format_metrics(self, logs): + """Format metrics in logs into a string. + + Arguments: + logs: dictionary of metrics and their values. + + Returns: + metrics_string: a string displaying metrics using the given + formators passed in through the constructor. + """ + metrics = self.params['metrics'] + metric_value_pairs = [] + for metric in metrics: + if metric in logs: + pair = self.metric_format.format( + name=metric, value=np.mean(logs[metric], axis=None)) + metric_value_pairs.append(pair) + metrics_string = self.separator.join(metric_value_pairs) + return metrics_string