本文接上文,继续学习TensorFlow在CIFAR-10上的教程,该代码主要由以下五部分组成:
文件 | 作用 |
---|---|
cifar10_input.py |
读取原始的 CIFAR-10 二进制格式文件 |
cifar10.py |
建立 CIFAR-10 网络模型 |
cifar10_train.py |
在单块CPU或者GPU上训练 CIFAR-10 模型 |
cifar10_multi_gpu_train.py |
在多块GPU上训练 CIFAR-10 模型 |
cifar10_eval.py |
在测试集上评估 CIFAR-10 模型的表现 |
本次主要学习cifar10_train.py
和cifar10_eval.py
两个文件,内容分别为训练模型和评估模型,并最终给出实验过程。
教程地址:https://www.tensorflow.org/tutorials/deep_cnn
代码地址:https://github.com/tensorflow/models/tree/master/tutorials/image/cifar10
训练模型
这部分代码在cifar10_train.py
文件中,实现了用单块GPU训练模型,具体训练过程设计为:
- 共计100万次迭代(自己实验时改成了10万次)
- batch_size为128
- 每10次迭代打印一次训练数据(损失、样本/秒、秒/batch)
- 每600s保存一次checkpoint文件
- 每300s对最新的checkpoint文件执行一次评估
- 每100次迭代保存一次summary
具体代码如下:
|
|
补充
|
|
每600s保存一次checkpoint,每100s保存一次summary。
评估模型
这部分代码在cifar10_eval.py
中,默认每300s执行一次评估,具体流程:
evaluate()
负责创建和维护整个评估过程:
- 获得测试数据
- 搭建神经网络模型(和训练过程一样)
- 创建
saver
,saver
负责恢复shadow variable
的值并赋给variable
- 每隔固定的间隔(300s),运行一次
eval_once()
eval_once()
负责完成一次评估,步骤是:
- 从checkpoint中取出最新模型
- 运行
saver.restore
从checkpoint
中恢复shadow variable
的值并赋给variable
。 - 运行神经网络,对测试集的数据按批次进行预测
- 计算整个测试集的预测精度
具体代码如下:
|
|
补充
|
|
判断targets
是否在top k的预测之中。输出batch_size
大小的bool数组,如果对目标累的预测在所有预测的top k中,则out[i]=True
。
实验过程
作者在单块Tesla K40中训练了10万次用了8小时(350 - 600 images/sec),我在单块Quadro M5000上只用了46分钟(4800~5000 images/sec),下面是训练过程:
|
|
训练和评估过程是放在两个程序分开进行的,具体的实现方法是,在训练过程中,为每个训练变量添加指数滑动平均变量,然后每600s就将模型训练到的变量值保存在checkpoint中,评估过程运行时,从最新存储的checkpoint中取出模型的shadow variable
,赋值给对应的变量,然后进行评估。
我们需要同时运行两个程序才能实时的对训练过程进行评估,否则得到的永远只是最新的checkpoint文件中的评估结果。具体可以先运行python cifar_train.py
,再打开另一个窗口运行python cifar_eval.py
。
官方给的代码最大迭代次数是100万,我运行的时候改成了10万。
因为我的迭代速度太快了,到600s时第一次保存checkpoint就已经是两万多次迭代了,可以通过修改tf.train.MonitoredTrainingSession()
函数的save_checkpoint_secs
参数来修改保存checkpoint的时间间隔,默认600s。
最终10万次迭代后的评估准确率是86.2%,和官方给出的数据还是吻合的。
最后来张TensorBoard的图:
参考
- Convolutional Neural Networks | TensorFlow
- CIFAR-10 and CIFAR-100 datasets
- Images | TensorFlow
- tf.strided_slice | TensorFlow
- TensorFlow学习笔记(11):数据操作指南 - 数据实验室 - SegmentFault
- tf.nn.local_response_normalization | TensorFlow
- tf.train.MonitoredTrainingSession | TensorFlow
- TensorFlow官网教程Convolutional Neural Networks 难点详解 - 玛莎鱼的博客 - CSDN博客