/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
/*!
 * \file sdot_host.cpp
 * \brief Single-precision dot product (SIMD AIV): result = sum(x[i] * y[i])
 * Arch35 (ascend950) host-side implementation.
 */
#include <cstdint>
#include <cstdio>
#include <algorithm>
#include "acl/acl.h"
#include "log/log.h"
#include "cann_ops_blas.h"
#include "cann_ops_blas_common.h"
#include "common/helper/aclblas_handle_internal.h"
#include "common/helper/host_utils.h"
#include "sdot_tiling_data.h"
/**
 * Launches the Sdot kernel. The definition is not in this file (presumably the
 * kernel-side translation unit — TODO confirm); this declaration must match it exactly.
 *
 * @param x         input vector x, passed as raw bytes
 * @param y         input vector y, passed as raw bytes
 * @param result    output location for the dot product, passed as raw bytes
 * @param workSpace scratch buffer taken from the handle; LaunchSdotKernel sizes it at
 *                  useCoreNum * sizeof(float) (presumably one partial sum per core — TODO confirm)
 * @param numBlocks number of blocks to launch (this file passes tiling.useCoreNum)
 * @param tiling    tiling parameters built by CalSdotTilingData
 * @param stream    stream to launch on (this file passes the handle's stream)
 */
void sdot_kernel_do(uint8_t* x, uint8_t* y, uint8_t* result, uint8_t* workSpace,
uint32_t numBlocks, const SdotTilingData& tiling, void *stream);
/**
 * Validates the aclblasSdot arguments that are only inspected when n > 0
 * (aclblasSdot returns early for n <= 0 before calling this).
 *
 * @param incx   stride of x; must be non-zero
 * @param incy   stride of y; must be non-zero
 * @param x      first input vector; must not be nullptr
 * @param y      second input vector; must not be nullptr
 * @param result output location; must not be nullptr
 * @return ACLBLAS_STATUS_SUCCESS if every check passes; otherwise logs an error and
 *         returns ACLBLAS_STATUS_INVALID_VALUE for the first failing check.
 */
static aclblasStatus_t ValidateSdotParams(
    int64_t incx, int64_t incy, const float* x, const float* y, const float* result)
{
    // The order of these guards is fixed so the first invalid argument is the one reported.
    if (incx == 0) {
        OP_LOGE("aclblasSdot", "incx must not be zero");
        return ACLBLAS_STATUS_INVALID_VALUE;
    }
    if (incy == 0) {
        OP_LOGE("aclblasSdot", "incy must not be zero");
        return ACLBLAS_STATUS_INVALID_VALUE;
    }
    if (x == nullptr) {
        OP_LOGE("aclblasSdot", "x must not be nullptr");
        return ACLBLAS_STATUS_INVALID_VALUE;
    }
    if (y == nullptr) {
        OP_LOGE("aclblasSdot", "y must not be nullptr");
        return ACLBLAS_STATUS_INVALID_VALUE;
    }
    if (result == nullptr) {
        OP_LOGE("aclblasSdot", "result must not be nullptr");
        return ACLBLAS_STATUS_INVALID_VALUE;
    }
    return ACLBLAS_STATUS_SUCCESS;
}
/**
 * Builds the tiling descriptor passed to the Sdot kernel.
 *
 * @param n             number of elements (aclblasSdot only calls this with n > 0)
 * @param incx          stride of x, copied into the tiling unchanged
 * @param incy          stride of y, copied into the tiling unchanged
 * @param vectorCoreNum number of vector cores reported by the runtime
 * @return tiling whose useCoreNum is min(vectorCoreNum, n), but never less than 1.
 */
static SdotTilingData CalSdotTilingData(int64_t n, int64_t incx, int64_t incy, uint32_t vectorCoreNum)
{
    SdotTilingData tiling{};
    tiling.n = n;
    tiling.incx = incx;
    tiling.incy = incy;
    // Compare in the 64-bit domain. Narrowing n to uint32_t first wraps for n >= 2^32
    // (e.g. n == 2^32 became 0 and collapsed to a single core) and turns negative n
    // into a huge value.
    const int64_t coreCandidates = std::min(static_cast<int64_t>(vectorCoreNum), n);
    // Always launch at least one core, even for degenerate n.
    tiling.useCoreNum = coreCandidates > 0 ? static_cast<uint32_t>(coreCandidates) : 1U;
    return tiling;
}
/**
 * Verifies the handle's workspace is large enough and launches the Sdot kernel
 * on the handle's stream.
 *
 * The kernel is given useCoreNum * sizeof(float) bytes of workspace (presumably one
 * float partial result per core — TODO confirm against the kernel source).
 *
 * @param h      internal handle supplying the workspace and the stream
 * @param tiling tiling built by CalSdotTilingData
 * @param x      first input vector
 * @param y      second input vector
 * @param result output location written by the kernel
 * @return ACLBLAS_STATUS_EXECUTION_FAILED if the workspace is too small; otherwise
 *         ACLBLAS_STATUS_SUCCESS once the launch call returns (the launch itself
 *         reports no status here).
 */
