* This file is part of the OpenBOAT project at Harbin Institute of Technology (HIT)
* and is contributed to the CANN Open Software.
*
* Copyright (c) 2026 AISS Group, Harbin Institute of Technology (HIT).
* All Rights Reserved.
*
* Authors (accounts):
* - Shi Xiangyang <@shi-xiangyang225>
* - Su Tonghua <@sutonghua>
*
* This program is free software: you can redistribute it and/or modify it.
* Licensed under the CANN Open Software License Agreement Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* See the LICENSE file at the root of the repository for the full text of the License.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTIES OF ANY KIND, EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
*/
* \file split_infer.cpp
* \brief
*/
#include "register/op_impl_registry.h"
#include "log/log.h"
using namespace ge;
namespace ops {
constexpr size_t GMM_INDEX_IN_X = 0;
constexpr size_t GMM_INDEX_IN_INDICES = 1;
constexpr size_t GMM_INDEX_OUT_Y = 0;
constexpr size_t GMM_MIN_FM_DIM = 1;
constexpr size_t GMM_NORM_AXIS = -1;
constexpr size_t ATTRPOS0 = 0;
constexpr size_t ATTRPOS1 = 1;
constexpr int32_t INDICES_LIMIT = 10;
struct SplitSetOutputParams {
int64_t indicesOrSections[INDICES_LIMIT];
int64_t axis;
};
static ge::graphStatus UpdateShapeY(gert::InferShapeContext* context, size_t idxY, const std::vector<int64_t>& yDims) {
gert::Shape* yShape = context->GetOutputShape(idxY);
OP_CHECK_NULL_WITH_CONTEXT(context, yShape);
yShape->SetDimNum(yDims.size());
for (size_t dim = 0; dim < yDims.size(); ++dim) {
yShape->SetDim(dim, yDims[dim]);
}
return GRAPH_SUCCESS;
}
static ge::graphStatus ComputeSplitGroups(gert::InferShapeContext* context,
int64_t xAxisLen,
int64_t* indicesOrSections,
bool& isUniformSplit,
int64_t& groupNum,
int64_t& splitSize) {
isUniformSplit = true;
groupNum = 0;
splitSize = 0;
for (int i = 1; (indicesOrSections[i] != 0) && (i < INDICES_LIMIT); ++i) {
if (indicesOrSections[i] != 0) {
isUniformSplit = false;
break;
}
}
if (isUniformSplit) {
groupNum = indicesOrSections[0];
if (groupNum == 0) {
OP_LOGE(context->GetNodeName(), "groupNum must be greater than 0.");
return GRAPH_FAILED;
}
if (xAxisLen % groupNum != 0) {
OP_LOGE(context->GetNodeName(),
"When splitting uniformly, xAxisLen %ld must be divisible by groupNum %ld.",
xAxisLen, groupNum);
return GRAPH_FAILED;
}
splitSize = xAxisLen / groupNum;
} else {
int numIndices = 0;
for (int i = 0; indicesOrSections[i] != 0; ++i) {
if (indicesOrSections[i] > xAxisLen) {
OP_LOGE(context->GetNodeName(), "Index %ld exceeds xAxisLen %ld.", indicesOrSections[i], xAxisLen);
return GRAPH_FAILED;
}
numIndices++;
}
groupNum = numIndices + 1;
}
return GRAPH_SUCCESS;
}
static ge::graphStatus UpdateMultipleShapeYNorm(gert::InferShapeContext* context,
int64_t* indicesOrSections) {
const gert::Shape* xShape = context->GetDynamicInputShape(GMM_INDEX_IN_X, 0);
OP_CHECK_NULL_WITH_CONTEXT(context, xShape);
int64_t xAxisLen = 1;
for (size_t dim = 0; dim < xShape->GetDimNum(); ++dim) {
xAxisLen *= xShape->GetDim(dim);
}
int64_t groupNum = 0;
bool isUniformSplit = true;
int64_t splitSize = 0;
OP_CHECK_IF(ComputeSplitGroups(context, xAxisLen, indicesOrSections, isUniformSplit, groupNum, splitSize) != GRAPH_SUCCESS,
OP_LOGE(context->GetNodeName(), "ComputeSplitGroups failed."),
return GRAPH_FAILED);
int64_t preOffset = 0;
for (int idx = 0; idx < groupNum; ++idx) {
int64_t currentSize = 0;
if (isUniformSplit) {
currentSize = splitSize;
} else {
if (idx == groupNum - 1) {
currentSize = xAxisLen - indicesOrSections[idx - 1];
} else {
currentSize = indicesOrSections[idx] - preOffset;
preOffset = indicesOrSections[idx];
}
}
std::vector<int64_t> yDims;
yDims.push_back(currentSize);
OP_CHECK_IF(UpdateShapeY(context, GMM_INDEX_OUT_Y + idx, yDims) != GRAPH_SUCCESS,
OP_LOGE(context->GetNodeName(), "Failed to update shape of y[%d].", idx),
return GRAPH_FAILED);
}
return GRAPH_SUCCESS;
}
static ge::graphStatus UpdateMultipleShapeY(gert::InferShapeContext* context,
int64_t* indicesOrSections, int64_t axis) {
const gert::Shape* xShape = context->GetDynamicInputShape(GMM_INDEX_IN_X, 0);
OP_CHECK_NULL_WITH_CONTEXT(context, xShape);
int64_t xAxisLen = xShape->GetDim(axis);
int64_t groupNum = 0;
bool isUniformSplit = true;
int64_t splitSize = 0;
OP_CHECK_IF(ComputeSplitGroups(context, xAxisLen, indicesOrSections, isUniformSplit, groupNum, splitSize) != GRAPH_SUCCESS,
OP_LOGE(context->GetNodeName(), "ComputeSplitGroups failed."),
return GRAPH_FAILED);
int64_t preOffset = 0;
for (int idx = 0; idx < groupNum; ++idx) {
int64_t currentSize = 0;
if (isUniformSplit) {
currentSize = splitSize;
} else {
if (idx == groupNum - 1) {
currentSize = xAxisLen - indicesOrSections[idx - 1];
} else {
currentSize = indicesOrSections[idx] - preOffset;
preOffset = indicesOrSections[idx];
}
}
std::vector<int64_t> yDims;
for (size_t i = 0; i < xShape->GetDimNum(); ++i) {
if (i == static_cast<size_t>(axis)) {
yDims.push_back(currentSize);
} else {
yDims.push_back(xShape->GetDim(i));
}
}
OP_CHECK_IF(UpdateShapeY(context, GMM_INDEX_OUT_Y + idx, yDims) != GRAPH_SUCCESS,
OP_LOGE(context->GetNodeName(), "Failed to update shape of y[%d].", idx),
return GRAPH_FAILED);
}
return GRAPH_SUCCESS;
}
static ge::graphStatus XSingleYSeparatedNorm(gert::InferShapeContext* context,
int64_t* indicesOrSections) {
OP_CHECK_IF(UpdateMultipleShapeYNorm(context, indicesOrSections) != GRAPH_SUCCESS,
OP_LOGE(context->GetNodeName(), "Failed to update shape of y."),
return GRAPH_FAILED);
return GRAPH_SUCCESS;
}
static ge::graphStatus XSingleYSeparated(gert::InferShapeContext* context,
int64_t* indicesOrSections, int64_t axis) {
OP_CHECK_IF(UpdateMultipleShapeY(context, indicesOrSections, axis) != GRAPH_SUCCESS,
OP_LOGE(context->GetNodeName(), "Failed to update shape of y."),
return GRAPH_FAILED);
return GRAPH_SUCCESS;
}
static ge::graphStatus SplitSetOutputShape(gert::InferShapeContext* context,
const SplitSetOutputParams& outputParams) {
int64_t indicesOrSections[INDICES_LIMIT];
for(int i = 0; i < INDICES_LIMIT; i ++){
indicesOrSections[i] = outputParams.indicesOrSections[i];
}
int64_t axis = outputParams.axis;
if(axis == (int64_t)GMM_NORM_AXIS) {
OP_CHECK_IF(XSingleYSeparatedNorm(context, indicesOrSections) != GRAPH_SUCCESS,
OP_LOGE(context->GetNodeName(), "Failed to update shape of y."), return GRAPH_FAILED);
} else {
OP_CHECK_IF(XSingleYSeparated(context, indicesOrSections, axis ) != GRAPH_SUCCESS,
OP_LOGE(context->GetNodeName(), "Failed to update shape of y."), return GRAPH_FAILED);
}
return GRAPH_SUCCESS;
}
static ge::graphStatus InferShapeSplit(gert::InferShapeContext* context) {
auto attrs = context->GetAttrs();
int64_t axis = -1;
int64_t* src_indices = nullptr;
if(attrs) {
const int64_t* src_indices_const = nullptr;
if (attrs->GetInt(ATTRPOS0) != nullptr) {
src_indices_const = attrs->GetInt(ATTRPOS0);
}
src_indices = const_cast<int64_t*>(src_indices_const);
if (attrs->GetInt(ATTRPOS1) != nullptr) {
axis = *(attrs->GetInt(ATTRPOS1));
}
}
SplitSetOutputParams outputParams;
uint32_t len = sizeof(outputParams.indicesOrSections) / sizeof(outputParams.indicesOrSections[0]);
for(uint32_t i = 0; i < len ; i ++){
outputParams.indicesOrSections[i] = src_indices[i];
}
outputParams.axis = axis;
OP_CHECK_IF(SplitSetOutputShape(context, outputParams) != GRAPH_SUCCESS,
OP_LOGE(context->GetNodeName(), "SplitSetOutputShape failed"), return GRAPH_FAILED);
return GRAPH_SUCCESS;
}
IMPL_OP_INFERSHAPE(Split).InferShape(InferShapeSplit);
}