Sinkhorn

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT ×
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品
Atlas 200I/500 A2 推理产品 ×
Atlas 推理系列产品 ×
Atlas 训练系列产品 ×

功能说明

  • 算子功能:

    计算Sinkhorn距离,可以用于MoE模型中的专家路由。

  • 计算公式:

    p=Sinkhorn(cost,tol)p=Sinkhorn(cost, tol)

参数说明

参数名 输入/输出/属性 描述 数据类型 数据格式
cost 输入
  • 表示成本张量,公式中的。cost,Device侧的aclTensor。
  • 输入为二维矩阵且列数不超过4096。
  • 支持非连续的Tensor
FLOAT、FLOAT16、BFLOAT16 ND
tol 输入
  • 计算Sinkhorn的误差。
  • 如果传入空指针,则tol取0.0001。
FLOAT ND
p 输出
  • 表示最优传输张量,公式中的p,Device侧的aclTensor。
  • 如果传入空指针,则tol取0.0001。
  • shape维度为2,不支持非连续的Tensor
  • 数据类型和shape与入参cost的数据类型和shape一致。
FLOAT、FLOAT16、BFLOAT16 ND

约束说明

调用说明

调用方式 调用样例 说明
aclnn调用 test_aclnn_sinkhorn 通过aclnnSinkhorn接口方式调用Sinkhorn算子。