#include <string>

inline const std::string kCubeKernelTilingWrapperHppValue = R"(
/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

#ifndef CUBE_KERNEL_TILING_WRAPPER_H
#define CUBE_KERNEL_TILING_WRAPPER_H

#include <string>
#include <vector>
#include <map>
#include <utility>
#include <cstdint>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <cstddef>
#include <cstring>
#include <iomanip>
#include <algorithm>

namespace ge {
namespace autofuse {

namespace json_internal {

enum class Type {
    null,
    boolean,
    number_integer,
    number_float,
    string,
    array,
    object
};

class Json {
public:
    Json() : type_(Type::null) {}
    Json(bool value) : type_(Type::boolean), bool_value_(value) {}
    Json(int value) : type_(Type::number_integer), int_value_(value) {}
    Json(int64_t value) : type_(Type::number_integer), int_value_(value) {}
    Json(double value) : type_(Type::number_float), float_value_(value) {}
    Json(const char* value) : type_(Type::string), string_value_(new std::string(value)) {}
    Json(const std::string& value) : type_(Type::string), string_value_(new std::string(value)) {}
    Json(const std::vector<int64_t>& value) : type_(Type::array), array_value_(new std::vector<Json>()) {
        for (const auto& v : value) {
            array_value_->push_back(Json(v));
        }
    }
    Json(const std::vector<double>& value) : type_(Type::array), array_value_(new std::vector<Json>()) {
        for (const auto& v : value) {
            array_value_->push_back(Json(v));
        }
    }
    Json(const std::vector<std::string>& value) : type_(Type::array), array_value_(new std::vector<Json>()) {
        for (const auto& v : value) {
            array_value_->push_back(Json(v));
        }
    }
    
    Json(const Json& other) : type_(other.type_) {
        CopyValue(other);
    }
    
    Json(Json&& other) noexcept : type_(other.type_) {
        MoveValue(std::move(other));
        other.type_ = Type::null;
    }
    
    Json& operator=(const Json& other) {
        if (this != &other) {
            Clear();
            type_ = other.type_;
            CopyValue(other);
        }
        return *this;
    }
    
    Json& operator=(Json&& other) noexcept {
        if (this != &other) {
            Clear();
            type_ = other.type_;
            MoveValue(std::move(other));
            other.type_ = Type::null;
        }
        return *this;
    }
    
    ~Json() {
        Clear();
    }
    
    Type type() const { return type_; }
    bool is_null() const { return type_ == Type::null; }
    bool is_boolean() const { return type_ == Type::boolean; }
    bool is_number() const { return type_ == Type::number_integer || type_ == Type::number_float; }
    bool is_string() const { return type_ == Type::string; }
    bool is_array() const { return type_ == Type::array; }
    bool is_object() const { return type_ == Type::object; }
    
    bool get_bool() const {
        if (type_ != Type::boolean) throw std::runtime_error("Json is not a boolean");
        return bool_value_;
    }
    
    int64_t get_int64() const {
        if (type_ == Type::number_integer) return int_value_;
        if (type_ == Type::number_float) return static_cast<int64_t>(float_value_);
        throw std::runtime_error("Json is not a number");
    }
    
    int get_int() const {
        return static_cast<int>(get_int64());
    }
    
    double get_double() const {
        if (type_ == Type::number_float) return float_value_;
        if (type_ == Type::number_integer) return static_cast<double>(int_value_);
        throw std::runtime_error("Json is not a number");
    }
    
    std::string get_string() const {
        if (type_ != Type::string) throw std::runtime_error("Json is not a string");
        return *string_value_;
    }
    
    std::vector<int64_t> get_int64_array() const {
        if (type_ != Type::array) throw std::runtime_error("Json is not an array");
        std::vector<int64_t> result;
        for (const auto& item : *array_value_) {
            result.push_back(item.get_int64());
        }
        return result;
    }
    
    std::vector<double> get_double_array() const {
        if (type_ != Type::array) throw std::runtime_error("Json is not an array");
        std::vector<double> result;
        for (const auto& item : *array_value_) {
            result.push_back(item.get_double());
        }
        return result;
    }
    
    std::vector<std::string> get_string_array() const {
        if (type_ != Type::array) throw std::runtime_error("Json is not an array");
        std::vector<std::string> result;
        for (const auto& item : *array_value_) {
            result.push_back(item.get_string());
        }
        return result;
    }
    
    Json& operator[](const std::string& key) {
        if (type_ == Type::null) {
            type_ = Type::object;
            object_value_ = new std::map<std::string, Json>();
        }
        if (type_ != Type::object) throw std::runtime_error("Json is not an object");
        return (*object_value_)[key];
    }
    
    const Json& operator[](const std::string& key) const {
        if (type_ != Type::object) throw std::runtime_error("Json is not an object");
        static const Json null_json;
        auto it = object_value_->find(key);
        if (it == object_value_->end()) return null_json;
        return it->second;
    }
    
    Json& operator[](size_t index) {
        if (type_ != Type::array) throw std::runtime_error("Json is not an array");
        if (index >= array_value_->size()) throw std::runtime_error("Array index out of bounds");
        return (*array_value_)[index];
    }
    
    const Json& operator[](size_t index) const {
        if (type_ != Type::array) throw std::runtime_error("Json is not an array");
        if (index >= array_value_->size()) throw std::runtime_error("Array index out of bounds");
        return (*array_value_)[index];
    }
    
    bool contains(const std::string& key) const {
        if (type_ != Type::object) return false;
        return object_value_->find(key) != object_value_->end();
    }
    
    void push_back(const Json& value) {
        if (type_ == Type::null) {
            type_ = Type::array;
            array_value_ = new std::vector<Json>();
        }
        if (type_ != Type::array) throw std::runtime_error("Json is not an array");
        array_value_->push_back(value);
    }
    
