Scal算子
算子概述
向量缩放算子,实现向量乘以标量的运算。包含实数向量缩放(Sscal)和复数向量缩放(Cscal)。
数学表达式:
x[i] = alpha * x[i] (i = 0 .. n-1,步长为 incx)
包含以下接口:
| 接口名 | 功能简述 |
|---|---|
| aclblasSscal | 实数向量乘以标量 |
| aclblasCscal | 复数向量乘以复数标量 |
算子执行接口
aclblasSscal
产品支持情况
- Ascend 950PR / Ascend 950DT:支持
- Atlas A3 训练系列产品 / Atlas A3 推理系列产品:支持
- Atlas A2 训练系列产品 / Atlas A2 推理系列产品:支持
函数原型
aclblasStatus_t aclblasSscal(aclblasHandle_t handle, int n, const float* alpha, float* x, int incx);
参数说明
| 参数名 | 输入/输出 | 参数类型 | 说明 |
|---|---|---|---|
| handle | 输入 | aclblasHandle_t | ops-blas 库上下文句柄,携带 stream,Host 内存 |
| n | 输入 | int | 向量 x 中的元素个数,Host 内存 |
| alpha | 输入 | const float*(FP32) | 指向标量乘数的指针,Host 内存 |
| x | 输入/输出 | float*(FP32) | float 向量,包含 n 个元素,Device 内存 |
| incx | 输入 | int | x 中连续元素之间的步长,Host 内存 |
约束说明
Ascend 950PR / Ascend 950DT(arch35):
- n 为整数;n <= 0 时为 no-op(直接返回 ACLBLAS_STATUS_SUCCESS,不修改 x)
- incx 为整数;incx <= 0 时为 no-op(不修改 x,对齐参考 BLAS cblas_sscal 的 IF (INCX.LE.0) RETURN 语义);incx > 0 时支持任意步长
- handle 不能为 nullptr,否则返回 ACLBLAS_STATUS_HANDLE_IS_NULLPTR
- n > 0 时 alpha、x 不能为 nullptr,否则返回 ACLBLAS_STATUS_INVALID_VALUE
Atlas A2 / Atlas A3 系列产品(arch22):
- incx 参数当前实现未启用(固定按连续向量 incx=1 处理,传入的 incx 取值不生效)
- 未对 n、handle、alpha、x 做入参校验
调用示例
示例代码如下,仅供参考,具体编译和执行过程请参考编译与运行样例。
#include <cstdio>
#include <memory>
#include <vector>
#include "acl/acl.h"
#include "cann_ops_blas.h"
#define CHECK_RET(cond, return_expr) \
do { \
if (!(cond)) { \
return_expr; \
} \
} while (0)
#define LOG_PRINT(message, ...) \
do { \
printf(message, ##__VA_ARGS__); \
} while (0)
class AclContext {
public:
explicit AclContext(int32_t deviceId) : deviceId_(deviceId) {}
~AclContext()
{
if (stream_ != nullptr) {
aclrtDestroyStream(stream_);
stream_ = nullptr;
}
if (deviceSet_) {
aclrtResetDevice(deviceId_);
deviceSet_ = false;
}
if (aclInited_) {
aclFinalize();
aclInited_ = false;
}
}
int Init()
{
auto ret = aclInit(nullptr);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclInit failed. ERROR: %d\n", ret); return ret);
aclInited_ = true;
ret = aclrtSetDevice(deviceId_);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtSetDevice failed. ERROR: %d\n", ret); return ret);
deviceSet_ = true;
ret = aclrtCreateStream(&stream_);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclrtCreateStream failed. ERROR: %d\n", ret); return ret);
return ACL_SUCCESS;
}
aclrtStream Stream() const { return stream_; }
private:
int32_t deviceId_;
aclrtStream stream_ = nullptr;
bool aclInited_ = false;
bool deviceSet_ = false;
};
struct AclMemDeleter {
void operator()(void* p) const { aclrtFree(p); }
};
struct BlasHandleDeleter {
void operator()(aclblasHandle_t h) const { aclblasDestroy(h); }
};
int aclblasSscalTest(AclContext& ctx)
{
aclrtStream stream = ctx.Stream();
// 1. 创建 ops-blas 句柄
aclblasHandle_t rawHandle = nullptr;
auto blasRet = aclblasCreate(&rawHandle);
CHECK_RET(blasRet == ACLBLAS_STATUS_SUCCESS, LOG_PRINT("aclblasCreate failed. ERROR: %d\n", blasRet);
return blasRet);
std::unique_ptr<std::remove_pointer<aclblasHandle_t>::type, BlasHandleDeleter> handlePtr(rawHandle);
blasRet = aclblasSetStream(handlePtr.get(), stream);
CHECK_RET(blasRet == ACLBLAS_STATUS_SUCCESS, LOG_PRINT("aclblasSetStream failed. ERROR: %d\n", blasRet);
return blasRet);
// 2. 准备 Host 数据
int n = 5;
int incx = 1;
float alpha = 2.0f;
std::vector<float> xHostData = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f}; // 缩放后期望 {2,4,6,8,10}
size_t xBytes = n * sizeof(float);
// 3. 申请 Device 内存并拷贝数据
void* rawMem = nullptr;
auto aclRet = aclrtMalloc(&rawMem, xBytes, ACL_MEM_MALLOC_HUGE_FIRST);
CHECK_RET(aclRet == ACL_SUCCESS, LOG_PRINT("aclrtMalloc for x failed. ERROR: %d\n", aclRet); return aclRet);
std::unique_ptr<float, AclMemDeleter> xDevicePtr(static_cast<float*>(rawMem));
aclRet = aclrtMemcpy(xDevicePtr.get(), xBytes, xHostData.data(), xBytes, ACL_MEMCPY_HOST_TO_DEVICE);
CHECK_RET(aclRet == ACL_SUCCESS, LOG_PRINT("aclrtMemcpy for x failed. ERROR: %d\n", aclRet); return aclRet);
// 4. 调用 aclblasSscal(alpha 为 Host 指针,原地缩放)
blasRet = aclblasSscal(handlePtr.get(), n, &alpha, xDevicePtr.get(), incx);
CHECK_RET(blasRet == ACLBLAS_STATUS_SUCCESS, LOG_PRINT("aclblasSscal failed. ERROR: %d\n", blasRet);
return blasRet);
// 5. 同步等待任务执行结束
aclRet = aclrtSynchronizeStream(stream);
CHECK_RET(aclRet == ACL_SUCCESS, LOG_PRINT("aclrtSynchronizeStream failed. ERROR: %d\n", aclRet); return aclRet);
// 6. 将结果从 Device 拷贝回 Host 并打印
std::vector<float> resultData(n, 0);
aclRet = aclrtMemcpy(resultData.data(), xBytes, xDevicePtr.get(), xBytes, ACL_MEMCPY_DEVICE_TO_HOST);
CHECK_RET(aclRet == ACL_SUCCESS, LOG_PRINT("copy result from device to host failed. ERROR: %d\n", aclRet);
return aclRet);
for (int i = 0; i < n; i++) {
LOG_PRINT("result[%d] is: %f\n", i, resultData[i]);
}
return ACL_SUCCESS;
}
int main()
{
AclContext ctx(0);
auto ret = ctx.Init();
CHECK_RET(ret == ACL_SUCCESS, return ret);
ret = aclblasSscalTest(ctx);
CHECK_RET(ret == ACL_SUCCESS, LOG_PRINT("aclblasSscalTest failed. ERROR: %d\n", ret); return ret);
return 0;
}
aclblasCscal
产品支持情况
- Ascend 950PR / Ascend 950DT:不支持
- Atlas A3 训练系列产品 / Atlas A3 推理系列产品:支持
- Atlas A2 训练系列产品 / Atlas A2 推理系列产品:支持
函数原型
aclblasStatus_t aclblasCscal(aclblasHandle_t handle, const int64_t n, const std::complex<float> alpha, uint8_t* x, const int64_t incx)
参数说明
| 参数名 | 输入/输出 | 参数类型 | 说明 |
|---|---|---|---|
| handle | 输入 | aclblasHandle_t | ops-blas 库上下文句柄,携带 stream,Host 内存 |
| n | 输入 | int64_t | 向量 x 中的复数元素个数,Host 内存 |
| alpha | 输入 | const std::complex | 用于乘法的复数标量,Host 内存 |
| x | 输入/输出 | uint8_t*(FP32 complex) | 复数向量,包含 n 个 complex 元素,Device 内存 |
| incx | 输入 | int64_t | x 中连续元素之间的步长,Host 内存 |
约束说明
- n >= 0
- incx != 0