* Copyright 2026 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "mfusion/Dialect/Mfuse/IR/Mfuse.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LogicalResult.h"
namespace mlir {
namespace mfuse {
mlir::LogicalResult AclnnVarMeanOp::verify() {
auto inputType = mlir::dyn_cast<mlir::RankedTensorType>(getSelf().getType());
if (!inputType || !inputType.hasRank()) {
return emitOpError("input must be a ranked tensor");
}
int64_t rank = inputType.getRank();
auto dimAttr = getDim();
llvm::DenseSet<int64_t> seenDims;
for (auto dimAttrVal : dimAttr.getValue()) {
auto dim = mlir::cast<mlir::IntegerAttr>(dimAttrVal).getValue().getSExtValue();
int64_t actualDim = dim < 0 ? rank + dim : dim;
if (actualDim < 0) {
return emitOpError("dimension out of range, got ") << dim;
}
if (actualDim >= rank) {
return emitOpError("dimension out of range, got ") << dim << " for input rank " << rank;
}
if (!seenDims.insert(actualDim).second) {
return emitOpError("duplicate reduction dimensions are not supported, got ") << dim;
}
}
auto varianceOutType = mlir::dyn_cast<mlir::RankedTensorType>(getVarianceOut().getType());
auto meanOutType = mlir::dyn_cast<mlir::RankedTensorType>(getMeanOut().getType());
if (!varianceOutType || !meanOutType) {
return emitOpError("results must be ranked tensors");
}
if (!mlir::isa<mlir::FloatType>(varianceOutType.getElementType()) ||
!mlir::isa<mlir::FloatType>(meanOutType.getElementType())) {
return emitOpError("result element types must be floating point");
}
auto correction = getCorrection();
if (correction < 0) {
return emitOpError("correction must be non-negative, got ") << correction;
}
return mlir::success();
}
}
}