文件最后提交记录最后更新时间
10 天前
18 天前
4 天前
README.md

Axpy算子

算子概述

BLAS Level-1 基础向量运算,实现 y = alpha * x + y,即将向量 x 按标量 alpha 缩放后累加到向量 y 上(结果原地写回 y)。

数学表达式:

$$y_i = \alpha \cdot x_i + y_i,\quad i = 0, 1, \dots, n-1$$

其中 $x_i$、$y_i$ 分别表示按步长 incx、incy 访问到的第 i 个元素。

包含以下接口:

接口名 功能简述
aclblasSaxpy 单精度浮点 AXPY
aclblasCaxpy 单精度复数(std::complex<float>)AXPY

算子执行接口

aclblasSaxpy

产品支持情况

  • Ascend 950PR / Ascend 950DT:支持
  • Atlas A3 训练系列产品 / Atlas A3 推理系列产品:不支持
  • Atlas A2 训练系列产品 / Atlas A2 推理系列产品:不支持

函数原型

aclblasStatus_t aclblasSaxpy(aclblasHandle_t handle, int n, const float* alpha, float* x, int incx, float* y, int incy)

参数说明

参数名 输入/输出 参数类型 说明
handle 输入 aclblasHandle_t ops-blas 库上下文句柄,携带 stream,Host 内存
n 输入 int 向量元素个数,Host 内存
alpha 输入 const float*(FP32) 指向标量乘数的指针,Host 内存
x 输入 float*(FP32) 输入向量 x,Device 内存
incx 输入 int 向量 x 的步长,Host 内存
y 输入/输出 float*(FP32) 输入/输出向量 y,Device 内存
incy 输入 int 向量 y 的步长,Host 内存

约束说明

  • n >= 0
  • incx != 0
  • incy != 0

调用示例

示例代码如下,仅供参考,具体编译和执行过程请参考编译与运行样例

#include <cstdio>
#include <memory>

#include "acl/acl.h"
#include "cann_ops_blas.h"

