* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
* \file coalesce_sparse_tiling.cpp
* \brief Host-side tiling (tiling key and tiling data calculation) for the CoalesceSparse operator.
*/
#include "log/log.h"
#include "platform/platform_infos_def.h"
#include "register/op_impl_registry.h"
#include "platform/platform_ascendc.h"
#include "util/math_util.h"
#include "coalesce_sparse_tiling.h"
namespace {
// Dtype identifiers that TilingKeyCalcu() combines into the tiling key.
// INT64_ID/INT32_ID are used for both index inputs and (INT32_ID only) for int32 values;
// FP32_ID/FP16_ID are used for the values input.
constexpr uint64_t INT64_ID = 0;
constexpr uint64_t INT32_ID = 1;
constexpr uint64_t FP32_ID = 0;
constexpr uint64_t FP16_ID = 2;
// Weight of the unique-indices dtype identifier inside the tiling key.
constexpr uint64_t KEY_MODE1 = 6;
// Weight of the indices dtype identifier inside the tiling key.
constexpr uint64_t KEY_MODE2 = 3;
// Added to the tiling key when the SoC version is ASCEND950.
constexpr uint64_t ASCEND950_SIMT_TILING_OFFSET = 100;
// NOTE(review): "AILGN" looks like a typo of "ALIGN"; the names are kept because the rest of the file uses them.
// Granularity (in elements) to which the per-core task count is rounded up.
constexpr uint64_t AILGN256 = 256;
// Granularity used when rounding buffer sizes and moveOneSize in TilingDataCalcu().
constexpr uint64_t AILGN32 = 32;
// Cap applied to moveValueLen and (minus AILGN32) to moveOneSize in TilingDataCalcu().
constexpr uint64_t MAXRPTIME = 4096;
}
namespace optiling {
/**
 * Tiling helper for the CoalesceSparse operator.
 *
 * Init() derives the tiling key and all tiling fields from the inputs of the wrapped
 * gert::TilingContext; RunKernelTiling() then publishes them (block dim, tiling key,
 * workspace size and serialized tiling data) through the same context.
 */
class CoalesceSparseTiling {
public:
    /** Wraps the given tiling context; the context is not owned by this object. */
    explicit CoalesceSparseTiling(gert::TilingContext* context) : TilingContext(context) {}

    /** Reads input dtypes/shapes from the context and runs both calculation steps below. */
    ge::graphStatus Init();

    /** Writes the previously computed values into the tiling context. */
    ge::graphStatus RunKernelTiling();

    /** Computes tiling_key from the three input dtypes and the SoC version. */
    void TilingKeyCalcu(ge::DataType uniqueIndicesDtype, ge::DataType indicesDtype, ge::DataType valuesDtype);

    /** Computes the core split and buffer-related tiling fields from dtypes and shapes. */
    ge::graphStatus TilingDataCalcu(ge::DataType uniqueIndicesDtype, ge::DataType indicesDtype,
                                    ge::DataType valuesDtype, const gert::Shape &indicesShape,
                                    const gert::Shape &valueShape);

    /** Logs every tiling field at debug level. */
    void TilingDataPrint();

private:
    // Serialized into the raw tiling buffer by RunKernelTiling().
    CoalesceSparseTilingData TilingData;
    // Context supplied by the framework for the current tiling call.
    gert::TilingContext* TilingContext{nullptr};

    // Value handed to SetTilingKey().
    uint64_t tiling_key{0};

    // Dim 0 and dim 1 of the indices storage shape.
    uint64_t n{0};
    uint64_t m{0};
    // Product of the value shape dims starting at dim 1.
    uint64_t valueSize{0};

    // Split of n across cores: per-core task count, number of cores used, task count of the last core.
    uint64_t taskNum{0};
    uint64_t usedCoreNum{0};
    uint64_t taskTail{0};

    // Number of tasks handled per repeat, and the resulting quotient/remainder pairs
    // for a full core (taskNum) and for the last core (taskTail).
    uint64_t moveOneSize{0};
    uint64_t taskRepeatTimes{0};
    uint64_t taskRepeatTail{0};
    uint64_t taskTailRepeatTimes{0};
    uint64_t taskTailRepeatTail{0};

