sigmoid_focal_loss

接口原型

mx_driving.sigmoid_focal_loss(Tensor logit, Tensor target, float gamma=2, float alpha=0.25, Tensor weight=None, str reduction='mean') -> Tensor

功能描述

先计算输入logit中每个元素的sigmoid值,然后计算sigmoid值与类别目标值之间的Focal Loss,功能与mmcv库的sigmoid_focal_loss一致。

参数说明

  • logit (Tensor):表示全部样本的分类预测值,数据类型为float32。Shape为[N, C],其中N为样本数量,C为类别数量。
  • target (Tensor):表示全部样本的分类目标值,数据类型为int64。Shape为[N]
  • gamma (float):用于平衡易分类样本和难分类样本的超参数,默认值为2.0。
  • alpha (float):用于平衡正样本和负样本的超参数,默认值为0.25。
  • weight (Tensor):表示全部类别的权重系数,数据类型为float32。Shape为[C]
  • reduction (str):规约方式,数据类型为str,默认为mean

返回值

  • output (Tensor):表示输入中每个元素的focal loss,数据类型为float32。Shape为[N, C]

支持的型号

  • Atlas A5 训练系列产品

调用示例

import torch, torch_npu
from mx_driving import sigmoid_focal_loss
logit = torch.rand(1800, 10, dtype=torch.float32, device='npu') * 10 - 5
target = torch.randint(low=0, high=10, size=(1800,), dtype=torch.int64, device='npu')
weight = torch.rand(1800, dtype=torch.float32, device='npu') * 10 - 5
logit.requires_grad = True
output = sigmoid_focal_loss(logit, target, 2.0, 0.25, weight, 'mean')
output.backward()