Axpy算子
算子概述
基础向量运算,实现 y = alpha * x + y。
数学表达式:
y[i] = alpha * x[i] + y[i] for i = 0 to n-1
包含以下接口:
| 接口名 | 功能简述 |
|---|---|
| aclblasSaxpy | 单精度浮点 AXPY |
| aclblasCaxpy | 单精度复数 AXPY |
算子执行接口
aclblasSaxpy
产品支持情况
- Ascend 950PR / Ascend 950DT:支持
- Atlas A3 训练系列产品 / Atlas A3 推理系列产品:不支持
- Atlas A2 训练系列产品 / Atlas A2 推理系列产品:不支持
函数原型
aclblasStatus_t aclblasSaxpy(aclblasHandle_t handle, int n, const float* alpha, float* x, int incx, float* y, int incy)
参数说明
| 参数名 | 输入/输出 | 参数类型 | 说明 |
|---|---|---|---|
| handle | 输入 | aclblasHandle_t | ops-blas 库上下文句柄,携带 stream,Host 内存 |
| n | 输入 | int | 向量元素个数,Host 内存 |
| alpha | 输入 | const float*(FP32) | 指向标量乘数的指针,Host 内存 |
| x | 输入 | float*(FP32) | 输入向量 x,Device 内存 |
| incx | 输入 | int | 向量 x 的步长,Host 内存 |
| y | 输入/输出 | float*(FP32) | 输入/输出向量 y,Device 内存 |
| incy | 输入 | int | 向量 y 的步长,Host 内存 |
约束说明
- n >= 0
- incx != 0
- incy != 0
调用示例
示例代码如下,仅供参考,具体编译和执行过程请参考编译与运行样例。
#include <cstdio>
#include <memory>
#include "acl/acl.h"
#include "cann_ops_blas.h"
// Guard macro used throughout the sample: evaluates `cond` and, when it does
// not hold, executes `return_expr` (call sites pass a `return ...` statement).
// The do { ... } while (0) wrapper makes the expansion a single statement.
#define CHECK_RET(cond, return_expr) \
    do {                             \
        if (cond) {                  \
            break;                   \
        }                            \
        return_expr;                 \
    } while (0)
// Owns the ACL state the sample needs: ACL initialization, the selected
// device and one stream. Init() acquires them in that order; the destructor
// releases whatever was acquired, in reverse order.
class AclContext {
public:
    explicit AclContext(int deviceId) : deviceId_(deviceId) {}

    // The destructor destroys the stream, resets the device and finalizes
    // ACL, so a copied object would release the same resources a second time.
    // Copying is disabled; declaring these also suppresses the implicit moves.
    AclContext(const AclContext&) = delete;
    AclContext& operator=(const AclContext&) = delete;

    ~AclContext()
    {
        if (stream_ != nullptr) {
            aclrtDestroyStream(stream_);
            stream_ = nullptr;
        }
        if (deviceSet_) {
            aclrtResetDevice(deviceId_);
            deviceSet_ = false;
        }
        if (aclInited_) {
            aclFinalize();
            aclInited_ = false;
        }
    }

    // Initializes ACL, selects the device and creates the stream.
    // Returns ACL_SUCCESS, or the status of the first call that failed. Each
    // step is recorded as soon as it succeeds, so the destructor cleans up a
    // partially initialized context as well.
    int Init()
    {
        auto ret = aclInit(nullptr);
        CHECK_RET(ret == ACL_SUCCESS, return ret);
        aclInited_ = true;
        ret = aclrtSetDevice(deviceId_);
        CHECK_RET(ret == ACL_SUCCESS, return ret);
        deviceSet_ = true;
        ret = aclrtCreateStream(&stream_);
        CHECK_RET(ret == ACL_SUCCESS, return ret);
        return ACL_SUCCESS;
    }

