* Copyright (c) 2024 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <mki/kernel_info.h>
#include <mki/base/kernel_base.h>
#include <mki_loader/op_register.h>
#include <mki/utils/log/log.h>
#include <mki/utils/math/math.h>
#include <mki/utils/math/tensor_utils.h>
#include <mki/utils/checktensor/check_tensor.h>
#include <mki/utils/platform/platform_info.h>
#include "atbops/params/params.h"
#include "mixkernels/pagedattention/tiling/paged_attention_tiling.h"
#include "mixkernels/pagedattention/tiling/paged_attention_tiling_dependency.h"
#include "mixkernels/utils/common.h"
namespace AtbOps {
using namespace Mki;
// Granularity handed to Utils::RoundUp when sizing the launch/tiling buffers in this file.
// The values being rounded are uint32_t word counts multiplied by sizeof(uint32_t).
constexpr uint32_t TILINGMIN = 512;
class PagedAttentionKernel : public KernelBase {
public:
explicit PagedAttentionKernel(const std::string &kernelName, const BinHandle *handle)
: KernelBase(kernelName, handle)
{
launchBufferSize_ = Utils::RoundUp((TILING_PARA_SIZE + TILING_HEAD_SIZE) * sizeof(uint32_t), TILINGMIN);
}
bool CanSupport(const LaunchParam &launchParam) const override
{
MKI_CHECK(launchParam.GetOutTensorCount() == 1, "out tensor num invalid", return false);
MKI_CHECK(launchParam.GetInTensor(0).desc.dims.size() == 3, "in tensor0 dims invalid", return false);
MKI_CHECK(launchParam.GetInTensor(1).desc.dims.size() == 4, "in tensor1 dims invalid", return false);
MKI_CHECK(launchParam.GetInTensor(2).desc.dims.size() == 4, "in tensor2 dims invalid", return false);
MKI_CHECK(launchParam.GetInTensor(3).desc.dims.size() == 2, "in tensor3 dims invalid", return false);
MKI_CHECK(launchParam.GetParam().Type() == typeid(OpParam::PagedAttention),
"paged attention: param type invalid", return false);
auto param = AnyCast<OpParam::PagedAttention>(launchParam.GetParam());
if (param.dataShapeType == OpParam::PagedAttention::DataShapeType::BNSD) {
MKI_CHECK(param.quantType == OpParam::PagedAttention::QuantType::TYPE_QUANT_UNDEFINED &&
!param.compressHead && param.scaleType == OpParam::PagedAttention::ScaleType::SCALE_TOR,
"BNSD does not support quant,compressHead and logn", return false);
}
return true;
}
Status InitImpl(const LaunchParam &launchParam) override
{
auto status = PagedAttentionTiling(launchParam, kernelInfo_);
MKI_CHECK_NO_LOG(status.Ok(), return status);
kernelInfo_.SetHwsyncIdx(0);
return Status::OkStatus();
}
uint64_t GetTilingSize(const LaunchParam &launchParam) const override
{
MKI_CHECK(launchParam.GetParam().Type() == typeid(OpParam::PagedAttention),
"paged attention: param type invalid", return false);
auto param = AnyCast<OpParam::PagedAttention>(launchParam.GetParam());
auto batch = param.kvSeqLen.size();
MKI_CHECK(batch > 0 && batch <= ND_BATCH_LIMIT, "batch is invalid", return 0);
uint64_t bufferSize =
Utils::RoundUp(launchBufferSize_ + TILING_PARA_SIZE * (batch - 1) * sizeof(uint32_t), TILINGMIN);
return bufferSize;
}
};
// Index passed to AddConstTensorData() when registering param.identityM in
// PagedAttentionMaskNdKernel and PagedMultiLatentAttentionSplitCacheMaskNdKernel.
constexpr int32_t QUANT_EYE_CONST_TENSOR_IDX = 13;
// ND paged-attention kernel taking 12 input tensors. On top of the PagedAttentionKernel
// behaviour it appends param.identityM as an int8_t const tensor behind the tiling data.
class PagedAttentionMaskNdKernel : public PagedAttentionKernel {
public:
    explicit PagedAttentionMaskNdKernel(const std::string &kernelName, const BinHandle *handle)
        : PagedAttentionKernel(kernelName, handle)
    {
    }

