npu_index_select[beta]

接口原型

mx_driving.npu_index_select(Tensor feature, Int dim, Tensor index) -> Tensor

功能描述

从输入feature的指定维度dim,按index中的下标序号提取元素,保存到out中。 例如,对于输入张量 feature=[123456789]feature=\begin{bmatrix}1 & 2 & 3 \\ 4 & 5 & 6 \\ 7 & 8 & 9\end{bmatrix} 和索引张量 index=[1,0]index=[1, 0]mx_driving.npu_index_select(feature,0,index)mx\_driving.npu\_index\_select(feature, 0, index) 的结果: out=[456123]out=\begin{bmatrix}4 & 5 & 6 \\ 1 & 2 & 3\end{bmatrix}

参数说明

  • feature(Tensor):待提取张量,数据类型支持FLOAT、FLOAT16、INT32、INT16,维度仅支持二维,支持非连续的Tensor
  • dim(Int):提取维度,数据类型支持INT64
  • index(Tensor):提取索引,数据类型支持INT64、INT32,仅支持一维Tensor,支持非连续的Tensor

返回值

  • out(Tensor):输出Tensor,数据类型与feature一致,维度为两维,维度为[index.shape[0], feature.shape[1]]

约束说明

  • feature仅支持二维Tensor
  • dim仅支持0和-2;
  • index不支持负索引和越界索引,即取值范围为[0, feature.shape[0])
  • 该API依赖aclnnIndexAddV2接口,因此需配置2025年6月18日之后的CANN包才能生效;
  • 反向具有相同约束。

支持的型号

  • Atlas A2 训练系列产品

调用示例

import torch, torch_npu
from mx_driving import npu_index_select

x = torch.randn(3, 4)
index = torch.tensor([0, 2])
npu_output = npu_index_select(x.npu(), 0, index.npu())
npu_output.backward(torch.ones_like(npu_output))