萧箫发自凹非寺

  量子位公众号 QbitAI

  只需给大模型“加点小零件”,推理速度立刻提升 2 倍!

  不需要额外训练一个模型,也不需要对计算硬件做优化,单张 A100 最快几小时就能微调完成。

  这项新研究名叫 Medusa(美杜莎),来自普林斯顿、UIUC、CMU 和康涅狄格大学,FlashAttention 作者 Tri Dao 也在其中。

  目前,它已经成功部署到伯克利 70 亿参数的“骆马”Vicuna中,后续还会支持其他大模型,已经登上 GitHub 热榜:

  但其实,在这种方法推出之前,业界并非没有大模型推理加速方法,主流的就是 DeepMind 推出的投机采样(speculative decoding)。

  相比这种方法,Medusa 有什么不一样的地方?

  投机采样的 2 个“bug”

  要想加速大模型推理,需要先知道究竟是什么“限制”了它的速度。

  相比计算量的增加,大模型推理速度更容易受到内存带宽的影响(memory bound)。

  这是因为,大模型由于参数量巨大、远超缓存容量,因此推理时需要先把权重从外部内存(显存)读取一次到缓存中,这个过程受内存带宽限制,速度通常很慢。

  因此,模型做批量推理(batch inference)时,一次处理 100 个 tokens 和一个 tokens 时间上区别不大。

  基于这个特点,DeepMind 去年 11 月想出了一个名叫投机采样的神奇操作——

  训练一个更小的模型(draft 模型),给大模型提前生成一批“候选词”,相比于让大模型自己“思考”生成,直接做“选择”就好。

  由于小模型生成速度比大模型快好几倍,一旦大模型觉得小模型已有的词“可用”,就直接拿来,不用自己再缓慢生成一遍。

  这个过程,有点像是输入法的联想词候选,在我们(大模型)想好下一个词用什么之前,输入法(小模型)先给列出一些备选项:

  要是看到觉得不错,就从中选一个用;要是觉得生成的都不行,就 pass 掉自己重新打。

  这种投机采样方法确实取得了显著成效,甚至能轻轻松松在 M2 Ultra 上以高精度跑 340 亿参数 LLaMA 大模型。

  BUT,这种方法存在两个问题。

  一方面,给大模型找个生成“候选词”的 draft 小模型,没那么容易。

  这个小模型可不是随便抓个生成模型就能用,除了接口统一、概率分布接近等要求,生成质量也不能比大模型差太多。

  对于 Meta 发布的 LLaMA 这种模型可能还好,既有几百亿参数的大模型版本,又有几十亿参数的小模型版本,可以把参数量更小的版本拿来当 draft 模型使用。

  但对于其他开源大模型,这种方法就不太适用了,自己去搭建训练一个小模型,不仅时间成本更高,生成效果可能还不达预期。

  另一方面,双模型的组合,使得后续要想做系统调优变得更复杂。

  这是因为,相比于大模型自身是一个系统,新增加的 draft 模型相当于又引入了一个系统。

  这样会导致模型部署起来更复杂,包括额外的网络传输、不同的硬件条件都需要考虑到,在做计算优化时难度也会进一步提升。

  为了解决这些问题,Medusa 出现了。

  不用小模型,加几个“头”就行

  Medusa(美杜莎,一种长有多个头的妖怪)是一种新的大模型推理加速方法。

  相比投机采样,它选择直接给 Transformer 大模型多加几个解码头(decoding heads),每个头都是一个单层前馈网络。

  这几个多出来的解码头,可以让大模型直接一次多生成几个词,而不是“挤牙膏式”一个一个生成。

  生成准确率也还可以,在预测“下一个词的下一个词”时,Medusa 准确率达到了 60%,还在不断优化中。

  随后,结合树状注意力机制(tree-based attention mechanism)并行验证这些词,从而实现推理加速。

  基于 Medusa,Vicuna 的 70 亿、130 亿和 330 亿参数大模型推理速度,均有了1. 9 倍以上的效率提升:

  针对 70 亿参数的模型,研究者们还在不同任务上测试了一下加速效果,显示最高在代码生成上有 2.15 倍的速度提升。

  最关键的是,用上 Medusa 后,并不需要将整个大模型重新训练一遍。

  相比之下,它可以和大模型一起训练,只需要冻结大模型的参数就行,甚至单个 GPU 就能搞定。

  由于不增加额外的模型,对于分布式推理也很友好。

  作者介绍

  这项研究有两位共同一作。

  共同一作蔡天乐,普林斯顿大学博士生,研究方向包括优化、表示学习、架构设计等,本科毕业于北京大学数学科学学院,获得应用数学和计算机科学双学位。

  共同一作 Yuhong (Jesse) Li,伊利诺伊大学香槟分校(UIUC)博士生,研究方向是高效机器学习,本科毕业于北京邮电大学。

  此外,这项研究也有 FlashAttention 作者、斯坦福博士 Tri Dao 的参与。

  FlashAttention 是一种能加快注意力并减少内存占用的方法,相比 PyTorch 标准注意力实现,最高能提速 9 倍。

  GitHub 地址:

  https://github.com/FasterDecoding/Medusa

  研究地址:

  https://sites.google.com/view/medusa-llm