    // Requires exactly 12 inputs and everything PagedAttentionKernel::CanSupport() requires.
    bool CanSupport(const LaunchParam &launchParam) const override
    {
        MKI_CHECK(launchParam.GetInTensorCount() == 12, "in tensor num invalid", return false);
        MKI_CHECK(PagedAttentionKernel::CanSupport(launchParam), "failed to check support", return false);
        return true;
    }

    // Returns the base tiling size plus the const-tensor size of identityM (as int8_t).
    // Returns 0 when the parameter is not an OpParam::PagedAttention.
    uint64_t GetTilingSize(const LaunchParam &launchParam) const override
    {
        // Validate the parameter type before AnyCast; otherwise the cast would run on an
        // unverified parameter ahead of the check done inside the base-class call.
        MKI_CHECK(launchParam.GetParam().Type() == typeid(OpParam::PagedAttention),
                  "paged attention: param type invalid", return 0);
        auto param = AnyCast<OpParam::PagedAttention>(launchParam.GetParam());
        return PagedAttentionKernel::GetTilingSize(launchParam) + Utils::GetConstTensorSize<int8_t>(param.identityM);
    }

    // Runs the base initialisation, then places the const-tensor area directly after the base
    // tiling data and registers identityM under QUANT_EYE_CONST_TENSOR_IDX.
    Status InitImpl(const LaunchParam &launchParam) override
    {
        Status st = PagedAttentionKernel::InitImpl(launchParam);
        MKI_CHECK_NO_LOG(st.Ok(), return st);
        // The offset equals the base-class tiling size, i.e. the size without the const tensor.
        kernelInfo_.SetConstTensorOffset(PagedAttentionKernel::GetTilingSize(launchParam));
        auto param = AnyCast<OpParam::PagedAttention>(launchParam.GetParam());
        kernelInfo_.AddConstTensorData<int8_t>(QUANT_EYE_CONST_TENSOR_IDX, param.identityM);
        return Status::OkStatus();
    }
};
// MLA split-cache kernel taking 12 input tensors. It reuses the PagedAttentionKernel checks,
// additionally requires the BSND data layout, and appends param.identityM as an int8_t const
// tensor behind the tiling data.
class PagedMultiLatentAttentionSplitCacheMaskNdKernel : public PagedAttentionKernel {
public:
    explicit PagedMultiLatentAttentionSplitCacheMaskNdKernel(const std::string &kernelName, const BinHandle *handle)
        : PagedAttentionKernel(kernelName, handle)
    {
    }

    // Requires exactly 12 inputs, a passing PagedAttentionKernel::CanSupport() and
    // dataShapeType == BSND.
    bool CanSupport(const LaunchParam &launchParam) const override
    {
        MKI_CHECK(launchParam.GetInTensorCount() == 12, "in tensor num invalid", return false);
        MKI_CHECK(PagedAttentionKernel::CanSupport(launchParam), "failed to check support", return false);
        // The base check above has already verified the parameter type, so the cast is safe here.
        auto param = AnyCast<OpParam::PagedAttention>(launchParam.GetParam());
        MKI_CHECK(param.dataShapeType == OpParam::PagedAttention::DataShapeType::BSND, "MLA only supports BSND",
                  return false);
        return true;
    }

    // Returns the base tiling size plus the const-tensor size of identityM (as int8_t).
    // Returns 0 when the parameter is not an OpParam::PagedAttention.
    uint64_t GetTilingSize(const LaunchParam &launchParam) const override
    {
        // Validate the parameter type before AnyCast; otherwise the cast would run on an
        // unverified parameter ahead of the check done inside the base-class call.
        MKI_CHECK(launchParam.GetParam().Type() == typeid(OpParam::PagedAttention),
                  "paged attention: param type invalid", return 0);
        auto param = AnyCast<OpParam::PagedAttention>(launchParam.GetParam());
        return PagedAttentionKernel::GetTilingSize(launchParam) + Utils::GetConstTensorSize<int8_t>(param.identityM);
    }