    // Returns the stream held by this context (nullptr until Init() creates it).
    aclrtStream Stream() const { return stream_; }

private:
    int deviceId_;
    aclrtStream stream_ = nullptr;
    bool aclInited_ = false;  // true once aclInit() has succeeded
    bool deviceSet_ = false;  // true once aclrtSetDevice() has succeeded
};
// Deleter for std::unique_ptr<void, ...>: hands a non-null pointer to aclrtFree.
struct AclrtMemDeleter {
    void operator()(void* ptr) const
    {
        if (ptr == nullptr) {
            return;  // nothing to free
        }
        aclrtFree(ptr);
    }
};
struct AclblasHandleDeleter {
void operator()(aclblasHandle_t handle) const
{
if (handle != nullptr) {
aclblasDestroy(handle);
}
}
};
// Runs a 4-element single-precision AXPY (y = alpha * x + y) on the device
// using the stream held by `ctx`, then prints the resulting y vector.
// Returns 0 on success, otherwise the status of the first ACL / aclblas call
// that failed.
int aclblasSaxpyTest(AclContext& ctx)
{
    constexpr int kCount = 4;
    constexpr int kStrideX = 1;
    constexpr int kStrideY = 1;
    constexpr size_t kBufSize = sizeof(float) * kCount;

    float scale = 2.0f;
    float hostX[kCount] = {1.0f, 2.0f, 3.0f, 4.0f};
    float hostY[kCount] = {10.0f, 20.0f, 30.0f, 40.0f};

    // Device buffers: each allocation is handed to a unique_ptr right after
    // it succeeds, so it is freed on every return path.
    void* xMem = nullptr;
    auto rtStatus = aclrtMalloc(&xMem, kBufSize, ACL_MEM_MALLOC_HUGE_FIRST);
    CHECK_RET(rtStatus == ACL_SUCCESS, return rtStatus);
    std::unique_ptr<void, AclrtMemDeleter> devX(xMem);

    void* yMem = nullptr;
    rtStatus = aclrtMalloc(&yMem, kBufSize, ACL_MEM_MALLOC_HUGE_FIRST);
    CHECK_RET(rtStatus == ACL_SUCCESS, return rtStatus);
    std::unique_ptr<void, AclrtMemDeleter> devY(yMem);

    // Upload both input vectors.
    rtStatus = aclrtMemcpy(devX.get(), kBufSize, hostX, kBufSize, ACL_MEMCPY_HOST_TO_DEVICE);
    CHECK_RET(rtStatus == ACL_SUCCESS, return rtStatus);
    rtStatus = aclrtMemcpy(devY.get(), kBufSize, hostY, kBufSize, ACL_MEMCPY_HOST_TO_DEVICE);
    CHECK_RET(rtStatus == ACL_SUCCESS, return rtStatus);

    // Create the aclblas handle and attach the context's stream to it.
    aclblasHandle_t createdHandle = nullptr;
    auto blasStatus = aclblasCreate(&createdHandle);
    CHECK_RET(blasStatus == ACLBLAS_STATUS_SUCCESS, return blasStatus);
    std::unique_ptr<void, AclblasHandleDeleter> handleOwner(createdHandle);
    const auto blas = static_cast<aclblasHandle_t>(handleOwner.get());

    blasStatus = aclblasSetStream(blas, ctx.Stream());
    CHECK_RET(blasStatus == ACLBLAS_STATUS_SUCCESS, return blasStatus);

    float* const xDev = static_cast<float*>(devX.get());
    float* const yDev = static_cast<float*>(devY.get());
    blasStatus = aclblasSaxpy(blas, kCount, &scale, xDev, kStrideX, yDev, kStrideY);
    CHECK_RET(blasStatus == ACLBLAS_STATUS_SUCCESS, return blasStatus);

    // Synchronize the stream before reading the result back to the host.
    rtStatus = aclrtSynchronizeStream(ctx.Stream());
    CHECK_RET(rtStatus == ACL_SUCCESS, return rtStatus);
    rtStatus = aclrtMemcpy(hostY, kBufSize, devY.get(), kBufSize, ACL_MEMCPY_DEVICE_TO_HOST);
    CHECK_RET(rtStatus == ACL_SUCCESS, return rtStatus);

    // Expected result: y = alpha*x + y = {12, 24, 36, 48}
    for (int idx = 0; idx < kCount; ++idx) {
        printf("y[%d] = %f\n", idx, hostY[idx]);
    }
    return 0;
}
// Entry point of the sample: initializes ACL on device 0, then runs the
// Saxpy test. Returns 0 on success, otherwise the failing status code.
int main()
{
    AclContext ctx(0);

    const auto initStatus = ctx.Init();
    if (initStatus != ACL_SUCCESS) {
        return initStatus;
    }

    const auto testStatus = aclblasSaxpyTest(ctx);
    if (testStatus != ACL_SUCCESS) {
        return testStatus;
    }
    return 0;
}
aclblasCaxpy
产品支持情况
- Ascend 950PR / Ascend 950DT:不支持
- Atlas A3 训练系列产品 / Atlas A3 推理系列产品:支持
- Atlas A2 训练系列产品 / Atlas A2 推理系列产品:支持
函数原型
aclblasStatus_t aclblasCaxpy(aclblasHandle_t handle, const int64_t n, const std::complex<float> alpha, uint8_t* x, int64_t incx, uint8_t* y, int64_t incy)
参数说明
| 参数名 | 输入/输出 | 参数类型 | 说明 |
|---|---|---|---|
| handle | 输入 | aclblasHandle_t | ops-blas 库上下文句柄,携带 stream,Host 内存 |
| n | 输入 | const int64_t | 向量元素个数,Host 内存 |
| alpha | 输入 | `const std::complex<float>` | 复数标量系数(按值传递),Host 内存 |
| x | 输入 | uint8_t* | 输入复向量,Device 内存 |
| incx | 输入 | int64_t | x 的步长,Host 内存 |
| y | 输入/输出 | uint8_t* | 输入/输出复向量,Device 内存 |
| incy | 输入 | int64_t | y 的步长,Host 内存 |
约束说明
- n >= 0
调用示例
示例代码如下,仅供参考,具体编译和执行过程请参考编译与运行样例。
#include <complex>
#include <cstdint>
#include <cstdio>
#include <memory>
#include "acl/acl.h"
#include "cann_ops_blas.h"
// Guard macro used throughout the sample: evaluates `cond` and, when it does
// not hold, executes `return_expr` (call sites pass a `return ...` statement).
// The do { ... } while (0) wrapper makes the expansion a single statement.
#define CHECK_RET(cond, return_expr) \
    do {                             \
        if (cond) {                  \
            break;                   \
        }                            \
        return_expr;                 \
    } while (0)
