/**
* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file mhc_pre_infershape.cpp
* \brief Shape and data-type inference for the MhcPre operator.
*/
#include <map>
#include <string>
#include <sstream>
#include <initializer_list>
#include "exe_graph/runtime/infer_shape_context.h"
#include "exe_graph/runtime/shape.h"
#include "exe_graph/runtime/storage_shape.h"
#include "register/op_impl_registry.h"
#include "log/log.h"
#include "err/ops_err.h"
using namespace gert;
using namespace ge;
namespace ops {
// Input indices used to look up the x and phi inputs.
constexpr int64_t X_INDEX = 0;
constexpr int64_t PHI_INDEX = 1;

// Output indices, in registration order.
constexpr int64_t OUT_H_IN_INDEX = 0;
constexpr int64_t OUT_H_POST_INDEX = 1;
constexpr int64_t OUT_H_RES_INDEX = 2;
constexpr int64_t OUT_INV_RMS_INDEX = 3;
constexpr int64_t OUT_MM_RES_INDEX = 4;
constexpr int64_t OUT_H_PRE_INDEX = 5;

// Ranks accepted for x: 4 (dims read as b, s, n, d) or 3 (dims read as t, n, d).
constexpr int64_t BSND_DIM_NUM = 4;
constexpr int64_t TND_DIM_NUM = 3;

// Not referenced by the code visible in this file.
constexpr int64_t UNKNOWN_DIM_VALUE = -1;
/**
 * @brief Resizes a shape to rank 3 and writes its three dimensions.
 * @param shape Shape to overwrite; dereferenced without a null check.
 * @param d0 Value written to dimension 0.
 * @param d1 Value written to dimension 1.
 * @param d2 Value written to dimension 2.
 */
static void SetShape3D(gert::Shape *shape, uint64_t d0, uint64_t d1, uint64_t d2)
{
    constexpr size_t rank = 3;
    const uint64_t dims[rank] = {d0, d1, d2};
    shape->SetDimNum(rank);
    for (size_t axis = 0; axis < rank; ++axis) {
        shape->SetDim(axis, dims[axis]);
    }
}
/**
 * @brief Resizes a shape to rank 2 and writes its two dimensions.
 * @param shape Shape to overwrite; dereferenced without a null check.
 * @param d0 Value written to dimension 0.
 * @param d1 Value written to dimension 1.
 */
static void SetShape2D(gert::Shape *shape, uint64_t d0, uint64_t d1)
{
    constexpr size_t rank = 2;
    const uint64_t dims[rank] = {d0, d1};
    shape->SetDimNum(rank);
    for (size_t axis = 0; axis < rank; ++axis) {
        shape->SetDim(axis, dims[axis]);
    }
}
/**
 * @brief Resizes a shape to rank 4 and writes its four dimensions.
 * @param shape Shape to overwrite; dereferenced without a null check.
 * @param d0 Value written to dimension 0.
 * @param d1 Value written to dimension 1.
 * @param d2 Value written to dimension 2.
 * @param d3 Value written to dimension 3.
 */
static void SetShape4D(gert::Shape *shape, uint64_t d0, uint64_t d1, uint64_t d2, uint64_t d3)
{
    constexpr size_t rank = 4;
    const uint64_t dims[rank] = {d0, d1, d2, d3};
    shape->SetDimNum(rank);
    for (size_t axis = 0; axis < rank; ++axis) {
        shape->SetDim(axis, dims[axis]);
    }
}
/**
 * @brief Resizes a shape to rank 1 and writes its single dimension.
 * @param shape Shape to overwrite; dereferenced without a null check.
 * @param d0 Value written to dimension 0.
 */
static void SetShape1D(gert::Shape *shape, uint64_t d0)
{
    gert::Shape &target = *shape;
    target.SetDimNum(1);
    target.SetDim(0, d0);
}
/**
 * @brief Copies the rank and every dimension of one shape into another.
 * @param dst Shape to overwrite; dereferenced without a null check.
 * @param src Shape to copy from; dereferenced without a null check.
 */
static void SetShapeFromX(gert::Shape *dst, const gert::Shape *src)
{
    // Read the rank once and count with size_t, so the loop bound and the counter
    // share one unsigned type instead of mixing a signed counter with the rank.
    const size_t dimNum = src->GetDimNum();
    dst->SetDimNum(dimNum);
    for (size_t i = 0; i < dimNum; ++i) {
        dst->SetDim(i, src->GetDim(i));
    }
}
/**
 * @brief Infers the six MhcPre output shapes from the shapes of x and phi.
 *
 * x must have rank 4 (dims read as b, s, n, d) or rank 3 (dims read as t, n, d);
 * phi must have rank 2. With k = phi.dim(0), the outputs are:
 *   rank-4 x: h_in [b,s,d], h_post [b,s,n], h_res [b,s,n,n], inv_rms [b,s], mm_res [b,s,k], h_pre [b,s,n]
 *   rank-3 x: h_in [t,d],   h_post [t,n],   h_res [t,n,n],   inv_rms [t],   mm_res [t,k],   h_pre [t,n]
 *
 * @param context Inference context that supplies the input shapes and owns the output shapes.
 * @return GRAPH_SUCCESS when all output shapes were written; GRAPH_FAILED when a rank check fails.
 *         Null input or output shape pointers are rejected through OP_CHECK_NULL_WITH_CONTEXT.
 */
static ge::graphStatus InferShape4MhcPre(InferShapeContext *context)
{
    OP_LOGD(context->GetNodeName(), "Begin to do InferShape MhcPre");

    const gert::Shape *xShape = context->GetDynamicInputShape(X_INDEX, 0);
    const gert::Shape *phiShape = context->GetDynamicInputShape(PHI_INDEX, 0);
    OP_CHECK_NULL_WITH_CONTEXT(context, xShape);
    OP_CHECK_NULL_WITH_CONTEXT(context, phiShape);

    int64_t phiDim = phiShape->GetDimNum();
    int64_t xDim = xShape->GetDimNum();
    OP_CHECK_IF(phiDim != 2, OP_LOGE(context->GetNodeName(), "phiShapeDim should be 2, but got %ld", phiDim),
                return GRAPH_FAILED);
    OP_CHECK_IF(
        xDim != BSND_DIM_NUM && xDim != TND_DIM_NUM,
        OP_LOGE(context->GetNodeName(), "xShapeDim should be %ld or %ld, but got %ld", BSND_DIM_NUM, TND_DIM_NUM, xDim),
        return GRAPH_FAILED);

    // Every output shape is written through below, so reject a null pointer up front
    // in the same way the input shapes are checked.
    gert::Shape *hInShape = context->GetOutputShape(OUT_H_IN_INDEX);
    gert::Shape *hPostShape = context->GetOutputShape(OUT_H_POST_INDEX);
    gert::Shape *hResShape = context->GetOutputShape(OUT_H_RES_INDEX);
    gert::Shape *invRmsShape = context->GetOutputShape(OUT_INV_RMS_INDEX);
    gert::Shape *mmResShape = context->GetOutputShape(OUT_MM_RES_INDEX);
    gert::Shape *hPreShape = context->GetOutputShape(OUT_H_PRE_INDEX);
    OP_CHECK_NULL_WITH_CONTEXT(context, hInShape);
    OP_CHECK_NULL_WITH_CONTEXT(context, hPostShape);
    OP_CHECK_NULL_WITH_CONTEXT(context, hResShape);
    OP_CHECK_NULL_WITH_CONTEXT(context, invRmsShape);
    OP_CHECK_NULL_WITH_CONTEXT(context, mmResShape);
    OP_CHECK_NULL_WITH_CONTEXT(context, hPreShape);

    // The last dimension of mm_res is taken from the first dimension of phi.
    const uint64_t matK = phiShape->GetDim(0);

    if (xDim == BSND_DIM_NUM) {
        const uint64_t b = xShape->GetDim(0);
        const uint64_t s = xShape->GetDim(1);
        const uint64_t n = xShape->GetDim(2);
        const uint64_t d = xShape->GetDim(3);
        SetShape3D(hInShape, b, s, d);
        SetShape3D(hPostShape, b, s, n);
        SetShape4D(hResShape, b, s, n, n);
        SetShape2D(invRmsShape, b, s);
        SetShape3D(mmResShape, b, s, matK);
        SetShape3D(hPreShape, b, s, n);
    } else {
        // Rank 3: same outputs as above with the leading (b, s) pair replaced by t.
        const uint64_t t = xShape->GetDim(0);
        const uint64_t n = xShape->GetDim(1);
        const uint64_t d = xShape->GetDim(2);
        SetShape2D(hInShape, t, d);
        SetShape2D(hPostShape, t, n);
        SetShape3D(hResShape, t, n, n);
        SetShape1D(invRmsShape, t);
        SetShape2D(mmResShape, t, matK);
        SetShape2D(hPreShape, t, n);
    }

    OP_LOGD(context->GetNodeName(), "End to do InferShape MhcPre");
    return GRAPH_SUCCESS;
}
/**
 * @brief Assigns the data types of the six MhcPre outputs.
 *
 * h_in takes the data type of x; h_post, h_res, inv_rms, mm_res and h_pre are set to DT_FLOAT.
 *
 * @param context Data-type inference context for the node.
 * @return Always GRAPH_SUCCESS.
 */
static graphStatus InferDataType4MhcPre(gert::InferDataTypeContext *context)
{
    // h_in mirrors the input data type.
    context->SetOutputDataType(OUT_H_IN_INDEX, context->GetInputDataType(X_INDEX));

    // The remaining outputs are float, set in ascending index order.
    for (const int64_t floatOutIndex :
         {OUT_H_POST_INDEX, OUT_H_RES_INDEX, OUT_INV_RMS_INDEX, OUT_MM_RES_INDEX, OUT_H_PRE_INDEX}) {
        context->SetOutputDataType(floatOutIndex, DataType::DT_FLOAT);
    }
    return GRAPH_SUCCESS;
}
// Register the shape and data-type inference callbacks for the MhcPre operator.
IMPL_OP_INFERSHAPE(MhcPre).InferShape(InferShape4MhcPre).InferDataType(InferDataType4MhcPre);
}