#include "csrc/OpApiCommon.h"
#include "csrc/functions.h"
/**
 * @brief Forward sigmoid focal loss dispatched through aclnnSigmoidFocalLoss.
 *
 * Prepares an int32 target tensor and a weight tensor expanded to the shape of
 * `input`, then hands everything to the kernel, which receives `output` as its
 * last argument.
 *
 * @param input   Tensor whose dimension 1 is read as the class count.
 * @param target  Labels; fed to at::one_hot and used as the gather index into
 *                `weight`. NOTE(review): presumably 1-D int64 class indices -
 *                confirm against the caller.
 * @param weight  Per-class weights. `weight.size(0) == 0` selects the
 *                unweighted path (all-ones weights).
 * @param output  Tensor passed to the kernel as the result argument.
 * @param gamma   Forwarded to the kernel unchanged.
 * @param alpha   Forwarded to the kernel unchanged.
 */
void sigmoid_focal_loss(const at::Tensor &input, const at::Tensor &target, const at::Tensor &weight,
                        const at::Tensor &output, double gamma, double alpha) {
    const int64_t n_class = input.size(1);

    // Both branches below assign the target tensor, so it starts undefined
    // instead of paying for a throw-away ones_like allocation on the device.
    at::Tensor target_y;
    if (n_class == 1) {
        // Single-class case: reshape like input and compute (1 - target).
        target_y = at::add(at::mul(at::reshape(target, input.sizes()), -1.0), 1.0);
    } else {
        target_y = at::one_hot(target, n_class);
    }
    target_y = target_y.to(at::kInt);

    // All-ones weights are only materialised when no class weights are given.
    at::Tensor weight_y;
    if (weight.size(0) > 0) {
        // Pick weight[target[i]] per sample, then view it as input's shape.
        weight_y = weight.gather(0, target).unsqueeze(1).expand_as(input);
    } else {
        weight_y = at::ones_like(input);
    }

    EXEC_NPU_CMD(aclnnSigmoidFocalLoss, input, target_y, weight_y, gamma, alpha, output);
}
/**
 * @brief Backward sigmoid focal loss dispatched through aclnnSigmoidFocalLossGrad.
 *
 * Builds the same int32 target tensor and expanded weight tensor as the
 * forward aclnn path, then invokes the gradient kernel with `grad_input` as
 * its last argument.
 *
 * @param input       Tensor whose dimension 1 is read as the class count.
 * @param target      Labels; fed to at::one_hot and used as the gather index
 *                    into `weight`. NOTE(review): presumably 1-D int64 class
 *                    indices - confirm against the caller.
 * @param weight      Per-class weights. `weight.size(0) == 0` selects the
 *                    unweighted path (all-ones weights).
 * @param grad_input  Tensor passed to the kernel as the result argument.
 * @param gamma       Forwarded to the kernel unchanged.
 * @param alpha       Forwarded to the kernel unchanged.
 */
void sigmoid_focal_loss_backward(const at::Tensor &input, const at::Tensor &target, const at::Tensor &weight,
                                 const at::Tensor &grad_input, double gamma, double alpha) {
    const int64_t n_class = input.size(1);

    // Assigned in both branches below; no placeholder allocation is needed.
    at::Tensor target_y;
    if (n_class == 1) {
        // Single-class case: reshape like input and compute (1 - target).
        target_y = at::add(at::mul(at::reshape(target, input.sizes()), -1.0), 1.0);
    } else {
        target_y = at::one_hot(target, n_class);
    }
    target_y = target_y.to(at::kInt);

    // All-ones weights are only materialised when no class weights are given.
    at::Tensor weight_y;
    if (weight.size(0) > 0) {
        // Pick weight[target[i]] per sample, then view it as input's shape.
        weight_y = weight.gather(0, target).unsqueeze(1).expand_as(input);
    } else {
        weight_y = at::ones_like(input);
    }

    EXEC_NPU_CMD(aclnnSigmoidFocalLossGrad, input, target_y, weight_y, gamma, alpha, grad_input);
}
/**
 * @brief Forward sigmoid focal loss executed through the "SigmoidFocalLoss"
 *        operator via at_npu::native::OpCommand.
 *
 * Prepares an int32 target tensor and a weight tensor expanded to the shape of
 * `input`, then runs the operator with reduction fixed to "none".
 *
 * @param input   Tensor whose dimension 1 is read as the class count.
 * @param target  Labels; fed to at::one_hot and used as the gather index into
 *                `weight`. NOTE(review): presumably 1-D int64 class indices -
 *                confirm against the caller.
 * @param weight  Per-class weights. `weight.size(0) == 0` selects the
 *                unweighted path (all-ones weights).
 * @param output  Tensor registered as the operator's output.
 * @param gamma   Passed as the "gamma" attribute.
 * @param alpha   Passed as the "alpha" attribute.
 */