    // Runs the base initialisation, then places the const-tensor area directly after the base
    // tiling data and registers identityM under QUANT_EYE_CONST_TENSOR_IDX.
    Status InitImpl(const LaunchParam &launchParam) override
    {
        Status st = PagedAttentionKernel::InitImpl(launchParam);
        MKI_CHECK_NO_LOG(st.Ok(), return st);
        // The offset equals the base-class tiling size, i.e. the size without the const tensor.
        kernelInfo_.SetConstTensorOffset(PagedAttentionKernel::GetTilingSize(launchParam));
        auto param = AnyCast<OpParam::PagedAttention>(launchParam.GetParam());
        kernelInfo_.AddConstTensorData<int8_t>(QUANT_EYE_CONST_TENSOR_IDX, param.identityM);
        return Status::OkStatus();
    }
};
// Index passed to AddConstTensorData() when registering param.identityM in
// PagedMultiLatentAttentionCombineCacheMaskNdKernel.
constexpr int32_t QUANT_EYE_CONST_TENSOR_IDX_MLA = 10;
// MLA combine-cache kernel taking 9 input tensors. It performs its own shape/parameter checks
// (it does not call PagedAttentionKernel::CanSupport) and appends param.identityM as an int8_t
// const tensor behind the tiling data.
class PagedMultiLatentAttentionCombineCacheMaskNdKernel : public PagedAttentionKernel {
public:
    explicit PagedMultiLatentAttentionCombineCacheMaskNdKernel(const std::string &kernelName, const BinHandle *handle)
        : PagedAttentionKernel(kernelName, handle)
    {
    }

    // Requires exactly 9 inputs and 1 output, inputs 0..2 with 3/4/2 dims, an
    // OpParam::PagedAttention parameter and dataShapeType == BSND.
    bool CanSupport(const LaunchParam &launchParam) const override
    {
        MKI_CHECK(launchParam.GetInTensorCount() == 9, "in tensor num invalid", return false);
        MKI_CHECK(launchParam.GetOutTensorCount() == 1, "out tensor num invalid", return false);
        MKI_CHECK(launchParam.GetInTensor(0).desc.dims.size() == 3, "in tensor0 dims invalid", return false);
        MKI_CHECK(launchParam.GetInTensor(1).desc.dims.size() == 4, "in tensor1 dims invalid", return false);
        MKI_CHECK(launchParam.GetInTensor(2).desc.dims.size() == 2, "in tensor2 dims invalid", return false);
        MKI_CHECK(launchParam.GetParam().Type() == typeid(OpParam::PagedAttention),
                  "paged attention: param type invalid", return false);
        auto param = AnyCast<OpParam::PagedAttention>(launchParam.GetParam());
        MKI_CHECK(param.dataShapeType == OpParam::PagedAttention::DataShapeType::BSND, "MLA only supports BSND",
                  return false);
        return true;
    }

    // Returns the base tiling size plus the const-tensor size of identityM (as int8_t).
    // Returns 0 when the parameter is not an OpParam::PagedAttention.
    uint64_t GetTilingSize(const LaunchParam &launchParam) const override
    {
        // Validate the parameter type before AnyCast; otherwise the cast would run on an
        // unverified parameter ahead of the check done inside the base-class call.
        MKI_CHECK(launchParam.GetParam().Type() == typeid(OpParam::PagedAttention),
                  "paged attention: param type invalid", return 0);
        auto param = AnyCast<OpParam::PagedAttention>(launchParam.GetParam());
        return PagedAttentionKernel::GetTilingSize(launchParam) + Utils::GetConstTensorSize<int8_t>(param.identityM);
    }

