Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
368 changes: 368 additions & 0 deletions examples/TQDM_Callbacks.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,368 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# TensorFlow Addons Callbacks: TQDM"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Overview\n",
"This notebook will demonstrate how to use TQDMCallback in TensorFlow Addons."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Setup"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: ipywidgets in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (7.5.1)\n",
"Requirement already satisfied: nbformat>=4.2.0 in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from ipywidgets) (4.4.0)\n",
"Requirement already satisfied: widgetsnbextension~=3.5.0 in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from ipywidgets) (3.5.1)\n",
"Requirement already satisfied: traitlets>=4.3.1 in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from ipywidgets) (4.3.2)\n",
"Requirement already satisfied: ipykernel>=4.5.1 in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from ipywidgets) (5.1.2)\n",
"Requirement already satisfied: ipython>=4.0.0; python_version >= \"3.3\" in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from ipywidgets) (7.8.0)\n",
"Requirement already satisfied: jsonschema!=2.5.0,>=2.4 in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from nbformat>=4.2.0->ipywidgets) (3.0.2)\n",
"Requirement already satisfied: jupyter-core in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from nbformat>=4.2.0->ipywidgets) (4.5.0)\n",
"Requirement already satisfied: ipython-genutils in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from nbformat>=4.2.0->ipywidgets) (0.2.0)\n",
"Requirement already satisfied: notebook>=4.4.1 in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from widgetsnbextension~=3.5.0->ipywidgets) (6.0.1)\n",
"Requirement already satisfied: six in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from traitlets>=4.3.1->ipywidgets) (1.12.0)\n",
"Requirement already satisfied: decorator in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from traitlets>=4.3.1->ipywidgets) (4.4.0)\n",
"Requirement already satisfied: jupyter-client in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from ipykernel>=4.5.1->ipywidgets) (5.3.3)\n",
"Requirement already satisfied: tornado>=4.2 in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from ipykernel>=4.5.1->ipywidgets) (6.0.3)\n",
"Requirement already satisfied: jedi>=0.10 in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets) (0.15.1)\n",
"Requirement already satisfied: backcall in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets) (0.1.0)\n",
"Requirement already satisfied: pickleshare in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets) (0.7.5)\n",
"Requirement already satisfied: pexpect; sys_platform != \"win32\" in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets) (4.7.0)\n",
"Requirement already satisfied: pygments in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets) (2.4.2)\n",
"Requirement already satisfied: prompt-toolkit<2.1.0,>=2.0.0 in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets) (2.0.9)\n",
"Requirement already satisfied: setuptools>=18.5 in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets) (41.2.0)\n",
"Requirement already satisfied: pyrsistent>=0.14.0 in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from jsonschema!=2.5.0,>=2.4->nbformat>=4.2.0->ipywidgets) (0.15.4)\n",
"Requirement already satisfied: attrs>=17.4.0 in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from jsonschema!=2.5.0,>=2.4->nbformat>=4.2.0->ipywidgets) (19.1.0)\n",
"Requirement already satisfied: prometheus-client in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (0.7.1)\n",
"Requirement already satisfied: terminado>=0.8.1 in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (0.8.2)\n",
"Requirement already satisfied: Send2Trash in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (1.5.0)\n",
"Requirement already satisfied: pyzmq>=17 in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (18.1.0)\n",
"Requirement already satisfied: nbconvert in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (5.6.0)\n",
"Requirement already satisfied: jinja2 in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (2.10.1)\n",
"Requirement already satisfied: python-dateutil>=2.1 in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from jupyter-client->ipykernel>=4.5.1->ipywidgets) (2.8.0)\n",
"Requirement already satisfied: parso>=0.5.0 in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from jedi>=0.10->ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets) (0.5.1)\n",
"Requirement already satisfied: ptyprocess>=0.5 in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from pexpect; sys_platform != \"win32\"->ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets) (0.6.0)\n",
"Requirement already satisfied: wcwidth in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from prompt-toolkit<2.1.0,>=2.0.0->ipython>=4.0.0; python_version >= \"3.3\"->ipywidgets) (0.1.7)\n",
"Requirement already satisfied: mistune<2,>=0.8.1 in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (0.8.4)\n",
"Requirement already satisfied: pandocfilters>=1.4.1 in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (1.4.2)\n",
"Requirement already satisfied: entrypoints>=0.2.2 in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (0.3)\n",
"Requirement already satisfied: bleach in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (3.1.0)\n",
"Requirement already satisfied: defusedxml in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (0.6.0)\n",
"Requirement already satisfied: testpath in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (0.4.2)\n",
"Requirement already satisfied: MarkupSafe>=0.23 in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from jinja2->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (1.1.1)\n",
"Requirement already satisfied: webencodings in /usr/local/google/home/shunlin/addons/env/lib/python3.6/site-packages (from bleach->nbconvert->notebook>=4.4.1->widgetsnbextension~=3.5.0->ipywidgets) (0.5.1)\n",
"Enabling notebook extension jupyter-js-widgets/extension...\n",
" - Validating: \u001b[32mOK\u001b[0m\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using TensorFlow backend.\n"
]
}
],
"source": [
"!pip install -q tensorflow-gpu==2.0.0rc0\n",
"!pip install -q tensorflow-addons~=0.5\n",
"!pip install -q tqdm\n",
"!pip install -q keras\n",
"\n",
"!pip install ipywidgets\n",
"!jupyter nbextension enable --py widgetsnbextension --sys-prefix\n",
"\n",
"import numpy as np\n",
"import tensorflow as tf\n",
"import tensorflow_addons as tfa\n",
"import keras\n",
"from keras.datasets import mnist\n",
"from keras.models import Sequential\n",
"from keras.layers import Dense, Dropout, Flatten\n",
"from keras.layers import Conv2D, MaxPooling2D\n",
"from keras import backend as K"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Import Data"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# the data, split between train and test sets\n",
"(x_train, y_train), (x_test, y_test) = mnist.load_data()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Preprocess Data"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"x_train shape: (60000, 28, 28, 1)\n",
"60000 train samples\n",
"10000 test samples\n"
]
}
],
"source": [
"batch_size = 128\n",
"num_classes = 10\n",
"epochs = 3\n",
"\n",
"# input image dimensions\n",
"img_rows, img_cols = 28, 28\n",
"\n",
"if K.image_data_format() == 'channels_first':\n",
" x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)\n",
" x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)\n",
" input_shape = (1, img_rows, img_cols)\n",
"else:\n",
" x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)\n",
" x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)\n",
" input_shape = (img_rows, img_cols, 1)\n",
"\n",
"x_train = x_train.astype('float32')\n",
"x_test = x_test.astype('float32')\n",
"x_train /= 255\n",
"x_test /= 255\n",
"print('x_train shape:', x_train.shape)\n",
"print(x_train.shape[0], 'train samples')\n",
"print(x_test.shape[0], 'test samples')\n",
"\n",
"# convert class vectors to binary class matrices\n",
"y_train = keras.utils.to_categorical(y_train, num_classes)\n",
"y_test = keras.utils.to_categorical(y_test, num_classes)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Build Simple MNIST CNN Model"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"model = Sequential()\n",
"model.add(Conv2D(32, kernel_size=(3, 3),\n",
" activation='relu',\n",
" input_shape=input_shape))\n",
"model.add(Conv2D(64, (3, 3), activation='relu'))\n",
"model.add(MaxPooling2D(pool_size=(2, 2)))\n",
"model.add(Dropout(0.25))\n",
"model.add(Flatten())\n",
"model.add(Dense(128, activation='relu'))\n",
"model.add(Dropout(0.5))\n",
"model.add(Dense(num_classes, activation='softmax'))\n",
"\n",
"model.compile(loss=keras.losses.categorical_crossentropy,\n",
" optimizer=keras.optimizers.Adadelta(),\n",
" metrics=['accuracy'])\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## TQDMCallback example usage 1"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2ecc9bbf7e8d4ee2a420d4b6341a33a9",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, description='Training', max=3, style=ProgressStyle(description_width='init…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5daaabe6392847559ef1cbd8fcdebf49",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, description='Epoch: 0', max=60000, style=ProgressStyle(description_width='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "6fa1dd5a90ce461483b64b6b090c4038",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, description='Epoch: 1', max=60000, style=ProgressStyle(description_width='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "8167fc3e3b6042a78e8c9f7a28ec9d77",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, description='Epoch: 2', max=60000, style=ProgressStyle(description_width='…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n"
]
},
{
"data": {
"text/plain": [
"<keras.callbacks.callbacks.History at 0x7f7cfc7693c8>"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tqdm_callback = tfa.callbacks.TQDMCallback(leave_outer=True)\n",
"model.fit(x_train, y_train,\n",
" batch_size=batch_size,\n",
" epochs=epochs,\n",
" verbose=0,\n",
" callbacks = [tqdm_callback],\n",
" validation_data=(x_test, y_test))"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test loss: 2.3033354560852053\n",
"Test accuracy: 0.07720000296831131\n"
]
}
],
"source": [
"score = model.evaluate(x_test, y_test, verbose=0)\n",
"print('Test loss:', score[0])\n",
"print('Test accuracy:', score[1])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading