Moonshot AI 新突破:MoBA 为大语言模型长文本处理提效论文速读

news/2025/2/22 4:30:38

前言

自然语言处理领域,随着大语言模型(LLMs)不断拓展其阅读、理解和生成文本的能力,如何高效处理长文本成为一项关键挑战。近日,Moonshot AI Research 联合清华大学、浙江大学的研究人员提出了一种创新方法 —— 混合块注意力机制(Mixture of Block Attention,MoBA),它将专家混合(Mixture of Experts,MoE)原理应用于注意力机制,为解决长文本处理难题带来了新的思路。

在 Transformer 架构广泛应用的当下,其注意力机制存在明显弊端。在处理长文本时,传统注意力机制需将每个 token 与其他所有 token 进行比较,这使得计算成本随序列长度呈二次方增长。当模型处理长篇文档、多章书籍、法律简报或大型代码库等包含大量文本信息的任务时,这种计算成本会变得难以承受。此前,为解决这一问题,研究人员尝试过多种方法。例如,滑动窗口机制将 token 限制在局部邻域内,虽降低了计算量,但会忽略重要的全局关系;而一些彻底改变基本架构的方法,如用全新结构替代 softmax 注意力机制,往往需要从头开始重新训练模型,难以利用现有的预训练成果。

核心原理

MoBA 的出现有效弥补了上述方法的不足。它的核心在于将输入划分为易于管理的 “块”,并借助可训练的门控系统来确定每个查询 token 相关的块。这种设计遵循 “少结构” 原则,不预先定义哪些 token 应该相互作用,而是由学习到的门控网络做出决策。与固定结构或近似处理的方法不同,MoBA 能让模型自主学习注意力的聚焦点。而且,MoBA 可与现有的基于 Transformer 的模型无缝协作,它作为一种 “插件” 或替代方案,保持与原模型相同的参数数量,避免架构膨胀,同时保留因果掩码,确保自回归生成的准确性。在实际应用中,MoBA 能在稀疏注意力和全注意力之间灵活切换。处理超长输入时,稀疏注意力可提升速度;而在训练的某些层或阶段,若需要全注意力,模型也能切换回标准模式。

从技术细节来看,MoBA 将上下文划分为多个块,每个块包含连续的 token 序列。门控机制通过比较查询 token 与块的池化键表示,计算查询 token 与每个块之间的 “亲和度” 分数,然后选择得分最高的块。这样,只有最相关块中的 token 才会对最终的注意力分布产生影响。同时,包含查询 token 本身的块始终被纳入,以确保局部上下文信息可访问。并且,MoBA 执行因果掩码,防止 token 关注未来位置,维持从左到右的自回归属性。这种基于块的方法大幅减少了 token 比较次数,使计算规模低于二次方,随着上下文长度增加到数十万甚至数百万个 token,效率提升愈发显著。此外,MoBA 与现代加速器和专用内核兼容性良好。研究人员将 MoBA 与 FlashAttention(一种高性能的快速、内存高效的精确注意力库)相结合,根据所选块对查询 - 键 - 值操作进行精心分组,进一步优化了计算流程。实验数据显示,在处理一百万个 token 时,MoBA 相比传统全注意力机制速度提升约 6 倍,凸显了其在实际应用中的优势。

在性能测试方面,MoBA 表现出色。技术报告显示,在多种任务中,MoBA 的性能与全注意力机制相当,但在处理长序列时可显著节省计算资源。在语言建模数据测试中,当序列长度为 8192 或 32768 个 token 时,MoBA 的困惑度与全注意力 Transformer 相近。更为关键的是,当研究人员将上下文长度逐渐扩展到 128000 及更长时,MoBA 仍能保持强大的长上下文理解能力。在 “尾随 token” 评估中,MoBA 能够有效处理长提示末尾附近的 token 预测任务,且预测质量没有明显下降。研究人员还对 MoBA 的块大小和门控策略进行了敏感性探索。实验表明,细化粒度(使用更小的块但选择更多的块)有助于模型更接近全注意力的效果。即使在忽略大部分上下文的情况下,自适应门控也能识别与查询真正相关的块。此外,“混合” 模式展现出一种平衡策略:部分层继续使用 MoBA 提升速度,少数层则恢复全注意力。这种混合方法在监督微调任务中尤为有益,例如当输入中的某些位置在训练目标中被屏蔽时,保留少数上层的全注意力,可使模型保持广泛的上下文覆盖,有助于需要全局视角的任务。

