scatter_add

接口原型

mx_driving.scatter_add(Tensor src, Tensor indices, int dim=0, Tensor out=None, int dim_size=None) -> Tensor

功能描述

将输入张量src中的元素按照indices中的索引在指定的dim维进行分组,并对每组进行求和,求和后的结果放在out中。

参数说明

  • src (Tensor):源张量 (Tensor),数据类型为float32
  • indices (Tensor):索引张量 (Tensor),数据类型为int32
  • out (Tensor):被更新张量 (Tensor),数据类型为float32,可选入参,默认为None,输入out不为None时,out中的元素参与求和的计算。
  • dim (int):指定的维度,表示按照哪个维度进行分组求和计算,数据类型为int32,可选入参,默认取值为0
  • dim_size (int):输出张量在dim维的长度,数据类型为int32,可选入参,默认为None,该参数仅在输入outNone时生效。

返回值

  • out (Tensor):求和后的张量 (Tensor),数据类型为float32

算子约束

  • indices的维度必须小于等于src的维度,且每一维的长度均必须与src长度相同。
  • indices的取值必须为非负的有效索引值,参数outdim_size不为None时,indices的取值应该为输出张量在dim维的有效索引值。
  • out的维度必须与src的维度相同,且除第dim维外其余维的长度必须与src相同。
  • dim取值不能超过indices的维度。
  • dim_size的取值必须为非负的有效长度值。
  • srcout不支持inf-infnan
  • 该算子的正反向均对尾块较大的场景较为亲和,对尾块很小的场景不亲和,其中,尾块表示srcN维的大小,N = src.dim() - indices.dim()

支持的型号

  • Atlas A2 训练系列产品

调用示例

import torch, torch_npu
from mx_driving import scatter_add
src = torch.randn(4, 5, 6).to(torch.float)
indices = torch.randint(5, (4, 5)).to(torch.int32)
dim = 0
src.requires_grad = True
out = scatter_add(src.npu(), indices.npu(), None, dim)
grad_out_tensor = torch.ones_like(out)
out.backward(grad_out_tensor)