// Checks `cond`; when it is false, reports the failed condition (with file and line) on
// stderr and then executes `return_expr`, which is normally a `return ...;` statement.
// Without the report, a failing sample would exit with a bare non-zero code and give no
// hint about which call failed.
#define CHECK_RET(cond, return_expr)                                                         \
    do {                                                                                     \
        if (!(cond)) {                                                                       \
            fprintf(stderr, "[ERROR] %s:%d: check failed: %s\n", __FILE__, __LINE__, #cond); \
            return_expr;                                                                     \
        }                                                                                    \
    } while (0)

/// Owns the ACL runtime state used by this sample: aclInit, the selected device, and one stream.
///
/// Call Init() before use. The destructor releases only what Init() actually acquired,
/// in the order stream -> device -> aclFinalize, so a partially failed Init() is still
/// cleaned up. The class is non-copyable: a copy would release the same stream, device
/// and runtime a second time.
class AclContext {
public:
    explicit AclContext(int deviceId) : deviceId_(deviceId) {}

    // Copying would duplicate ownership of the stream/device/runtime and tear them down twice.
    AclContext(const AclContext&) = delete;
    AclContext& operator=(const AclContext&) = delete;

    ~AclContext()
    {
        if (stream_ != nullptr) {
            aclrtDestroyStream(stream_);
            stream_ = nullptr;
        }
        if (deviceSet_) {
            aclrtResetDevice(deviceId_);
            deviceSet_ = false;
        }
        if (aclInited_) {
            aclFinalize();
            aclInited_ = false;
        }
    }

    /// Initializes ACL, selects the device and creates the stream.
    /// @return ACL_SUCCESS, or the aclError of the first step that failed.
    int Init()
    {
        auto ret = aclInit(nullptr);
        CHECK_RET(ret == ACL_SUCCESS, return ret);
        aclInited_ = true;

        ret = aclrtSetDevice(deviceId_);
        CHECK_RET(ret == ACL_SUCCESS, return ret);
        deviceSet_ = true;

        ret = aclrtCreateStream(&stream_);
        CHECK_RET(ret == ACL_SUCCESS, return ret);
        return ACL_SUCCESS;
    }

    /// Stream created by Init(); nullptr until Init() succeeds in creating it.
    aclrtStream Stream() const { return stream_; }

private:
    int deviceId_;
    aclrtStream stream_ = nullptr;
    bool aclInited_ = false;  // set once aclInit succeeded, so the destructor calls aclFinalize
    bool deviceSet_ = false;  // set once aclrtSetDevice succeeded, so the destructor resets it
};

/// std::unique_ptr deleter for device memory obtained from aclrtMalloc.
/// A null pointer is ignored.
struct AclrtMemDeleter {
    void operator()(void* devPtr) const
    {
        if (devPtr == nullptr) {
            return;
        }
        aclrtFree(devPtr);
    }
};

struct AclblasHandleDeleter {
    void operator()(aclblasHandle_t handle) const
    {
        if (handle != nullptr) {
            aclblasDestroy(handle);
        }
    }
};

/// Runs aclblasSaxpy on a 4-element sample and prints the resulting y vector.
///
/// Steps:
///   1. allocate device buffers for x and y and upload the host data;
///   2. create an aclblas handle and bind it to ctx's stream;
///   3. launch SAXPY, wait for the stream, and copy y back to the host.
///
/// @param ctx  an initialized AclContext; its stream is used for the launch
/// @return 0 on success, otherwise the first failing aclError / aclblasStatus_t value
int aclblasSaxpyTest(AclContext& ctx)
{
    constexpr int n = 4;
    constexpr int incx = 1;
    constexpr int incy = 1;
    constexpr size_t bytes = n * sizeof(float);
    float alpha = 2.0f;

    float hostX[n] = {1.0f, 2.0f, 3.0f, 4.0f};
    float hostY[n] = {10.0f, 20.0f, 30.0f, 40.0f};

    // Device buffers are owned by unique_ptr so every early return frees them.
    void* xDevRaw = nullptr;
    auto status = aclrtMalloc(&xDevRaw, bytes, ACL_MEM_MALLOC_HUGE_FIRST);
    CHECK_RET(status == ACL_SUCCESS, return status);
    std::unique_ptr<void, AclrtMemDeleter> xDev(xDevRaw);

    void* yDevRaw = nullptr;
    status = aclrtMalloc(&yDevRaw, bytes, ACL_MEM_MALLOC_HUGE_FIRST);
    CHECK_RET(status == ACL_SUCCESS, return status);
    std::unique_ptr<void, AclrtMemDeleter> yDev(yDevRaw);

    status = aclrtMemcpy(xDev.get(), bytes, hostX, bytes, ACL_MEMCPY_HOST_TO_DEVICE);
    CHECK_RET(status == ACL_SUCCESS, return status);
    status = aclrtMemcpy(yDev.get(), bytes, hostY, bytes, ACL_MEMCPY_HOST_TO_DEVICE);
    CHECK_RET(status == ACL_SUCCESS, return status);

    aclblasHandle_t handleRaw = nullptr;
    auto blasStatus = aclblasCreate(&handleRaw);
    CHECK_RET(blasStatus == ACLBLAS_STATUS_SUCCESS, return blasStatus);
    std::unique_ptr<void, AclblasHandleDeleter> handleOwner(handleRaw);
    const auto blasHandle = static_cast<aclblasHandle_t>(handleOwner.get());

    blasStatus = aclblasSetStream(blasHandle, ctx.Stream());
    CHECK_RET(blasStatus == ACLBLAS_STATUS_SUCCESS, return blasStatus);

    float* const xData = static_cast<float*>(xDev.get());
    float* const yData = static_cast<float*>(yDev.get());
    blasStatus = aclblasSaxpy(blasHandle, n, &alpha, xData, incx, yData, incy);
    CHECK_RET(blasStatus == ACLBLAS_STATUS_SUCCESS, return blasStatus);

    // The operator is submitted on ctx's stream; synchronize before reading y back.
    status = aclrtSynchronizeStream(ctx.Stream());
    CHECK_RET(status == ACL_SUCCESS, return status);

    status = aclrtMemcpy(hostY, bytes, yDev.get(), bytes, ACL_MEMCPY_DEVICE_TO_HOST);
    CHECK_RET(status == ACL_SUCCESS, return status);

    // Expected result: y = alpha * x + y = {12, 24, 36, 48}
    for (int idx = 0; idx < n; ++idx) {
        printf("y[%d] = %f\n", idx, hostY[idx]);
    }

    return 0;
}

int main()
{
    AclContext ctx(0);
    auto ret = ctx.Init();
    CHECK_RET(ret == ACL_SUCCESS, return ret);

    ret = aclblasSaxpyTest(ctx);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    return 0;
}

aclblasCaxpy

产品支持情况

  • Ascend 950PR / Ascend 950DT:不支持
  • Atlas A3 训练系列产品 / Atlas A3 推理系列产品:支持
  • Atlas A2 训练系列产品 / Atlas A2 推理系列产品:支持

函数原型

aclblasStatus_t aclblasCaxpy(aclblasHandle_t handle, const int64_t n, const std::complex<float> alpha, uint8_t* x, int64_t incx, uint8_t* y, int64_t incy)

参数说明

参数名 输入/输出 参数类型 说明
handle 输入 aclblasHandle_t ops-blas 库上下文句柄,携带 stream,Host 内存
n 输入 const int64_t 向量元素个数,Host 内存
alpha 输入 const std::complex<float> 复数标量系数(按值传递),Host 内存
x 输入 uint8_t* 输入复向量(元素按 std::complex<float> 排布,以字节指针传入),Device 内存
incx 输入 int64_t x 的步长,Host 内存
y 输入/输出 uint8_t* 输入/输出复向量(元素按 std::complex<float> 排布,以字节指针传入),Device 内存
incy 输入 int64_t y 的步长,Host 内存

约束说明

  • n >= 0

调用示例

示例代码如下,仅供参考,具体编译和执行过程请参考编译与运行样例

#include <complex>
#include <cstdint>
#include <cstdio>
#include <memory>

#include "acl/acl.h"
#include "cann_ops_blas.h"

// Checks `cond`; when it is false, reports the failed condition (with file and line) on
// stderr and then executes `return_expr`, which is normally a `return ...;` statement.
// Without the report, a failing sample would exit with a bare non-zero code and give no
// hint about which call failed.
#define CHECK_RET(cond, return_expr)                                                         \
    do {                                                                                     \
        if (!(cond)) {                                                                       \
            fprintf(stderr, "[ERROR] %s:%d: check failed: %s\n", __FILE__, __LINE__, #cond); \
            return_expr;                                                                     \
        }                                                                                    \
    } while (0)

/// Owns the ACL runtime state used by this sample: aclInit, the selected device, and one stream.
///
/// Call Init() before use. The destructor releases only what Init() actually acquired,
/// in the order stream -> device -> aclFinalize, so a partially failed Init() is still
/// cleaned up. The class is non-copyable: a copy would release the same stream, device
/// and runtime a second time.
class AclContext {
public:
    explicit AclContext(int deviceId) : deviceId_(deviceId) {}

    // Copying would duplicate ownership of the stream/device/runtime and tear them down twice.
    AclContext(const AclContext&) = delete;
    AclContext& operator=(const AclContext&) = delete;

    ~AclContext()
    {
        if (stream_ != nullptr) {
            aclrtDestroyStream(stream_);
            stream_ = nullptr;
        }
        if (deviceSet_) {
            aclrtResetDevice(deviceId_);
            deviceSet_ = false;
        }
        if (aclInited_) {
            aclFinalize();
            aclInited_ = false;
        }
    }

    /// Initializes ACL, selects the device and creates the stream.
    /// @return ACL_SUCCESS, or the aclError of the first step that failed.
    int Init()
    {
        auto ret = aclInit(nullptr);
        CHECK_RET(ret == ACL_SUCCESS, return ret);
        aclInited_ = true;

        ret = aclrtSetDevice(deviceId_);
        CHECK_RET(ret == ACL_SUCCESS, return ret);
        deviceSet_ = true;

        ret = aclrtCreateStream(&stream_);
        CHECK_RET(ret == ACL_SUCCESS, return ret);
        return ACL_SUCCESS;
    }

    /// Stream created by Init(); nullptr until Init() succeeds in creating it.
    aclrtStream Stream() const { return stream_; }

private:
    int deviceId_;
    aclrtStream stream_ = nullptr;
    bool aclInited_ = false;  // set once aclInit succeeded, so the destructor calls aclFinalize
    bool deviceSet_ = false;  // set once aclrtSetDevice succeeded, so the destructor resets it
};

/// std::unique_ptr deleter for device memory obtained from aclrtMalloc.
/// A null pointer is ignored.
struct AclrtMemDeleter {
    void operator()(void* devPtr) const
    {
        if (devPtr == nullptr) {
            return;
        }
        aclrtFree(devPtr);
    }
};

struct AclblasHandleDeleter {
    void operator()(aclblasHandle_t handle) const
    {
        if (handle != nullptr) {
            aclblasDestroy(handle);
        }
    }
};

/// Runs aclblasCaxpy on a 4-element complex sample and prints the resulting y vector.
///
/// Steps:
///   1. allocate device buffers for x and y and upload the host data;
///   2. create an aclblas handle and bind it to ctx's stream;
///   3. launch CAXPY, wait for the stream, copy y back and print each element
///      as (real, imag).
///
/// @param ctx  an initialized AclContext; its stream is used for the launch
/// @return 0 on success, otherwise the first failing aclError / aclblasStatus_t value
int aclblasCaxpyTest(AclContext& ctx)
{
    constexpr int64_t n = 4;
    constexpr int64_t incx = 1;
    constexpr int64_t incy = 1;
    constexpr size_t bytes = static_cast<size_t>(n) * sizeof(std::complex<float>);
    std::complex<float> alpha(2.0f, 1.0f);  // alpha = 2 + 1i

    // x and y are purely real, so every imaginary part of the result comes from alpha.
    std::complex<float> hostX[n] = {{1.0f, 0.0f}, {2.0f, 0.0f}, {3.0f, 0.0f}, {4.0f, 0.0f}};
    std::complex<float> hostY[n] = {{10.0f, 0.0f}, {20.0f, 0.0f}, {30.0f, 0.0f}, {40.0f, 0.0f}};

    // Device buffers are owned by unique_ptr so every early return frees them.
    void* xDevRaw = nullptr;
    auto status = aclrtMalloc(&xDevRaw, bytes, ACL_MEM_MALLOC_HUGE_FIRST);
    CHECK_RET(status == ACL_SUCCESS, return status);
    std::unique_ptr<void, AclrtMemDeleter> xDev(xDevRaw);

    void* yDevRaw = nullptr;
    status = aclrtMalloc(&yDevRaw, bytes, ACL_MEM_MALLOC_HUGE_FIRST);
    CHECK_RET(status == ACL_SUCCESS, return status);
    std::unique_ptr<void, AclrtMemDeleter> yDev(yDevRaw);

    status = aclrtMemcpy(xDev.get(), bytes, hostX, bytes, ACL_MEMCPY_HOST_TO_DEVICE);
    CHECK_RET(status == ACL_SUCCESS, return status);
    status = aclrtMemcpy(yDev.get(), bytes, hostY, bytes, ACL_MEMCPY_HOST_TO_DEVICE);
    CHECK_RET(status == ACL_SUCCESS, return status);

    aclblasHandle_t handleRaw = nullptr;
    auto blasStatus = aclblasCreate(&handleRaw);
    CHECK_RET(blasStatus == ACLBLAS_STATUS_SUCCESS, return blasStatus);
    std::unique_ptr<void, AclblasHandleDeleter> handleOwner(handleRaw);
    const auto blasHandle = static_cast<aclblasHandle_t>(handleOwner.get());

    blasStatus = aclblasSetStream(blasHandle, ctx.Stream());
    CHECK_RET(blasStatus == ACLBLAS_STATUS_SUCCESS, return blasStatus);

    // The API takes the complex<float> buffers as raw byte pointers.
    uint8_t* const xBytes = static_cast<uint8_t*>(xDev.get());
    uint8_t* const yBytes = static_cast<uint8_t*>(yDev.get());
    blasStatus = aclblasCaxpy(blasHandle, n, alpha, xBytes, incx, yBytes, incy);
    CHECK_RET(blasStatus == ACLBLAS_STATUS_SUCCESS, return blasStatus);

    // The operator is submitted on ctx's stream; synchronize before reading y back.
    status = aclrtSynchronizeStream(ctx.Stream());
    CHECK_RET(status == ACL_SUCCESS, return status);

    status = aclrtMemcpy(hostY, bytes, yDev.get(), bytes, ACL_MEMCPY_DEVICE_TO_HOST);
    CHECK_RET(status == ACL_SUCCESS, return status);

    // Expected result: y = alpha * x + y = (12, 1), (24, 2), (36, 3), (48, 4)
    for (int64_t idx = 0; idx < n; ++idx) {
        printf("y[%lld] = (%f, %f)\n", static_cast<long long>(idx), hostY[idx].real(), hostY[idx].imag());
    }

    return 0;
}

int main()
{
    AclContext ctx(0);
    auto ret = ctx.Init();
    CHECK_RET(ret == ACL_SUCCESS, return ret);

    ret = aclblasCaxpyTest(ctx);
    CHECK_RET(ret == ACL_SUCCESS, return ret);
    return 0;
}