关键代码分析:

以下是对 MoBA 库关键代码 MixedAttention 类的分析以及关键代码的摘录与注释:

整体分析

MixedAttention 类是一个自定义的 torch.autograd.Function,用于实现混合块注意力机制。这个类主要包含两个静态方法:forward 和 backward,分别用于前向传播和反向传播。

class MixedAttention(torch.autograd.Function):

    # 前向传播函数
    @staticmethod
    def forward(
        ctx,
        q,  # 查询张量
        k,  # 键张量
        v,  # 值张量
        self_attn_cu_seqlen,  # 自注意力累积序列长度
        moba_q,  # MoBA 查询张量
        moba_kv,  # MoBA 键值张量
        moba_cu_seqlen_q,  # MoBA 查询累积序列长度
        moba_cu_seqlen_kv,  # MoBA 键值累积序列长度
        max_seqlen,  # 最大序列长度
        moba_chunk_size,  # MoBA 块大小
        moba_q_sh_indices,  # MoBA 查询块索引
    ):
        # 保存一些参数,用于后续的反向传播
        ctx.max_seqlen = max_seqlen
        ctx.moba_chunk_size = moba_chunk_size
        ctx.softmax_scale = softmax_scale = q.shape[-1] ** (-0.5)

        # 自注意力计算
        _, _, _, _, self_attn_out_sh, self_attn_lse_hs, _, _ = (
            _flash_attn_varlen_forward(
                q=q,
                k=k,
                v=v,
                cu_seqlens_q=self_attn_cu_seqlen,
                cu_seqlens_k=self_attn_cu_seqlen,
                max_seqlen_q=max_seqlen,
                max_seqlen_k=max_seqlen,
                softmax_scale=softmax_scale,
                causal=True,
                dropout_p=0.0,
            )
        )

        # MoBA 注意力计算
        _, _, _, _, moba_attn_out, moba_attn_lse_hs, _, _ = _flash_attn_varlen_forward(
            q=moba_q,
            k=moba_kv[:, 0],
            v=moba_kv[:, 1],
            cu_seqlens_q=moba_cu_seqlen_q,
            cu_seqlens_k=moba_cu_seqlen_kv,
            max_seqlen_q=max_seqlen,
            max_seqlen_k=moba_chunk_size,
            softmax_scale=softmax_scale,
            causal=False,
            dropout_p=0.0,
        )

        # 转换 lse 形状,从 hs 转换为 sh(遵循传统混合注意力逻辑)
        self_attn_lse_sh = self_attn_lse_hs.t().contiguous()
        moba_attn_lse = moba_attn_lse_hs.t().contiguous()

        # 初始化输出缓冲区,形状与 q 相同
        output = torch.zeros(
            (q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32
        )

        # 将输出张量展平为二维,便于后续索引操作
        output_2d = output.view(-1, q.shape[2])

        # 计算混合 lse
        # 减去最大 lse 以避免指数爆炸
        max_lse_1d = self_attn_lse_sh.view(-1)
        max_lse_1d = max_lse_1d.index_reduce(
            0, moba_q_sh_indices, moba_attn_lse.view(-1), "amax"
        )
        self_attn_lse_sh = self_attn_lse_sh - max_lse_1d.view_as(self_attn_lse_sh)
        moba_attn_lse = (
            moba_attn_lse.view(-1)
            .sub(max_lse_1d.index_select(0, moba_q_sh_indices))
            .reshape_as(moba_attn_lse)
        )

        # 计算自注意力和 MoBA 注意力的 softmax 结果
        mixed_attn_se_sh = self_attn_lse_sh.exp()
        moba_attn_se = moba_attn_lse.exp()

        # 将 MoBA 注意力结果累加到自注意力结果上
        mixed_attn_se_sh.view(-1).index_add_(
            0, moba_q_sh_indices, moba_attn_se.view(-1)
        )
        mixed_attn_lse_sh = mixed_attn_se_sh.log()

        # 加权自注意力输出
        factor = (self_attn_lse_sh - mixed_attn_lse_sh).exp()  # [ vS, H ]
        self_attn_out_sh = self_attn_out_sh * factor.unsqueeze(-1)
        output_2d += self_attn_out_sh.reshape_as(output_2d)

        # 加权 MoBA 输出
        mixed_attn_lse = (
            mixed_attn_lse_sh.view(-1)
            .index_select(0, moba_q_sh_indices)
            .view_as(moba_attn_lse)
        )
        factor = (moba_attn_lse - mixed_attn_lse).exp()  # [ vS, H ]
        moba_attn_out = moba_attn_out * factor.unsqueeze(-1)
        raw_attn_out = moba_attn_out.view(-1, moba_attn_out.shape[-1])
        output_2d.index_add_(0, moba_q_sh_indices, raw_attn_out)

        # 将输出转换为与输入相同的数据类型
        output = output.to(q.dtype)

        # 恢复最大 lse
        mixed_attn_lse_sh = mixed_attn_lse_sh + max_lse_1d.view_as(mixed_attn_se_sh)

        # 保存中间结果,用于反向传播
        ctx.save_for_backward(
            output,
            mixed_attn_lse_sh,
            q,
            k,
            v,
            self_attn_cu_seqlen,
            moba_q,
            moba_kv,
            moba_cu_seqlen_q,
            moba_cu_seqlen_kv,
            moba_q_sh_indices,
        )

        return output

    # 反向传播函数
    @staticmethod
    def backward(ctx, d_output):
        # 从上下文中获取保存的参数
        max_seqlen = ctx.max_seqlen
        moba_chunk_size = ctx.moba_chunk_size
        softmax_scale = ctx.softmax_scale

        (
            output,
            mixed_attn_vlse_sh,
            q,
            k,
            v,
            self_attn_cu_seqlen,
            moba_q,
            moba_kv,
            moba_cu_seqlen_q,
            moba_cu_seqlen_kv,
            moba_q_sh_indices,
        ) = ctx.saved_tensors

        # 确保输入梯度连续
        d_output = d_output.contiguous()

        # 计算自注意力的梯度
        dq, dk, dv, _ = _flash_attn_varlen_backward(
            dout=d_output,
            q=q,
            k=k,
            v=v,
            out=output,
            softmax_lse=mixed_attn_vlse_sh.t().contiguous(),
            dq=None,
            dk=None,
            dv=None,
            cu_seqlens_q=self_attn_cu_seqlen,
            cu_seqlens_k=self_attn_cu_seqlen,
            max_seqlen_q=max_seqlen,
            max_seqlen_k=max_seqlen,
            softmax_scale=softmax_scale,
            causal=True,
            dropout_p=0.0,
            window_size=(-1, -1),
            softcap=0.0,
            alibi_slopes=None,
            deterministic=True,
        )

        # 计算 MoBA 注意力的梯度
        headdim = q.shape[-1]
        d_moba_output = (
            d_output.view(-1, headdim).index_select(0, moba_q_sh_indices).unsqueeze(1)
        )
        moba_output = (
            output.view(-1, headdim).index_select(0, moba_q_sh_indices).unsqueeze(1)
        )

        mixed_attn_vlse = (
            mixed_attn_vlse_sh.view(-1).index_select(0, moba_q_sh_indices).view(1, -1)
        )

        dmq, dmk, dmv, _ = _flash_attn_varlen_backward(
            dout=d_moba_output,
            q=moba_q,
            k=moba_kv[:, 0],
            v=moba_kv[:, 1],
            out=moba_output,
            softmax_lse=mixed_attn_vlse,
            dq=None,
            dk=None,
            dv=None,
            cu_seqlens_q=moba_cu_seqlen_q,
            cu_seqlens_k=moba_cu_seqlen_kv,
            max_seqlen_q=max_seqlen,
            max_seqlen_k=moba_chunk_size,
            softmax_scale=softmax_scale,
            causal=False,
            dropout_p=0.0,
            window_size=(-1, -1),
            softcap=0.0,
            alibi_slopes=None,
            deterministic=True,
        )

        # 合并 MoBA 的键和值的梯度
        dmkv = torch.stack((dmk, dmv), dim=1)

        return dq, dk, dv, None, dmq, dmkv, None, None, None, None, None

代码关键部分解释

  • 前向传播 (forward)

    • 分别计算自注意力和 MoBA 注意力的结果。
    • 对注意力分数进行处理,包括形状转换、归一化等操作,以避免指数爆炸。
    • 将自注意力和 MoBA 注意力的结果进行加权合并,得到最终的输出。
    • 保存中间结果,用于后续的反向传播。
  • 反向传播 (backward)

    • 根据前向传播保存的中间结果,计算自注意力和 MoBA 注意力的梯度。
    • 最终返回各个输入张量的梯度。

小结

通过这种方式,MixedAttention 类实现了 MoBA 混合块注意力机制,通过将上下文划分为块并进行选择性的注意力计算,有效减少了计算量,提升了处理长文本的效率。

总结

总体而言,MoBA 非常适合处理涉及大量上下文的任务,如长篇文档阅读理解、大规模代码补全以及需要完整对话历史的多轮对话系统。它在提高效率的同时,性能损失极小,为大规模训练大语言模型提供了一种极具吸引力的方法。虽然目前 MoBA 主要应用于文本领域,但研究人员认为,其底层机制在其他数据模态中也具有应用潜力。只要序列长度足够长,引发计算或内存问题,将查询分配给块 “专家” 的思路就有望缓解瓶颈,同时保持处理关键全局依赖关系的能力。随着语言应用中的序列长度持续增长,像 MoBA 这样的方法可能会在推动神经语言建模的可扩展性和成本效益方面发挥关键作用,为人工智能的发展注入新的活力。


http://www.niftyadmin.cn/n/5861588.html

相关文章

【Kafka系列】Kafka 消息传递保障机制

Kafka 消息传递保障机制 在现代分布式系统中,消息队列扮演着至关重要的角色。Kafka 作为一款高性能、高吞吐量的消息队列系统,提供了多种消息传递保障机制来满足不同的业务需求。本文将详细介绍 Kafka 的三种主要消息传递保障机制:最多一次&a…

【开源商城系统是否能直接拿去售卖】

开源商城系统是否能直接拿去售卖,需要根据具体的开源协议和相关法律法规来判断,以下是具体分析: 遵循开源协议的情况 GPL协议:如果开源商城系统遵循GNU通用公共许可证(GPL),这种协议属于强拷贝…

PDF文档管理系统V2.0

在<PDF文档管理系统V1.0>的基础上新增了&#xff08;图片文档识别&#xff09;、&#xff08;文档翻译&#xff09;、&#xff08;阅读计划管理的功能&#xff09;&#xff0c;以及其他的小功能完善。由于此版本需要安装数据库&#xff0c;所以不再提供免费下载链接&…

星途汽车掉队?2024销量增速回落,“星纪元”序列后劲不足

近日&#xff0c;奇瑞集团旗下的星途汽车召开了2025商务年会&#xff0c;勾勒了“科技新豪华三步走”的未来规划&#xff0c;宣布将锚定“3大目标、5大超越、3大满意”开启加速奔跑模式。 “技术奇瑞、品质奇瑞、国际化奇瑞,最终都要汇聚成星途星纪元品牌的向上”&#xff0c;…

前端web安全

一、黑盒扫描和白盒扫描 白盒扫描和黑盒扫描都是针对网络安全和应用程序安全的常用测试方法。 1、白盒扫描指的是测试人员具有关于系统内部结构和代码的全部或部分信息&#xff0c;是基于源代码的安全检测方法&#xff0c;它可以对源代码进行深度分析&#xff0c;从而发现潜在…

EasyPoi系列之通用导入接口设计

EasyPoi系列之通用导入接口设计 1 背景2 分析及设计2.1 标准导入交互分析2.2 设计2.2.1 导入模板生成接口2.2.2 数据导入接口 3、代码实现3.1 人员实体-PersonEntity3.2 定义数据保存通用接口-ExcelImporter3.3 人员数据保存实现- PersonService3.4 建立业务与实体及保存实现类…

数据结构系列一:初识集合框架+复杂度

前言 数据结构——是相互之间存在一种或多种特定关系的数据元素的集合。数据结构是计算机专业的基础课程&#xff0c;但也是一门不太容易学好的课&#xff0c;它当中有很多费脑子的东西&#xff0c;之后在学习时&#xff0c;你若碰到了困惑或不解的地方 都是很正常的反应&…

petalinux-build ERROR

最近编译Xilinx的固件的时候报了一个错&#xff0c;看的我云里雾里&#xff0c;一度认为ubuntu的版本跟petalinux的版本不匹配&#xff0c;想要重新安装操作系统和编译环境&#xff0c;想想都头大。 petalinux-create -t project --template zynqMP -n petalinux-config --ge…