    void push_back(Json&& value) {
        if (type_ == Type::null) {
            type_ = Type::array;
            array_value_ = new std::vector<Json>();
        }
        if (type_ != Type::array) throw std::runtime_error("Json is not an array");
        array_value_->push_back(std::move(value));
    }
    
    size_t size() const {
        if (type_ == Type::array) return array_value_->size();
        if (type_ == Type::object) return object_value_->size();
        return 0;
    }
    
    std::string dump(int indent = -1) const {
        std::ostringstream oss;
        Dump(oss, indent, 0);
        return oss.str();
    }
    
    static Json parse(const std::string& str) {
        Parser parser(str);
        return parser.Parse();
    }
    
    static Json array() {
        Json j;
        j.type_ = Type::array;
        j.array_value_ = new std::vector<Json>();
        return j;
    }
    
    static Json object() {
        Json j;
        j.type_ = Type::object;
        j.object_value_ = new std::map<std::string, Json>();
        return j;
    }
    
private:
    Type type_;
    
    union {
        bool bool_value_;
        int64_t int_value_;
        double float_value_;
        std::string* string_value_;
        std::vector<Json>* array_value_;
        std::map<std::string, Json>* object_value_;
    };
    
    void Clear() {
        switch (type_) {
            case Type::string:
                delete string_value_;
                break;
            case Type::array:
                delete array_value_;
                break;
            case Type::object:
                delete object_value_;
                break;
            default:
                break;
        }
    }
    
    void CopyValue(const Json& other) {
        switch (other.type_) {
            case Type::null:
                break;
            case Type::boolean:
                bool_value_ = other.bool_value_;
                break;
            case Type::number_integer:
                int_value_ = other.int_value_;
                break;
            case Type::number_float:
                float_value_ = other.float_value_;
                break;
            case Type::string:
                string_value_ = new std::string(*other.string_value_);
                break;
            case Type::array:
                array_value_ = new std::vector<Json>(*other.array_value_);
                break;
            case Type::object:
                object_value_ = new std::map<std::string, Json>(*other.object_value_);
                break;
        }
    }
    
    void MoveValue(Json&& other) {
        switch (other.type_) {
            case Type::null:
                break;
            case Type::boolean:
                bool_value_ = other.bool_value_;
                break;
            case Type::number_integer:
                int_value_ = other.int_value_;
                break;
            case Type::number_float:
                float_value_ = other.float_value_;
                break;
            case Type::string:
                string_value_ = other.string_value_;
                other.string_value_ = nullptr;
                break;
            case Type::array:
                array_value_ = other.array_value_;
                other.array_value_ = nullptr;
                break;
            case Type::object:
                object_value_ = other.object_value_;
                other.object_value_ = nullptr;
                break;
        }
    }
    
    void Dump(std::ostringstream& oss, int indent, int level) const {
        std::string indent_str;
        if (indent > 0) {
            indent_str = std::string(level * indent, ' ');
        }
        
        switch (type_) {
            case Type::null:
                oss << "null";
                break;
            case Type::boolean:
                oss << (bool_value_ ? "true" : "false");
                break;
            case Type::number_integer:
                oss << int_value_;
                break;
            case Type::number_float:
                oss << float_value_;
                break;
            case Type::string:
                oss << "\"" << EscapeString(*string_value_) << "\"";
                break;
            case Type::array:
                oss << "[";
                if (indent > 0 && !array_value_->empty()) {
                    oss << "\n";
                }
                for (size_t i = 0; i < array_value_->size(); ++i) {
                    if (indent > 0) {
                        oss << indent_str << std::string(indent, ' ');
                    }
                    (*array_value_)[i].Dump(oss, indent, level + 1);
                    if (i < array_value_->size() - 1) {
                        oss << ",";
                    }
                    if (indent > 0) {
                        oss << "\n";
                    }
                }
                if (indent > 0 && !array_value_->empty()) {
                    oss << indent_str;
                }
                oss << "]";
                break;
            case Type::object:
                oss << "{";
                if (indent > 0 && !object_value_->empty()) {
                    oss << "\n";
                }
                auto it = object_value_->begin();
                for (size_t i = 0; i < object_value_->size(); ++i, ++it) {
                    if (indent > 0) {
                        oss << indent_str << std::string(indent, ' ');
                    }
                    oss << "\"" << it->first << "\":";
                    if (indent > 0) {
                        oss << " ";
                    }
                    it->second.Dump(oss, indent, level + 1);
                    if (i < object_value_->size() - 1) {
                        oss << ",";
                    }
                    if (indent > 0) {
                        oss << "\n";
                    }
                }
                if (indent > 0 && !object_value_->empty()) {
                    oss << indent_str;
                }
                oss << "}";
                break;
        }
    }
    
    static std::string EscapeString(const std::string& str) {
        std::string result;
        for (char c : str) {
            switch (c) {
                case '"': result += "\\\""; break;
                case '\\': result += "\\\\"; break;
                case '\b': result += "\\b"; break;
                case '\f': result += "\\f"; break;
                case '\n': result += "\\n"; break;
                case '\r': result += "\\r"; break;
                case '\t': result += "\\t"; break;
                default:
                    if (static_cast<unsigned char>(c) < 0x20) {
                        char buf[7];
                        snprintf(buf, sizeof(buf), "\\u%04x", static_cast<unsigned char>(c));
                        result += buf;
                    } else {
                        result += c;
                    }
                    break;
            }
        }
        return result;
    }
    
    class Parser {
    public:
        Parser(const std::string& str) : str_(str), pos_(0) {
            SkipWhitespace();
        }
        
