/**
* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#include <any>
#include <cctype>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "core/any_cast.h"
#include "core/dtype.h"
#include "core/logging.h"
#include "ir/expr.h"
#include "ir/kind_traits.h"
#include "ir/op_registry.h"
#include "ir/scalar_expr.h"
#include "ir/type.h"
namespace pypto {
namespace ir {
namespace {
// Returns true when `conversion` is one of the conversion characters accepted
// in debug format strings: 'd', 'i', 'u', 'x' or 'f'.
bool IsSupportedPrintfConversion(char conversion)
{
    switch (conversion) {
        case 'd':
        case 'i':
        case 'u':
        case 'x':
        case 'f':
            return true;
        default:
            return false;
    }
}
// Extracts the conversion characters from a printf-style `format` string, in
// order of appearance.
//
// Each conversion specification is read as: '%', any run of the flag
// characters '-', '+', ' ', '#', '0', optional width digits, an optional '.'
// that must be followed by at least one precision digit, and finally a single
// conversion character.
//
// CHECK failures are raised for a literal "%%", a '.' without following
// digits, a format that ends before the conversion character, and any
// conversion character rejected by IsSupportedPrintfConversion.
std::vector<char> ParsePrintfConversions(const std::string& format)
{
    // True when `pos` lies inside `format` and refers to a decimal digit.
    const auto is_digit_at = [&format](size_t pos) {
        return pos < format.size() && std::isdigit(static_cast<unsigned char>(format[pos]));
    };
    // True for the flag characters allowed right after '%'.
    const auto is_flag = [](char c) {
        switch (c) {
            case '-':
            case '+':
            case ' ':
            case '#':
            case '0':
                return true;
            default:
                return false;
        }
    };

    std::vector<char> found;
    size_t percent = format.find('%');
    while (percent != std::string::npos) {
        if (percent + 1 < format.size() && format[percent + 1] == '%') {
            CHECK(false) << "debug.printf does not support literal '%%'";
        }

        // `j` walks over flags, width and precision until it rests on the
        // conversion character.
        size_t j = percent + 1;
        while (j < format.size() && is_flag(format[j])) {
            ++j;
        }
        while (is_digit_at(j)) {
            ++j;
        }
        if (j < format.size() && format[j] == '.') {
            ++j;
            CHECK(j < format.size() && std::isdigit(static_cast<unsigned char>(format[j])))
                << "debug.printf precision must be followed by digits";
            while (is_digit_at(j)) {
                ++j;
            }
        }

        CHECK(j < format.size()) << "debug.printf format ends with an incomplete conversion";
        const char conversion = format[j];
        CHECK(IsSupportedPrintfConversion(conversion))
            << "debug.printf does not support conversion '%" << conversion << "'";
        found.push_back(conversion);

        // Resume the search just past the conversion character.
        percent = format.find('%', j + 1);
    }
    return found;
}
// Returns true when `dtype` is one of the fixed-width integer types
// INT8/INT16/INT32/INT64 or UINT8/UINT16/UINT32/UINT64.
[[maybe_unused]] bool IsPrintfIntegerType(const DataType& dtype)
{
    // Signed widths first, then the unsigned ones.
    if (dtype == DataType::INT8 || dtype == DataType::INT16 || dtype == DataType::INT32 ||
        dtype == DataType::INT64) {
        return true;
    }
    return dtype == DataType::UINT8 || dtype == DataType::UINT16 || dtype == DataType::UINT32 ||
           dtype == DataType::UINT64;
}
// Returns true when `dtype` is INT8, INT16, INT32 or INT64.
bool IsPrintfSignedIntegerType(const DataType& dtype)
{
    if (dtype == DataType::INT8 || dtype == DataType::INT16) {
        return true;
    }
    return dtype == DataType::INT32 || dtype == DataType::INT64;
}
// Returns true when `dtype` is UINT8, UINT16, UINT32 or UINT64.
bool IsPrintfUnsignedIntegerType(const DataType& dtype)
{
    if (dtype == DataType::UINT8 || dtype == DataType::UINT16) {
        return true;
    }
    return dtype == DataType::UINT32 || dtype == DataType::UINT64;
}
// Returns true when `dtype` is DataType::INDEX.
bool IsPrintfIndexType(const DataType& dtype)
{
    return dtype == DataType::INDEX;
}
// Returns true when `dtype` is DataType::BOOL.
bool IsPrintfBoolType(const DataType& dtype)
{
    return dtype == DataType::BOOL;
}
// Validates the arguments of debug.dump_tensor and returns GetUnknownType().
//
// Expected arguments: (tensor, offsets, shapes), where `tensor` has a
// TensorType and `offsets` / `shapes` are MakeTuple expressions holding one
// integer scalar per tensor dimension. `kwargs` is not inspected.
//
// When the request is not recognisably a full-tensor dump (all offsets are the
// constant 0 and every shape equals the matching constant tensor dimension)
// and the tensor carries a view with a non-empty stride list, the innermost
// stride must be the constant 1. Constant shape entries must be positive.
// Every violation is reported through CHECK.
TypePtr DeduceDebugDumpTensorType(
    [[maybe_unused]] const std::vector<ExprPtr>& args,
    [[maybe_unused]] const std::vector<std::pair<std::string, std::any>>& kwargs)
{
    CHECK(args.size() == 0x3) << "debug.dump_tensor requires exactly 3 arguments (tensor, offsets, shapes), but got "
                              << args.size();

    auto tensor_type = As<TensorType>(args[0]->GetType());
    CHECK(tensor_type) << "debug.dump_tensor requires first argument to be a TensorType, but got "
                       << args[0]->GetType()->TypeName();

    auto offsets = As<MakeTuple>(args[1]);
    CHECK(offsets) << "debug.dump_tensor requires offsets to be a MakeTuple";

    auto shapes = As<MakeTuple>(args[2]);
    CHECK(shapes) << "debug.dump_tensor requires shapes to be a MakeTuple";

    const size_t rank = tensor_type->shape_.size();
    CHECK(offsets->elements_.size() == rank) << "debug.dump_tensor offsets count (" << offsets->elements_.size()
                                             << ") must match tensor rank (" << rank << ")";
    CHECK(shapes->elements_.size() == rank) << "debug.dump_tensor shapes count (" << shapes->elements_.size()
                                            << ") must match tensor rank (" << rank << ")";

    // A dimension is fully covered when its offset is the constant 0 and its
    // window extent is a constant equal to the (constant) tensor extent.
    const auto covers_whole_dim = [&](size_t dim) {
        auto offset_const = As<ConstInt>(offsets->elements_[dim]);
        if (!offset_const || offset_const->value_ != 0) {
            return false;
        }
        auto shape_const = As<ConstInt>(shapes->elements_[dim]);
        auto tensor_dim_const = As<ConstInt>(tensor_type->shape_[dim]);
        if (!shape_const || !tensor_dim_const || shape_const->value_ != tensor_dim_const->value_) {
            return false;
        }
        return true;
    };

    bool is_full_tensor_window = true;
    for (size_t dim = 0; dim < rank && is_full_tensor_window; ++dim) {
        is_full_tensor_window = covers_whole_dim(dim);
    }

    // Windowed dumps over a strided view need an innermost stride of exactly 1.
    if (!is_full_tensor_window && tensor_type->tensor_view_.has_value()) {
        const auto& strides = tensor_type->tensor_view_->stride;
        if (!strides.empty()) {
            auto last_stride_const = As<ConstInt>(strides.back());
            CHECK(last_stride_const)
                << "debug.dump_tensor windowed mode requires the innermost stride to be statically 1";
            CHECK(last_stride_const->value_ == 1)
                << "debug.dump_tensor windowed mode requires innermost stride == 1, got "
                << last_stride_const->value_;
        }
    }

    // Per-dimension checks on the offset and shape tuple elements.
    for (size_t i = 0; i < rank; ++i) {
        const auto& offset_expr = offsets->elements_[i];
        auto offset_scalar = As<ScalarType>(offset_expr->GetType());
        CHECK(offset_scalar) << "debug.dump_tensor offset element " << i << " must be ScalarType, but got "
                             << offset_expr->GetType()->TypeName();
        CHECK(offset_scalar->dtype_.IsInt())
            << "debug.dump_tensor offset element " << i << " must have integer dtype, but got "
            << offset_scalar->dtype_.ToString();

        const auto& shape_expr = shapes->elements_[i];
        auto shape_scalar = As<ScalarType>(shape_expr->GetType());
        CHECK(shape_scalar) << "debug.dump_tensor shape element " << i << " must be ScalarType, but got "
                            << shape_expr->GetType()->TypeName();
        CHECK(shape_scalar->dtype_.IsInt()) << "debug.dump_tensor shape element " << i
                                            << " must have integer dtype, but got "
                                            << shape_scalar->dtype_.ToString();

        // Only constant extents can be range-checked here.
        if (auto shape_const = As<ConstInt>(shape_expr)) {
            CHECK(shape_const->value_ > 0)
                << "debug.dump_tensor shape element " << i << " must be positive, got " << shape_const->value_;
        }
    }

    return GetUnknownType();
}
// Validates the arguments of debug.dump_tile and returns GetUnknownType().
//
// Accepted forms: (tile) or (tile, offsets, shapes). `tile` must have a
// TileType. In the three-argument form the tile must be rank 2, and `offsets`
// / `shapes` must be MakeTuple expressions with one integer scalar per tile
// dimension; constant shape entries must be positive. `kwargs` is not
// inspected. Every violation is reported through CHECK.
TypePtr DeduceDebugDumpTileType(
    [[maybe_unused]] const std::vector<ExprPtr>& args,
    [[maybe_unused]] const std::vector<std::pair<std::string, std::any>>& kwargs)
{
    CHECK(args.size() == 1 || args.size() == 0x3)
        << "debug.dump_tile requires 1 argument (tile) or 3 arguments (tile, offsets, shapes), but got " << args.size();

    auto tile_type = As<TileType>(args[0]->GetType());
    CHECK(tile_type) << "debug.dump_tile requires first argument to be a TileType, but got "
                     << args[0]->GetType()->TypeName();

    // Without offsets/shapes there is nothing further to validate.
    if (args.size() != 0x3) {
        return GetUnknownType();
    }

    auto offsets = As<MakeTuple>(args[1]);
    CHECK(offsets) << "debug.dump_tile requires second argument to be a MakeTuple (offsets)";

    auto shapes = As<MakeTuple>(args[2]);
    CHECK(shapes) << "debug.dump_tile requires third argument to be a MakeTuple (shapes)";

    const size_t rank = tile_type->shape_.size();
    CHECK(rank == 0x2) << "debug.dump_tile currently only supports 2D tile windows, but got rank " << rank;
    CHECK(offsets->elements_.size() == rank) << "debug.dump_tile offsets count (" << offsets->elements_.size()
                                             << ") must match tile rank (" << rank << ")";
    CHECK(shapes->elements_.size() == rank) << "debug.dump_tile shapes count (" << shapes->elements_.size()
                                            << ") must match tile rank (" << rank << ")";

    for (size_t i = 0; i < rank; ++i) {
        const auto& offset_expr = offsets->elements_[i];
        auto offset_scalar = As<ScalarType>(offset_expr->GetType());
        CHECK(offset_scalar) << "debug.dump_tile offset element " << i << " must be ScalarType, but got "
                             << offset_expr->GetType()->TypeName();
        CHECK(offset_scalar->dtype_.IsInt())
            << "debug.dump_tile offset element " << i << " must have integer dtype, but got "
            << offset_scalar->dtype_.ToString();

        const auto& shape_expr = shapes->elements_[i];
        auto shape_scalar = As<ScalarType>(shape_expr->GetType());
        CHECK(shape_scalar) << "debug.dump_tile shape element " << i << " must be ScalarType, but got "
                            << shape_expr->GetType()->TypeName();
        CHECK(shape_scalar->dtype_.IsInt())
            << "debug.dump_tile shape element " << i << " must have integer dtype, but got "
            << shape_scalar->dtype_.ToString();

        // Only constant extents can be range-checked here.
        auto shape_const = As<ConstInt>(shape_expr);
        if (shape_const) {
            CHECK(shape_const->value_ > 0)
                << "debug.dump_tile shape element " << i << " must be positive, got " << shape_const->value_;
        }
    }

    return GetUnknownType();
}
// Validates the arguments of debug.printf and returns GetUnknownType().
//
// The first "format" kwarg is parsed with ParsePrintfConversions; the number
// of conversions must equal the number of positional arguments, each argument
// must have a ScalarType, and its dtype must be compatible with the matching
// conversion:
//   %f      -> FP32
//   %x      -> unsigned integer or index
//   %u      -> unsigned integer, bool or index
//   other   -> signed integer, bool or index
// Every violation is reported through CHECK.
TypePtr DeduceDebugPrintfType(
    [[maybe_unused]] const std::vector<ExprPtr>& args,
    [[maybe_unused]] const std::vector<std::pair<std::string, std::any>>& kwargs)
{
    bool found_format = false;
    std::string format;
    for (const auto& kwarg : kwargs) {
        if (kwarg.first != "format") {
            continue;
        }
        format = AnyCast<std::string>(kwarg.second, "kwarg key: format");
        found_format = true;
        break;  // only the first "format" entry is consulted
    }
    CHECK(found_format) << "debug.printf requires 'format' kwarg";

    auto conversions = ParsePrintfConversions(format);
    CHECK(conversions.size() == args.size())
        << "debug.printf format expects " << conversions.size() << " scalar arguments, but got " << args.size();

    for (size_t i = 0; i < args.size(); ++i) {
        auto scalar_type = As<ScalarType>(args[i]->GetType());
        CHECK(scalar_type) << "debug.printf argument " << i << " must be ScalarType, but got "
                           << args[i]->GetType()->TypeName();

        const DataType& dtype = scalar_type->dtype_;
        const char conversion = conversions[i];
        switch (conversion) {
            case 'f': {
                CHECK(dtype == DataType::FP32)
                    << "debug.printf conversion '%f' requires FP32 scalar, but got " << dtype.ToString();
                break;
            }
            case 'x': {
                CHECK(IsPrintfUnsignedIntegerType(dtype) || IsPrintfIndexType(dtype))
                    << "debug.printf conversion '%" << conversion
                    << "' requires unsigned integer or index scalar, but got " << dtype.ToString();
                break;
            }
            case 'u': {
                CHECK(IsPrintfUnsignedIntegerType(dtype) || IsPrintfBoolType(dtype) || IsPrintfIndexType(dtype))
                    << "debug.printf conversion '%" << conversion
                    << "' requires unsigned integer, bool, or index scalar, but got " << dtype.ToString();
                break;
            }
            default: {
                // Remaining conversions accepted by the parser ('d' and 'i').
                CHECK(IsPrintfSignedIntegerType(dtype) || IsPrintfBoolType(dtype) || IsPrintfIndexType(dtype))
                    << "debug.printf conversion '%" << conversion
                    << "' requires signed integer, bool, or index scalar, but got " << dtype.ToString();
                break;
            }
        }
    }

    return GetUnknownType();
}
// Validates the arguments of debug.assert and returns GetUnknownType().
//
// Positional arguments: a bool scalar condition followed by the scalars
// consumed by the format string. Required kwargs: "condition_text" (must be
// non-empty) and "format"; when a key appears more than once the last
// occurrence is used. The format is parsed with ParsePrintfConversions and
// each trailing argument is matched against its conversion:
//   %f      -> FP32
//   %x      -> unsigned integer or index
//   %u      -> unsigned integer, bool or index
//   other   -> signed integer, bool or index
// Every violation is reported through CHECK.
TypePtr DeduceDebugAssertType(
    [[maybe_unused]] const std::vector<ExprPtr>& args,
    [[maybe_unused]] const std::vector<std::pair<std::string, std::any>>& kwargs)
{
    CHECK(args.size() >= 1) << "debug.assert requires at least 1 argument (condition), but got " << args.size();

    bool found_condition_text = false;
    bool found_format = false;
    std::string condition_text;
    std::string format;
    for (const auto& kwarg : kwargs) {
        const std::string& key = kwarg.first;
        if (key == "condition_text") {
            condition_text = AnyCast<std::string>(kwarg.second, "kwarg key: condition_text");
            found_condition_text = true;
            continue;
        }
        if (key == "format") {
            format = AnyCast<std::string>(kwarg.second, "kwarg key: format");
            found_format = true;
        }
    }
    CHECK(found_condition_text) << "debug.assert requires 'condition_text' kwarg";
    CHECK(found_format) << "debug.assert requires 'format' kwarg";
    CHECK(!condition_text.empty()) << "debug.assert requires non-empty condition_text";

    auto scalar_type = As<ScalarType>(args[0]->GetType());
    CHECK(scalar_type) << "debug.assert requires condition to be ScalarType, but got "
                       << args[0]->GetType()->TypeName();
    CHECK(scalar_type->dtype_ == DataType::BOOL)
        << "debug.assert requires bool scalar condition, but got " << scalar_type->dtype_.ToString();

    auto conversions = ParsePrintfConversions(format);
    CHECK(conversions.size() == args.size() - 1)
        << "debug.assert format expects " << conversions.size() << " scalar arguments, but got " << (args.size() - 1);

    // `slot` numbers the format arguments from zero; the matching positional
    // argument sits one place later because args[0] is the condition.
    for (size_t slot = 0; slot + 1 < args.size(); ++slot) {
        const auto& arg = args[slot + 1];
        auto printf_arg = As<ScalarType>(arg->GetType());
        CHECK(printf_arg) << "debug.assert argument " << slot << " must be ScalarType, but got "
                          << arg->GetType()->TypeName();

        const DataType& dtype = printf_arg->dtype_;
        const char conversion = conversions[slot];
        switch (conversion) {
            case 'f': {
                CHECK(dtype == DataType::FP32)
                    << "debug.assert conversion '%f' requires FP32 scalar, but got " << dtype.ToString();
                break;
            }
            case 'x': {
                CHECK(IsPrintfUnsignedIntegerType(dtype) || IsPrintfIndexType(dtype))
                    << "debug.assert conversion '%" << conversion
                    << "' requires unsigned integer or index scalar, but got " << dtype.ToString();
                break;
            }
            case 'u': {
                CHECK(IsPrintfUnsignedIntegerType(dtype) || IsPrintfBoolType(dtype) || IsPrintfIndexType(dtype))
                    << "debug.assert conversion '%" << conversion
                    << "' requires unsigned integer, bool, or index scalar, but got " << dtype.ToString();
                break;
            }
            default: {
                // Remaining conversions accepted by the parser ('d' and 'i').
                CHECK(IsPrintfSignedIntegerType(dtype) || IsPrintfBoolType(dtype) || IsPrintfIndexType(dtype))
                    << "debug.assert conversion '%" << conversion
                    << "' requires signed integer, bool, or index scalar, but got " << dtype.ToString();
                break;
            }
        }
    }

    return GetUnknownType();
}
}
// Registers the "debug.dump_tensor" op in the "DebugOp" category.
// Declared arguments: tensor, offsets, shapes. Declared attribute: show_location (bool).
// Type deduction is forwarded unchanged to DeduceDebugDumpTensorType.
REGISTER_OP("debug.dump_tensor")
.set_op_category("DebugOp")
.set_description(
"Print a tensor or tensor window for debugging. Supports full dumps and window dumps "
"with dynamic offsets/shapes when the innermost stride is statically 1.")
.add_argument("tensor", "Input tensor (TensorType)")
.add_argument("offsets", "Offsets per dimension (MakeTuple of integer scalars)")
.add_argument("shapes", "Shape per dimension (MakeTuple of integer scalars)")
.set_attr<bool>("show_location")
.f_deduce_type([]([[maybe_unused]] const std::vector<ExprPtr>& args,
[[maybe_unused]] const std::vector<std::pair<std::string, std::any>>& kwargs) {
return DeduceDebugDumpTensorType(args, kwargs);
});
// Registers the "debug.dump_tile" op in the "DebugOp" category.
// Declared arguments: tile, offsets, shapes (the latter two are described as optional).
// Declared attribute: show_location (bool).
// Type deduction is forwarded unchanged to DeduceDebugDumpTileType.
REGISTER_OP("debug.dump_tile")
.set_op_category("DebugOp")
.set_description(
"Print a tile or tile window for debugging. Full dumps support tiles with dynamic valid-shape; "
"window dumps support dynamic offsets on PTO and CCE, and dynamic shapes on CCE only; "
"PTO window shapes remain static-only. Tile windows are currently 2D-only.")
.add_argument("tile", "Input tile (TileType)")
.add_argument("offsets", "Optional offsets per dimension (MakeTuple of integer scalars)")
.add_argument("shapes", "Optional shape per dimension (MakeTuple of integer scalars; dynamic only on CCE)")
.set_attr<bool>("show_location")
.f_deduce_type([]([[maybe_unused]] const std::vector<ExprPtr>& args,
[[maybe_unused]] const std::vector<std::pair<std::string, std::any>>& kwargs) {
return DeduceDebugDumpTileType(args, kwargs);
});
// Registers the "debug.printf" op in the "DebugOp" category.
// Declared argument: scalars (described as variadic).
// Declared attributes: format (std::string), show_location (bool).
// Type deduction is forwarded unchanged to DeduceDebugPrintfType.
REGISTER_OP("debug.printf")
.set_op_category("DebugOp")
.set_description("Print scalar values using a compile-time format string")
.add_argument("scalars", "Scalar arguments consumed by format conversions (variadic)")
.set_attr<std::string>("format")
.set_attr<bool>("show_location")
.f_deduce_type([]([[maybe_unused]] const std::vector<ExprPtr>& args,
[[maybe_unused]] const std::vector<std::pair<std::string, std::any>>& kwargs) {
return DeduceDebugPrintfType(args, kwargs);
});
// Registers the "debug.assert" op in the "DebugOp" category.
// Declared arguments: condition, scalars (described as optional and variadic).
// Declared attributes: condition_text (std::string), format (std::string), show_location (bool).
// Type deduction is forwarded unchanged to DeduceDebugAssertType.
REGISTER_OP("debug.assert")
.set_op_category("DebugOp")
.set_description("Print an assertion failure message and abort execution when condition is false")
.add_argument("condition", "Scalar boolean condition")
.add_argument("scalars", "Optional scalar arguments consumed by assertion format conversions (variadic)")
.set_attr<std::string>("condition_text")
.set_attr<std::string>("format")
.set_attr<bool>("show_location")
.f_deduce_type([]([[maybe_unused]] const std::vector<ExprPtr>& args,
[[maybe_unused]] const std::vector<std::pair<std::string, std::any>>& kwargs) {
return DeduceDebugAssertType(args, kwargs);
});
// Registers the "debug.trap" op in the "DebugOp" category.
// The op is declared with no arguments; its deduce callback ignores both
// inputs and always returns GetUnknownType().
REGISTER_OP("debug.trap")
.set_op_category("DebugOp")
.set_description("Abort execution by inserting a trap")
.no_argument()
.f_deduce_type([]([[maybe_unused]] const std::vector<ExprPtr>& args,
[[maybe_unused]] const std::vector<std::pair<std::string, std::any>>& kwargs) {
// Explicitly mark both parameters as intentionally unused.
(void)args;
(void)kwargs;
return GetUnknownType();
});
}
}