/**
* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
/*!
* \file matmul.cpp
* \brief Matrix multiplication tensor operations
*
* This file implements matrix multiplication operations for tensors,
* supporting transpose options and output dtype control.
*/
#include <any>
#include <cstddef>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>
#include "core/any_cast.h"
#include "core/dtype.h"
#include "core/error.h"
#include "core/logging.h"
#include "ir/kind_traits.h"
#include "ir/op_registry.h"
#include "ir/type.h"
#include "ir/type_inference.h"
namespace pypto {
namespace ir {
/**
 * \brief Fetch a keyword argument by name and return it converted to T.
 *
 * The kwargs list is scanned front to back; the value of the first entry whose
 * name equals \p key is converted with AnyCast<T> and returned. When no entry
 * matches, the supplied default is returned if it holds a value.
 *
 * \tparam T Type the stored std::any value is converted to.
 * \param kwargs Ordered (name, value) keyword-argument pairs.
 * \param key Name of the keyword argument to look up.
 * \param default_value Fallback used when \p key is absent; empty by default.
 * \return The converted value of the first match, otherwise the default.
 * \throws ValueError when \p key is absent and no default was supplied.
 * \note Any error raised by AnyCast<T> on a type mismatch propagates unchanged.
 */
template <typename T>
T GetKwarg(
    [[maybe_unused]] const std::vector<std::pair<std::string, std::any>>& kwargs, const std::string& key,
    const std::optional<T>& default_value = std::nullopt)
{
    for (auto entry = kwargs.begin(); entry != kwargs.end(); ++entry) {
        if (entry->first != key) {
            continue;
        }
        // First match wins; later duplicates of the same name are never inspected.
        return AnyCast<T>(entry->second, "kwarg key: " + key);
    }

    if (!default_value.has_value()) {
        throw ValueError("Missing kwarg: " + key);
    }
    return default_value.value();
}
/**
 * \brief Deduce the result TensorType of tensor.matmul from its operands and kwargs.
 *
 * Output dtype: the "out_dtype" kwarg when present; otherwise the result of
 * PromoteDataTypes on the two operand dtypes.
 *
 * Output shape, by operand rank (a_trans / b_trans default to false):
 *  - 1-D x 1-D: empty shape.
 *  - 2-D x 1-D: one dimension, lhs dim 0 (lhs dim 1 when a_trans is set).
 *  - 1-D x 2-D: one dimension, rhs dim 1 (rhs dim 0 when b_trans is set).
 *  - 2-D x 2-D: {M, N} with M = a_trans ? lhs[1] : lhs[0], N = b_trans ? rhs[0] : rhs[1].
 *  - otherwise: both operands must have at least 2 dimensions; the leading
 *    (batch) dimensions are combined with BroadcastShapes and {M, N} taken from
 *    the last two dimensions is appended.
 *
 * The contracted (inner) dimensions are not compared here.
 *
 * \param args Operator operands; exactly two are required (lhs, rhs), both TensorType.
 * \param kwargs Operator keyword arguments; "out_dtype", "a_trans" and "b_trans" are read.
 * \return A newly created TensorType carrying the deduced shape and dtype.
 */
TypePtr DeduceTensorMatMulType(
    const std::vector<ExprPtr>& args, const std::vector<std::pair<std::string, std::any>>& kwargs)
{
    constexpr size_t kMatrixRank = 2;

    CHECK(args.size() == 2) << "tensor.matmul requires exactly 2 arguments (lhs, rhs), but got " << args.size();

    auto lhs_type = As<TensorType>(args[0]->GetType());
    auto rhs_type = As<TensorType>(args[1]->GetType());
    CHECK(lhs_type) << "tensor.matmul requires first argument to be a TensorType, but got "
                    << args[0]->GetType()->TypeName();
    CHECK(rhs_type) << "tensor.matmul requires second argument to be a TensorType, but got "
                    << args[1]->GetType()->TypeName();

    const auto& lhs_shape = lhs_type->shape_;
    const auto& rhs_shape = rhs_type->shape_;
    const size_t lhs_ndim = lhs_shape.size();
    const size_t rhs_ndim = rhs_shape.size();
    CHECK(lhs_ndim >= 1) << "tensor.matmul requires lhs to have at least 1 dimension";
    CHECK(rhs_ndim >= 1) << "tensor.matmul requires rhs to have at least 1 dimension";

    DataType out_dtype;
    try {
        out_dtype = GetKwarg<DataType>(kwargs, "out_dtype");
    } catch (const ValueError&) {
        // No explicit out_dtype was given: fall back to the promoted operand dtype.
        auto promoted = PromoteDataTypes(lhs_type->dtype_, rhs_type->dtype_);
        CHECK(promoted) << "Cannot promote data types for tensor.matmul";
        out_dtype = *promoted;
    } catch (const TypeError& e) {
        CHECK(false) << "Invalid kwarg type for out_dtype: " << e.what();
        // Presumably unreachable because CHECK(false) aborts the call; kept so that
        // out_dtype is always assigned should CHECK ever return.
        out_dtype = lhs_type->dtype_;
    }

    const bool a_trans = GetKwarg<bool>(kwargs, "a_trans", false);
    const bool b_trans = GetKwarg<bool>(kwargs, "b_trans", false);

    std::vector<ExprPtr> output_shape;
    if (lhs_ndim == 1 && rhs_ndim == 1) {
        // Vector dot product: the result has an empty shape.
    } else if (lhs_ndim == kMatrixRank && rhs_ndim == 1) {
        // Matrix-vector product: the surviving dimension is the row dimension of
        // the (optionally transposed) lhs, consistent with the 2-D x 2-D case.
        output_shape = {a_trans ? lhs_shape[1] : lhs_shape[0]};
    } else if (lhs_ndim == 1 && rhs_ndim == kMatrixRank) {
        // Vector-matrix product: the surviving dimension is the column dimension
        // of the (optionally transposed) rhs, consistent with the 2-D x 2-D case.
        output_shape = {b_trans ? rhs_shape[0] : rhs_shape[1]};
    } else if (lhs_ndim == kMatrixRank && rhs_ndim == kMatrixRank) {
        ExprPtr m_dim = a_trans ? lhs_shape[1] : lhs_shape[0];
        ExprPtr n_dim = b_trans ? rhs_shape[0] : rhs_shape[1];
        output_shape = {m_dim, n_dim};
    } else {
        // Batched matmul. A 1-D operand paired with a 3-D+ operand is rejected here.
        CHECK(lhs_ndim >= kMatrixRank && rhs_ndim >= kMatrixRank)
            << "tensor.matmul requires both tensors to have at least 2 dimensions "
            << "for batched matmul, but got lhs shape size " << lhs_ndim << " and rhs shape size " << rhs_ndim;

        // Everything except the trailing two dimensions is a batch dimension.
        std::vector<ExprPtr> lhs_batch(lhs_shape.begin(), lhs_shape.end() - kMatrixRank);
        std::vector<ExprPtr> rhs_batch(rhs_shape.begin(), rhs_shape.end() - kMatrixRank);
        auto broadcast_result = BroadcastShapes(lhs_batch, rhs_batch);
        CHECK(broadcast_result.success) << "Cannot broadcast batch dimensions for tensor.matmul";
        output_shape = broadcast_result.shape;

        ExprPtr m_dim = a_trans ? lhs_shape[lhs_ndim - 1] : lhs_shape[lhs_ndim - 2];
        ExprPtr n_dim = b_trans ? rhs_shape[rhs_ndim - 2] : rhs_shape[rhs_ndim - 1];
        output_shape.push_back(m_dim);
        output_shape.push_back(n_dim);
    }

    return std::make_shared<TensorType>(output_shape, out_dtype);
}
// Registers the "tensor.matmul" operator: two tensor operands (lhs, rhs), the
// attributes out_dtype / a_trans / b_trans / c_matrix_nz, and a type-deduction
// hook that forwards to DeduceTensorMatMulType. Of the declared attributes,
// c_matrix_nz is not read by the type deduction in this file.
REGISTER_OP("tensor.matmul")
    .set_op_category("TensorOp")
    .set_description("Matrix multiplication of two tensors with optional transpose")
    .add_argument("lhs", "Left-hand side tensor (TensorType)")
    .add_argument("rhs", "Right-hand side tensor (TensorType)")
    .set_attr<DataType>("out_dtype")
    .set_attr<bool>("a_trans")
    .set_attr<bool>("b_trans")
    .set_attr<bool>("c_matrix_nz")
    .f_deduce_type(
        [](const std::vector<ExprPtr>& op_args,
           const std::vector<std::pair<std::string, std::any>>& op_kwargs) -> TypePtr {
            return DeduceTensorMatMulType(op_args, op_kwargs);
        });
}
}