        Json Parse() {
            if (pos_ >= str_.size()) {
                throw std::runtime_error("Empty JSON string");
            }
            return ParseValue();
        }
        
    private:
        const std::string& str_;
        size_t pos_;
        
        void SkipWhitespace() {
            while (pos_ < str_.size() && (str_[pos_] == ' ' || str_[pos_] == '\t' || 
                                          str_[pos_] == '\n' || str_[pos_] == '\r')) {
                ++pos_;
            }
        }
        
        char Peek() const {
            if (pos_ >= str_.size()) return '\0';
            return str_[pos_];
        }
        
        char Consume() {
            if (pos_ >= str_.size()) return '\0';
            return str_[pos_++];
        }
        
        Json ParseValue() {
            SkipWhitespace();
            char c = Peek();
            
            if (c == 'n') return ParseNull();
            if (c == 't' || c == 'f') return ParseBoolean();
            if (c == '"') return ParseString();
            if (c == '[') return ParseArray();
            if (c == '{') return ParseObject();
            if (c == '-' || (c >= '0' && c <= '9')) return ParseNumber();
            
            throw std::runtime_error(std::string("Unexpected character: ") + c);
        }
        
        Json ParseNull() {
            if (str_.substr(pos_, 4) == "null") {
                pos_ += 4;
                return Json();
            }
            throw std::runtime_error("Expected 'null'");
        }
        
        Json ParseBoolean() {
            if (str_.substr(pos_, 4) == "true") {
                pos_ += 4;
                return Json(true);
            }
            if (str_.substr(pos_, 5) == "false") {
                pos_ += 5;
                return Json(false);
            }
            throw std::runtime_error("Expected 'true' or 'false'");
        }
        
        Json ParseNumber() {
            size_t start = pos_;
            if (Peek() == '-') Consume();
            
            while (pos_ < str_.size() && (str_[pos_] >= '0' && str_[pos_] <= '9')) {
                ++pos_;
            }
            
            bool is_float = false;
            if (pos_ < str_.size() && str_[pos_] == '.') {
                is_float = true;
                ++pos_;
                while (pos_ < str_.size() && (str_[pos_] >= '0' && str_[pos_] <= '9')) {
                    ++pos_;
                }
            }
            
            if (pos_ < str_.size() && (str_[pos_] == 'e' || str_[pos_] == 'E')) {
                is_float = true;
                ++pos_;
                if (pos_ < str_.size() && (str_[pos_] == '+' || str_[pos_] == '-')) {
                    ++pos_;
                }
                while (pos_ < str_.size() && (str_[pos_] >= '0' && str_[pos_] <= '9')) {
                    ++pos_;
                }
            }
            
            std::string num_str = str_.substr(start, pos_ - start);
            if (is_float) {
                return Json(std::stod(num_str));
            } else {
                return Json(static_cast<int64_t>(std::stoll(num_str)));
            }
        }
        
        Json ParseString() {
            if (Consume() != '"') {
                throw std::runtime_error("Expected '\"'");
            }
            
            std::string result;
            while (pos_ < str_.size() && str_[pos_] != '"') {
                if (str_[pos_] == '\\') {
                    ++pos_;
                    if (pos_ >= str_.size()) {
                        throw std::runtime_error("Unexpected end of string");
                    }
                    switch (str_[pos_]) {
                        case '"': result += '"'; break;
                        case '\\': result += '\\'; break;
                        case '/': result += '/'; break;
                        case 'b': result += '\b'; break;
                        case 'f': result += '\f'; break;
                        case 'n': result += '\n'; break;
                        case 'r': result += '\r'; break;
                        case 't': result += '\t'; break;
                        case 'u': {
                            if (pos_ + 4 >= str_.size()) {
                                throw std::runtime_error("Invalid unicode escape");
                            }
                            std::string hex_str = str_.substr(pos_ + 1, 4);
                            unsigned int codepoint = std::stoul(hex_str, nullptr, 16);
                            if (codepoint < 0x80) {
                                result += static_cast<char>(codepoint);
                            } else if (codepoint < 0x800) {
                                result += static_cast<char>(0xC0 | (codepoint >> 6));
                                result += static_cast<char>(0x80 | (codepoint & 0x3F));
                            } else {
                                result += static_cast<char>(0xE0 | (codepoint >> 12));
                                result += static_cast<char>(0x80 | ((codepoint >> 6) & 0x3F));
                                result += static_cast<char>(0x80 | (codepoint & 0x3F));
                            }
                            pos_ += 4;
                            break;
                        }
                        default:
                            throw std::runtime_error("Invalid escape sequence");
                    }
                } else {
                    result += str_[pos_];
                }
                ++pos_;
            }
            
            if (pos_ >= str_.size() || Consume() != '"') {
                throw std::runtime_error("Unterminated string");
            }
            
            return Json(result);
        }
        
        Json ParseArray() {
            if (Consume() != '[') {
                throw std::runtime_error("Expected '['");
            }
            
            Json result = Json::array();
            SkipWhitespace();
            
            if (Peek() == ']') {
                Consume();
                return result;
            }
            
            while (true) {
                result.push_back(ParseValue());
                SkipWhitespace();
                
                if (Peek() == ']') {
                    Consume();
                    return result;
                }
                
                if (Peek() == ',') {
                    Consume();
                } else {
                    throw std::runtime_error("Expected ',' or ']' in array");
                }
            }
        }
        
