QuantAllReduce

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT
Atlas A3 训练系列产品/Atlas A3 推理系列产品 ×
Atlas A2 训练系列产品/Atlas A2 推理系列产品 ×
Atlas 200I/500 A2 推理产品 ×
Atlas 推理系列产品 ×
Atlas 训练系列产品 ×

功能说明

算子功能:实现低比特数据的AllReduce通信,在通信的过程中对数据进行反量化,并输出通信结果。

  • 计算公式

    AllGatherData=AllGather(x)AllGatherData = AllGather(x)

    AllGatherScales=AllGather(scales)AllGatherScales = AllGather(scales)

    output=Reduce(AllGatherScales∗AllGatherData)output = Reduce(AllGatherScales * AllGatherData)

    其中的Reduce计算是将来自不同rank的数据进行reduce计算。

参数说明

参数名 输入/输出/属性 描述 使用说明 数据类型 数据格式 维度(shape) 连续Tensor
x 输入 公式中的输入x。
  • 不支持空Tensor。
  • 支持的shape为:(bs, H)或者(b, s, H)。b为batch size,s为sequence length,H为hidden size。
  • INT8、HIFLOAT8、FLOAT8_E4M3FN、FLOAT8_E5M2 ND 2-3
    scales 输入 公式中的输入scales。
  • 不支持空Tensor。
  • 当scales的数据类型为FLOAT8_E8M0时,x的数据类型必须为FLOAT8_E4M3FN、FLOAT8_E5M2,x的shape为(bs, H)或者(b, s, H),scales的shape必须对应x的shape为(bs, H/64, 2)或者(b, s, H/64, 2)。
  • 当scales的数据类型为FLOAT时,x的数据类型必须为INT8、HIFLOAT8、FLOAT8_E4M3FN、FLOAT8_E5M2,x的shape为(bs, H)或者(b, s, H),scales的shape必须对应x的shape为(bs, H/128)或者(b, s, H/128)。
  • FLOAT、FLOAT8_E8M0 ND 2-4
    group 属性 通信域标识。
  • Host侧标识列组的字符串,通信域名称。
  • 通过Hccl提供的接口"extern HcclResult HcclGetCommName(HcclComm comm, char* commName);"获取,其中commName即为group。
  • Char*、String - - -
    reduceOp 可选属性 公式中的reduce操作类型。 当前仅支持"sum"操作。 Char*、String - - -
    output 输出 公式中的输出output。
  • 不支持空Tensor。
  • 支持的shape为(bs, H)或者(b, s, H),output的shape与x保持一致。
  • FLOAT、FLOAT16、BFLOAT16 ND 2-3

    约束说明

    • 当x的数据类型为FLOAT8_E4M3FN、FLOAT8_E5M2并且scales的数据类型为FLOAT8_E8M0时,输入数据的量化方式为mx量化。
    • 当x的数据类型为INT8、HIFLOAT8、FLOAT8_E4M3FN、FLOAT8_E5M2并且scales的数据类型为FLOAT时,输入数据的量化方式为pertoken-pergroup量化(groupSize=128)。
    • 只在Ascend950系列平台使能。
    • 不支持空Tensor输入。
    • 通信引擎约束:
      • Ascend950PR/Ascend950DT: 仅支持UB-Memory通信。
    • 通信域大小支持2, 4, 8。
    • 通信域使用约束:同一通信域内仅允许连续执行aclnnQuantAllReduceaclnnQuantReduceScatter算且子,该通信域中不允许有其他通信算子。
    • HCCL_BUFFSIZE:调用本算子前需检查HCCL_BUFFSIZE环境变量取值是否合理,该环境变量表示单个通信域占用内存大小,单位MB,不配置时默认为200MB。要求满足HCCL_BUFFSIZE>= 2 * (xDataSize + scalesDataSize + 1)。其中xDataSize为输入x的数据大小,计算公式为:xDataSize = b * s * H * 1 (Byte)scalesDataSizescales的数据大小,当量化方式为pertoken-pergroup量化时,计算公式为:scalesDataSize = b * s * H / 128 * 4 (Byte),当量化方式为mx量化时,计算公式为:scalesDataSize = b * s * H / 32 * 1 (Byte)
    • H范围仅支持[1024, 8192],要求128对齐。

    调用说明

    调用方式 样例代码 说明
    aclnn接口 test_aclnn_quant_all_reduce.cpp 通过aclnnQuantAllReduce接口方式调用quant_all_reduce算子。