FfnWorkerBatching

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT ×
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品
Atlas 200I/500 A2 推理产品 ×
Atlas 推理系列产品 ×
Atlas 训练系列产品 ×

功能说明

  • 算子功能:FFNWorkerBatching 在 Attention 与 FFN 分离部署场景下,完成 FFN worker 上的 token 重排操作。Attention 将 token 按专家路由发送到对应 FFN worker 的预分配数据区,FFNWorkerBatching 从该数据区中扫描调度信息,按专家维度聚合并重排 token,产出各专家对应的连续 token 数据块。

  • 计算步骤:

    1. expert_ids_in 中所有 token 的专家 ID 进行排序(被 mask 的 token 初始化为大值),生成 gather 索引。
    2. 多核并行按 gather 索引从 token_data 中提取 token 的 hidden states 和 dynamic scale,同时查表得到对应的 session_idmicro_batch_idtoken_id
    3. 单核扫描排序后的专家 ID 序列,查找跳变点,生成 group_list(每个专家处理的 token 起止偏移)。

    其中 Y=A×BS×(K+1)Y = A \times BS \times (K+1)AA 为 Attention worker 数量,BSBS 为 micro batch size,K+1K+1 为 topK 加共享专家数。

参数说明

参数名 输入/输出/属性 描述 数据类型 数据格式
schedule_context 输入 调度上下文数据结构,内含 CommonArea、ControlArea、AttentionArea、FfnArea。算子从 FfnArea 中读取 token_info_buf 和 token_data_buf 获取待重排的 token 数据与描述信息,并获取 layer_id、session_id、micro_batch_id、expert_ids 等路由信息。结构体总大小 1024 字节。 INT8 -
expert_num 属性 本卡专家总数,等于每层本卡专家数 × layer_num。用于推导 group_list 输出大小。 INT64 -
max_out_shape 属性 输出 shape 上限,格式为 {A, BS, topK+1, H}。用于推导 y 输出的 shape 上限 Y = A × BS × (topK+1),以及 H 值。 LIST_INT -
token_dtype 属性 输入 token 的数据类型。0 表示 FP16;1 表示 BF16;2 表示 INT8 动态量化(INT8 数据与 FP32 dynamic scale 连续排布)。默认值为 0。取值为 2 时需输出 dynamic_scale。 INT64 -
need_schedule 属性 调度模式。0 表示仅做 batching 不扫描数据;1 表示先扫描数据再做 batching。默认值为 0。 INT64 -
layer_num 属性 层数,每层专家独立索引。默认值为 0。 INT64 -
y 输出 重排后的 token hidden states,按专家 ID 排序后连续存放。shape 为 [Y, H]。 FP16、BF16、INT8 ND
group_list 输出 每个专家处理的 token 范围,shape 为 [expert_num, 2]。每行格式为 [expert_id, expert_token_num],未使用的专家填 [0, 0]。示例:[[1, 20], [10, 40], [22, 15], ...]。 INT64 ND
session_ids 输出 每个输出 token 对应的 Attention session ID,shape 为 [Y]。 INT32 ND
micro_batch_ids 输出 每个输出 token 对应的 micro batch ID,shape 为 [Y]。 INT32 ND
token_ids 输出 每个输出 token 在原始输入中的位置索引,shape 为 [Y]。 INT32 ND
expert_offsets 输出 每个输出 token 在其所属专家分组内的偏移,shape 为 [Y]。 INT32 ND
dynamic_scale 输出 动态量化的 scale 值,仅在 token_dtype=2 时有效。shape 为 [Y]。token_dtype 为 0 或 1 时为空 tensor。 FP32 ND
actual_token_num 输出 所有专家有效 token 数之和,标量输出。shape 为 []。 INT64 -

约束说明

  • 该接口支持图模式(GEIR)和单算子模式(aclnn)。
  • 参数 A(Attention worker 数量)支持 ≤ 1024。
  • 参数 M(micro batch 数量)支持 ≤ 64。
  • 参数 K(topK 数)支持 ≤ 64。
  • 参数 BS(micro batch size)和 Y 支持泛化,无硬上限(受内存限制)。
  • 参数 H(hidden size)支持泛化。
  • token_dtype 为 2 时,输入 int8 数据与 fp32 scale 连续排布。

调用说明

调用方式 样例代码 说明
图模式调用 test_geir_ffn_worker_batching.cpp 通过算子IR构图方式调用FfnWorkerBatching算子。

调度上下文数据结构

struct ScheduleContext {
  struct CommonArea {
    uint32_t session_num;           // Attention 节点数
    uint32_t micro_batch_num;       // micro batch 拆分数量
    uint32_t micro_batch_size;      // batch_size / micro_batch_num
    uint32_t selected_expert_num;   // topK 个数 + 1(含共享专家)
    uint32_t expert_num;            // 每层专家个数
    uint32_t attn_to_ffn_token_size;// 每个 token 在 FFN window 数据区占用大小,对齐到 512
    uint32_t ffn_to_attn_token_size;// 每个 token 在 Attention window 数据区占用大小,对齐到 512
    int32_t  schedule_mode;         // 0:只调度FFN, 1:只调度Attention, 2:同时调度
    int8_t   reserve0[96];          // padding to 128 bytes
  };
  struct ControlArea {
    int32_t run_flag;               // 控制循环退出
    int8_t  reserve1[124];
  };
  struct AttentionArea {
    uint64_t token_info_buf;        // [M, DataDesc],DataDesc 含 flags[batch_size][topK+1]
    uint64_t token_info_buf_size;
    uint64_t token_data_buf;        // [M, BS, K+1, HS]
    uint64_t token_data_buf_size;
    uint32_t micro_batch_id;        // 轮询用,初始值 micro_batch_num-1
    int8_t   reserve5[92];
  };
  struct FfnArea {
    // FFN 输入区
    uint64_t token_info_buf;        // [A, M, F],DataDesc 含 flag/layer_id/expert_ids
    uint64_t token_info_buf_size;
    uint64_t token_data_buf;        // [A, M, BS, K+1, HS]
    uint64_t token_data_buf_size;
    uint64_t polling_index;
    int8_t   reserve3[88];
    // FFN 输出区
    uint64_t layer_ids_buf;         // [session_num]
    uint64_t layer_ids_buf_size;
    uint64_t session_ids_buf;       // [session_num]
    uint64_t session_ids_buf_size;
    uint64_t micro_batch_ids_buf;   // [session_num]
    uint64_t micro_batch_ids_buf_size;
    uint64_t expert_ids_buf;        // [session_num, BS, K+1]
    uint64_t expert_ids_buf_size;
    uint32_t out_num;               // 实际收齐的 session 个数
    int8_t   reserve4[60];
  };
  CommonArea    common;
  ControlArea   control;
  AttentionArea attention;
  FfnArea       ffn;
  int8_t        reserve6[384];      // padding to 1024 bytes
}; // 总大小 1024 字节