// Owns the ACL state the sample needs: ACL initialization, the selected
// device and one stream. Init() acquires them in that order; the destructor
// releases whatever was acquired, in reverse order.
class AclContext {
public:
    explicit AclContext(int deviceId) : deviceId_(deviceId) {}

    // The destructor destroys the stream, resets the device and finalizes
    // ACL, so a copied object would release the same resources a second time.
    // Copying is disabled; declaring these also suppresses the implicit moves.
    AclContext(const AclContext&) = delete;
    AclContext& operator=(const AclContext&) = delete;

    ~AclContext()
    {
        if (stream_ != nullptr) {
            aclrtDestroyStream(stream_);
            stream_ = nullptr;
        }
        if (deviceSet_) {
            aclrtResetDevice(deviceId_);
            deviceSet_ = false;
        }
        if (aclInited_) {
            aclFinalize();
            aclInited_ = false;
        }
    }

    // Initializes ACL, selects the device and creates the stream.
    // Returns ACL_SUCCESS, or the status of the first call that failed. Each
    // step is recorded as soon as it succeeds, so the destructor cleans up a
    // partially initialized context as well.
    int Init()
    {
        auto ret = aclInit(nullptr);
        CHECK_RET(ret == ACL_SUCCESS, return ret);
        aclInited_ = true;
        ret = aclrtSetDevice(deviceId_);
        CHECK_RET(ret == ACL_SUCCESS, return ret);
        deviceSet_ = true;
        ret = aclrtCreateStream(&stream_);
        CHECK_RET(ret == ACL_SUCCESS, return ret);
        return ACL_SUCCESS;
    }

