SpatialTransformer

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT
Atlas A3 训练系列产品/Atlas A3 推理系列产品
Atlas A2 训练系列产品/Atlas A2 推理系列产品
Atlas 200I/500 A2 推理产品 ×
Atlas 推理系列产品
Atlas 训练系列产品

功能说明

算子功能:Spatial Transformer Network (STN) 算子用于对输入张量进行仿射变换。该算子通过变换矩阵theta对输入图像x进行空间变换,输出变换后的图像y。

参数说明

参数名 输入/输出/属性 描述 数据类型 数据格式
x 输入 输入张量。 INT8、INT16、INT32、INT64、UINT8、UINT16、UINT32、UINT64
FLOAT16、FLOAT、DOUBLE
NCHW、NC1HWC0
theta 输入 变换矩阵,包含仿射变换参数。 INT8、INT16、INT32、INT64、UINT8、UINT16、UINT32、UINT64
FLOAT16、FLOAT、DOUBLE
ND
y 输出 变换后的输出张量。 INT8、INT16、INT32、INT64、UINT8、UINT16、UINT32、UINT64
FLOAT16、FLOAT、DOUBLE
NCHW、NC1HWC0
output_size 属性 指定输出的高度和宽度,包含2个整数。默认为 [-1, -1],表示使用输入尺寸。 ListInt -
default_theta 属性 默认的仿射变换参数,当use_default_theta为true时使用。默认为空列表。 ListFloat -
align_corners 属性 如果为true,则输入和输出张量的4个角像素中心对齐,保留角像素的值。默认为false。 Bool -
use_default_theta 属性 指定哪些theta参数从default_theta使用。1表示使用默认值,0表示使用输入theta。默认为空列表。 ListInt -

约束说明

  • 输入张量x的格式必须为NCHW或NC1HWC0。
  • 输出张量y的格式必须与输入张量x的格式一致。
  • 变换矩阵theta的形状必须为 [batch, 2, 3] 或 [2, 3]。
  • 当use_default_theta为空列表时,使用输入theta进行变换。
  • 当use_default_theta不为空时,对应位置为1的参数使用default_theta中的值,为0的参数使用输入theta中的值。

调用说明

调用方式 样例代码 说明
图模式调用 test_geir_spatial_transformer 通过算子IR构图方式调用SpatialTransformer算子。