|
9 | 9 | "# 使用LeNet在MNIST数据集实现图像分类\n", |
10 | 10 | "\n", |
11 | 11 | "**作者:** [PaddlePaddle](https://github.com/PaddlePaddle) <br>\n", |
12 | | - "**日期:** 2022.4 <br>\n", |
| 12 | + "**日期:** 2022.5 <br>\n", |
13 | 13 | "**摘要:** 本示例教程演示如何在MNIST数据集上用LeNet进行图像分类。" |
14 | 14 | ] |
15 | 15 | }, |
|
21 | 21 | "source": [ |
22 | 22 | "## 一、环境配置\n", |
23 | 23 | "\n", |
24 | | - "本教程基于PaddlePaddle 2.3.0-rc0 编写,如果你的环境不是本版本,请先参考官网[安装](https://www.paddlepaddle.org.cn/install/quick) PaddlePaddle 2.3.0-rc0。" |
| 24 | + "本教程基于PaddlePaddle 2.3.0 编写,如果你的环境不是本版本,请先参考官网[安装](https://www.paddlepaddle.org.cn/install/quick) PaddlePaddle 2.3.0。" |
25 | 25 | ] |
26 | 26 | }, |
27 | 27 | { |
28 | 28 | "cell_type": "code", |
29 | | - "execution_count": 1, |
| 29 | + "execution_count": null, |
30 | 30 | "metadata": { |
31 | 31 | "collapsed": false |
32 | 32 | }, |
|
35 | 35 | "name": "stdout", |
36 | 36 | "output_type": "stream", |
37 | 37 | "text": [ |
38 | | - "2.3.0-rc0\n" |
| 38 | + "2.3.0\n" |
39 | 39 | ] |
40 | 40 | } |
41 | 41 | ], |
|
58 | 58 | }, |
59 | 59 | { |
60 | 60 | "cell_type": "code", |
61 | | - "execution_count": 2, |
| 61 | + "execution_count": null, |
62 | 62 | "metadata": { |
63 | 63 | "collapsed": false |
64 | 64 | }, |
|
87 | 87 | }, |
88 | 88 | { |
89 | 89 | "cell_type": "code", |
90 | | - "execution_count": 3, |
| 90 | + "execution_count": null, |
91 | 91 | "metadata": { |
92 | 92 | "collapsed": false |
93 | 93 | }, |
|
98 | 98 | "text": [ |
99 | 99 | "train_data0 label is: [5]\n" |
100 | 100 | ] |
| 101 | + }, |
| 102 | + { |
| 103 | + "data": { |
| 104 | + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAJIAAACPCAYAAAARM4LLAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAACHBJREFUeJzt3V1oVOkZB/D/42j8ql9pZInZYBYVIRT8INYWi0atH13Q4E2JilZZWC/8aMFgTb3QCy+KQi803ixWUrGmFGvYtSwEXcyFuEgSDDbZNasuxs3i1yJq0QtdeXsxx+k8B5M5mXnmnDOZ/w9Czv+cZM4LPr7zzjmTZ8Q5B6JcjYp6ADQysJDIBAuJTLCQyAQLiUywkMgEC4lMsJDIRE6FJCJrRaRPRG6LyH6rQVHhkWyvbItIAsA3AFYBGADQAWCjc+6rwX6nrKzMVVVVZXU+ikZXV9cPzrnpmX5udA7n+DmA2865bwFARP4BoA7AoIVUVVWFzs7OHE5JYROR/iA/l8tTWwWA79LygLfPP5CPRaRTRDofP36cw+kozvK+2HbOfeKcq3HO1UyfnnGGpAKVSyF9D6AyLb/v7aMilEshdQCYIyIfiEgJgHoAn9kMiwpN1ott59yPIrILQBuABIBTzrles5FRQcnlVRucc58D+NxoLFTAeGWbTLCQyAQLiUywkMgEC4lMsJDIBAuJTLCQyAQLiUywkMgEC4lM5HSvrZi8efNG5WfPngX+3aamJpVfvnypcl9fn8onTpxQuaGhQeWWlhaVx40bp/L+/f9/+/zBgwcDjzMXnJHIBAuJTLCQyETRrJHu3bun8qtXr1S+evWqyleuXFH56dOnKp87d85sbJWVlSrv3r1b5dbWVpUnTZqk8rx581RetmyZ2diC4oxEJlhIZIKFRCZG7Brp+vXrKq9YsULl4VwHspZIJFQ+fPiwyhMnTlR58+bNKs+YMUPladOmqTx37txchzhsnJHIBAuJTLCQyMSIXSPNnDlT5bKyMpUt10iLFy9W2b9muXz5ssolJSUqb9myxWwsUeGMRCZYSGSChUQmRuwaqbS0VOWjR4+qfOHCBZUXLFig8p49e4Z8/Pnz56e2L126pI75rwP19PSofOzYsSEfuxBxRiITGQtJRE6JyCMR6UnbVyoiF0Xklvd92lCPQSNfkBmpGcBa3779AL5wzs0B8IWXqYgFao8sIlUA/u2c+5mX+wDUOufui0g5gHbnXMYbPDU1NS4uXW2fP3+usv89Pjt27FD55MmTKp85cya1vWnTJuPRxYeIdDnnajL9XLZrpPecc/e97QcA3svycWiEyHmx7ZJT2qDTGtsjF4dsC+mh95QG7/ujwX6Q7ZGLQ7bXkT4D8DsAf/a+f2o2opBMnjx5yONTpkwZ8nj6mqm+vl4dGzWq+K6qBHn53wLgSwBzRWRARD5CsoBWicgtAL/2MhWxjDOSc27jIIdWGo+FCljxzcGUFyP2XluuDh06pHJXV5fK7e3tqW3/vbbVq1fna1ixxRmJTLCQyAQLiUxk/VGk2YjTvbbhunPnjsoLFy5MbU+dOlUdW758uco1NfpW1c6dO1UWEYsh5kW+77URKSwkMsGX/wHNmjVL5ebm5tT29u3b1bHTp08PmV+8eKHy1q1bVS4vL892mJHhjEQmWEhkgoVEJrhGytKGDRtS27Nnz1bH9u7dq7L/FkpjY6PK/f39Kh84cEDlioqKrMcZFs5IZIKFRCZYSGSCt0jywN9K2f/n4du2bVPZ/2+wcqV+z+DFixftBjdMvEVCoWIhkQkWEpngGikCY8eOVfn169cqjxkzRuW2tjaVa2tr8zKud+EaiULFQiITLCQywXttBm7cuKGy/yO4Ojo6VPavifyqq6tVXrp0aQ6jCwdnJDLBQiITLCQywTVSQP6PVD9+/Hhq+/z58+rYgwcPhvXYo0frfwb/e7YLoU1O/EdIBSFIf6RKEbksIl+JSK+I/N7bzxbJlBJkRvoRwF7nXDWAXwDYKSLVYItkShOk0dZ9APe97f+KyNcAKgDUAaj1fuxvANoB/DEvowyBf11z9uxZlZuamlS+e/du1udatGiRyv73aK9fvz7rx47KsNZIXr/tBQCugS2SKU3gQhKRnwD4F4A/OOdUt/OhWiSzPXJxCFRIIjIGySL6u3Pu7WvdQC2S2R65OGRcI0my58pfAXztnPtL2qGCapH88OFDlXt7e1XetWuXyjdv3sz6XP6PJt23b5/KdXV1KhfCdaJMglyQXAJgC4D/iEi3t+9PSBbQP712yf0AfpufIVIhCPKq7QqAwTpBsUUyAeCVbTIyYu61PXnyRGX/x2R1d3er7G/lN1xLlixJbfv/1n/NmjUqjx8/PqdzFQLOSGSChUQmWEhkoqDWSNeuXUttHzlyRB3zvy96YGAgp3NNmDBBZf/Ht6ffH/N/PHsx4oxEJlhIZKKgntpaW1vfuR2E/0981q1bp3IikVC5oaFBZX93f9I4I5EJFhKZYCGRCba1oSGxrQ2FioVEJlhIZIKFRCZYSGSChUQmWEhkgoVEJlhIZIKFRCZYSGQi1HttIvIYyb/KLQPwQ2gnHp64ji2qcc10zmVs2hBqIaVOKtIZ5EZgFOI6triO6y0+tZEJFhKZiKqQPonovEHEdWxxHReAiNZINPLwqY1MhFpIIrJWRPpE5LaIRNpOWUROicgjEelJ2xeL3uGF2Ns8tEISkQSAEwB+A6AawEavX3dUmgGs9e2LS+/wwutt7pwL5QvALwG0peVGAI1hnX+QMVUB6EnLfQDKve1yAH1Rji9tXJ8CWBXX8TnnQn1qqwDwXVoe8PbFSex6hxdKb3Mutgfhkv/tI31Jm21v8yiEWUjfA6hMy+97++IkUO/wMOTS2zwKYRZSB4A5IvKBiJQAqEeyV3ecvO0dDkTYOzxAb3Mgbr3NQ140fgjgGwB3AByIeAHbguSH9bxGcr32EYCfIvlq6BaASwBKIxrbr5B82roBoNv7+jAu43vXF69skwkutskEC4lMsJDIBAuJTLCQyAQLiUywkMgEC4lM/A+jN2A4bkW+2gAAAABJRU5ErkJggg==\n", |
| 105 | + "text/plain": [ |
| 106 | + "<Figure size 144x144 with 1 Axes>" |
| 107 | + ] |
| 108 | + }, |
| 109 | + "metadata": {}, |
| 110 | + "output_type": "display_data" |
101 | 111 | } |
102 | 112 | ], |
103 | 113 | "source": [ |
|
122 | 132 | }, |
123 | 133 | { |
124 | 134 | "cell_type": "code", |
125 | | - "execution_count": 4, |
| 135 | + "execution_count": null, |
126 | 136 | "metadata": { |
127 | 137 | "collapsed": false |
128 | 138 | }, |
|
178 | 188 | }, |
179 | 189 | { |
180 | 190 | "cell_type": "code", |
181 | | - "execution_count": 5, |
| 191 | + "execution_count": null, |
182 | 192 | "metadata": { |
183 | 193 | "collapsed": false |
184 | 194 | }, |
185 | | - "outputs": [ |
186 | | - { |
187 | | - "name": "stderr", |
188 | | - "output_type": "stream", |
189 | | - "text": [ |
190 | | - "W0422 18:56:10.020583 19533 gpu_context.cc:244] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1\n", |
191 | | - "W0422 18:56:10.026566 19533 gpu_context.cc:272] device: 0, cuDNN Version: 7.6.\n" |
192 | | - ] |
193 | | - } |
194 | | - ], |
| 195 | + "outputs": [], |
195 | 196 | "source": [ |
196 | 197 | "from paddle.metric import Accuracy\n", |
197 | 198 | "model = paddle.Model(LeNet()) # 用Model封装模型\n", |
|
207 | 208 | }, |
208 | 209 | { |
209 | 210 | "cell_type": "code", |
210 | | - "execution_count": 6, |
| 211 | + "execution_count": null, |
211 | 212 | "metadata": { |
212 | 213 | "collapsed": false |
213 | 214 | }, |
214 | | - "outputs": [ |
215 | | - { |
216 | | - "name": "stdout", |
217 | | - "output_type": "stream", |
218 | | - "text": [ |
219 | | - "The loss value printed in the log is the current step, and the metric is the average value of previous steps.\n", |
220 | | - "Epoch 1/2\n", |
221 | | - "step 20/938 [..............................] - loss: 1.4646 - acc: 0.3828 - ETA: 17s - 19ms/ste" |
222 | | - ] |
223 | | - }, |
224 | | - { |
225 | | - "name": "stdout", |
226 | | - "output_type": "stream", |
227 | | - "text": [ |
228 | | - "step 30/938 [..............................] - loss: 1.1068 - acc: 0.4672 - ETA: 14s - 16ms/stepstep 938/938 [==============================] - loss: 0.1653 - acc: 0.9273 - 11ms/step \n", |
229 | | - "Epoch 2/2\n", |
230 | | - "step 938/938 [==============================] - loss: 0.0199 - acc: 0.9767 - 11ms/step \n" |
231 | | - ] |
232 | | - } |
233 | | - ], |
| 215 | + "outputs": [], |
234 | 216 | "source": [ |
235 | 217 | "# 训练模型\n", |
236 | 218 | "model.fit(train_dataset,\n", |
|
251 | 233 | }, |
252 | 234 | { |
253 | 235 | "cell_type": "code", |
254 | | - "execution_count": 7, |
| 236 | + "execution_count": null, |
255 | 237 | "metadata": { |
256 | 238 | "collapsed": false |
257 | 239 | }, |
|
261 | 243 | "output_type": "stream", |
262 | 244 | "text": [ |
263 | 245 | "Eval begin...\n", |
264 | | - "step 157/157 [==============================] - loss: 0.0048 - acc: 0.9780 - 8ms/step \n", |
| 246 | + "step 157/157 [==============================] - loss: 4.2854e-04 - acc: 0.9841 - 7ms/step \n", |
265 | 247 | "Eval samples: 10000\n" |
266 | 248 | ] |
267 | 249 | }, |
268 | 250 | { |
269 | 251 | "data": { |
270 | 252 | "text/plain": [ |
271 | | - "{'loss': [0.0047780997], 'acc': 0.978}" |
| 253 | + "{'loss': [0.00042853763], 'acc': 0.9841}" |
272 | 254 | ] |
273 | 255 | }, |
274 | | - "execution_count": 7, |
| 256 | + "execution_count": null, |
275 | 257 | "metadata": {}, |
276 | 258 | "output_type": "execute_result" |
277 | 259 | } |
|
303 | 285 | }, |
304 | 286 | { |
305 | 287 | "cell_type": "code", |
306 | | - "execution_count": 8, |
| 288 | + "execution_count": null, |
307 | 289 | "metadata": { |
308 | 290 | "collapsed": false |
309 | 291 | }, |
|
312 | 294 | "name": "stdout", |
313 | 295 | "output_type": "stream", |
314 | 296 | "text": [ |
315 | | - "epoch: 0, batch_id: 0, loss is: [3.7514806], acc is: [0.21875]\n", |
316 | | - "epoch: 0, batch_id: 300, loss is: [0.19029362], acc is: [0.953125]\n", |
317 | | - "epoch: 0, batch_id: 600, loss is: [0.12201739], acc is: [0.953125]\n", |
318 | | - "epoch: 0, batch_id: 900, loss is: [0.03218058], acc is: [0.984375]\n", |
319 | | - "epoch: 1, batch_id: 0, loss is: [0.114471], acc is: [0.953125]\n", |
320 | | - "epoch: 1, batch_id: 300, loss is: [0.00857661], acc is: [1.]\n", |
321 | | - "epoch: 1, batch_id: 600, loss is: [0.10740176], acc is: [0.96875]\n", |
322 | | - "epoch: 1, batch_id: 900, loss is: [0.19590104], acc is: [0.9375]\n" |
| 297 | + "epoch: 0, batch_id: 0, loss is: [2.9878871], acc is: [0.140625]\n", |
| 298 | + "epoch: 0, batch_id: 300, loss is: [0.22775462], acc is: [0.921875]\n", |
| 299 | + "epoch: 0, batch_id: 600, loss is: [0.06251755], acc is: [0.984375]\n", |
| 300 | + "epoch: 0, batch_id: 900, loss is: [0.1097075], acc is: [0.96875]\n", |
| 301 | + "epoch: 1, batch_id: 0, loss is: [0.04311676], acc is: [0.984375]\n", |
| 302 | + "epoch: 1, batch_id: 300, loss is: [0.00150577], acc is: [1.]\n", |
| 303 | + "epoch: 1, batch_id: 600, loss is: [0.08764459], acc is: [0.96875]\n", |
| 304 | + "epoch: 1, batch_id: 900, loss is: [0.14419323], acc is: [0.9375]\n" |
323 | 305 | ] |
324 | 306 | } |
325 | 307 | ], |
|
361 | 343 | }, |
362 | 344 | { |
363 | 345 | "cell_type": "code", |
364 | | - "execution_count": 9, |
| 346 | + "execution_count": null, |
365 | 347 | "metadata": { |
366 | 348 | "collapsed": false |
367 | 349 | }, |
|
370 | 352 | "name": "stdout", |
371 | 353 | "output_type": "stream", |
372 | 354 | "text": [ |
373 | | - "batch_id: 0, loss is: [0.04440754], acc is: [0.984375]\n", |
374 | | - "batch_id: 20, loss is: [0.19196557], acc is: [0.9375]\n", |
375 | | - "batch_id: 40, loss is: [0.09817676], acc is: [0.984375]\n", |
376 | | - "batch_id: 60, loss is: [0.16782945], acc is: [0.953125]\n", |
377 | | - "batch_id: 80, loss is: [0.05786889], acc is: [0.96875]\n", |
378 | | - "batch_id: 100, loss is: [0.00799548], acc is: [1.]\n", |
379 | | - "batch_id: 120, loss is: [0.00511317], acc is: [1.]\n", |
380 | | - "batch_id: 140, loss is: [0.01672031], acc is: [1.]\n" |
| 355 | + "batch_id: 0, loss is: [0.01201783], acc is: [1.]\n", |
| 356 | + "batch_id: 20, loss is: [0.09013407], acc is: [0.984375]\n", |
| 357 | + "batch_id: 40, loss is: [0.07025866], acc is: [0.96875]\n", |
| 358 | + "batch_id: 60, loss is: [0.08602518], acc is: [0.984375]\n", |
| 359 | + "batch_id: 80, loss is: [0.00779913], acc is: [1.]\n", |
| 360 | + "batch_id: 100, loss is: [0.00508764], acc is: [1.]\n", |
| 361 | + "batch_id: 120, loss is: [0.00401443], acc is: [1.]\n", |
| 362 | + "batch_id: 140, loss is: [0.03930391], acc is: [0.96875]\n" |
381 | 363 | ] |
382 | 364 | } |
383 | 365 | ], |
|
0 commit comments