    // Returns the stream held by this context (nullptr until Init() creates it).
    aclrtStream Stream() const { return stream_; }

private:
    int deviceId_;
    aclrtStream stream_ = nullptr;
    bool aclInited_ = false;  // true once aclInit() has succeeded
    bool deviceSet_ = false;  // true once aclrtSetDevice() has succeeded
};
// Deleter for std::unique_ptr<void, ...>: hands a non-null pointer to aclrtFree.
struct AclrtMemDeleter {
    void operator()(void* ptr) const
    {
        if (ptr == nullptr) {
            return;  // nothing to free
        }
        aclrtFree(ptr);
    }
};
struct AclblasHandleDeleter {
void operator()(aclblasHandle_t handle) const
{
if (handle != nullptr) {
aclblasDestroy(handle);
}
}
};
// Runs a 4-element single-precision complex AXPY (y = alpha * x + y) on the
// device using the stream held by `ctx`, then prints the resulting y vector.
// Returns 0 on success, otherwise the status of the first ACL / aclblas call
// that failed.
int aclblasCaxpyTest(AclContext& ctx)
{
    constexpr int64_t kCount = 4;
    constexpr int64_t kStrideX = 1;
    constexpr int64_t kStrideY = 1;
    constexpr size_t kBufSize = sizeof(std::complex<float>) * static_cast<size_t>(kCount);

    std::complex<float> scale(2.0f, 1.0f);
    std::complex<float> hostX[kCount] = {{1.0f, 0.0f}, {2.0f, 0.0f}, {3.0f, 0.0f}, {4.0f, 0.0f}};
    std::complex<float> hostY[kCount] = {{10.0f, 0.0f}, {20.0f, 0.0f}, {30.0f, 0.0f}, {40.0f, 0.0f}};

    // Device buffers: each allocation is handed to a unique_ptr right after
    // it succeeds, so it is freed on every return path.
    void* xMem = nullptr;
    auto rtStatus = aclrtMalloc(&xMem, kBufSize, ACL_MEM_MALLOC_HUGE_FIRST);
    CHECK_RET(rtStatus == ACL_SUCCESS, return rtStatus);
    std::unique_ptr<void, AclrtMemDeleter> devX(xMem);

    void* yMem = nullptr;
    rtStatus = aclrtMalloc(&yMem, kBufSize, ACL_MEM_MALLOC_HUGE_FIRST);
    CHECK_RET(rtStatus == ACL_SUCCESS, return rtStatus);
    std::unique_ptr<void, AclrtMemDeleter> devY(yMem);

    // Upload both input vectors.
    rtStatus = aclrtMemcpy(devX.get(), kBufSize, hostX, kBufSize, ACL_MEMCPY_HOST_TO_DEVICE);
    CHECK_RET(rtStatus == ACL_SUCCESS, return rtStatus);
    rtStatus = aclrtMemcpy(devY.get(), kBufSize, hostY, kBufSize, ACL_MEMCPY_HOST_TO_DEVICE);
    CHECK_RET(rtStatus == ACL_SUCCESS, return rtStatus);

    // Create the aclblas handle and attach the context's stream to it.
    aclblasHandle_t createdHandle = nullptr;
    auto blasStatus = aclblasCreate(&createdHandle);
    CHECK_RET(blasStatus == ACLBLAS_STATUS_SUCCESS, return blasStatus);
    std::unique_ptr<void, AclblasHandleDeleter> handleOwner(createdHandle);
    const auto blas = static_cast<aclblasHandle_t>(handleOwner.get());

    blasStatus = aclblasSetStream(blas, ctx.Stream());
    CHECK_RET(blasStatus == ACLBLAS_STATUS_SUCCESS, return blasStatus);

    // aclblasCaxpy takes the complex buffers as uint8_t* and alpha by value.
    uint8_t* const xDev = static_cast<uint8_t*>(devX.get());
    uint8_t* const yDev = static_cast<uint8_t*>(devY.get());
    blasStatus = aclblasCaxpy(blas, kCount, scale, xDev, kStrideX, yDev, kStrideY);
    CHECK_RET(blasStatus == ACLBLAS_STATUS_SUCCESS, return blasStatus);

    // Synchronize the stream before reading the result back to the host.
    rtStatus = aclrtSynchronizeStream(ctx.Stream());
    CHECK_RET(rtStatus == ACL_SUCCESS, return rtStatus);
    rtStatus = aclrtMemcpy(hostY, kBufSize, devY.get(), kBufSize, ACL_MEMCPY_DEVICE_TO_HOST);
    CHECK_RET(rtStatus == ACL_SUCCESS, return rtStatus);

    // Expected result: y = alpha*x + y = {(12, 1), (24, 2), (36, 3), (48, 4)}
    for (int64_t idx = 0; idx < kCount; ++idx) {
        printf("y[%lld] = (%f, %f)\n", static_cast<long long>(idx), hostY[idx].real(), hostY[idx].imag());
    }
    return 0;
}
// Entry point of the sample: initializes ACL on device 0, then runs the
// Caxpy test. Returns 0 on success, otherwise the failing status code.
int main()
{
    AclContext ctx(0);

    const auto initStatus = ctx.Init();
    if (initStatus != ACL_SUCCESS) {
        return initStatus;
    }

    const auto testStatus = aclblasCaxpyTest(ctx);
    if (testStatus != ACL_SUCCESS) {
        return testStatus;
    }
    return 0;
}