谁动了我的显存?——深度学习训练过程显存占用分析及优化

谁动了我的显存?——深度学习训练过程显存占用分析及优化

💡 原文中文,约8100字,阅读约需20分钟。
📝

内容提要

在大语言模型时代,显存不足是一个突出问题。本文分析了深度学习训练中的显存占用,包括框架、模型参数和特征相关的占用。通过样例程序展示了不同情况下的显存需求。使用torch.autograd.Function实现算子融合可以节约显存开销。介绍了查看pytorch自带算子保存的变量的方法。

🎯

关键要点

  • 在大语言模型时代,显存不足是一个突出问题。
  • 深度学习训练中的显存占用分为框架占用、模型参数相关占用和特征相关占用。
  • 框架占用如pytorch的cuda context占用几百MB显存。
  • 模型参数占用以FP16格式的7B模型需要14GB显存,优化器和梯度相关参数也占用显存。
  • 特征相关的显存占用与模型计算流程有关,具体比例系数难以分析。
  • 使用样例程序计算(x+1)(y+1)的显存需求,区分峰值显存占用与持续显存占用。
  • 不需要计算梯度时,显存占用较低,计算结束后几乎不占显存。
  • 需要计算梯度时,临时变量不会被释放,显存占用增加。
  • 通过torch.autograd.Function实现算子融合可以节约显存开销。
  • 使用pytorch自带的融合算子如sigmoid,显存占用显著降低。
  • 算子融合是深度学习编译器的核心技术,优化仍需人工设计。
  • 可以通过grad_fn属性查看pytorch自带算子保存的变量。
➡️

继续阅读