
Commit bb189aa

LazyAdam example (#167)
* LazyAdam example
1 parent 8edcf7d commit bb189aa

2 files changed: 275 additions & 92 deletions
tensorflow_addons/examples/notebooks/optimizers_lazyadam.ipynb

Lines changed: 275 additions & 0 deletions
@@ -0,0 +1,275 @@
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "optimizers_lazyadam.ipynb",
      "version": "0.3.2",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "metadata": {
        "colab_type": "text",
        "id": "Tce3stUlHN0L"
      },
      "cell_type": "markdown",
      "source": [
        "##### Copyright 2019 The TensorFlow Authors.\n",
        "\n"
      ]
    },
    {
      "metadata": {
        "colab_type": "code",
        "id": "tuOe1ymfHZPu",
        "cellView": "form",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "#@title Licensed under the Apache License, Version 2.0\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "colab_type": "text",
        "id": "MfBg1C5NB3X0"
      },
      "cell_type": "markdown",
      "source": [
        "# TensorFlow Addons Optimizers: LazyAdam\n",
        "\n",
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/addons/blob/master/tensorflow_addons/examples/notebooks/optimizers_lazyadam.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/addons/blob/master/tensorflow_addons/examples/notebooks/optimizers_lazyadam.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "metadata": {
        "colab_type": "text",
        "id": "xHxb-dlhMIzW"
      },
      "cell_type": "markdown",
      "source": [
        "# Overview\n",
        "\n",
        "This notebook will demonstrate how to use the LazyAdam optimizer from the Addons package.\n"
      ]
    },
    {
      "metadata": {
        "id": "bQwBbFVAyHJ_",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "# LazyAdam\n",
        "\n",
        "> LazyAdam is a variant of the Adam optimizer that handles sparse updates more efficiently.\n",
        "  The original Adam algorithm maintains two moving-average accumulators for\n",
        "  each trainable variable; the accumulators are updated at every step.\n",
        "  This class provides lazier handling of gradient updates for sparse\n",
        "  variables. It only updates moving-average accumulators for sparse variable\n",
        "  indices that appear in the current batch, rather than updating the\n",
        "  accumulators for all indices. Compared with the original Adam optimizer,\n",
        "  it can provide large improvements in model training throughput for some\n",
        "  applications. However, it provides slightly different semantics than the\n",
        "  original Adam algorithm, and may lead to different empirical results.\n",
        "\n",
        "An illustrative sparse-gradient sketch is included at the end of this notebook."
      ]
    },
    {
      "metadata": {
        "colab_type": "text",
        "id": "MUXex9ctTuDB"
      },
      "cell_type": "markdown",
      "source": [
        "## Setup"
      ]
    },
    {
      "metadata": {
        "colab_type": "code",
        "id": "IqR2PQG4ZaZ0",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "!pip install tensorflow-gpu==2.0.0.a0\n",
        "!pip install tensorflow-addons\n",
        "from __future__ import absolute_import, division, print_function, unicode_literals\n",
        "\n",
        "import tensorflow as tf\n",
        "import tensorflow_addons as tfa\n",
        "import tensorflow_datasets as tfds\n",
        "import numpy as np\n",
        "from matplotlib import pyplot as plt"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "ys65MwOLKnXq",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# Hyperparameters\n",
        "batch_size = 64\n",
        "epochs = 10"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "KR01t9v_fxbT",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "# Build the Model"
      ]
    },
    {
      "metadata": {
        "id": "djpoAvfWNyL5",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "model = tf.keras.Sequential([\n",
        "    tf.keras.layers.Dense(64, input_shape=(784,), activation='relu', name='dense_1'),\n",
        "    tf.keras.layers.Dense(64, activation='relu', name='dense_2'),\n",
        "    tf.keras.layers.Dense(10, activation='softmax', name='predictions'),\n",
        "])"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "0_D7CZqkv_Hj",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "# Prep the Data"
      ]
    },
    {
      "metadata": {
        "id": "U0bS3SyowBoB",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# Load MNIST dataset as NumPy arrays\n",
        "dataset = {}\n",
        "num_validation = 10000\n",
        "(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()\n",
        "\n",
        "# Preprocess the data\n",
        "x_train = x_train.reshape(-1, 784).astype('float32') / 255\n",
        "x_test = x_test.reshape(-1, 784).astype('float32') / 255"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "HYE-BxhOzFQp",
        "colab_type": "text"
      },
      "cell_type": "markdown",
      "source": [
        "# Train and Evaluate\n",
        "\n",
        "Simply replace typical Keras optimizers with the new TFA optimizer."
      ]
    },
    {
      "metadata": {
        "id": "NxfYhtiSzHf-",
        "colab_type": "code",
        "colab": {}
      },
      "cell_type": "code",
      "source": [
        "# Compile the model\n",
        "model.compile(\n",
        "    optimizer=tfa.optimizers.LazyAdam(0.001),  # Utilize TFA optimizer\n",
        "    loss=tf.keras.losses.SparseCategoricalCrossentropy(),\n",
        "    metrics=['accuracy'])\n",
        "\n",
        "# Train the network\n",
        "history = model.fit(\n",
        "    x_train,\n",
        "    y_train,\n",
        "    batch_size=batch_size,\n",
        "    epochs=epochs)\n"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "metadata": {
        "id": "1Y--0tK69SXf",
        "colab_type": "code",
        "outputId": "163a7751-e35b-4d9f-cc07-1f8580bdf6bf",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 68
        }
      },
      "cell_type": "code",
      "source": [
        "# Evaluate the network\n",
        "print('Evaluate on test data:')\n",
        "results = model.evaluate(x_test, y_test, batch_size=128)\n",
        "print('Test loss = {0}, Test acc: {1}'.format(results[0], results[1]))"
      ],
      "execution_count": 9,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Evaluate on test data:\n",
            "10000/10000 [==============================] - 0s 21us/sample - loss: 0.0884 - accuracy: 0.9752\n",
            "Test loss = 0.08840992146739736, Test acc: 0.9751999974250793\n"
          ],
          "name": "stdout"
        }
      ]
    },
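    {
      "metadata": {},
      "cell_type": "markdown",
      "source": [
        "## Appendix: a sparse-gradient sketch (editorial addition)\n",
        "\n",
        "The dense MNIST model above never produces sparse gradients, so LazyAdam behaves much like plain Adam there. The cell below is an illustrative sketch, not part of the original example: it builds the kind of model the LazyAdam description targets, an `Embedding` lookup whose gradient is a `tf.IndexedSlices` touching only the rows referenced by the current batch. The vocabulary size, embedding width, and sequence length are arbitrary choices."
      ]
    },
    {
      "metadata": {},
      "cell_type": "code",
      "source": [
        "# Illustrative sketch only: an embedding-based model whose gradients are\n",
        "# sparse, so LazyAdam updates the Adam accumulators only for the embedding\n",
        "# rows that appear in each batch. All sizes below are arbitrary.\n",
        "sparse_model = tf.keras.Sequential([\n",
        "    tf.keras.layers.Embedding(input_dim=10000, output_dim=16, input_length=20),\n",
        "    tf.keras.layers.GlobalAveragePooling1D(),\n",
        "    tf.keras.layers.Dense(1, activation='sigmoid'),\n",
        "])\n",
        "\n",
        "sparse_model.compile(\n",
        "    optimizer=tfa.optimizers.LazyAdam(0.001),  # same drop-in usage as above\n",
        "    loss=tf.keras.losses.BinaryCrossentropy(),\n",
        "    metrics=['accuracy'])\n",
        "\n",
        "# Quick smoke test on random token ids and binary labels; only the looked-up\n",
        "# embedding rows have their moving-average slots updated each step.\n",
        "dummy_ids = np.random.randint(0, 10000, size=(256, 20))\n",
        "dummy_labels = np.random.randint(0, 2, size=(256, 1))\n",
        "sparse_model.fit(dummy_ids, dummy_labels, batch_size=32, epochs=1)"
      ],
      "execution_count": 0,
      "outputs": []
    }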
  ]
}

tensorflow_addons/examples/tfa_optimizer.py

Lines changed: 0 additions & 92 deletions
This file was deleted.

0 commit comments
