"""
FillPad Operator System Test
FillPad 功能说明:
- 输入 tensor 有两个关键属性:
- shape: tensor 的总大小 (pad后总大小)
- valid_shape: 有效数据的大小
- fillpad 的作用: 将 valid_shape 之外的区域填充为指定值
示例:
输入 tensor shape = [32, 32], valid_shape = [16, 16]
fillpad 后: [0:16, 0:16] 保持原数据, [16:32, :] 和 [:, 16:32] 被填充为 pad_val
"""
import pypto
import torch
import numpy as np
from st.pypto_test import TestBuilder
def op_fillpad(params, a, b):
"""
FillPad 算子实现
参数说明:
- view_shape: 每个 tile 的大小 (等于 pad 后的总大小)
- tile_shape: NPU 计算单元的 tile 形状
- valid_shape: 有效数据的大小 (小于 view_shape 时会产生 padding 区域)
- pad_val: 填充值
"""
view_shape, tile_shape, valid_shape, pad_val = params
valid_h, valid_w = valid_shape
for _ in pypto.loop(1, name="LOOP_FILLPAD_L0", idx_name="b_idx"):
for _ in pypto.loop(1, name="LOOP_FILLPAD_L1", idx_name="s_idx"):
offset_x = 0
offset_y = 0
tile_a = pypto.view(a, view_shape, [offset_x, offset_y],
valid_shape=[valid_h, valid_w])
pypto.set_vec_tile_shapes(tile_shape[0], tile_shape[1])
tile_res = tile_a.fillpad(mode="constant", value=pad_val)
pypto.assemble(tile_res, [offset_x, offset_y], b)
def op_fillpad_golden(params, a, b):
"""
FillPad Golden 实现
将 valid_shape 之外的区域填充为 pad_val
"""
view_shape, tile_shape, valid_shape, pad_val = params
valid_h, valid_w = valid_shape
result = a.clone()
h, w = a.shape
if valid_h < h:
result[valid_h:, :] = pad_val
if valid_w < w:
result[:, valid_w:] = pad_val
return result
class FillPadTest(TestBuilder):
def __init__(self, params: tuple, kernel, kernel_golden, tiling: int):
super().__init__(params, kernel, kernel_golden, tiling)
def get_input_from_param(self):
view_shape = self.params[0]
n_in, m_in = view_shape
a_tensor = torch.rand(n_in, m_in, dtype=torch.float32) * 10
self.setup_inputs(a_tensor)
self.set_tol(rtol=1e-3, atol=1e-3)
return (a_tensor, )
def test():
"""
测试用例说明:
- view_shape = (32, 32): pad 后的总大小
- tile_shape = (16, 16): NPU 计算单元 tile 形状
- valid_shape = (16, 16): 有效数据大小 (只有前 16x16 有数据)
- pad_val = 0: 填充值
预期结果:
- 输出 tensor 的 [0:16, 0:16] 保持原数据
- 输出 tensor的 [16:32, :] 和 [:, 16:32] 被填充为 0
"""
params = ((32, 32), (16, 16), (16, 16), 0.0)
st = FillPadTest(params, op_fillpad, op_fillpad_golden, tiling=32)
st()
def test_partial_valid():
"""
测试部分有效数据的场景
- view_shape = (32, 32)
- valid_shape = (20, 24): 有效数据只有 20x24
"""
params = ((32, 32), (16, 16), (20, 24), 0.0)
st = FillPadTest(params, op_fillpad, op_fillpad_golden, tiling=32)
st()
def test_small_valid():
"""
测试有效数据较小的场景
- view_shape = (16, 16)
- valid_shape = (8, 8): 有效数据只有 8x8
"""
params = ((16, 16), (8, 8), (8, 8), 0.0)
st = FillPadTest(params, op_fillpad, op_fillpad_golden, tiling=16)
st()
if __name__ == "__main__":
test()
test_partial_valid()
test_small_valid()