(beta)torch_npu.contrib.function.roll
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Atlas A3 训练系列产品 | √ |
| Atlas A2 训练系列产品 | √ |
| Atlas 推理系列产品 | √ |
| Atlas 训练系列产品 | √ |
功能说明
使用NPU亲和写法替换swin-transformer中的原生roll。
函数原型
torch_npu.contrib.function.roll(input1, shifts, dims)
参数说明
- input1 (
Tensor):输入张量。 - shifts (
Tupleofints):每个维度张量滚动(roll)的位移量。 - dims (
Tupleofints):要滚动的维度。
返回值说明
Tensor
滚动之后的结果。
约束说明
input1是4维张量,shifts和dims的长度为2。
调用示例
>>> import torch, torch_npu
>>> from torch_npu.contrib.function import roll
>>> input1 = torch.randn(32, 56, 56, 16).npu()
>>> shift_size = 3
>>> shifted_x_npu = roll(input1, shifts=(-shift_size, -shift_size), dims=(1, 2))
>>> shifted_x_npu.shape
torch.Size([32, 56, 56, 16])