        Json ParseObject() {
            if (Consume() != '{') {
                throw std::runtime_error("Expected '{'");
            }
            
            Json result = Json::object();
            SkipWhitespace();
            
            if (Peek() == '}') {
                Consume();
                return result;
            }
            
            while (true) {
                SkipWhitespace();
                Json key = ParseString();
                SkipWhitespace();
                
                if (Consume() != ':') {
                    throw std::runtime_error("Expected ':' after key");
                }
                
                Json value = ParseValue();
                result[key.get_string()] = std::move(value);
                SkipWhitespace();
                
                if (Peek() == '}') {
                    Consume();
                    return result;
                }
                
                if (Peek() == ',') {
                    Consume();
                } else {
                    throw std::runtime_error("Expected ',' or '}' in object");
                }
            }
        }
    };
};

} // namespace json_internal

namespace crypto {

class SHA1 {
public:
    static constexpr size_t DIGEST_LENGTH = 20;
    
    static std::string Hash(const std::string& input) {
        SHA1 sha1;
        sha1.Update(reinterpret_cast<const uint8_t*>(input.c_str()), input.length());
        uint8_t digest[DIGEST_LENGTH];
        sha1.Final(digest);
        return DigestToHex(digest);
    }

private:
    SHA1() {
        Reset();
    }
    
    void Reset() {
        m_digest[0] = 0x67452301;
        m_digest[1] = 0xEFCDAB89;
        m_digest[2] = 0x98BADCFE;
        m_digest[3] = 0x10325476;
        m_digest[4] = 0xC3D2E1F0;
        m_block_len = 0;
        m_total_len = 0;
    }
    
    void Update(const uint8_t* data, size_t len) {
        while (len) {
            size_t copy_len = std::min(len, 64 - m_block_len);
            std::memcpy(m_block + m_block_len, data, copy_len);
            
            m_block_len += copy_len;
            m_total_len += copy_len;
            data += copy_len;
            len -= copy_len;
            
            if (m_block_len == 64) {
                ProcessBlock(m_block);
                m_block_len = 0;
            }
        }
    }
    
    void Final(uint8_t* digest) {
        uint64_t total_bits = m_total_len * 8;
        
        m_block[m_block_len++] = 0x80;
        if (m_block_len > 56) {
            while (m_block_len < 64) {
                m_block[m_block_len++] = 0;
            }
            ProcessBlock(m_block);
            m_block_len = 0;
        }
        
        while (m_block_len < 56) {
            m_block[m_block_len++] = 0;
        }
        
        for (int i = 7; i >= 0; --i) {
            m_block[m_block_len++] = static_cast<uint8_t>((total_bits >> (i * 8)) & 0xFF);
        }
        
        ProcessBlock(m_block);
        
        for (int i = 0; i < 5; ++i) {
            digest[i * 4 + 0] = static_cast<uint8_t>((m_digest[i] >> 24) & 0xFF);
            digest[i * 4 + 1] = static_cast<uint8_t>((m_digest[i] >> 16) & 0xFF);
            digest[i * 4 + 2] = static_cast<uint8_t>((m_digest[i] >> 8) & 0xFF);
            digest[i * 4 + 3] = static_cast<uint8_t>(m_digest[i] & 0xFF);
        }
    }
    
    void ProcessBlock(const uint8_t* block) {
        uint32_t w[80];
        
        for (int i = 0; i < 16; ++i) {
            w[i] = (block[i * 4 + 0] << 24) | (block[i * 4 + 1] << 16) | 
                   (block[i * 4 + 2] << 8) | block[i * 4 + 3];
        }
        
        for (int i = 16; i < 80; ++i) {
            uint32_t temp = w[i - 3] ^ w[i - 8] ^ w[i - 14] ^ w[i - 16];
            w[i] = ROTL(temp, 1);
        }
        
        uint32_t a = m_digest[0];
        uint32_t b = m_digest[1];
        uint32_t c = m_digest[2];
        uint32_t d = m_digest[3];
        uint32_t e = m_digest[4];
        
        for (int i = 0; i < 80; ++i) {
            uint32_t f, k;
            
            if (i < 20) {
                f = (b & c) | ((~b) & d);
                k = 0x5A827999;
            } else if (i < 40) {
                f = b ^ c ^ d;
                k = 0x6ED9EBA1;
            } else if (i < 60) {
                f = (b & c) | (b & d) | (c & d);
                k = 0x8F1BBCDC;
            } else {
                f = b ^ c ^ d;
                k = 0xCA62C1D6;
            }
            
            uint32_t temp = ROTL(a, 5) + f + e + k + w[i];
            e = d;
            d = c;
            c = ROTL(b, 30);
            b = a;
            a = temp;
        }
        
        m_digest[0] += a;
        m_digest[1] += b;
        m_digest[2] += c;
        m_digest[3] += d;
        m_digest[4] += e;
    }
    
    static uint32_t ROTL(uint32_t x, uint32_t n) {
        return (x << n) | (x >> (32 - n));
    }
    
    static std::string DigestToHex(const uint8_t* digest) {
        std::ostringstream oss;
        oss << std::hex << std::setfill('0');
        for (size_t i = 0; i < DIGEST_LENGTH; ++i) {
            oss << std::setw(2) << static_cast<int>(digest[i]);
        }
        return oss.str();
    }
    