    // Chunking of valueSize: chunk count, chunk length and remainder.
    uint64_t moveValueTimes{0};
    uint64_t moveValueLen{0};
    uint64_t moveValueTail{0};
};
/**
 * Builds the tiling key from the dtypes of the three tensor inputs.
 *
 * key = uniqueIndicesId * KEY_MODE1 + indicesId * KEY_MODE2 + valuesId,
 * plus ASCEND950_SIMT_TILING_OFFSET when the SoC version is ASCEND950.
 *
 * Any index dtype other than DT_INT64 is mapped to INT32_ID, and any values dtype other
 * than DT_FLOAT / DT_INT32 is mapped to FP16_ID.
 *
 * @param uniqueIndicesDtype dtype of the unique-indices input
 * @param indicesDtype       dtype of the indices input
 * @param valuesDtype        dtype of the values input
 */
void CoalesceSparseTiling::TilingKeyCalcu(ge::DataType uniqueIndicesDtype, ge::DataType indicesDtype,
                                          ge::DataType valuesDtype)
{
    OP_LOGD(TilingContext, "GetTensorType.");
    OP_LOGD(TilingContext, "SetTilingKey.");

    auto socPlatform = platform_ascendc::PlatformAscendC(TilingContext->GetPlatformInfo());
    const bool onAscend950 = (socPlatform.GetSocVersion() == platform_ascendc::SocVersion::ASCEND950);

    const uint64_t uniqueIndicesId = (uniqueIndicesDtype == ge::DT_INT64) ? INT64_ID : INT32_ID;
    const uint64_t indicesId = (indicesDtype == ge::DT_INT64) ? INT64_ID : INT32_ID;

    // FP16_ID is the fallback for every values dtype that is neither float nor int32.
    uint64_t valuesId = FP16_ID;
    if (valuesDtype == ge::DT_FLOAT) {
        valuesId = FP32_ID;
    } else if (valuesDtype == ge::DT_INT32) {
        valuesId = INT32_ID;
    }

    uint64_t key = uniqueIndicesId * KEY_MODE1 + indicesId * KEY_MODE2 + valuesId;
    if (onAscend950) {
        key += ASCEND950_SIMT_TILING_OFFSET;
    }
    tiling_key = key;
}
/**
 * Computes every tiling field except the tiling key and stores the results in the members
 * that RunKernelTiling() later copies into the tiling data.
 *
 * Steps:
 *  1. n / m are taken from dims 0 / 1 of indicesShape; valueSize is the product of the
 *     valueShape dims starting at dim 1.
 *  2. n is split across cores in units of AILGN256 (taskNum, usedCoreNum, taskTail).
 *  3. The UB size reported by the platform, minus an AILGN32-rounded buffer for one value
 *     chunk, determines moveOneSize; the repeat counts/tails follow from it.
 *
 * Compared with a naive computation this function rejects inputs that would otherwise
 * underflow unsigned arithmetic or divide by zero (empty/unknown dims, too small UB,
 * invalid dtype sizes, missing platform info).
 *
 * @param uniqueIndicesDtype dtype of the unique-indices input (only its byte size is used)
 * @param indicesDtype       dtype of the indices input (only its byte size is used)
 * @param valuesDtype        dtype of the values input (only its byte size is used)
 * @param indicesShape       storage shape of the indices input, at least 2 dims
 * @param valueShape         storage shape of the values input
 * @return ge::GRAPH_SUCCESS on success, ge::GRAPH_FAILED if any precondition is violated
 */
ge::graphStatus CoalesceSparseTiling::TilingDataCalcu(ge::DataType uniqueIndicesDtype, ge::DataType indicesDtype,
                                                      ge::DataType valuesDtype, const gert::Shape &indicesShape,
                                                      const gert::Shape &valueShape) {
    constexpr size_t indicesMinDimNum = 2;

    auto platformInfo = TilingContext->GetPlatformInfo();
    OP_CHECK_IF(
        platformInfo == nullptr, OP_LOGD(TilingContext, "PlatformInfo should not be nullptr."),
        return ge::GRAPH_FAILED);

    // GetSizeByDataType() yields a signed value; a negative result wraps to a huge unsigned
    // number and is rejected by the "> AILGN32" checks below.
    const uint64_t uniqueIndiceTypeSize = GetSizeByDataType(uniqueIndicesDtype);
    const uint64_t indiceTypeSize = GetSizeByDataType(indicesDtype);
    const uint64_t valueTypeSize = GetSizeByDataType(valuesDtype);
    OP_CHECK_IF(
        indiceTypeSize == 0, OP_LOGD(TilingContext, "Variale indiceTypeSize should not be 0."),
        return ge::GRAPH_FAILED);
    OP_CHECK_IF(
        indiceTypeSize > AILGN32 , OP_LOGD(TilingContext, "indiceTypeSize should not be greater than 32."),
        return ge::GRAPH_FAILED);
    OP_CHECK_IF(
        uniqueIndiceTypeSize == 0 || uniqueIndiceTypeSize > AILGN32,
        OP_LOGD(TilingContext, "uniqueIndiceTypeSize should be in range [1, 32]."),
        return ge::GRAPH_FAILED);
    OP_CHECK_IF(
        valueTypeSize == 0 || valueTypeSize > AILGN32,
        OP_LOGD(TilingContext, "valueTypeSize should be in range [1, 32]."),
        return ge::GRAPH_FAILED);

    // Both dims are used as "x - 1" in unsigned arithmetic below, so they must be positive.
    OP_CHECK_IF(
        indicesShape.GetDimNum() < indicesMinDimNum,
        OP_LOGD(TilingContext, "indices should have at least 2 dims."),
        return ge::GRAPH_FAILED);
    const int64_t indicesDim0 = indicesShape.GetDim(0);
    const int64_t indicesDim1 = indicesShape.GetDim(1);
    OP_CHECK_IF(
        indicesDim0 <= 0 || indicesDim1 <= 0,
        OP_LOGD(TilingContext, "dims of indices should be greater than 0."),
        return ge::GRAPH_FAILED);
    n = static_cast<uint64_t>(indicesDim0);
    m = static_cast<uint64_t>(indicesDim1);

    OP_LOGD(TilingContext, "Calculate valueSize.");
    valueSize = 1;
    const size_t valueDimNum = valueShape.GetDimNum();
    for (size_t i = 1; i < valueDimNum; ++i) {
        const int64_t valueDim = valueShape.GetDim(i);
        // A zero dim would make moveValueLen 0 and cause a division by zero further down.
        OP_CHECK_IF(
            valueDim <= 0, OP_LOGD(TilingContext, "dims of values should be greater than 0."),
            return ge::GRAPH_FAILED);
        valueSize *= static_cast<uint64_t>(valueDim);
    }

    OP_LOGD(TilingContext, "Init coreNum.");
    const uint32_t coreNum = platformInfo->GetCoreNum();
    OP_LOGD(TilingContext, "Calculate taskNum.");
    OP_CHECK_IF(coreNum == 0, OP_LOGD(TilingContext, "Variale CoreNum should not be 0."), return ge::GRAPH_FAILED);

    // Per-core task count, rounded up to a multiple of AILGN256; the last core takes the rest.
    taskNum = ((n - 1) / coreNum / AILGN256 + 1) * AILGN256;
    usedCoreNum = (n - 1) / taskNum + 1;
    taskTail = n % taskNum;
    if (taskTail == 0) {
        taskTail = taskNum;
    }

    // valueSize >= 1 here, so moveValueLen is never 0.
    moveValueLen = (valueSize > MAXRPTIME) ? MAXRPTIME : valueSize;

    OP_LOGD(TilingContext, "Init ubSize.");
    uint64_t ubSize = 0;
    auto ascendcPlatform = platform_ascendc::PlatformAscendC(platformInfo);
    ascendcPlatform.GetCoreMemSize(platform_ascendc::CoreMemType::UB, ubSize);

    // Space reserved for one chunk of values, rounded to AILGN32; the remainder holds indices.
    const uint64_t valueUb = (moveValueLen * valueTypeSize / AILGN32 + 1) * AILGN32;
    OP_CHECK_IF(
        ubSize <= valueUb, OP_LOGD(TilingContext, "ubSize is too small for the value buffer."),
        return ge::GRAPH_FAILED);
    const uint64_t indexUb = ubSize - valueUb;

    const uint64_t indicesSizeAlign32 = ((m - 1) / (AILGN32 / indiceTypeSize) + 1) * AILGN32;
    moveOneSize = (indexUb / (indicesSizeAlign32 + uniqueIndiceTypeSize)) / AILGN32 * AILGN32;
    if (moveOneSize >= MAXRPTIME) {
        moveOneSize = MAXRPTIME - AILGN32;
    }
    // moveOneSize is a divisor below; it becomes 0 when one AILGN32 group of rows does not fit in UB.
    OP_CHECK_IF(
        moveOneSize == 0, OP_LOGD(TilingContext, "moveOneSize should not be 0."),
        return ge::GRAPH_FAILED);

    taskRepeatTimes = taskNum / moveOneSize;
    taskRepeatTail = taskNum % moveOneSize;
    taskTailRepeatTimes = taskTail / moveOneSize;
    taskTailRepeatTail = taskTail % moveOneSize;
    moveValueTimes = valueSize / moveValueLen;
    moveValueTail = valueSize % moveValueLen;
    return ge::GRAPH_SUCCESS;
}
/**
 * Reads the dtypes of inputs 1..3 and the storage shapes of inputs 2 and 3 from the tiling
 * context, then computes the tiling key and the tiling data from them.
 *
 * Every pointer obtained from the context is validated before it is dereferenced, and the
 * platform info is checked up front because TilingKeyCalcu() cannot report a failure itself.
 *
 * @return ge::GRAPH_SUCCESS on success; ge::GRAPH_FAILED if the context, a descriptor, a
 *         shape or the platform info is missing, or if TilingDataCalcu() fails.
 */
ge::graphStatus CoalesceSparseTiling::Init()
{
    if (TilingContext == nullptr) {
        return ge::GRAPH_FAILED;
    }
    OP_LOGD(TilingContext, "Tiling initing.");

    const auto *uniqueIndicesDesc = TilingContext->GetInputDesc(1);
    const auto *indicesDesc = TilingContext->GetInputDesc(2);
    const auto *valuesDesc = TilingContext->GetInputDesc(3);
    OP_CHECK_IF(
        uniqueIndicesDesc == nullptr || indicesDesc == nullptr || valuesDesc == nullptr,
        OP_LOGD(TilingContext, "Input desc should not be nullptr."),
        return ge::GRAPH_FAILED);

    const auto *indicesStorage = TilingContext->GetInputShape(2);
    const auto *valuesStorage = TilingContext->GetInputShape(3);
    OP_CHECK_IF(
        indicesStorage == nullptr || valuesStorage == nullptr,
        OP_LOGD(TilingContext, "Input shape should not be nullptr."),
        return ge::GRAPH_FAILED);

    OP_CHECK_IF(
        TilingContext->GetPlatformInfo() == nullptr,
        OP_LOGD(TilingContext, "PlatformInfo should not be nullptr."),
        return ge::GRAPH_FAILED);

    const auto uniqueIndicesDtype = uniqueIndicesDesc->GetDataType();
    const auto indicesDtype = indicesDesc->GetDataType();
    const auto valuesDtype = valuesDesc->GetDataType();
    // Bind by reference: the shapes are only read, so no copy is needed.
    const auto &indicesShape = indicesStorage->GetStorageShape();
    const auto &valueShape = valuesStorage->GetStorageShape();

    TilingKeyCalcu(uniqueIndicesDtype, indicesDtype, valuesDtype);
    const auto ret = TilingDataCalcu(uniqueIndicesDtype, indicesDtype, valuesDtype, indicesShape, valueShape);
    OP_LOGD(TilingContext, "Tiling inited.");
    return ret;
}
/**
 * Publishes the values computed by Init() through the tiling context: block dim, tiling key,
 * one workspace entry of 16 MiB, and the serialized CoalesceSparseTilingData.
 *
 * The workspace array and the raw tiling buffer are validated before they are written, and
 * the buffer capacity is checked against the serialized size.
 *
 * @return ge::GRAPH_SUCCESS on success; ge::GRAPH_FAILED if the workspace array or the raw
 *         tiling buffer is unavailable or too small.
 */
ge::graphStatus CoalesceSparseTiling::RunKernelTiling()
{
    constexpr size_t sysWorkspaceSize = 16 * 1024 * 1024;

    OP_LOGD(TilingContext, "Tiling start.");

    size_t* currentWorkspace = TilingContext->GetWorkspaceSizes(1);
    OP_CHECK_IF(
        currentWorkspace == nullptr, OP_LOGD(TilingContext, "Workspace sizes should not be nullptr."),
        return ge::GRAPH_FAILED);

    auto rawTilingData = TilingContext->GetRawTilingData();
    OP_CHECK_IF(
        rawTilingData == nullptr, OP_LOGD(TilingContext, "Raw tiling data should not be nullptr."),
        return ge::GRAPH_FAILED);
    OP_CHECK_IF(
        rawTilingData->GetCapacity() < TilingData.GetDataSize(),
        OP_LOGD(TilingContext, "Raw tiling data capacity is too small."),
        return ge::GRAPH_FAILED);

    // usedCoreNum never exceeds the platform core count, so the narrowing is lossless.
    TilingContext->SetBlockDim(static_cast<uint32_t>(usedCoreNum));
    TilingContext->SetTilingKey(tiling_key);

    TilingData.set_usedCoreNum(usedCoreNum);
    TilingData.set_m(m);
    TilingData.set_valueSize(valueSize);
    TilingData.set_taskNum(taskNum);
    TilingData.set_taskTail(taskTail);
    TilingData.set_moveOneSize(moveOneSize);
    TilingData.set_taskRepeatTimes(taskRepeatTimes);
    TilingData.set_taskRepeatTail(taskRepeatTail);
    TilingData.set_taskTailRepeatTimes(taskTailRepeatTimes);
    TilingData.set_taskTailRepeatTail(taskTailRepeatTail);
    TilingData.set_moveValueTimes(moveValueTimes);
    TilingData.set_moveValueLen(moveValueLen);
    TilingData.set_moveValueTail(moveValueTail);

    currentWorkspace[0] = sysWorkspaceSize;

    TilingData.SaveToBuffer(rawTilingData->GetData(), rawTilingData->GetCapacity());
    rawTilingData->SetDataSize(TilingData.GetDataSize());

    TilingDataPrint();
    OP_LOGD(TilingContext, "Tiling end.");
    return ge::GRAPH_SUCCESS;
}
/**
 * Logs every tiling field at debug level.
 *
 * All fields are uint64_t, so they are printed with the unsigned conversion "%lu"; the signed
 * "%ld" would show values above INT64_MAX as negative numbers.
 */
void CoalesceSparseTiling::TilingDataPrint()
{
    OP_LOGD(TilingContext, "usedCoreNum:%lu.", usedCoreNum);
    OP_LOGD(TilingContext, "m:%lu.", m);
    OP_LOGD(TilingContext, "valueSize:%lu.", valueSize);
    OP_LOGD(TilingContext, "taskNum:%lu.", taskNum);
    OP_LOGD(TilingContext, "taskTail:%lu.", taskTail);
    OP_LOGD(TilingContext, "moveOneSize:%lu.", moveOneSize);
    OP_LOGD(TilingContext, "taskRepeatTimes:%lu.", taskRepeatTimes);
    OP_LOGD(TilingContext, "taskRepeatTail:%lu.", taskRepeatTail);
    OP_LOGD(TilingContext, "taskTailRepeatTimes:%lu.", taskTailRepeatTimes);
    OP_LOGD(TilingContext, "taskTailRepeatTail:%lu.", taskTailRepeatTail);
    OP_LOGD(TilingContext, "moveValueTimes:%lu.", moveValueTimes);
    OP_LOGD(TilingContext, "moveValueLen:%lu.", moveValueLen);
    OP_LOGD(TilingContext, "moveValueTail:%lu.", moveValueTail);
}
/**
 * Tiling entry point registered for the CoalesceSparse op.
 *
 * Runs Init() and, only if it succeeded, RunKernelTiling(). A failed Init() is reported to
 * the caller instead of publishing tiling data built from default or partial state.
 *
 * @param context tiling context supplied by the framework
 * @return ge::GRAPH_SUCCESS on success, otherwise the failing step's status
 */
static ge::graphStatus TilingCoalesceSparse(gert::TilingContext* context)
{
    if (context == nullptr) {
        return ge::GRAPH_FAILED;
    }
    CoalesceSparseTiling tilingObject(context);
    const ge::graphStatus initStatus = tilingObject.Init();
    if (initStatus != ge::GRAPH_SUCCESS) {
        OP_LOGD(context, "Tiling init failed.");
        return initStatus;
    }
    return tilingObject.RunKernelTiling();
}
// Empty compile-info type; it is not referenced by any other code visible in this file.
struct CoalesceSparseCompileInfo {};
// Registers TilingCoalesceSparse as the tiling function of the CoalesceSparse op.
IMPL_OP_OPTILING(CoalesceSparse).Tiling(TilingCoalesceSparse);
}