void sigmoid_focal_loss_cann(const at::Tensor &input, const at::Tensor &target, const at::Tensor &weight,
                             at::Tensor &output, float gamma, float alpha) {
    const int64_t n_class = input.size(1);

    // Assigned in both branches below; no placeholder allocation is needed.
    at::Tensor target_y;
    if (n_class == 1) {
        // Single-class case: reshape like input and compute (1 - target).
        target_y = at::add(at::mul(at::reshape(target, input.sizes()), -1.0), 1.0);
    } else {
        target_y = at::one_hot(target, n_class);
    }
    target_y = target_y.to(at::kInt);

    // All-ones weights are only materialised when no class weights are given.
    at::Tensor weight_y;
    if (weight.size(0) > 0) {
        // Pick weight[target[i]] per sample, then view it as input's shape.
        weight_y = weight.gather(0, target).unsqueeze(1).expand_as(input);
    } else {
        weight_y = at::ones_like(input);
    }

    const std::string reduction = "none";
    at_npu::native::OpCommand cmd;
    cmd.Name("SigmoidFocalLoss")
        .Input(input)
        .Input(target_y)
        .Input(weight_y)
        .Output(output)
        .Attr("gamma", gamma)
        .Attr("alpha", alpha)
        .Attr("reduction", reduction)
        .Run();
}
/**
 * @brief Backward sigmoid focal loss executed through the
 *        "SigmoidFocalLossGrad" operator via at_npu::native::OpCommand.
 *
 * Prepares an int32 target tensor, an all-ones upstream gradient and a weight
 * tensor expanded to the shape of `input`, then runs the operator with
 * reduction fixed to "none".
 *
 * @param input       Tensor whose dimension 1 is read as the class count.
 * @param target      Labels; fed to at::one_hot and used as the gather index
 *                    into `weight`. NOTE(review): presumably 1-D int64 class
 *                    indices - confirm against the caller.
 * @param weight      Per-class weights. `weight.size(0) == 0` selects the
 *                    unweighted path (all-ones weights).
 * @param grad_input  Tensor registered as the operator's output.
 * @param gamma       Passed as the "gamma" attribute.
 * @param alpha       Passed as the "alpha" attribute.
 */
void sigmoid_focal_loss_backward_cann(const at::Tensor &input, const at::Tensor &target, const at::Tensor &weight,
                                      at::Tensor &grad_input, float gamma, float alpha) {
    const int64_t n_class = input.size(1);

    // Assigned in both branches below; no placeholder allocation is needed.
    // NOTE(review): here the (1 - x) inversion is applied in the multi-class
    // branch, whereas sigmoid_focal_loss_cann applies it in the single-class
    // branch. Behaviour is kept as-is; confirm it matches what the
    // SigmoidFocalLossGrad operator expects.
    at::Tensor target_y;
    if (n_class == 1) {
        target_y = at::reshape(target, input.sizes());
    } else {
        target_y = at::add(at::mul(at::one_hot(target, n_class), -1.0), 1.0);
    }
    target_y = target_y.to(at::kInt);

    // Upstream gradient fed to the operator is all ones, shaped like input.
    at::Tensor grad_up = at::ones_like(input);

    // All-ones weights are only materialised when no class weights are given.
    at::Tensor weight_y;
    if (weight.size(0) > 0) {
        // Pick weight[target[i]] per sample, then view it as input's shape.
        weight_y = weight.gather(0, target).unsqueeze(1).expand_as(input);
    } else {
        weight_y = at::ones_like(input);
    }

    const std::string reduction = "none";
    at_npu::native::OpCommand cmd;
    cmd.Name("SigmoidFocalLossGrad")
        .Input(input)
        .Input(target_y)
        .Input(grad_up)
        .Input(weight_y)
        .Output(grad_input)
        .Attr("gamma", gamma)
        .Attr("alpha", alpha)
        .Attr("reduction", reduction)
        .Run();
}