    uint32_t m_digest[5];
    uint8_t m_block[64];
    size_t m_block_len;
    uint64_t m_total_len;
};

} // namespace crypto

using Json = json_internal::Json;

#include "arch35/mat_mul_tiling_data.h"

struct TensorInfo {
    std::string param_name;
    std::vector<int64_t> shape;
    std::vector<int64_t> ori_shape;
    std::string dtype;
    std::string format;
    std::string name;
    int64_t range_start = 0;
    int64_t range_end = 0;
};

struct AttrInfo {
    std::string name;
    std::string dtype;
    std::string value_str;
    bool value_bool = false;
    int64_t value_int = 0;
    double value_float = 0.0;
    std::vector<int64_t> value_list_int;
    std::vector<double> value_list_float;
    std::vector<std::string> value_list_str;
    bool is_list = false;
};

struct CompileInfo {
    std::string soc_version;
    std::string core_type;
    std::string op_kernel_lib;
    std::string op_impl_mode;
    int64_t aicore_num = 0;
    int64_t aiv_num = 0;
    std::map<std::string, std::string> extra_info;
};

struct TilingResult {
    std::vector<uint8_t> tiling_data;
    int64_t tiling_key = 0;
    int64_t block_dim = 0;
    int64_t workspace_size = 0;
    bool atomic_flag = false;
    std::string error_msg;
    bool success = false;
    
    BatchMatMulV3BasicTilingData batch_matmul_tiling_data;
    MatMulV3BasicTilingData matmul_basic_tiling_data;
};

class CubeKernelTilingWrapper {
public:
    CubeKernelTilingWrapper();
    ~CubeKernelTilingWrapper();

    TilingResult DoMatMulTiling(const CompileInfo& compile_info,
                                const std::vector<TensorInfo>& inputs,
                                const std::vector<TensorInfo>& outputs,
                                const std::vector<AttrInfo>& attrs,
                                bool is_batch = false);

    static void BuildMatMulArgs(const std::vector<TensorInfo>& args_list,
                               int input_num,
                               bool transpose_a,
                               bool transpose_b,
                               std::vector<TensorInfo>& origin_inputs,
                               std::vector<TensorInfo>& origin_outputs,
                               std::vector<TensorInfo>& inputs);

    static std::string GenerateCompileInfoHash(const std::string& compile_info_info);

    static void ChangeParamNameToName(std::vector<TensorInfo>& inputs);
    static void InputsPreProcess(std::vector<TensorInfo>& inputs);
    static void AttrsPreProcess(std::vector<AttrInfo>& attrs);
    static std::vector<uint8_t> AlignTilingDataTo8Bytes(const std::vector<uint8_t>& tiling_data, const std::string& soc_version);

private:
    static std::string SerializeToJson(const CompileInfo& compile_info);
    static std::string SerializeToJson(const std::vector<TensorInfo>& tensors);
    static std::string SerializeToJson(const std::vector<AttrInfo>& attrs);
    static std::string SerializeToJson(const std::map<std::string, std::string>& extra_params);

    static bool ParseTilingResult(const std::string& json_str, TilingResult& result);

    char* CallDoOpTilingForCompile(const char* op_type,
                                           const char* compile_info,
                                           const char* compile_info_hash,
                                           const char* inputs,
                                           const char* outputs,
                                           const char* attrs,
                                           char* buf,
                                           size_t buf_size,
                                           uint64_t* timer,
                                           const char* extra_params);
};

} // namespace autofuse
} // namespace ge

#endif // CUBE_KERNEL_TILING_WRAPPER_H
)";

inline const std::string kCubeKernelTilingWrapperCppValue = R"(
#include "cube_kernel_tiling_wrapper.h"
#include <sstream>
#include <iomanip>
#include <dlfcn.h>
#include <iostream>
#include <cstdlib>
#include <unistd.h>
#include <cstring>
#include <cmath>
#include <limits>

using json = ge::autofuse::Json;
using SHA1 = ge::autofuse::crypto::SHA1;

#ifndef DEFAULT_ASCEND_OPP_PATH
#define DEFAULT_ASCEND_OPP_PATH "/usr/local/Ascend/cann/opp"
#endif

