2021
07-09
07-09
解决Pytorch半精度浮点型网络训练的问题
用Pytorch1.0进行半精度浮点型网络训练需要注意下问题:1、网络要在GPU上跑,模型和输入样本数据都要cuda().half()2、模型参数转换为half型,不必索引到每层,直接model.cuda().half()即可3、对于半精度模型,优化算法,Adam我在使用过程中,在某些参数的梯度为0的时候,更新权重后,梯度为零的权重变成了NAN,这非常奇怪,但是Adam算法对于全精度数据类型却没有这个问题。另外,SGD算法对于半精度和全精度计算均没有问题。还有一...
继续阅读 >