Trmv算子
算子概述
trmv (Triangular Matrix-Vector Multiplication) 实现三角矩阵与向量的乘法运算。本算子包含实数三角矩阵-向量乘法(Strmv)和复数三角矩阵-向量乘法(Ctrmv)。
数学表达式:
x = op(A) * x
包含以下接口:
| 接口名 | 功能简述 |
|---|---|
| aclblasStrmv | 实数三角矩阵-向量乘法 |
| aclblasCtrmv | 复数三角矩阵-向量乘法 |
算子执行接口
aclblasStrmv
产品支持情况
- Ascend 950PR / Ascend 950DT:支持
- Atlas A3 训练系列产品 / Atlas A3 推理系列产品:支持
- Atlas A2 训练系列产品 / Atlas A2 推理系列产品:支持
函数原型
aclblasStatus_t aclblasStrmv(aclblasHandle_t handle, aclblasFillMode_t uplo, aclblasOperation_t trans, aclblasDiagType_t diag, int n, const float *A, int lda, float *x, int incx)
参数说明
| 参数名 | 输入/输出 | 参数类型 | 说明 |
|---|---|---|---|
| handle | 输入 | aclblasHandle_t | ops-blas 库上下文句柄,携带 stream,Host 内存 |
| uplo | 输入 | aclblasFillMode_t | 矩阵填充类型:ACLBLAS_UPPER(上三角)、ACLBLAS_LOWER(下三角),Host 内存 |
| trans | 输入 | aclblasOperation_t | 矩阵操作类型:ACLBLAS_OP_N(不转置)、ACLBLAS_OP_T(转置)、ACLBLAS_OP_C(共轭转置,实数下同 T),Host 内存 |
| diag | 输入 | aclblasDiagType_t | 对角线类型:ACLBLAS_NON_UNIT(非单位对角线)、ACLBLAS_UNIT(单位对角线,对角元素视为 1),Host 内存 |
| n | 输入 | int | 三角矩阵 A 的行数和列数,Host 内存 |
| A | 输入 | const float*(FP32) | 三角矩阵数组,维度为 lda x n,Device 内存 |
| lda | 输入 | int | 矩阵 A 存储的主维长度,lda >= n,Host 内存 |
| x | 输入/输出 | float*(FP32) | 向量,包含 n 个元素。输入为原始向量,输出为计算结果(原地覆盖),Device 内存 |
| incx | 输入 | int | x 中连续元素之间的步长,不可为 0,Host 内存 |
约束说明
- n >= 0
- incx != 0
- lda >= n
调用示例
示例代码如下,仅供参考,具体编译和执行过程请参考编译与运行样例。
#include <memory>
#include "acl/acl.h"
#include "cann_ops_blas.h"
#define CHECK_RET(cond, return_expr) \
do { \
if (!(cond)) { \
return_expr; \
} \
} while (0)
class AclContext {
public:
explicit AclContext(int deviceId) : deviceId_(deviceId) {}
~AclContext()
{
if (stream_ != nullptr) {
aclrtDestroyStream(stream_);
stream_ = nullptr;
}
if (deviceSet_) {
aclrtResetDevice(deviceId_);
deviceSet_ = false;
}
if (aclInited_) {
aclFinalize();
aclInited_ = false;
}
}
int Init()
{
auto ret = aclInit(nullptr);
CHECK_RET(ret == ACL_SUCCESS, return ret);
aclInited_ = true;
ret = aclrtSetDevice(deviceId_);
CHECK_RET(ret == ACL_SUCCESS, return ret);
deviceSet_ = true;
ret = aclrtCreateStream(&stream_);
CHECK_RET(ret == ACL_SUCCESS, return ret);
return ACL_SUCCESS;
}
aclrtStream Stream() const { return stream_; }
private:
int deviceId_;
aclrtStream stream_ = nullptr;
bool aclInited_ = false;
bool deviceSet_ = false;
};
struct AclrtMemDeleter {
void operator()(void* ptr) const
{
if (ptr != nullptr) {
aclrtFree(ptr);
}
}
};
struct AclblasHandleDeleter {
void operator()(aclblasHandle_t handle) const
{
if (handle != nullptr) {
aclblasDestroy(handle);
}
}
};
int aclblasStrmvTest(AclContext& ctx)
{
constexpr int n = 3;
constexpr int lda = 3;
constexpr int incx = 1;
constexpr size_t aSize = lda * n * sizeof(float);
constexpr size_t xSize = n * sizeof(float);
// A 按列主序存储,此处存储下三角部分。
float hA[lda * n] = {
1.0f, 2.0f, 4.0f,
0.0f, 3.0f, 5.0f,
0.0f, 0.0f, 6.0f
};
float hX[n] = {1.0f, 2.0f, 3.0f};
void *rawA = nullptr;
auto aclRet = aclrtMalloc(&rawA, aSize, ACL_MEM_MALLOC_HUGE_FIRST);
CHECK_RET(aclRet == ACL_SUCCESS, return aclRet);
std::unique_ptr<void, AclrtMemDeleter> dA(rawA);
void *rawX = nullptr;
aclRet = aclrtMalloc(&rawX, xSize, ACL_MEM_MALLOC_HUGE_FIRST);
CHECK_RET(aclRet == ACL_SUCCESS, return aclRet);
std::unique_ptr<void, AclrtMemDeleter> dX(rawX);
aclRet = aclrtMemcpy(dA.get(), aSize, hA, aSize, ACL_MEMCPY_HOST_TO_DEVICE);
CHECK_RET(aclRet == ACL_SUCCESS, return aclRet);
aclRet = aclrtMemcpy(dX.get(), xSize, hX, xSize, ACL_MEMCPY_HOST_TO_DEVICE);
CHECK_RET(aclRet == ACL_SUCCESS, return aclRet);
aclblasHandle_t rawHandle = nullptr;
auto blasRet = aclblasCreate(&rawHandle);
CHECK_RET(blasRet == ACLBLAS_STATUS_SUCCESS, return blasRet);
std::unique_ptr<void, AclblasHandleDeleter> handle(rawHandle);
blasRet = aclblasSetStream(static_cast<aclblasHandle_t>(handle.get()), ctx.Stream());
CHECK_RET(blasRet == ACLBLAS_STATUS_SUCCESS, return blasRet);
blasRet = aclblasStrmv(
static_cast<aclblasHandle_t>(handle.get()), ACLBLAS_LOWER, ACLBLAS_OP_N, ACLBLAS_NON_UNIT,
n, static_cast<float*>(dA.get()), lda, static_cast<float*>(dX.get()), incx);
CHECK_RET(blasRet == ACLBLAS_STATUS_SUCCESS, return blasRet);
aclRet = aclrtSynchronizeStream(ctx.Stream());
CHECK_RET(aclRet == ACL_SUCCESS, return aclRet);
aclRet = aclrtMemcpy(hX, xSize, dX.get(), xSize, ACL_MEMCPY_DEVICE_TO_HOST);
CHECK_RET(aclRet == ACL_SUCCESS, return aclRet);
return 0;
}
int main()
{
AclContext ctx(0);
auto ret = ctx.Init();
CHECK_RET(ret == ACL_SUCCESS, return ret);
ret = aclblasStrmvTest(ctx);
CHECK_RET(ret == ACL_SUCCESS, return ret);
return 0;
}
aclblasCtrmv
产品支持情况
- Ascend 950PR / Ascend 950DT:不支持
- Atlas A3 训练系列产品 / Atlas A3 推理系列产品:支持
- Atlas A2 训练系列产品 / Atlas A2 推理系列产品:支持
函数原型
aclblasStatus_t aclblasCtrmv(aclblasHandle_t handle, aclblasFillMode_t uplo, aclblasOperation_t trans, aclblasDiagType_t diag, int64_t n, uint8_t *A, int64_t lda, uint8_t *x, int64_t incx)
参数说明
| 参数名 | 输入/输出 | 参数类型 | 说明 |
|---|---|---|---|
| handle | 输入 | aclblasHandle_t | ops-blas 库上下文句柄,携带 stream,Host 内存 |
| uplo | 输入 | aclblasFillMode_t | 指定矩阵 A 的上三角或下三角部分。ACLBLAS_UPPER 或 ACLBLAS_LOWER,Host 内存 |
| trans | 输入 | aclblasOperation_t | 指定对矩阵 A 的操作类型。ACLBLAS_OP_N(不转置)、ACLBLAS_OP_T(转置)、ACLBLAS_OP_C(共轭转置),Host 内存 |
| diag | 输入 | aclblasDiagType_t | 指定对角线元素是否为单位元。ACLBLAS_UNIT(单位对角线)或 ACLBLAS_NON_UNIT(非单位对角线),Host 内存 |
| n | 输入 | int64_t | 矩阵 A 的阶数,即向量的长度,Host 内存 |
| A | 输入 | uint8_t* | n x lda 的复数矩阵,Device 内存 |
| lda | 输入 | int64_t | 矩阵 A 的主维度,Host 内存 |
| x | 输入/输出 | uint8_t* | 复数向量,长度为 n。既是输入也是输出,Device 内存 |
| incx | 输入 | int64_t | x 中连续元素之间的步长,Host 内存 |
约束说明
- n 的取值范围为 [1, 8192]
- 仅支持 complex 数据类型
- incx > 0
- lda > 0