    // Runs the base initialisation, then places the const-tensor area directly after the base
    // tiling data and registers identityM under QUANT_EYE_CONST_TENSOR_IDX_MLA.
    Status InitImpl(const LaunchParam &launchParam) override
    {
        Status st = PagedAttentionKernel::InitImpl(launchParam);
        MKI_CHECK_NO_LOG(st.Ok(), return st);
        // The offset equals the base-class tiling size, i.e. the size without the const tensor.
        kernelInfo_.SetConstTensorOffset(PagedAttentionKernel::GetTilingSize(launchParam));
        auto param = AnyCast<OpParam::PagedAttention>(launchParam.GetParam());
        kernelInfo_.AddConstTensorData<int8_t>(QUANT_EYE_CONST_TENSOR_IDX_MLA, param.identityM);
        return Status::OkStatus();
    }
};
// MLA multi-token-prediction kernel taking 4 input tensors. It performs its own
// shape/parameter checks (it does not call PagedAttentionKernel::CanSupport).
class PagedMultiLatentAttentionMultiTokenPredictionMaskNdKernel : public PagedAttentionKernel {
public:
    explicit PagedMultiLatentAttentionMultiTokenPredictionMaskNdKernel(const std::string &kernelName,
                                                                       const BinHandle *handle)
        : PagedAttentionKernel(kernelName, handle)
    {
    }

    // Requires exactly 4 inputs and 1 output, inputs 0..2 with 3/4/2 dims, an
    // OpParam::PagedAttention parameter and dataShapeType == BSND.
    bool CanSupport(const LaunchParam &launchParam) const override
    {
        MKI_CHECK(launchParam.GetInTensorCount() == 4, "in tensor num invalid", return false);
        MKI_CHECK(launchParam.GetOutTensorCount() == 1, "out tensor num invalid", return false);
        MKI_CHECK(launchParam.GetInTensor(0).desc.dims.size() == 3, "in tensor0 dims invalid", return false);
        MKI_CHECK(launchParam.GetInTensor(1).desc.dims.size() == 4, "in tensor1 dims invalid", return false);
        MKI_CHECK(launchParam.GetInTensor(2).desc.dims.size() == 2, "in tensor2 dims invalid", return false);
        MKI_CHECK(launchParam.GetParam().Type() == typeid(OpParam::PagedAttention),
                  "paged attention: param type invalid", return false);
        auto param = AnyCast<OpParam::PagedAttention>(launchParam.GetParam());
        MKI_CHECK(param.dataShapeType == OpParam::PagedAttention::DataShapeType::BSND, "MLA only supports BSND",
                  return false);
        return true;
    }

    // Returns the base tiling size plus the const-tensor size of identityM (as int8_t).
    // Returns 0 when the parameter is not an OpParam::PagedAttention.
    uint64_t GetTilingSize(const LaunchParam &launchParam) const override
    {
        // Validate the parameter type before AnyCast; otherwise the cast would run on an
        // unverified parameter ahead of the check done inside the base-class call.
        MKI_CHECK(launchParam.GetParam().Type() == typeid(OpParam::PagedAttention),
                  "paged attention: param type invalid", return 0);
        auto param = AnyCast<OpParam::PagedAttention>(launchParam.GetParam());
        return PagedAttentionKernel::GetTilingSize(launchParam) + Utils::GetConstTensorSize<int8_t>(param.identityM);
    }

