aclnnKlDivV2

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品
Atlas 200I/500 A2 推理产品 ×
Atlas 推理系列产品
Atlas 训练系列产品

功能说明

  • 算子功能:计算KL散度
  • 计算公式:
    • 定义loss_pointwise,保存中间结果。

      loss_pointwisei={NaN if logTarget=false and targeti<=0,targeti∗(log⁡(targeti)−selfi) if logTarget=false,exp⁡targeti∗(targeti−selfi) else. loss\_pointwise_i=\begin{cases} NaN & \text{ if }&logTarget=false \text{ and } target_i <= 0, \\ target_i * \left ( \log{(target_i)}- self_i \right ) & \text{ if }& logTarget=false, \\ \exp^ {target_i} * \left ( target_i- self_i \right ) & \text{ else. } \end{cases}

    • out计算公式为:

      out={loss_pointwiseˉ if reduction=1,∑loss_pointwise elif reduction=2,∑loss_pointwiseself.size(0) elif reduction=3,loss_pointwise else. out=\begin{cases} \bar{loss\_pointwise} & \text{ if }& reduction= 1, \\ \sum loss\_pointwise & \text{ elif }& reduction= 2,\\ \frac{\sum loss\_pointwise}{self.size(0)} & \text{ elif }& reduction= 3,\\ loss\_pointwise & \text{ else. } \end{cases}

参数说明

参数名 输入/输出/属性 描述 数据类型 数据格式
x 输入 公式中的输入张量x。 FLOAT、FLOAT16、BFLOAT16 ND
target 输入 公式中的输入张量x。 FLOAT、FLOAT16、BFLOAT16 ND
reduction 可选属性
  • 公式中的reduction。
  • 默认值为mean。
STRING ND
log_target 可选属性
  • 公式中的logTarget。
  • 默认值为false。
BOOL ND
y 输出 公式中的输出张量y。 FLOAT、FLOAT16、BFLOAT16 ND

约束说明

调用说明

调用方式 调用样例 说明
aclnn调用 test_aclnn_kl_div_v2 通过aclnnKlDiv接口方式调用KlDivV2算子。
图模式调用 test_geir_kl_div_v2 通过算子IR构图方式调用KlDivV2算子。