* Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
* MindIE is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#include "register/op_def_registry.h"
#include "graph/utils/type_utils.h"
namespace {
constexpr int OUTPUTINDEX0 = 0;
constexpr int OUTPUTINDEX1 = 1;
constexpr int OUTPUTINDEX2 = 2;
constexpr int SEQ_LEN_DIM = 2;
constexpr int HEAD_DIM_DIM = 3;
constexpr int INPUT_HEAD_NUM_DIM = 2;
constexpr int INPUT_HEAD_DIM_DIM = 3;
}
namespace ops {
static ge::graphStatus LaPreprocessInferShape(gert::InferShapeContext *context)
{
if (context == nullptr) {
return ge::GRAPH_FAILED;
}
const gert::Shape *qShape = context->GetInputShape(0);
const gert::Shape *kShape = context->GetInputShape(1);
const gert::Shape *vShape = context->GetInputShape(2);
gert::Shape *outQShape = context->GetOutputShape(0);
gert::Shape *outKShape = context->GetOutputShape(1);
gert::Shape *outVShape = context->GetOutputShape(2);
if (qShape == nullptr || kShape == nullptr || vShape == nullptr ||
outQShape == nullptr || outKShape == nullptr || outVShape == nullptr) {
return ge::GRAPH_FAILED;
}
const auto attrs = context->GetAttrs();
if (attrs == nullptr) {
return ge::GRAPH_FAILED;
}
auto alignLen = *context->GetAttrs()->GetAttrPointer<int32_t>(0);
outQShape->SetDimNum(qShape->GetDimNum());
outQShape->SetDim(0, qShape->GetDim(0));
outQShape->SetDim(1, qShape->GetDim(2));
int32_t qPadDim = (qShape->GetDim(1) + alignLen - 1) / alignLen * alignLen;
outQShape->SetDim(SEQ_LEN_DIM, qPadDim);
outQShape->SetDim(HEAD_DIM_DIM, qShape->GetDim(INPUT_HEAD_DIM_DIM));
outKShape->SetDimNum(kShape->GetDimNum());
outKShape->SetDim(0, kShape->GetDim(0));
outKShape->SetDim(1, kShape->GetDim(INPUT_HEAD_NUM_DIM));
int32_t kPadDim = (kShape->GetDim(1) + alignLen - 1) / alignLen * alignLen;
outKShape->SetDim(SEQ_LEN_DIM, kPadDim);
outKShape->SetDim(HEAD_DIM_DIM, kShape->GetDim(INPUT_HEAD_DIM_DIM));
outVShape->SetDimNum(vShape->GetDimNum());
outVShape->SetDim(0, vShape->GetDim(0));
outVShape->SetDim(1, vShape->GetDim(INPUT_HEAD_NUM_DIM));
int32_t vPadDim = (vShape->GetDim(1) + alignLen - 1) / alignLen * alignLen;
outVShape->SetDim(SEQ_LEN_DIM, vPadDim);
outVShape->SetDim(HEAD_DIM_DIM, vShape->GetDim(INPUT_HEAD_DIM_DIM));
return ge::GRAPH_SUCCESS;
}
static ge::graphStatus LaPreprocessInferDtype(gert::InferDataTypeContext *context)
{
if (context == nullptr) {
return ge::GRAPH_FAILED;
}
context->SetOutputDataType(OUTPUTINDEX0, ge::DT_FLOAT16);
context->SetOutputDataType(OUTPUTINDEX1, ge::DT_FLOAT16);
context->SetOutputDataType(OUTPUTINDEX2, ge::DT_FLOAT16);
return ge::GRAPH_SUCCESS;
}
IMPL_OP_INFERSHAPE(LaPreprocess)
.InferShape(LaPreprocessInferShape)
.InferDataType(LaPreprocessInferDtype);
}