|
9 | 9 | "# 使用LeNet在MNIST数据集实现图像分类\n", |
10 | 10 | "\n", |
11 | 11 | "**作者:** [PaddlePaddle](https://github.com/PaddlePaddle) <br>\n", |
12 | | - "**日期:** 2022.1 <br>\n", |
| 12 | + "**日期:** 2022.4 <br>\n", |
13 | 13 | "**摘要:** 本示例教程演示如何在MNIST数据集上用LeNet进行图像分类。" |
14 | 14 | ] |
15 | 15 | }, |
|
21 | 21 | "source": [ |
22 | 22 | "## 一、环境配置\n", |
23 | 23 | "\n", |
24 | | - "本教程基于Paddle 2.2 编写,如果你的环境不是本版本,请先参考官网[安装](https://www.paddlepaddle.org.cn/install/quick) Paddle 2.2。" |
| 24 | + "本教程基于PaddlePaddle 2.3.0-rc0 编写,如果你的环境不是本版本,请先参考官网[安装](https://www.paddlepaddle.org.cn/install/quick) PaddlePaddle 2.3.0-rc0。" |
25 | 25 | ] |
26 | 26 | }, |
27 | 27 | { |
28 | 28 | "cell_type": "code", |
29 | | - "execution_count": null, |
| 29 | + "execution_count": 1, |
30 | 30 | "metadata": { |
31 | 31 | "collapsed": false |
32 | 32 | }, |
|
35 | 35 | "name": "stdout", |
36 | 36 | "output_type": "stream", |
37 | 37 | "text": [ |
38 | | - "2.2.2\n" |
| 38 | + "2.3.0-rc0\n" |
39 | 39 | ] |
40 | 40 | } |
41 | 41 | ], |
|
58 | 58 | }, |
59 | 59 | { |
60 | 60 | "cell_type": "code", |
61 | | - "execution_count": null, |
| 61 | + "execution_count": 2, |
62 | 62 | "metadata": { |
63 | 63 | "collapsed": false |
64 | 64 | }, |
65 | | - "outputs": [ |
66 | | - { |
67 | | - "name": "stdout", |
68 | | - "output_type": "stream", |
69 | | - "text": [ |
70 | | - "download training data and load training data\n", |
71 | | - "load finished\n" |
72 | | - ] |
73 | | - } |
74 | | - ], |
| 65 | + "outputs": [], |
75 | 66 | "source": [ |
76 | 67 | "from paddle.vision.transforms import Compose, Normalize\n", |
77 | 68 | "\n", |
|
96 | 87 | }, |
97 | 88 | { |
98 | 89 | "cell_type": "code", |
99 | | - "execution_count": null, |
| 90 | + "execution_count": 3, |
100 | 91 | "metadata": { |
101 | 92 | "collapsed": false |
102 | 93 | }, |
|
107 | 98 | "text": [ |
108 | 99 | "train_data0 label is: [5]\n" |
109 | 100 | ] |
110 | | - }, |
111 | | - { |
112 | | - "data": { |
113 | | - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAJIAAACPCAYAAAARM4LLAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvIxREBQAACHBJREFUeJzt3V1oVOkZB/D/42j8ql9pZInZYBYVIRT8INYWi0atH13Q4E2JilZZWC/8aMFgTb3QCy+KQi803ixWUrGmFGvYtSwEXcyFuEgSDDbZNasuxs3i1yJq0QtdeXsxx+k8B5M5mXnmnDOZ/w9Czv+cZM4LPr7zzjmTZ8Q5B6JcjYp6ADQysJDIBAuJTLCQyAQLiUywkMgEC4lMsJDIRE6FJCJrRaRPRG6LyH6rQVHhkWyvbItIAsA3AFYBGADQAWCjc+6rwX6nrKzMVVVVZXU+ikZXV9cPzrnpmX5udA7n+DmA2865bwFARP4BoA7AoIVUVVWFzs7OHE5JYROR/iA/l8tTWwWA79LygLfPP5CPRaRTRDofP36cw+kozvK+2HbOfeKcq3HO1UyfnnGGpAKVSyF9D6AyLb/v7aMilEshdQCYIyIfiEgJgHoAn9kMiwpN1ott59yPIrILQBuABIBTzrles5FRQcnlVRucc58D+NxoLFTAeGWbTLCQyAQLiUywkMgEC4lMsJDIBAuJTLCQyAQLiUywkMgEC4lM5HSvrZi8efNG5WfPngX+3aamJpVfvnypcl9fn8onTpxQuaGhQeWWlhaVx40bp/L+/f9/+/zBgwcDjzMXnJHIBAuJTLCQyETRrJHu3bun8qtXr1S+evWqyleuXFH56dOnKp87d85sbJWVlSrv3r1b5dbWVpUnTZqk8rx581RetmyZ2diC4oxEJlhIZIKFRCZG7Brp+vXrKq9YsULl4VwHspZIJFQ+fPiwyhMnTlR58+bNKs+YMUPladOmqTx37txchzhsnJHIBAuJTLCQyMSIXSPNnDlT5bKyMpUt10iLFy9W2b9muXz5ssolJSUqb9myxWwsUeGMRCZYSGSChUQmRuwaqbS0VOWjR4+qfOHCBZUXLFig8p49e4Z8/Pnz56e2L126pI75rwP19PSofOzYsSEfuxBxRiITGQtJRE6JyCMR6UnbVyoiF0Xklvd92lCPQSNfkBmpGcBa3779AL5wzs0B8IWXqYgFao8sIlUA/u2c+5mX+wDUOufui0g5gHbnXMYbPDU1NS4uXW2fP3+usv89Pjt27FD55MmTKp85cya1vWnTJuPRxYeIdDnnajL9XLZrpPecc/e97QcA3svycWiEyHmx7ZJT2qDTGtsjF4dsC+mh95QG7/ujwX6Q7ZGLQ7bXkT4D8DsAf/a+f2o2opBMnjx5yONTpkwZ8nj6mqm+vl4dGzWq+K6qBHn53wLgSwBzRWRARD5CsoBWicgtAL/2MhWxjDOSc27jIIdWGo+FCljxzcGUFyP2XluuDh06pHJXV5fK7e3tqW3/vbbVq1fna1ixxRmJTLCQyAQLiUxk/VGk2YjTvbbhunPnjsoLFy5MbU+dOlUdW758uco1NfpW1c6dO1UWEYsh5kW+77URKSwkMsGX/wHNmjVL5ebm5tT29u3b1bHTp08PmV+8eKHy1q1bVS4vL892mJHhjEQmWEhkgoVEJrhGytKGDRtS27Nnz1bH9u7dq7L/FkpjY6PK/f39Kh84cEDlioqKrMcZFs5IZIKFRCZYSGSCt0jywN9K2f/n4du2bVPZ/2+wcqV+z+DFixftBjdMvEVCoWIhkQkWEpngGikCY8eOVfn169cqjxkzRuW2tjaVa2tr8zKud+EaiULFQiITLCQywXttBm7cuKGy/yO4Ojo6VPavifyqq6tVXrp0aQ6jCwdnJDLBQiITLCQywTVSQP6PVD9+/Hhq+/z58+rYgwcPhvXYo0frfwb/e7YLoU1O/EdIBSFIf6RKEbksIl+JSK+I/N7bzxbJlBJkRvoRwF7nXDWAXwDYKSLVYItkShOk0dZ9APe97f+KyNcAKgDUAaj1fuxvANoB/DEvowyBf11z9uxZlZuamlS+e/du1udatGiRyv73aK9fvz7rx47KsNZIXr/tBQCugS2SKU3gQhKRnwD4F4A/OOdUt/OhWiSzPXJxCFRIIjIGySL6u3Pu7WvdQC2S2R65OGRcI0my58pfAXztnPtL2qGCapH88OFDlXt7e1XetWuXyjdv3sz6XP6PJt23b5/KdXV1KhfCdaJMglyQXAJgC4D/iEi3t+9PSBbQP712yf0AfpufIVIhCPKq7QqAwTpBsUUyAeCVbTIyYu61PXnyRGX/x2R1d3er7G/lN1xLlixJbfv/1n/NmjUqjx8/PqdzFQLOSGSChUQmWEhkoqDWSNeuXUttHzlyRB3zvy96YGAgp3NNmDBBZf/Ht6ffH/N/PHsx4oxEJlhIZKKgntpaW1vfuR2E/0981q1bp3IikVC5oaFBZX93f9I4I5EJFhKZYCGRCba1oSGxrQ2FioVEJlhIZIKFRCZYSGSChUQmWEhkgoVEJlhIZIKFRCZYSGQi1HttIvIYyb/KLQPwQ2gnHp64ji2qcc10zmVs2hBqIaVOKtIZ5EZgFOI6triO6y0+tZEJFhKZiKqQPonovEHEdWxxHReAiNZINPLwqY1MhFpIIrJWRPpE5LaIRNpOWUROicgjEelJ2xeL3uGF2Ns8tEISkQSAEwB+A6AawEavX3dUmgGs9e2LS+/wwutt7pwL5QvALwG0peVGAI1hnX+QMVUB6EnLfQDKve1yAH1Rji9tXJ8CWBXX8TnnQn1qqwDwXVoe8PbFSex6hxdKb3Mutgfhkv/tI31Jm21v8yiEWUjfA6hMy+97++IkUO/wMOTS2zwKYRZSB4A5IvKBiJQAqEeyV3ecvO0dDkTYOzxAb3Mgbr3NQ140fgjgGwB3AByIeAHbguSH9bxGcr32EYCfIvlq6BaASwBKIxrbr5B82roBoNv7+jAu43vXF69skwkutskEC4lMsJDIBAuJTLCQyAQLiUywkMgEC4lM/A+jN2A4bkW+2gAAAABJRU5ErkJggg==\n", |
114 | | - "text/plain": [ |
115 | | - "<Figure size 144x144 with 1 Axes>" |
116 | | - ] |
117 | | - }, |
118 | | - "metadata": {}, |
119 | | - "output_type": "display_data" |
120 | 101 | } |
121 | 102 | ], |
122 | 103 | "source": [ |
|
141 | 122 | }, |
142 | 123 | { |
143 | 124 | "cell_type": "code", |
144 | | - "execution_count": null, |
| 125 | + "execution_count": 4, |
145 | 126 | "metadata": { |
146 | 127 | "collapsed": false |
147 | 128 | }, |
|
197 | 178 | }, |
198 | 179 | { |
199 | 180 | "cell_type": "code", |
200 | | - "execution_count": null, |
| 181 | + "execution_count": 5, |
201 | 182 | "metadata": { |
202 | 183 | "collapsed": false |
203 | 184 | }, |
204 | | - "outputs": [], |
| 185 | + "outputs": [ |
| 186 | + { |
| 187 | + "name": "stderr", |
| 188 | + "output_type": "stream", |
| 189 | + "text": [ |
| 190 | + "W0422 18:56:10.020583 19533 gpu_context.cc:244] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.1, Runtime API Version: 10.1\n", |
| 191 | + "W0422 18:56:10.026566 19533 gpu_context.cc:272] device: 0, cuDNN Version: 7.6.\n" |
| 192 | + ] |
| 193 | + } |
| 194 | + ], |
205 | 195 | "source": [ |
206 | 196 | "from paddle.metric import Accuracy\n", |
207 | 197 | "model = paddle.Model(LeNet()) # 用Model封装模型\n", |
|
217 | 207 | }, |
218 | 208 | { |
219 | 209 | "cell_type": "code", |
220 | | - "execution_count": 12, |
| 210 | + "execution_count": 6, |
221 | 211 | "metadata": { |
222 | 212 | "collapsed": false |
223 | 213 | }, |
|
228 | 218 | "text": [ |
229 | 219 | "The loss value printed in the log is the current step, and the metric is the average value of previous steps.\n", |
230 | 220 | "Epoch 1/2\n", |
231 | | - "step 938/938 [==============================] - loss: 0.0132 - acc: 0.9585 - 10ms/step \n", |
| 221 | + "step 20/938 [..............................] - loss: 1.4646 - acc: 0.3828 - ETA: 17s - 19ms/ste" |
| 222 | + ] |
| 223 | + }, |
| 224 | + { |
| 225 | + "name": "stdout", |
| 226 | + "output_type": "stream", |
| 227 | + "text": [ |
| 228 | + "step 30/938 [..............................] - loss: 1.1068 - acc: 0.4672 - ETA: 14s - 16ms/stepstep 938/938 [==============================] - loss: 0.1653 - acc: 0.9273 - 11ms/step \n", |
232 | 229 | "Epoch 2/2\n", |
233 | | - "step 938/938 [==============================] - loss: 0.0075 - acc: 0.9850 - 10ms/step \n" |
| 230 | + "step 938/938 [==============================] - loss: 0.0199 - acc: 0.9767 - 11ms/step \n" |
234 | 231 | ] |
235 | 232 | } |
236 | 233 | ], |
|
254 | 251 | }, |
255 | 252 | { |
256 | 253 | "cell_type": "code", |
257 | | - "execution_count": 13, |
| 254 | + "execution_count": 7, |
258 | 255 | "metadata": { |
259 | 256 | "collapsed": false |
260 | 257 | }, |
|
264 | 261 | "output_type": "stream", |
265 | 262 | "text": [ |
266 | 263 | "Eval begin...\n", |
267 | | - "step 157/157 [==============================] - loss: 1.6993e-04 - acc: 0.9865 - 8ms/step \n", |
| 264 | + "step 157/157 [==============================] - loss: 0.0048 - acc: 0.9780 - 8ms/step \n", |
268 | 265 | "Eval samples: 10000\n" |
269 | 266 | ] |
270 | 267 | }, |
271 | 268 | { |
272 | 269 | "data": { |
273 | 270 | "text/plain": [ |
274 | | - "{'loss': [0.0001699343], 'acc': 0.9865}" |
| 271 | + "{'loss': [0.0047780997], 'acc': 0.978}" |
275 | 272 | ] |
276 | 273 | }, |
277 | | - "execution_count": 13, |
| 274 | + "execution_count": 7, |
278 | 275 | "metadata": {}, |
279 | 276 | "output_type": "execute_result" |
280 | 277 | } |
|
306 | 303 | }, |
307 | 304 | { |
308 | 305 | "cell_type": "code", |
309 | | - "execution_count": null, |
| 306 | + "execution_count": 8, |
310 | 307 | "metadata": { |
311 | 308 | "collapsed": false |
312 | 309 | }, |
|
315 | 312 | "name": "stdout", |
316 | 313 | "output_type": "stream", |
317 | 314 | "text": [ |
318 | | - "epoch: 0, batch_id: 0, loss is: [2.8395634], acc is: [0.0625]\n", |
319 | | - "epoch: 0, batch_id: 300, loss is: [0.2528286], acc is: [0.890625]\n", |
320 | | - "epoch: 0, batch_id: 600, loss is: [0.02093708], acc is: [1.]\n", |
321 | | - "epoch: 0, batch_id: 900, loss is: [0.06315502], acc is: [0.984375]\n" |
| 315 | + "epoch: 0, batch_id: 0, loss is: [3.7514806], acc is: [0.21875]\n", |
| 316 | + "epoch: 0, batch_id: 300, loss is: [0.19029362], acc is: [0.953125]\n", |
| 317 | + "epoch: 0, batch_id: 600, loss is: [0.12201739], acc is: [0.953125]\n", |
| 318 | + "epoch: 0, batch_id: 900, loss is: [0.03218058], acc is: [0.984375]\n", |
| 319 | + "epoch: 1, batch_id: 0, loss is: [0.114471], acc is: [0.953125]\n", |
| 320 | + "epoch: 1, batch_id: 300, loss is: [0.00857661], acc is: [1.]\n", |
| 321 | + "epoch: 1, batch_id: 600, loss is: [0.10740176], acc is: [0.96875]\n", |
| 322 | + "epoch: 1, batch_id: 900, loss is: [0.19590104], acc is: [0.9375]\n" |
322 | 323 | ] |
323 | 324 | } |
324 | 325 | ], |
|
360 | 361 | }, |
361 | 362 | { |
362 | 363 | "cell_type": "code", |
363 | | - "execution_count": null, |
| 364 | + "execution_count": 9, |
364 | 365 | "metadata": { |
365 | 366 | "collapsed": false |
366 | 367 | }, |
|
369 | 370 | "name": "stdout", |
370 | 371 | "output_type": "stream", |
371 | 372 | "text": [ |
372 | | - "batch_id: 0, loss is: [0.01972857], acc is: [0.984375]\n", |
373 | | - "batch_id: 20, loss is: [0.19958115], acc is: [0.9375]\n", |
374 | | - "batch_id: 40, loss is: [0.23575728], acc is: [0.953125]\n", |
375 | | - "batch_id: 60, loss is: [0.07018849], acc is: [0.984375]\n", |
376 | | - "batch_id: 80, loss is: [0.02309197], acc is: [0.984375]\n", |
377 | | - "batch_id: 100, loss is: [0.00239462], acc is: [1.]\n", |
378 | | - "batch_id: 120, loss is: [0.01583934], acc is: [1.]\n", |
379 | | - "batch_id: 140, loss is: [0.00399609], acc is: [1.]\n" |
| 373 | + "batch_id: 0, loss is: [0.04440754], acc is: [0.984375]\n", |
| 374 | + "batch_id: 20, loss is: [0.19196557], acc is: [0.9375]\n", |
| 375 | + "batch_id: 40, loss is: [0.09817676], acc is: [0.984375]\n", |
| 376 | + "batch_id: 60, loss is: [0.16782945], acc is: [0.953125]\n", |
| 377 | + "batch_id: 80, loss is: [0.05786889], acc is: [0.96875]\n", |
| 378 | + "batch_id: 100, loss is: [0.00799548], acc is: [1.]\n", |
| 379 | + "batch_id: 120, loss is: [0.00511317], acc is: [1.]\n", |
| 380 | + "batch_id: 140, loss is: [0.01672031], acc is: [1.]\n" |
380 | 381 | ] |
381 | 382 | } |
382 | 383 | ], |
|
0 commit comments