namespace ge {
namespace autofuse {

CubeKernelTilingWrapper::CubeKernelTilingWrapper() {}

CubeKernelTilingWrapper::~CubeKernelTilingWrapper() {}

void CubeKernelTilingWrapper::BuildMatMulArgs(const std::vector<TensorInfo>& args_list,
                                               int input_num,
                                               bool transpose_a,
                                               bool transpose_b,
                                               std::vector<TensorInfo>& origin_inputs,
                                               std::vector<TensorInfo>& origin_outputs,
                                               std::vector<TensorInfo>& inputs) {
    origin_inputs.clear();
    origin_outputs.clear();
    inputs.clear();

    int64_t m = 0;
    int64_t n = 0;
    std::vector<int64_t> write_shape;

    for (int i = 0; i < input_num && i < static_cast<int>(args_list.size()); ++i) {
        TensorInfo input = args_list[i];
        input.param_name = "input" + std::to_string(i);
        input.ori_shape = input.shape;

        origin_inputs.push_back(input);
        inputs.push_back(input);

        if (i == 0) {
            write_shape = input.shape;
            if (transpose_a) {
                m = input.shape[input.shape.size() - 1];
            } else {
                m = input.shape[input.shape.size() - 2];
            }
        } else if (i == 1) {
            if (transpose_b) {
                n = input.shape[input.shape.size() - 2];
            } else {
                n = input.shape[input.shape.size() - 1];
            }
        }
    }

    if (args_list.size() >= 2) {
        TensorInfo output = args_list[args_list.size() - 2];
        output.param_name = "output0";
        if (!write_shape.empty()) {
            write_shape[write_shape.size() - 1] = n;
            write_shape[write_shape.size() - 2] = m;
            output.shape = write_shape;
            output.ori_shape = write_shape;
        }
        if (!inputs.empty()) {
            output.dtype = inputs.back().dtype;
        }
        origin_outputs.push_back(output);
    }
}

std::string CubeKernelTilingWrapper::SerializeToJson(const CompileInfo& compile_info) {
    json j;
    j["soc_version"] = compile_info.soc_version;
    j["core_type"] = compile_info.core_type;
    j["op_kernel_lib"] = compile_info.op_kernel_lib;
    j["op_impl_mode"] = compile_info.op_impl_mode;
    j["aicore_num"] = compile_info.aicore_num;
    j["aiv_num"] = compile_info.aiv_num;

    if (!compile_info.extra_info.empty()) {
        json extra;
        for (const auto& pair : compile_info.extra_info) {
            extra[pair.first] = pair.second;
        }
        j["extra_info"] = extra;
    }

    return j.dump();
}

std::string CubeKernelTilingWrapper::SerializeToJson(const std::vector<TensorInfo>& tensors) {
    json j = json::array();
    for (const auto& tensor : tensors) {
        json t;
        t["param_name"] = tensor.param_name;
        t["shape"] = tensor.shape;
        t["ori_shape"] = tensor.ori_shape;
        t["dtype"] = tensor.dtype;
        t["format"] = tensor.format;
        t["name"] = tensor.name;
        t["range_start"] = tensor.range_start;
        t["range_end"] = tensor.range_end;
        j.push_back(t);
    }
    return j.dump();
}

std::string CubeKernelTilingWrapper::SerializeToJson(const std::vector<AttrInfo>& attrs) {
    json j = json::array();
    for (const auto& attr : attrs) {
        json a;
        a["name"] = attr.name;
        a["dtype"] = attr.dtype;

        if (attr.is_list) {
            if (attr.dtype == "list_int" || attr.dtype == "list_int32") {
                a["value"] = attr.value_list_int;
            } else if (attr.dtype == "list_float" || attr.dtype == "list_float32") {
                a["value"] = attr.value_list_float;
            } else if (attr.dtype == "list_str") {
                a["value"] = attr.value_list_str;
            }
        } else {
            if (attr.dtype == "bool") {
                a["value"] = attr.value_bool;
            } else if (attr.dtype == "int" || attr.dtype == "int32" || attr.dtype == "int64") {
                a["value"] = attr.value_int;
            } else if (attr.dtype == "float" || attr.dtype == "float32" || attr.dtype == "float64") {
                a["value"] = attr.value_float;
            } else {
                a["value"] = attr.value_str;
            }
        }
        j.push_back(a);
    }
    return j.dump();
}

std::string CubeKernelTilingWrapper::SerializeToJson(const std::map<std::string, std::string>& extra_params) {
    json j;
    for (const auto& pair : extra_params) {
        j[pair.first] = pair.second;
    }
    return j.dump();
}

std::string CubeKernelTilingWrapper::GenerateCompileInfoHash(const std::string& compile_info_json) {
    return SHA1::Hash(compile_info_json);
}

void CubeKernelTilingWrapper::ChangeParamNameToName(std::vector<TensorInfo>& inputs) {
    for (auto& input : inputs) {
        if (input.name.empty() && !input.param_name.empty()) {
            input.name = input.param_name;
        }
    }
}

void CubeKernelTilingWrapper::InputsPreProcess(std::vector<TensorInfo>& inputs) {
    for (auto& input : inputs) {
        if (input.range_start != 0 || input.range_end != 0) {
            if (input.range_start == std::numeric_limits<int64_t>::min() ||
                input.range_start == std::numeric_limits<int64_t>::max()) {
                input.range_start = 0;
            }
            if (input.range_end == std::numeric_limits<int64_t>::min() ||
                input.range_end == std::numeric_limits<int64_t>::max()) {
                input.range_end = 0;
            }
        }
    }
}

void CubeKernelTilingWrapper::AttrsPreProcess(std::vector<AttrInfo>& attrs) {
    for (auto& attr : attrs) {
        if (attr.dtype == "float" || attr.dtype == "float32" || attr.dtype == "float64") {
            if (!attr.is_list) {
                if (std::isinf(attr.value_float)) {
                    if (attr.value_float > 0) {
                        attr.value_str = "float(1.0 / 0.0) ";
                    } else {
                        attr.value_str = "float(-1.0 / 0.0) ";
                    }
                } else if (std::isnan(attr.value_float)) {
                    attr.value_str = "float(0.0 / 0.0) ";
                }
            } else {
                for (auto& val : attr.value_list_float) {
                    if (std::isinf(val)) {
                        if (val > 0) {
                            val = std::numeric_limits<double>::max();
                        } else {
                            val = std::numeric_limits<double>::min();
                        }
                    } else if (std::isnan(val)) {
                        val = 0.0;
                    }
                }
            }
        } else if (attr.dtype == "list_float" || attr.dtype == "list_float32") {
            for (auto& val : attr.value_list_float) {
                if (std::isinf(val)) {
                    if (val > 0) {
                        val = std::numeric_limits<double>::max();
                    } else {
                        val = std::numeric_limits<double>::min();
                    }
                } else if (std::isnan(val)) {
                    val = 0.0;
                }
            }
        }
    }
}

std::vector<uint8_t> CubeKernelTilingWrapper::AlignTilingDataTo8Bytes(const std::vector<uint8_t>& tiling_data, const std::string& soc_version) {
    size_t original_size = tiling_data.size();
    std::vector<uint8_t> aligned_data = tiling_data;

    if (soc_version == "Ascend310P") {
        return aligned_data;
    }

    size_t aligned_size = ((original_size + 7) / 8) * 8;
    size_t padding_size = aligned_size - original_size;

    aligned_data.resize(aligned_size, 0);

    return aligned_data;
}

bool CubeKernelTilingWrapper::ParseTilingResult(const std::string& json_str, TilingResult& result) {
    try {
        json j = json::parse(json_str);

        if (j.contains("ret_code") && j["ret_code"].get_int64() != 0) {
            result.success = false;
            if (j.contains("error_messages") && j["error_messages"].is_array()) {
                for (size_t i = 0; i < j["error_messages"].size(); ++i) {
                    const auto& err = j["error_messages"][i];
                    if (err.contains("errormsg")) {
                        result.error_msg += err["errormsg"].get_string() + "; ";
                    }
                }
            }
            return false;
        }
        result.success = true;
        return true;
    } catch (const std::exception& e) {
        result.success = false;
        result.error_msg = std::string("Parse JSON failed: ") + e.what();
        return false;
    }
}

char* CubeKernelTilingWrapper::CallDoOpTilingForCompile(const char* op_type,
                                                        const char* compile_info,
                                                        const char* compile_info_hash,
                                                        const char* inputs,
                                                        const char* outputs,
                                                        const char* attrs,
                                                        char* buf,
                                                        size_t buf_size,
                                                        uint64_t* timer,
                                                        const char* extra_params) {
    const char* ascend_opp_path_env = std::getenv("ASCEND_OPP_PATH");
    std::string opp_base_path;
    
    if (ascend_opp_path_env != nullptr && std::strlen(ascend_opp_path_env) > 0) {
        opp_base_path = ascend_opp_path_env;
        OP_LOGI(OP_NAME, "Using ASCEND_OPP_PATH from environment: %s", opp_base_path.c_str());
    } else {
        opp_base_path = DEFAULT_ASCEND_OPP_PATH;
        OP_LOGI(OP_NAME, "Using default ASCEND_OPP_PATH: %s", opp_base_path.c_str());
    }

    // 尝试从 ASCEND_OPP_PATH 推断 libregister.so 的位置
    std::vector<std::string> libregister_paths = {
        opp_base_path + "/../x86_64-linux/lib64/libregister.so",
        opp_base_path + "/../../../x86_64-linux/lib64/libregister.so",
        opp_base_path + "/../../../runtime/lib64/libregister.so",
        opp_base_path + "/../../../compiler/lib64/libregister.so",
        "libregister.so"  // 最后尝试让系统自动查找
    };

    void* libregister_handle = nullptr;
    for (const auto& lib_path : libregister_paths) {
        OP_LOGI(OP_NAME, "Trying to load: %s", lib_path.c_str());
        libregister_handle = dlopen(lib_path.c_str(), RTLD_LAZY | RTLD_GLOBAL);
        if (libregister_handle != nullptr) {
            OP_LOGI(OP_NAME, "Successfully loaded libregister.so from: %s", lib_path.c_str());
            break;
        } else {
            OP_LOGE(OP_NAME, "Failed: %s", dlerror());
        }
    }

    if (libregister_handle == nullptr) {
        OP_LOGE(OP_NAME, "Failed to load libregister.so from all locations");
        return nullptr;
    }

    std::vector<void*> loaded_handles;
    loaded_handles.push_back(libregister_handle);

    typedef void (*TbeLoadSoAndSaveToRegistryFunc)(const char*);
    TbeLoadSoAndSaveToRegistryFunc tbe_load_func = 
        reinterpret_cast<TbeLoadSoAndSaveToRegistryFunc>(dlsym(libregister_handle, "TbeLoadSoAndSaveToRegistry"));

    std::vector<std::string> tiling_lib_paths = {
        opp_base_path + "/built-in/op_impl/ai_core/tbe/op_tiling/lib/linux/x86_64/libopmaster_rt.so",
        opp_base_path + "/built-in/op_impl/ai_core/tbe/op_tiling/lib/linux/x86_64/libopmaster_rt2.0.so",
        opp_base_path + "/op_impl/built-in/ai_core/tbe/op_tiling/lib/linux/x86_64/libopmaster_rt.so",
        opp_base_path + "/op_impl/built-in/ai_core/tbe/op_tiling/lib/linux/x86_64/libopmaster_rt2.0.so"
    };

    int loaded_count = 0;
    for (const auto& tiling_lib_path : tiling_lib_paths) {
        OP_LOGI(OP_NAME, "Checking tiling library: %s", tiling_lib_path.c_str());
        if (access(tiling_lib_path.c_str(), F_OK) == 0) {
            OP_LOGI(OP_NAME, "File exists, loading...");
            void* tiling_handle = dlopen(tiling_lib_path.c_str(), RTLD_LAZY | RTLD_GLOBAL);
            if (tiling_handle != nullptr) {
                if (tbe_load_func != nullptr) {
                    tbe_load_func(tiling_lib_path.c_str());
                }
                loaded_handles.push_back(tiling_handle);
                loaded_count++;
                OP_LOGI(OP_NAME, "Successfully loaded and registered");
            } else {
                OP_LOGE(OP_NAME, "Failed to load: %s", dlerror());
            }
        } else {
            OP_LOGI(OP_NAME, "File does not exist");
        }
    }
    
    // 如果没有加载到任何 tiling 库,尝试从 ascend-toolkit 路径加载
    if (loaded_count == 0) {
        OP_LOGI(OP_NAME, "No tiling libraries loaded from current path, trying ascend-toolkit path...");
        std::string ascend_toolkit_opp = opp_base_path;
        size_t pos = ascend_toolkit_opp.find("/cann-");
        if (pos != std::string::npos) {
            ascend_toolkit_opp.replace(pos, 6, "/ascend-toolkit/");
            OP_LOGI(OP_NAME, "Trying ASCEND_OPP_PATH: %s", ascend_toolkit_opp.c_str());
            
            std::vector<std::string> toolkit_tiling_paths = {
                ascend_toolkit_opp + "/built-in/op_impl/ai_core/tbe/op_tiling/lib/linux/x86_64/libopmaster_rt.so",
                ascend_toolkit_opp + "/built-in/op_impl/ai_core/tbe/op_tiling/lib/linux/x86_64/libopmaster_rt2.0.so"
            };
            
            for (const auto& tiling_lib_path : toolkit_tiling_paths) {
                OP_LOGI(OP_NAME, "Checking tiling library: %s", tiling_lib_path.c_str());
                if (access(tiling_lib_path.c_str(), F_OK) == 0) {
                    OP_LOGI(OP_NAME, "File exists, loading...");
                    void* tiling_handle = dlopen(tiling_lib_path.c_str(), RTLD_LAZY | RTLD_GLOBAL);
                    if (tiling_handle != nullptr) {
                        if (tbe_load_func != nullptr) {
                            tbe_load_func(tiling_lib_path.c_str());
                        }
                        loaded_handles.push_back(tiling_handle);
                        loaded_count++;
                        OP_LOGI(OP_NAME, "Successfully loaded and registered");
                    } else {
                        OP_LOGE(OP_NAME, "Failed to load: %s", dlerror());
                    }
                } else {
                    OP_LOGI(OP_NAME, "File does not exist");
                }
            }
        }
    }
    
    if (loaded_count == 0) {
        OP_LOGE(OP_NAME, "Warning: No tiling libraries were loaded!");
    } else {
        OP_LOGI(OP_NAME, "Total tiling libraries loaded: %d", loaded_count);
    }

    typedef char* (*DoOpTilingFunc)(const char*, const char*, const char*,
                                    const char*, const char*, const char*,
                                    char*, size_t, uint64_t*, const char*);
    DoOpTilingFunc func = reinterpret_cast<DoOpTilingFunc>(dlsym(libregister_handle, "DoOpTilingForCompile"));
    if (func == nullptr) {
        OP_LOGE(OP_NAME, "Failed to find DoOpTilingForCompile function in libregister.so");
        return nullptr;
    }

    OP_LOGI(OP_NAME, "Calling DoOpTilingForCompile...");
    char* result = func(op_type, compile_info, compile_info_hash,
                        inputs, outputs, attrs, buf, buf_size, timer, extra_params);
    OP_LOGI(OP_NAME, "DoOpTilingForCompile returned");

    return result;
}

TilingResult CubeKernelTilingWrapper::DoMatMulTiling(const CompileInfo& compile_info,
                                                       const std::vector<TensorInfo>& inputs,
                                                       const std::vector<TensorInfo>& outputs,
                                                       const std::vector<AttrInfo>& attrs,
                                                       bool is_batch) {
    TilingResult result;

    std::vector<TensorInfo> processed_inputs = inputs;
    std::vector<AttrInfo> processed_attrs = attrs;

    ChangeParamNameToName(processed_inputs);
    InputsPreProcess(processed_inputs);
    AttrsPreProcess(processed_attrs);

    std::string compile_info_json = SerializeToJson(compile_info);
    std::string inputs_json = SerializeToJson(processed_inputs);
    std::string outputs_json = SerializeToJson(outputs);
    std::string attrs_json = SerializeToJson(processed_attrs);

    std::string compile_info_hash = GenerateCompileInfoHash(compile_info_json);

    json extra_params;
    extra_params["op_name"] = is_batch ? "BatchMatMulV3" : "MatMulV3";
    extra_params["deterministic"] = false;
    std::string extra_params_json = extra_params.dump();

    std::string op_type = is_batch ? "BatchMatMulV3" : "MatMulV3";

    const size_t buf_size = 1024 * 64;
    std::vector<char> buf(buf_size, 0);

    char* ret = CallDoOpTilingForCompile(op_type.c_str(),
                                           compile_info_json.c_str(),
                                           compile_info_hash.c_str(),
                                           inputs_json.c_str(),
                                           outputs_json.c_str(),
                                           attrs_json.c_str(),
                                           buf.data(),
                                           buf_size,
                                           nullptr,
                                           extra_params_json.c_str());

    if (ret == nullptr) {
        result.success = false;
        result.error_msg = "DoOpTilingForCompile returned nullptr";
        return (result);
    }

    std::string ret_json(ret);
    ParseTilingResult(ret_json, result);

    if (result.success) {
        std::string buf_json(buf.data());
        json j = json::parse(buf_json);
        if (j.contains("tiling_data")) {
            std::string hex_str = j["tiling_data"].get_string();
            result.tiling_data.clear();
            for (size_t i = 0; i < hex_str.length(); i += 2) {
                std::string byte_str = hex_str.substr(i, 2);
                result.tiling_data.push_back(static_cast<uint8_t>(std::stoul(byte_str, nullptr, 16)));
            }

            result.tiling_data = AlignTilingDataTo8Bytes(result.tiling_data, compile_info.soc_version);

            if (result.tiling_data.size() >= sizeof(MatMulV3BasicTilingData)) {
                memcpy(&result.matmul_basic_tiling_data, result.tiling_data.data(), sizeof(MatMulV3BasicTilingData));
            }
            if (result.tiling_data.size() >= sizeof(BatchMatMulV3BasicTilingData)) {
                memcpy(&result.batch_matmul_tiling_data, result.tiling_data.data(), sizeof(BatchMatMulV3BasicTilingData));
            }
        }
        if (j.contains("tiling_key")) {
            result.tiling_key = j["tiling_key"].get_int64();
        }
        if (j.contains("block_dim")) {
            result.block_dim = j["block_dim"].get_int64();
        }
        if (j.contains("workspaces")) {
            result.workspace_size = j["workspaces"][0].get_int64();
        }
        if (j.contains("clear_atomic")) {
            result.atomic_flag = j["clear_atomic"].get_bool();
        }
    }

    return result;
}

} // namespace autofuse
} // namespace ge

)";