Oryx架构:以共享表示打破大模型注意力计算瓶颈,开启混合序列建模新纪元

针对大语言模型中Softmax注意力机制随序列长度呈二次方增长的算力瓶颈,Oryx架构提出了一种在序列轴上灵活切换不同混合器的混合模型范式。其核心创新在于至少90%的参数在混合器间共享,使模型能动态选择二次方复杂度的注意力机制以捕捉关键上下文,或采用线性递归机制实现高效生成。实验显示,在1.4B参数规模下,Oryx在语言建模任务中性能优于单一混合器基线,且在检索任务中仅需处理不到10%的Token即可达到Transformer同等水平,证明了共享表示混合架构在平衡计算效率与上下文理解能力方面的巨大潜力。

现代大语言模型的性能基石在于Softmax注意力机制,但其随着序列长度增加而呈现的线性内存增长和二次方计算复杂度,成为了制约长上下文处理效率的瓶颈。尽管线性递归模型如线性注意力和状态空间模型因具备线性计算和恒定内存优势而受到广泛关注,但它们在需要长上下文检索或上下文学习(in-context learning)的任务上仍落后于注意力模型。现有的混合架构尝试通过静态交错或合并注意力与递归块来缓解这一权衡,但缺乏灵活性。本研究提出了一种全新的混合模型开发维度:跨序列轴的动态混合。我们提出了Oryx架构,它允许模型在整个序列处理过程中,根据上下文需求灵活地在不同的混合器之间切换。例如,在需要深入理解复杂语义的关键位置使用二次方复杂度的注意力机制以充分利用上下文,而在生成阶段或简单序列段使用线性递归机制以追求极致的效率。这种设计旨在打破传统单一架构在效率与能力之间的零和博弈,实现两者的最优平衡。Oryx的技术核心在于其参数共享机制与动态路由策略。

不同于以往将不同模块简单堆叠的做法,Oryx将至少90%的参数在注意力混合器和线性递归混合器之间共享。这意味着两种模式并非独立运行,而是操作于高度一致的共享内部表示之上。这种设计不仅大幅减少了模型的整体参数量,还确保了模式切换时的语义一致性,避免了因表示空间不匹配导致的性能下降。在具体实现上,我们验证了基于Mamba-2和Gated DeltaNet两种先进线性递归变体的Oryx实例,模型规模最高达到1.4B参数。训练策略上,采用了混合训练方法,即在训练过程中动态地让模型在不同序列位置体验不同的混合器模式,从而学习到何时使用何种混合器最为有效。这种跨序列轴的混合化策略,使得模型能够自适应地分配计算资源,在关键节点投入高精度计算,在次要节点采用低开销处理。在实验评估方面,我们在多个标准基准上对Oryx进行了全面测试,并与单一混合器基线进行了严格对比。在固定token预算和混合训练策略下,Oryx展现了显著的优势。

特别是在1.4B参数规模下,所有Oryx实例在平均语言建模任务上的表现均优于其对应的单一混合器基线,提升幅度至少达到0.7个百分点。这一结果证实了共享表示混合架构在语言建模任务上的有效性。更令人印象深刻的是在检索任务上的表现:Oryx仅需以二次方复杂度的注意力模式处理序列中不到10%的关键token,即可达到与全注意力Transformer基线相当的性能。这意味着模型能够智能地识别并聚焦于对检索至关重要的信息片段,而忽略无关噪音,从而在保持高精度的同时极大降低了计算开销。消融实验进一步揭示了参数共享比例和混合策略对最终性能的影响,证明了90%以上的参数共享是实现高效混合的关键。Oryx的提出对开源社区和工业落地具有深远意义。首先,它证明了注意力机制和线性递归模型并非互斥,而是可以通过共享内部表示协同工作,为后续研究提供了新的理论视角和技术路径。其次,这种灵活的序列轴混合化策略为开发更高效、更强大的大语言模型提供了切实可行的方案,特别是在资源受限的边缘设备或需要长上下文处理的工业场景中,具有巨大的应用潜力。对于开源社区而言,Oryx的代码和模型权重将促进更多研究者探索混合架构的边界,加速AI基础设施的优化。未来,随着混合器类型的丰富和共享机制的完善,Oryx范式有望成为下一代高效大语言模型的主流架构之一,推动AI技术在更广泛领域的落地应用。