static aclblasStatus_t LaunchSdotKernel(
    _aclblas_handle* h, const SdotTilingData& tiling, const float* x, const float* y, float* result)
{
    const size_t requiredBytes = static_cast<size_t>(tiling.useCoreNum) * sizeof(float);
    const size_t availableBytes = aclblasGetEffectiveWorkspaceSize(h);
    if (requiredBytes > availableBytes) {
        OP_LOGE("aclblasSdot", "workspace %zu > handle %zu", requiredBytes, availableBytes);
        return ACLBLAS_STATUS_EXECUTION_FAILED;
    }

    uint8_t* const scratch = reinterpret_cast<uint8_t*>(aclblasGetEffectiveWorkspace(h));
    // The launcher takes non-const byte pointers; the inputs are only read by the kernel.
    uint8_t* const xBytes = const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(x));
    uint8_t* const yBytes = const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(y));
    uint8_t* const outBytes = reinterpret_cast<uint8_t*>(result);

    OP_LOGI("aclblasSdot", "launching kernel with %u cores", tiling.useCoreNum);
    sdot_kernel_do(xBytes, yBytes, outBytes, scratch, tiling.useCoreNum, tiling, h->stream);
    return ACLBLAS_STATUS_SUCCESS;
}
/**
 * Single-precision dot product of the n strided elements of x and y, written to *result.
 *
 * All device work is enqueued on the handle's stream; the caller must synchronize that
 * stream before reading *result.
 *
 * @param handle BLAS handle; must not be nullptr
 * @param n      number of elements; for n <= 0 the result is set to 0 (if result is non-null)
 * @param x      first input vector
 * @param incx   stride of x; must be non-zero when n > 0
 * @param y      second input vector
 * @param incy   stride of y; must be non-zero when n > 0
 * @param result output location
 * @return ACLBLAS_STATUS_HANDLE_IS_NULLPTR if handle is nullptr;
 *         ACLBLAS_STATUS_INVALID_VALUE if argument validation fails (n > 0 only);
 *         ACLBLAS_STATUS_EXECUTION_FAILED if zeroing the result, querying the core count,
 *         or the workspace-size check fails;
 *         ACLBLAS_STATUS_SUCCESS otherwise.
 */
aclblasStatus_t aclblasSdot(
    aclblasHandle_t handle, const int64_t n, const float* x, const int64_t incx, const float* y, const int64_t incy,
    float* result)
{
    CHECK_RET(handle != nullptr, OP_LOGE("aclblasSdot", "handle is nullptr"); return ACLBLAS_STATUS_HANDLE_IS_NULLPTR);
    auto* h = reinterpret_cast<_aclblas_handle*>(handle);

    if (n <= 0) {
        // An empty dot product is 0. Zero the result on the handle's stream so the write is
        // stream-ordered exactly like the kernel path; a synchronous aclrtMemset is not
        // ordered against work already queued on h->stream.
        if (result != nullptr) {
            aclError ret = aclrtMemsetAsync(result, sizeof(float), 0, sizeof(float), h->stream);
            if (ret != ACL_SUCCESS) {
                OP_LOGE("aclblasSdot", "aclrtMemsetAsync failed, ret=%d", ret);
                return ACLBLAS_STATUS_EXECUTION_FAILED;
            }
        }
        return ACLBLAS_STATUS_SUCCESS;
    }

    aclblasStatus_t status = ValidateSdotParams(incx, incy, x, y, result);
    if (status != ACLBLAS_STATUS_SUCCESS) {
        return status;
    }

    uint32_t vectorCoreNum = GetAivCoreCount();
    if (vectorCoreNum == 0) {
        OP_LOGE("aclblasSdot", "GetAivCoreCount failed");
        return ACLBLAS_STATUS_EXECUTION_FAILED;
    }

    SdotTilingData tiling = CalSdotTilingData(n, incx, incy, vectorCoreNum);
    OP_LOGD(
        "aclblasSdot", "tiling: n=%ld incx=%ld incy=%ld useCoreNum=%u", tiling.n, tiling.incx, tiling.incy,
        tiling.useCoreNum);
    return LaunchSdotKernel(h, tiling, x, y, result);
}