stpttr算子实现

概述

BLAS stpttr算子实现。

stpttr(Symmetric Triangular matrix, Packed format To Triangular matrix, Regular storage)算子将 LAPACK 压缩格式(packed format)中的对称三角矩阵展开为按列主序存储的常规二维矩阵。仅写入 uplo 指定的三角区域,矩阵另一三角及未参与运算的元素保持原值不变。

产品支持情况

产品 是否支持
Ascend 950PR/Ascend 950DT
Atlas A3 训练系列产品/Atlas A3 推理系列产品 ×
Atlas A2 训练系列产品/Atlas A2 推理系列产品 ×

目录结构介绍

blas/tpttr/
├── README.md                   // 说明文档
└── arch35/
    ├── stpttr_host.cpp         // Host 侧实现
    ├── stpttr_kernel.cpp       // Kernel 侧实现
    └── stpttr_tiling_data.h    // Tiling 数据结构

测试代码位于 test/tpttr/

test/tpttr/
├── CMakeLists.txt              // 编译工程文件
├── stpttr_param.h              // 参数结构体(继承 BlasTestParamBase)
├── stpttr_golden.h             // CPU golden(签名与 BLAS API 一致)
└── arch35/
    ├── stpttr_npu_wrapper.h    // NPU wrapper(封装 aclrtMalloc/H2D/kernel/D2H/free)
    ├── stpttr_test.cpp         // 精度测试(GTest 入口)
    └── stpttr_test.csv         // 精度测试用例表

算子描述

  • 算子功能:
    将压缩格式三角矩阵 AP 中的元素按 uplo 展开到常规矩阵 A 的对应三角区域:

    • uplo == ACLBLAS_LOWER:复制到 A 的下三角(含对角),上三角不变
    • uplo == ACLBLAS_UPPER:复制到 A 的上三角(含对角),下三角不变

    AP 为列优先压缩存储,长度为 n * (n + 1) / 2Alda × n 的列主序矩阵,lda >= max(1, n)n == 0 时直接返回成功,不访问缓冲区。

    对应的接口为:

aclblasStatus_t aclblasStpttr(
    aclblasHandle_t handle,
    aclblasFillMode_t uplo,
    int n,
    const float *AP,
    float *A,
    int lda);
参数 stpttr 参数说明
参数列表 Param. Memory in/out 含义
handle in aclbLAS 库上下文句柄。
uplo in 三角存储方式:ACLBLAS_UPPER(121)、ACLBLAS_LOWER(122)。
n in 方阵维数,须 >= 0;为 0 时立即返回成功。
AP device in 压缩格式输入,<type> 数组,长度 n*(n+1)/2。
A device in/out 常规输出矩阵,<type> 数组,维度 lda × n;非目标三角保持原值。
lda in A 的主维长度,须满足 lda >= max(1, n)。
  • 算子规格:

    算子类型(OpType)stpttr
    算子输入nameshapedata typeformat
    APn*(n+1)/2floatpacked
    算子输出Alda * nfloatND
    核函数名stpttr_kernel
  • 算子实现:

    Host 侧完成参数校验与 Tiling 计算(按 Vector Core 数切分列块),将 Tiling 数据拷贝至 Device 后,通过 stpttr_kernel_do 启动 Kernel。Kernel 按列从 GM 上的压缩缓冲区 AP 分块搬入 UB,再写回 GM 上常规矩阵 A 的对应三角列段;lda > n 时列间存在 stride 间隔。

  • 调用实现
    使用内核调用符 <<<>>>stpttr_kernel_do)在 aclblas 关联的 stream 上异步执行,Host 在返回前同步 stream。

测试用例覆盖

分组 用例数 覆盖场景
L0 参数校验 4 未初始化 handle、n<0、lda 过小、非法 uplo
L0 功能 13 n=0/1/2/4/32/128/512,LOWER/UPPER
L1 规模与 lda 18 n=8~1024、lda>n(8×12、16×32 等)
L1 特殊数值 12 全零、大数、负数、inf、nan、极值组合
L1 参数校验 8 AP/A 空指针、非法 uplo、n=0 与 lda 组合
L1 往返与大规模 4 strttp→stpttr 往返(32×32)、n=10240

ST 采用 GTest 参数化 + stpttr_test.csvBlasTest<StpttrParam> fixture,精度模式为 EXACT(仅比对有效三角区,其余位置为 sentinel -999)。

注意makeBlasArray 的 size 参数为 int64_t,调用时需显式转换:makeBlasArray(static_cast<int64_t>(p.lda) * p.n, p.a),确保负值 n 正确返回空数组。

编译运行

在本样例根目录下执行如下步骤,编译并执行算子。

  • 配置环境变量
    请根据当前环境上CANN开发套件包的安装方式,选择对应配置环境变量的命令。

    • 默认路径,root用户安装CANN软件包

      source /usr/local/Ascend/cann/set_env.sh
      
    • 默认路径,非root用户安装CANN软件包

      source $HOME/Ascend/cann/set_env.sh
      
    • 指定路径install_path,安装CANN软件包

      source ${install_path}/cann/set_env.sh
      
  • 样例执行

    bash build.sh --ops=stpttr --soc=ascend950 --run
    

    其中--soc可选参数,用于指定目标硬件平台(与上文「产品支持情况」对应)。按实际硬件选用:

    产品 --soc 取值
    Ascend 950PR / Ascend 950DT ascend950
    Atlas A3 训练系列产品 / Atlas A3 推理系列产品 ascend910_93
    Atlas A2 训练系列产品 / Atlas A2 推理系列产品 ascend910b

    执行结果如下,说明精度对比成功。

    [PASS] stpttr_test