#include "csrc/OpApiCommon.h"
#include "csrc/functions.h"
at::Tensor npu_three_interpolate(
int b, int c, int m, int n, const at::Tensor& points, const at::Tensor& idx, const at::Tensor& weight)
{
TORCH_CHECK_NPU(points);
TORCH_CHECK_NPU(idx);
TORCH_CHECK_NPU(weight);
auto point_dtype = points.scalar_type();
auto idx_dtype = idx.scalar_type();
auto weight_dtype = weight.scalar_type();
TORCH_CHECK((point_dtype == at::kFloat || point_dtype == at::kHalf),
"three_interpolate_forward ascend only support fp32 and fp16.");
TORCH_CHECK((weight_dtype == at::kFloat || weight_dtype == at::kHalf),
"three_interpolate_forward ascend only support fp32 and fp16.");
TORCH_CHECK((point_dtype == weight_dtype), "input dtype is inconsistent.");
TORCH_CHECK((idx_dtype == at::kInt), "indices: int32 tensor expected but got a tensor with dtype: ", idx_dtype);
auto point_size = points.sizes();
auto idx_size = idx.sizes();
auto weight_size = weight.sizes();
TORCH_CHECK(
(point_size.size() == 3 && idx_size.size() == 3 && weight_size.size() == 3), "input dimension should be 3.");
TORCH_CHECK((point_size[0] == idx_size[0] && point_size[0] == weight_size[0] && idx_size[0] == weight_size[0]),
"the first dimension of input should be the same.");
TORCH_CHECK((idx_size[1] == weight_size[1]), "the second dimension of indices and weight should be the same.");
TORCH_CHECK((idx_size[2] == 3 && weight_size[2] == 3), "the third dimension of indices and weight should be 3.");
TORCH_CHECK((b < 10001 && c < 10001 && m < 10001 && n < 10001), "input dimension is too heavy.");
auto point_c_trans = points.transpose(1, 2).to(at::kFloat);
auto weight_cast = weight.to(at::kFloat);
c10::SmallVector<int64_t, 8> output_size = {b, c, n};
at::Tensor out_cast = at::zeros(output_size, points.options()).to(at::kFloat);
at_npu::native::OpCommand cmd;
cmd.Name("ThreeInterpolate").Input(point_c_trans).Input(idx).Input(weight_cast).Output(out_cast).Run();
auto out = out_cast;
if (point_dtype == at::kHalf) {
out = out_cast.to(at::kHalf);
}
auto output = out_cast.view({b, n, c}).transpose(1, 2);
auto res = output.contiguous();
out.copy_(res);
return out;
}
at::Tensor npu_three_interpolate_backward(
int b, int c, int n, int m, const at::Tensor& grad_out, const at::Tensor& idx, const at::Tensor& weight)
{
TORCH_CHECK_NPU(grad_out);
TORCH_CHECK_NPU(idx);
TORCH_CHECK_NPU(weight);
auto grad_dtype = grad_out.scalar_type();
auto idx_dtype = idx.scalar_type();
auto weight_dtype = weight.scalar_type();
TORCH_CHECK((grad_dtype == at::kFloat || grad_dtype == at::kHalf),
"three_interpolate_forward ascend only support fp32 and fp16.");
TORCH_CHECK((weight_dtype == at::kFloat || weight_dtype == at::kHalf),
"three_interpolate_forward ascend only support fp32 and fp16.");
TORCH_CHECK((grad_dtype == weight_dtype), "input dtype is inconsistent.");
TORCH_CHECK((idx_dtype == at::kInt), "indices: int32 tensor expected but got a tensor with dtype: ", idx_dtype);
auto grad_size = grad_out.sizes();
auto idx_size = idx.sizes();
auto weight_size = weight.sizes();
TORCH_CHECK(
(grad_size.size() == 3 && idx_size.size() == 3 && weight_size.size() == 3), "the input dimension should be 3.");
TORCH_CHECK((grad_size[0] == idx_size[0] && grad_size[0] == weight_size[0] && idx_size[0] == weight_size[0]),
"the first dimension of input should be the same.");
TORCH_CHECK((grad_size[2] == idx_size[1] && grad_size[2] == weight_size[1] && idx_size[1] == weight_size[1]),
"the second dimension of indices and weight should be the same.");
TORCH_CHECK((idx_size[2] == 3 && weight_size[2] == 3), "the third dimension of indices and weight should be 3.");
TORCH_CHECK((b < 10001 && c < 10001 && m < 10001 && n < 10001), "input dimension is too heavy.");
at::Tensor grad_points = at::zeros({b, c, m}, grad_out.options());
auto grad_x = at::unsqueeze(grad_out, 3);
auto grad_y = at::unsqueeze(grad_points, 3);
EXEC_NPU_CMD(aclnnThreeInterpolateBackward, grad_x, idx, weight, m, grad_y);
auto output = at::squeeze(grad_y, 3);
auto res = output.contiguous();
grad_points.copy_(res);
return grad_points;
}