* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <limits>
#include <mki_loader/op_register.h>
#include <mki/base/kernel_base.h>
#include <mki/utils/assert/assert.h>
#include <mki/utils/log/log.h>
#include <mki/utils/math/tensor_utils.h>
#include <mki/utils/math/math.h>
#include <mki/utils/platform/platform_info.h>
#include "atbops/params/params.h"
#include "tiling_data.h"
using namespace Mki;
namespace AtbOps {
constexpr uint64_t TENSOR_SEQLEN_IDX = 4;
constexpr uint64_t TENSOR_TOKEN_OFFSET_IDX = 3;
constexpr uint64_t TENSOR_BATCH_STATUS_IDX = 5;
constexpr uint32_t MAX_PROCESS_NUM = 98304;
static constexpr uint64_t KVCACHE_TILING_ID_BASE = 2000000000;
static constexpr uint64_t KVCACHE_TILING_ID_TYPE = 100000000;
template <bool IS_PARAMS = false>
Status KVCacheTiling(const LaunchParam &launchParam, KernelInfo &kernelInfo)
{
auto tilingDataPtr = reinterpret_cast<KVCacheTilingData *>(kernelInfo.GetTilingHostAddr());
uint32_t batch = 0;
if constexpr (IS_PARAMS) {
auto param = AnyCast<OpParam::KVCache>(launchParam.GetParam());
batch = param.seqLen.size();
} else {
const auto &inTensor3Dims = launchParam.GetInTensor(DIM_3).desc.dims;
batch = static_cast<uint32_t>(inTensor3Dims[0]);
}
MKI_CHECK(batch > 0, "batch should be larger than 0 ",
return Status::FailStatus(ERROR_INVALID_VALUE));
const auto &inTensor0Dims = launchParam.GetInTensor(0).desc.dims;
auto hiddenSize = static_cast<uint32_t>(inTensor0Dims[1]);
uint32_t maxUbBytes = static_cast<uint32_t>(PlatformInfo::Instance().GetUbSize());
TensorDType dtype = launchParam.GetInTensor(0).desc.dtype;
uint32_t maxProcessNum = maxUbBytes / GetTensorElementSize(dtype);
MKI_CHECK(hiddenSize > 0 && hiddenSize <= maxProcessNum, "invalid hiddenSize: " << hiddenSize,
return Status::FailStatus(ERROR_INVALID_VALUE, "invalid hiddenSize:" + std::to_string(hiddenSize)));
const auto &inTensor2Dims = launchParam.GetInTensor(DIM_2).desc.dims;
uint32_t maxSeqLen = 0;
if (inTensor2Dims.size() == DIM_4) {
MKI_CHECK(inTensor2Dims[DIM_2] <= std::numeric_limits<uint32_t>::max(),
"invalid tilingData maxSeqLen: " << inTensor2Dims[DIM_2], return Status::FailStatus(ERROR_INVALID_VALUE,
"invalid tilingData maxSeqLen: " + std::to_string(inTensor2Dims[DIM_2]) +
", maxSeqLen > " + std::to_string(0)));
maxSeqLen = static_cast<uint32_t>(inTensor2Dims[DIM_2]);
} else if (inTensor2Dims.size() == DIM_3) {
MKI_CHECK(inTensor2Dims[DIM_1] <= std::numeric_limits<uint32_t>::max(),
"invalid tilingData maxSeqLen: " << inTensor2Dims[DIM_1], return Status::FailStatus(ERROR_INVALID_VALUE,
"invalid tilingData maxSeqLen: " + std::to_string(inTensor2Dims[DIM_1]) +
", maxSeqLen > " + std::to_string(0)));
maxSeqLen = static_cast<uint32_t>(inTensor2Dims[DIM_1]);
} else {
return Status::FailStatus(ERROR_INVALID_VALUE, "invalid tilingData maxSeqLen: " + std::to_string(maxSeqLen) +
", maxSeqLen > " + std::to_string(0));
}
tilingDataPtr->batch = batch;
tilingDataPtr->maxSeqLen = maxSeqLen;
tilingDataPtr->hiddenSize = hiddenSize;
MKI_LOG(INFO) << "KVCache Tiling Data: batch " << tilingDataPtr->batch << ", maxSeqLen "
<< tilingDataPtr->maxSeqLen << ", hiddenSize: " << tilingDataPtr->hiddenSize;
kernelInfo.SetBlockDim(batch);
return Status::OkStatus();
}
template <bool IS_PARAMS = false>
Status KVCacheNdTiling(const LaunchParam &launchParam, KernelInfo &kernelInfo)
{
Status status = KVCacheTiling<IS_PARAMS>(launchParam, kernelInfo);
if (!status.Ok()) {
return status;
}
uint64_t tilingKey = KVCACHE_TILING_ID_BASE;
TensorDType dtype = launchParam.GetInTensor(0).desc.dtype;
tilingKey += (dtype == TENSOR_DTYPE_FLOAT16 || dtype == TENSOR_DTYPE_BF16) ? 0 : KVCACHE_TILING_ID_TYPE;
kernelInfo.SetTilingId(tilingKey);
MKI_LOG(INFO) << "KVCacheNd TilingKey: " << tilingKey;
return Status::OkStatus();
}
template <bool IS_PARAMS = false>
Status KVCacheNzTiling(const LaunchParam &launchParam, KernelInfo &kernelInfo)
{
auto tilingDataPtr = reinterpret_cast<KVCacheTilingData *>(kernelInfo.GetTilingHostAddr());
uint32_t batch = 0;
if constexpr (IS_PARAMS) {
auto param = AnyCast<OpParam::KVCache>(launchParam.GetParam());
batch = param.seqLen.size();
} else {
const auto &inTensor3Dims = launchParam.GetInTensor(3).desc.dims;
batch = static_cast<uint32_t>(inTensor3Dims[0]);
}
MKI_CHECK(batch > 0, "batch should be larger than 0 ",
return Status::FailStatus(ERROR_INVALID_VALUE));
const auto &inTensor0Dims = launchParam.GetInTensor(0).desc.dims;
int64_t hiddenSize = inTensor0Dims[1] * 16;
MKI_CHECK(hiddenSize <= MAX_PROCESS_NUM, "invalid hiddenSize: " << hiddenSize,
return Status::FailStatus(ERROR_INVALID_VALUE, "invalid hiddenSize:" + std::to_string(hiddenSize)));
const auto &inTensor2Dims = launchParam.GetInTensor(2).desc.dims;
MKI_CHECK(inTensor2Dims[DIM_3] <= std::numeric_limits<uint32_t>::max(),
"invalid tilingData maxSeqLen: " << inTensor2Dims[DIM_3],
return Status::FailStatus(ERROR_INVALID_VALUE,
"invalid tilingData maxSeqLen: " + std::to_string(inTensor2Dims[DIM_3]) +
", maxSeqLen > " + std::to_string(0)));
auto maxSeqLen = static_cast<uint32_t>(inTensor2Dims[3]);
uint32_t coreNum = PlatformInfo::Instance().GetCoreNum(CoreType::CORE_TYPE_VECTOR);
tilingDataPtr->batchPerCore = Utils::CeilDiv(batch, coreNum);
uint32_t numCore = Utils::CeilDiv(batch, tilingDataPtr->batchPerCore);
tilingDataPtr->batch = batch;
tilingDataPtr->maxSeqLen = maxSeqLen;
tilingDataPtr->hiddenSize = static_cast<uint32_t>(hiddenSize);
MKI_LOG(INFO) << "KVCacheNz Tiling Data: batch " << tilingDataPtr->batch << ", maxSeqLen "
<< tilingDataPtr->maxSeqLen << ", hiddenSize: " << tilingDataPtr->hiddenSize << ", batchPerCore: "
<< tilingDataPtr->batchPerCore;
kernelInfo.SetBlockDim(numCore);
return Status::OkStatus();
}
class KVCacheKernel : public KernelBase {
public:
explicit KVCacheKernel(const std::string &kernelName, const BinHandle *handle) noexcept
: KernelBase(kernelName, handle)
{
launchBufferSize_ = sizeof(KVCacheTilingData);
}
bool CanSupport(const LaunchParam &launchParam) const override
{
MKI_CHECK(launchParam.GetParam().Type() == typeid(OpParam::KVCache),
"kv_cache: param type invalid", return false);
MKI_CHECK(launchParam.GetInTensorCount() == 5, "in tensor num invalid", return false);
MKI_CHECK(launchParam.GetOutTensorCount() == 1, "out tensor num invalid", return false);
MKI_CHECK(launchParam.GetInTensor(DIM_0).desc.dims.size() > 1,
"in tensor 0 dims num is invalid", return false);
MKI_CHECK(launchParam.GetInTensor(DIM_0).desc.dims[DIM_1] <= std::numeric_limits<uint32_t>::max(),
"in tensor 0 shape is invalid", return false);
MKI_CHECK(launchParam.GetInTensor(DIM_3).desc.dims.size() == 1,
"in tensor 3 dims num is invalid", return false);
MKI_CHECK(launchParam.GetInTensor(DIM_3).desc.dims[DIM_0] <= std::numeric_limits<uint32_t>::max(),
"in tensor 3 shape is invalid", return false);
return true;
}
};
class KVCacheNdKernel : public KVCacheKernel {
public:
explicit KVCacheNdKernel(const std::string &kernelName, const BinHandle *handle) noexcept
: KVCacheKernel(kernelName, handle) {}
bool CanSupport(const LaunchParam &launchParam) const override
{
MKI_CHECK(KVCacheKernel::CanSupport(launchParam), "failed to check support", return false);
MKI_CHECK(launchParam.GetInTensor(0).desc.format == TENSOR_FORMAT_ND,
"in tensor 0 format is invalid", return false);
MKI_CHECK(launchParam.GetInTensor(0).desc.dims.size() == 2,
"in tensor 0 dims num is invalid", return false);
return true;
}
Status InitImpl(const LaunchParam &launchParam) override
{
return KVCacheNdTiling(launchParam, kernelInfo_);
}
};
REG_KERNEL_BASE(KVCacheNdKernel);
class KVCacheNzKernel : public KVCacheKernel {
public:
explicit KVCacheNzKernel(const std::string &kernelName, const BinHandle *handle) noexcept
: KVCacheKernel(kernelName, handle) {}
bool CanSupport(const LaunchParam &launchParam) const override
{
MKI_CHECK(KVCacheKernel::CanSupport(launchParam), "failed to check support", return false);
MKI_CHECK(launchParam.GetInTensor(0).desc.format == TENSOR_FORMAT_FRACTAL_NZ,
"in tensor 0 format is invalid", return false);
MKI_CHECK(launchParam.GetInTensor(0).desc.dims.size() == 4,
"in tensor 0 dims num is invalid", return false);
MKI_CHECK(launchParam.GetInTensor(2).desc.dims.size() == 5,
"in tensor 2 dims num is invalid", return false);
return true;
}
Status InitImpl(const LaunchParam &launchParam) override
{
return KVCacheNzTiling(launchParam, kernelInfo_);
}
};
REG_KERNEL_BASE(KVCacheNzKernel);
class KVCacheNdDynamicBatchKernel : public KVCacheKernel {
public:
explicit KVCacheNdDynamicBatchKernel(const std::string &kernelName, const BinHandle *handle) noexcept
: KVCacheKernel(kernelName, handle) {}
bool CanSupport(const LaunchParam &launchParam) const override
{
MKI_CHECK(KVCacheKernel::CanSupport(launchParam), "failed to check support", return false);
MKI_CHECK(launchParam.GetInTensor(0).desc.format == TENSOR_FORMAT_ND,
"in tensor 0 format is invalid", return false);
MKI_CHECK(launchParam.GetInTensor(0).desc.dims.size() == 2,
"in tensor 0 dims num is invalid", return false);
return true;
}
uint64_t GetTilingSize(const LaunchParam &launchParam) const override
{
MKI_CHECK(launchParam.GetParam().Type() == typeid(OpParam::KVCache),
"OpParam is invalid", return 0);
auto param = AnyCast<OpParam::KVCache>(launchParam.GetParam());
size_t constTensorSize = Utils::GetConstTensorSize<int32_t>(param.batchRunStatus);
size_t maxVal = std::numeric_limits<uint32_t>::max();
MKI_CHECK(maxVal - constTensorSize >= launchBufferSize_, "batchRunStatus Size is invalid", return 0);
return launchBufferSize_ + constTensorSize;
}
Status InitImpl(const LaunchParam &launchParam) override
{
auto status = KVCacheTiling(launchParam, kernelInfo_);
MKI_CHECK_NO_LOG(status.Ok(), return status);
kernelInfo_.SetConstTensorOffset(launchBufferSize_);
auto ¶m = AnyCast<OpParam::KVCache>(launchParam.GetParam());
auto ret = kernelInfo_.AddConstTensorData<int32_t>(TENSOR_BATCH_STATUS_IDX, param.batchRunStatus);
MKI_CHECK_NO_LOG(ret, return Status::FailStatus(1));
return Status::OkStatus();
}
};
REG_KERNEL_BASE(KVCacheNdDynamicBatchKernel);
class KVCacheParamsKernel : public KernelBase {
public:
explicit KVCacheParamsKernel(const std::string &kernelName, const BinHandle *handle) noexcept
: KernelBase(kernelName, handle)
{
launchBufferSize_ = sizeof(KVCacheTilingData);
}
bool CanSupport(const LaunchParam &launchParam) const override
{
MKI_CHECK(launchParam.GetParam().Type() == typeid(OpParam::KVCache),
"kv_cache: param type invalid", return false);
MKI_CHECK(launchParam.GetInTensorCount() == 3, "in tensor num invalid", return false);
MKI_CHECK(launchParam.GetOutTensorCount() == 1, "out tensor num invalid", return false);
MKI_CHECK(launchParam.GetInTensor(DIM_0).desc.dims.size() > 1,
"in tensor 0 dims num is invalid", return false);
MKI_CHECK(launchParam.GetInTensor(DIM_0).desc.dims[DIM_1] <= std::numeric_limits<uint32_t>::max(),
"in tensor 0 shape is invalid", return false);
return true;
}
uint64_t GetTilingSize(const LaunchParam &launchParam) const override
{
MKI_CHECK(launchParam.GetParam().Type() == typeid(OpParam::KVCache),
"OpParam is invalid", return 0);
auto param = AnyCast<OpParam::KVCache>(launchParam.GetParam());
size_t constTensorSize = Utils::GetConstTensorSize<int32_t>(param.seqLen)
+ Utils::GetConstTensorSize<int32_t>(param.tokenOffset);
size_t maxVal = std::numeric_limits<uint32_t>::max();
MKI_CHECK(maxVal - constTensorSize >= launchBufferSize_, "batch Size is invalid", return 0);
return launchBufferSize_ + constTensorSize;
}
Status InitImpl(const LaunchParam &launchParam) override
{
kernelInfo_.SetConstTensorOffset(launchBufferSize_);
auto ¶m = AnyCast<OpParam::KVCache>(launchParam.GetParam());
auto ret = kernelInfo_.AddConstTensorData<int32_t>(TENSOR_SEQLEN_IDX, param.seqLen);
MKI_CHECK_NO_LOG(ret, return Status::FailStatus(1));
ret = kernelInfo_.AddConstTensorData<int32_t>(TENSOR_TOKEN_OFFSET_IDX, param.tokenOffset);
MKI_CHECK_NO_LOG(ret, return Status::FailStatus(1));
return Status::OkStatus();
}
};
class KVCacheNdParamsKernel : public KVCacheParamsKernel {
public:
explicit KVCacheNdParamsKernel(const std::string &kernelName, const BinHandle *handle) noexcept
: KVCacheParamsKernel(kernelName, handle) {}
bool CanSupport(const LaunchParam &launchParam) const override
{
MKI_CHECK(KVCacheParamsKernel::CanSupport(launchParam), "failed to check support", return false);
MKI_CHECK(launchParam.GetInTensor(0).desc.format == TENSOR_FORMAT_ND,
"in tensor 0 format is invalid", return false);
MKI_CHECK(launchParam.GetInTensor(0).desc.dims.size() == 2,
"in tensor 0 dims num is invalid", return false);
return true;
}
Status InitImpl(const LaunchParam &launchParam) override
{
auto status = KVCacheNdTiling<true>(launchParam, kernelInfo_);
MKI_CHECK_NO_LOG(status.Ok(), return status);
status = KVCacheParamsKernel::InitImpl(launchParam);
MKI_CHECK_NO_LOG(status.Ok(), return status);
return Status::OkStatus();
}
};
REG_KERNEL_BASE(KVCacheNdParamsKernel);
class KVCacheNzParamsKernel : public KVCacheParamsKernel {
public:
explicit KVCacheNzParamsKernel(const std::string &kernelName, const BinHandle *handle) noexcept
: KVCacheParamsKernel(kernelName, handle) {}
bool CanSupport(const LaunchParam &launchParam) const override
{
MKI_CHECK(KVCacheParamsKernel::CanSupport(launchParam), "failed to check support", return false);
MKI_CHECK(launchParam.GetInTensor(0).desc.format == TENSOR_FORMAT_FRACTAL_NZ,
"in tensor 0 format is invalid", return false);
MKI_CHECK(launchParam.GetInTensor(0).desc.dims.size() == 4,
"in tensor 0 dims num is invalid", return false);
MKI_CHECK(launchParam.GetInTensor(2).desc.dims.size() == 5,
"in tensor 2 dims num is invalid", return false);
return true;
}
Status InitImpl(const LaunchParam &launchParam) override
{
auto status = KVCacheNzTiling<true>(launchParam, kernelInfo_);
MKI_CHECK_NO_LOG(status.Ok(), return status);
status = KVCacheParamsKernel::InitImpl(launchParam);
MKI_CHECK_NO_LOG(status.Ok(), return status);
return Status::OkStatus();
}
};
REG_KERNEL_BASE(KVCacheNzParamsKernel);
class KVCacheNdDynamicBatchParamsKernel : public KVCacheParamsKernel {
public:
explicit KVCacheNdDynamicBatchParamsKernel(const std::string &kernelName, const BinHandle *handle) noexcept
: KVCacheParamsKernel(kernelName, handle) {}
bool CanSupport(const LaunchParam &launchParam) const override
{
MKI_CHECK(KVCacheParamsKernel::CanSupport(launchParam), "failed to check support", return false);
MKI_CHECK(launchParam.GetInTensor(0).desc.format == TENSOR_FORMAT_ND,
"in tensor 0 format is invalid", return false);
MKI_CHECK(launchParam.GetInTensor(0).desc.dims.size() == 2,
"in tensor 0 dims num is invalid", return false);
return true;
}
uint64_t GetTilingSize(const LaunchParam &launchParam) const override
{
MKI_CHECK(launchParam.GetParam().Type() == typeid(OpParam::KVCache),
"OpParam is invalid", return 0);
auto param = AnyCast<OpParam::KVCache>(launchParam.GetParam());
size_t constTensorSize = Utils::GetConstTensorSize<int32_t>(param.batchRunStatus);
size_t maxVal = std::numeric_limits<uint32_t>::max();
MKI_CHECK(maxVal - constTensorSize >= KVCacheParamsKernel::GetTilingSize(launchParam),
"batchRunStatus Size is invalid", return 0);
return KVCacheParamsKernel::GetTilingSize(launchParam) + constTensorSize;
}
Status InitImpl(const LaunchParam &launchParam) override
{
auto status = KVCacheTiling<true>(launchParam, kernelInfo_);
MKI_CHECK_NO_LOG(status.Ok(), return status);
status = KVCacheParamsKernel::InitImpl(launchParam);
MKI_CHECK_NO_LOG(status.Ok(), return status);
auto ¶m = AnyCast<OpParam::KVCache>(launchParam.GetParam());
auto ret = kernelInfo_.AddConstTensorData<int32_t>(TENSOR_BATCH_STATUS_IDX, param.batchRunStatus);
MKI_CHECK_NO_LOG(ret, return Status::FailStatus(1));
return Status::OkStatus();
}
};
REG_KERNEL_BASE(KVCacheNdDynamicBatchParamsKernel);
}