检查和可视化Torch FX图

💡 原文英文,约1200词,阅读约需5分钟。
📝

内容提要

本文介绍了如何使用Torch FX对PyTorch模块进行图形检查和可视化,特别是多层感知器(MLP)。通过符号追踪和导出生成图形,并使用FxGraphDrawer进行可视化。同时展示了记录函数调用和调度的方法,以帮助理解模块的结构和操作。

🎯

关键要点

  • PyTorch模块可以包含嵌套模块,理解其结构和操作较为复杂。
  • Torch FX图是PyTorch模块的中间表示,可以进行检查和可视化。
  • 本文展示了如何使用Torch FX的FxGraphDrawer对多层感知器(MLP)模块进行可视化。
  • 通过符号追踪和导出生成图形,并使用FxGraphDrawer进行可视化。
  • 使用TorchFunctionMode和TorchDispatchMode记录函数调用和调度,以帮助理解模块的结构和操作。
  • Torch FX符号追踪生成的图使用高层torch.nn.module描述模块结构和操作。
  • torch.export是用于将模型捕获为ExportedProgram对象的API,支持动态控制流。
  • 导出的ATen图使用ATen操作符描述模块结构,提供了张量形状信息。
  • Core ATen IR是ATen操作符的核心子集,功能完全且无副作用。
  • torch.cond和torch.loop支持动态控制流,导出时会将控制流的分支保存为子图。
  • FxGraphDrawer可视化导出模型时,子图需要单独保存以便可视化。
➡️

继续阅读