AttentionWorkerCombine

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT ×
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品
Atlas 200I/500 A2 推理产品 ×
Atlas 推理系列产品 ×
Atlas 训练系列产品 ×

功能说明

  • 算子功能:将多个计算单元处理的注意力token数据进行融合,结合专家权重对结果进行加权,输出最终的注意力融合结果,并更新层ID。

参数说明

参数名 输入/输出/属性 描述 数据类型 数据格式
schedule_context 输入 包含调度上下文信息。 INT8 ND
expert_scales 输入 表示专家权重。 FLOAT ND
layer_id 输入 当前的模型层ID。 INT32 ND
y 输出 最终的注意力合并结果。 FLOAT16,BFLOAT16 ND
next_layer_id 输出 下一个要处理的层ID。 INT32 ND
hidden_size 属性 token_data的隐藏维度大小,用于确定输出y的第二维大小。必要属性。 Int -
token_dtype 属性 指定schedule_context中token数据的原始精度类型,0表示FLOAT16,1表示BFLOAT16。 Int -
need_schedule 属性 指定是否等待token数据填充完成后再执行,0表示不等待,1表示等待。 Int -

约束说明

  • schedule_context为1D的Tensor。
  • expert_scales为2D的Tensor,[BatchSize, K]。
  • y为2D的Tensor,[BatchSize, HiddenSize],即第二维由属性hidden_size确定。
  • layer_id和next_layer_id为1D的Tensor。

调用说明

调用方式 样例代码 说明
图模式调用 test_geir_attention_worker_combine.cpp 通过算子IR构图方式调用AttentionWorkerCombine算子。