scatter_max

接口原型

mx_driving.scatter_max(Tensor updates, Tensor indices, Tensor out=None) -> (Tensor out, Tensor argmax)

功能描述

在第0维上,将输入张量updates中的元素按照indices中的索引进行分散,然后在第0维上取最大值,返回最大值和对应的索引。对于1维张量,公式如下:

outi=max(outi,maxj(updatesj))out_i = max(out_i, max_j(updates_j))

argmaxi=argmaxj(updatesj)argmax_i = argmax_j(updates_j)

这里,i=indicesji = indices_j

参数说明

  • updates(Tensor):更新源张量,数据类型为float32
  • indices(Tensor):索引张量,数据类型为int32
  • out(Tensor):被更新张量,数据类型为float32,默认为None

返回值

  • out(Tensor):更新后的张量,数据类型为float32
  • argmax(Tensor):最大值对应的索引张量,数据类型为int32

算子约束

  • 假设updates第0维的元素数量为N,其余轴总元素数量为M,则N与M的取值需满足:N * (M + 1) < 4,026,531,840,若超出该约束则算子出现显存错误。
  • updates的第0维外其余轴合轴后必须32字节对齐。
  • indices的维度必须为1indices第0维的长度必须与updates第0维的长度相同。
  • indices的取值必须为非负的有效索引值,且indices的最大值必须小于491520
  • out的维度必须与updates的维度相同,且除第0维外其余维的长度必须与updates相同。
  • 反向仅支持updates的维度为2,其余约束与正向相同。

支持的型号

  • Atlas A2 训练系列产品

调用示例

import torch, torch_npu
from mx_driving import scatter_max
updates = torch.tensor([[2, 0, 1, 3, 1, 0, 0, 4], [0, 2, 1, 3, 0, 3, 4, 2], [1, 2, 3, 4, 4, 3, 2, 1]], dtype=torch.float32).npu()
indices = torch.tensor([0, 2, 0], dtype=torch.int32).npu()
updates.requires_grad = True
out = updates.new_zeros((3, 8))
out, argmax = scatter_max(updates, indices, out)
grad_out_tensor = torch.ones_like(out)
out.backward(grad_out_tensor)