    // Runs the base initialisation, then sets the const-tensor offset to the base tiling size.
    // NOTE(review): GetTilingSize() reserves room for identityM, yet no AddConstTensorData()
    // call is made here, unlike the other MLA kernels in this file. Confirm this is intended.
    Status InitImpl(const LaunchParam &launchParam) override
    {
        Status st = PagedAttentionKernel::InitImpl(launchParam);
        MKI_CHECK_NO_LOG(st.Ok(), return st);
        kernelInfo_.SetConstTensorOffset(PagedAttentionKernel::GetTilingSize(launchParam));
        return Status::OkStatus();
    }
};
class PagedAttentionNzBaseKernel : public KernelBase {
public:
explicit PagedAttentionNzBaseKernel(const std::string &kernelName, const BinHandle *handle)
: KernelBase(kernelName, handle)
{
is910A_ = PlatformInfo::Instance().GetPlatformType() == PlatformType::ASCEND_910A;
tilingHeadSize_ = is910A_ ? TILING_HEAD_SIZE_910A : TILING_HEAD_SIZE_NZ;
launchBufferSize_ = Utils::RoundUp(tilingHeadSize_ * sizeof(uint32_t), TILINGMIN);
}
uint64_t GetTilingSize(const LaunchParam &launchParam) const override
{
MKI_CHECK(launchParam.GetParam().Type() == typeid(OpParam::PagedAttention),
"paged attention: param type invalid", return false);
auto param = AnyCast<OpParam::PagedAttention>(launchParam.GetParam());
auto batch = param.qSeqLen.size();
MKI_CHECK(batch <= ND_BATCH_LIMIT, "batch is invalid", return 0);
uint64_t bufferSize =
Utils::RoundUp((TILING_PARA_SIZE_NZ * batch + tilingHeadSize_) * sizeof(uint32_t), TILINGMIN);
return bufferSize;
}
bool CanSupport(const LaunchParam &launchParam) const override
{
MKI_CHECK(launchParam.GetOutTensorCount() == 1, "out tensor num invalid", return false);
MKI_CHECK(launchParam.GetParam().Type() == typeid(OpParam::PagedAttention),
"paged attention: param type invalid", return false);
MKI_CHECK(launchParam.GetInTensor(0).desc.dims.size() == 4, "in tensor0 dims invalid", return false);
MKI_CHECK(launchParam.GetInTensor(1).desc.dims.size() == 4, "in tensor1 dims invalid", return false);
MKI_CHECK(launchParam.GetInTensor(2).desc.dims.size() == 4, "in tensor2 dims invalid", return false);
MKI_CHECK(launchParam.GetInTensor(3).desc.dims.size() == 2, "in tensor3 dims invalid", return false);
MKI_CHECK(launchParam.GetInTensor(4).desc.dims.size() == 1, "in tensor4 dims invalid", return false);
return true;
}
Status InitImpl(const LaunchParam &launchParam) override
{
auto status = PagedAttentionTiling(launchParam, kernelInfo_);
MKI_CHECK_NO_LOG(status.Ok(), return status);
return Status::OkStatus();
}
private:
bool is910A_ = false;
int32_t tilingHeadSize_ = 0;
};
// NZ decoder kernel with mask: accepts launches with exactly 7 input tensors that also pass
// the PagedAttentionNzBaseKernel checks. Tiling and initialisation are inherited unchanged.
class PagedAttentionDecoderNzMaskKernel : public PagedAttentionNzBaseKernel {
public:
    explicit PagedAttentionDecoderNzMaskKernel(const std::string &kernelName, const BinHandle *handle)
        : PagedAttentionNzBaseKernel(kernelName, handle)
    {
    }

    // Checks the input count first, then defers to the base-class support check.
    // Each failure is logged by MKI_CHECK and yields false.
    bool CanSupport(const LaunchParam &launchParam) const override
    {
        const bool inputCountMatches = (launchParam.GetInTensorCount() == 7);
        MKI_CHECK(inputCountMatches, "in tensor num invalid", return false);

        // Evaluated only after the count check passed, exactly as a chained check would be.
        const bool baseAccepts = PagedAttentionNzBaseKernel::CanSupport(launchParam);
        MKI_CHECK(baseAccepts, "failed to check support", return false);

        return true;
    }
};
// Register the concrete kernel classes via REG_KERNEL_BASE. The shared bases
// PagedAttentionKernel and PagedAttentionNzBaseKernel are not registered here.
REG_KERNEL_BASE(PagedAttentionDecoderNzMaskKernel);
REG_KERNEL_BASE(PagedAttentionMaskNdKernel);
REG_KERNEL_BASE(PagedMultiLatentAttentionSplitCacheMaskNdKernel);
REG_KERNEL_BASE(PagedMultiLatentAttentionCombineCacheMaskNdKernel);
REG_KERNEL_BASE(PagedMultiLatentAttentionMultiTokenPredictionMaskNdKernel);
}