前言 作为一名每天与神经网络训练/测试打交道的同学,是否经常会遇到以下这几个问题,时常怀疑人生:
总而言之,当模型效果不如预期的时候去调试深度学习网络是一件头疼且繁琐的事情,为了让这件麻烦事情更加仅仅有条,笔者结合实际经验简单整理了一些checklist,方便广大炼丹师傅掌握火候。 1. 从最简单的数据/模型开始 现在开源社区做的很好,同学们用模型也十分方便,但也有相应的问题。以句子情感识别为例,新入手的同学可能一上来就调出HuggingFace/transformers代码库,然后一股脑BERT/Roberta啥的跑个结果,当然文档做好的开源代码一般都能照着跑个好结果,但改到自己数据集上往往就懵逼了,啊这?51%的二分类准确率(可能夸张了点,但如果任务要比二分类稍微复杂点,基本结果不会如预期),也太差了吧,HuggingFace/transofmrers这些模型不行啊。算了,咱换一个库吧,再次求助github和谷歌搜索。其实可能都还不清楚数据输入格式对不对?数据量够不够?评测指标含义是否清楚?Roberta的tokenizer是咋做的?模型结构是什么样子? 所以第1个checklist是:请尽量简单!
模型简单:解决一个深度学习任务,最好是先自己搭建一个最简单的神经网络,就几层全连接的那种。 数据简单:一般来说少于10个样本做调试足够了,一定要做过拟合测试(特别是工作的同学,拿过来前人的代码直接改个小结构就跑全量数据训练7-8天是可能踩坑的哦,比如某tensorflow版本 GPU embedding查表,输入超出了vocab size维度甚至可能都不报错哦,但cpu又会报错)!如果你的模型无法在7、8个样本上过拟合,要么模型参数实在太少,要么有模型有bug,要么数据有bug。为什么不建议1个样本呢?多选几个有代表性的输入数据有助于直接测试出非法数据格式。但数据太多模型就很难轻松过拟合了,所以建议在10个以下,1个以上,基本ok了。 2. loss设计是否合理? loss决定了模型参数如何更新,所以记得确定一下你的loss是否合理?
3. 网络中间输出检查、网络连接检查 Pytorch已经可以让我们像写python一样单步debug了,所以输入输出shape对齐这步目前还挺好做的,基本上单步debug走一遍forward就能将网络中间输出shape对齐,连接也能对上,但有时候还是可能眼花看漏几个子网络的连接。 所以最好再外部测试一下每个参数的梯度是否更新了,训练前后参数是否都改变了。
读者可以参考stanford cs231n中的Gradient checking: https://cs231n./neural-networks-3/#gradcheck https://cs231n./optimization-1/#gradcompute 另外用tensorboard来检查一下网络连接/输入输出shape和连接关系也是不错的。 4. 时刻关注着模型参数 所谓模型参数也就是一堆矩阵/或者说大量的数值。如果这些数值中有些数值异常大/小,那么模型效果一般也会出现异常。一般来说,让模型参数保持正常有这么几个方法:
统计梯度下降中,我们需要的batch size要求是:1、batch size足够大到能让我们在loss反向传播时候正确估算出梯度;2、batch size足够小到统计梯度下降(SGD)能够一定程度上regularize我们的网络结构。batch size太小优化困难,太大又会导致:Generalization Gap和Sharp Minima(具体参考:论文https:///abs/1609.04836,On Large-Batch Training for Deep Learning: Generalization Gap and Sharp Minima)。 Tensorflow: https://www./api_docs/python/tf/train/exponential_decay Keras: https:///callbacks/#learningratescheduler。 5. 详细记录调试/调参数过程 可能这会儿换了一个learning rate,过会儿增大了dropout,过会儿又加了一个batch normalization,最后也不知道自己改了啥。
由于要实验或者改的地方太多,通常就时不时忘记/不方便使用git了,而是copy一大堆名字相似的文件,这个时候,请千万注意你的代码结构/命名规则,当然使用好bash脚本将使用的参数,训练过程一一存放起来也是不错的选择。
模型对数据/超参数,甚至是随机种子、GPU版本,tensorflow/pytorch版本,所以请尽可能记录好每个部分,并且最好时刻可以复现。最后小时候学的控制变量法也很重要哦。 总结 将以上内容做一个总结:
参考文献: https:///checklist-for-debugging-neural-networks-d8b2a9434f21 |
|