// Source snapshot: commit 38a62078, created 2025-12-08 (commit history).
/*
 * Copyright (c) 2025 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
#include <sys/stat.h>
#include <unistd.h>
#include <cerrno>
#include <climits>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <string>
#include <vector>
#include <mki/base/operation_base.h>
#include <mki/types.h>
#include <mki/utils/const/op_const.h>
#include <mki/utils/env/env.h>
#include <mki/utils/log/log.h>
#include <mki/utils/platform/platform_info.h>
#include <mki/utils/strings/str_replace.h>
#include <mki/utils/strings/str_checker.h>
#include <mki_loader/op_register.h>
#include "atbops/params/params.h"

namespace AtbOps {
using namespace Mki;
// Input tensor count returned by FusionOperation::GetInputNum for every fusion type except MATMUL_ADD.
const uint64_t INPUT_NUM_TWO = 2;
// Input tensor count returned by FusionOperation::GetInputNum for MATMUL_ADD.
const uint64_t INPUT_NUM_THREE = 3;
// Size, in bytes, of each static buffer that receives a kernel binary loaded from disk (126976 = 124 * 1024).
const size_t DYNAMICSIZE = 126976;
// Array extent used for the buffer registered with "FusionErasedKernel", which is registered with length zero.
const size_t DYNAMICNULLSIZE = 0;
class FusionOperation : public OperationBase {
    using MmType = OpParam::Fusion;

public:
    explicit FusionOperation(const std::string &opName) noexcept : OperationBase(opName) {}

    Kernel *GetBestKernel(const LaunchParam &launchParam) const override
    {
        std::string kernelName = "FusionMatmulAddKernel";
        OpParam::Fusion fusionType = launchParam.GetParam<OpParam::Fusion>();
        if (fusionType.fusionType == OpParam::Fusion::MATMUL_GELU) {
            kernelName = "FusionMatmulGeluKernel";
        } else if (fusionType.fusionType == OpParam::Fusion::MATMUL_SIGMOID) {
            kernelName = "FusionMatmulSigmoidKernel";
        } else if (fusionType.fusionType == OpParam::Fusion::MATMUL_SWIGLU) {
            kernelName = "FusionMatmulSwiGluKernel";
        } else if (fusionType.fusionType == OpParam::Fusion::NON_FUSION) {
            kernelName = "FusionErasedKernel";
        }
        MKI_LOG(INFO) << "getBestKernel " << kernelName;
        return GetKernelByName(kernelName);
    }

    void MatMulAddFusion()
    {
        static uint8_t matMulAddFusionKernelBinData[DYNAMICSIZE];
        std::string path = std::getenv("HOME") ? std::string(std::getenv("HOME")) : "";
        if ("" == path) {
            MKI_LOG(ERROR) << "Get ENV HOME failed!";
            return;
        }
        path += "/.atb_auto_fusion/bishengir_bin/matmul_add.cpp";
        if (!Mki::CheckNameValid(path, 256)) { // 256: 最大路径路径长度
            MKI_LOG(ERROR) << "path is invalid, please check the path: " << path;
            return;
        }
        char resolvedPath[PATH_MAX] = {0};
        if (realpath(path.c_str(), resolvedPath) == nullptr) {
            MKI_LOG(ERROR) << "path resolve fail, please check the path: " << path;
            return;
        }
        if (IsSoftLink(resolvedPath)) {
            MKI_LOG(ERROR) << "MatMulSigmoidFusion CPP SHOULD NOT be a symbolic link ";
            return;
        }
        std::ifstream cpp(resolvedPath);
        std::string line;
        uint32_t counter = 0;
        while (std::getline(cpp, line)) {
            char c = static_cast<char>(std::stoi(line, nullptr, 16));
            matMulAddFusionKernelBinData[counter++] = c;
        }
        cpp.close();
        std::string deviceVersion = PlatformInfo::Instance().GetPlatformName();
        std::string kernelName = "FusionMatmulAddKernel";
        if (deviceVersion == "ascend910b") {
            static KernelBinaryRegister binFusionMatmulAddKernelascend910bregister =
                KernelBinaryRegister(deviceVersion.c_str(), kernelName.c_str(), matMulAddFusionKernelBinData, counter);
            UNUSED_VALUE(binFusionMatmulAddKernelascend910bregister);
        } else if (deviceVersion == "ascend910") {
            static KernelBinaryRegister binFusionMatmulAddKernelascend910register =
                KernelBinaryRegister(deviceVersion.c_str(), kernelName.c_str(), matMulAddFusionKernelBinData, counter);
            UNUSED_VALUE(binFusionMatmulAddKernelascend910register);
        } else if (deviceVersion == "ascend310b") {
            static KernelBinaryRegister binFusionMatmulAddKernelascend310bregister =
                KernelBinaryRegister(deviceVersion.c_str(), kernelName.c_str(), matMulAddFusionKernelBinData, counter);
            UNUSED_VALUE(binFusionMatmulAddKernelascend310bregister);
        } else if (deviceVersion == "ascend310p") {
            static KernelBinaryRegister binFusionMatmulAddKernelascend310pregister =
                KernelBinaryRegister(deviceVersion.c_str(), kernelName.c_str(), matMulAddFusionKernelBinData, counter);
            UNUSED_VALUE(binFusionMatmulAddKernelascend310pregister);
        } else {
            MKI_LOG(ERROR) << "MatMulAddFusion operation add failed ";
        }
        return;
    }

    void MatMulGeluFusion()
    {
        static uint8_t matMulGeluFusionKernelBinData[DYNAMICSIZE];
        std::string path = std::getenv("HOME") ? std::string(std::getenv("HOME")) : "";
        if ("" == path) {
            MKI_LOG(ERROR) << "Get ENV HOME failed!";
            return;
        }
        path += "/.atb_auto_fusion/bishengir_bin/matmul_gelu.cpp";
        if (!Mki::CheckNameValid(path, 256)) { // 256: 最大路径路径长度
            MKI_LOG(ERROR) << "path is invalid, please check the path: " << path;
            return;
        }
        char resolvedPath[PATH_MAX] = {0};
        if (realpath(path.c_str(), resolvedPath) == nullptr) {
            MKI_LOG(ERROR) << "path resolve fail, please check the path: " << path;
            return;
        }
        if (IsSoftLink(resolvedPath)) {
            MKI_LOG(ERROR) << "MatMulSigmoidFusion CPP SHOULD NOT be a symbolic link ";
            return;
        }
        std::ifstream cpp(resolvedPath);
        std::string line;
        uint32_t counter = 0;
        while (std::getline(cpp, line)) {
            char c = static_cast<char>(std::stoi(line, nullptr, 16));
            matMulGeluFusionKernelBinData[counter++] = c;
        }
        cpp.close();
        std::string deviceVersion = PlatformInfo::Instance().GetPlatformName();
        std::string kernelName = "FusionMatmulGeluKernel";
        if (deviceVersion == "ascend910b") {
            static KernelBinaryRegister binFusionMatmulGeluKernelascend910bregister =
                KernelBinaryRegister(deviceVersion.c_str(), kernelName.c_str(), matMulGeluFusionKernelBinData, counter);
            UNUSED_VALUE(binFusionMatmulGeluKernelascend910bregister);
        } else if (deviceVersion == "ascend910") {
            static KernelBinaryRegister binFusionMatmulGeluKernelascend910register =
                KernelBinaryRegister(deviceVersion.c_str(), kernelName.c_str(), matMulGeluFusionKernelBinData, counter);
            UNUSED_VALUE(binFusionMatmulGeluKernelascend910register);
        } else if (deviceVersion == "ascend310b") {
            static KernelBinaryRegister binFusionMatmulGeluKernelascend310bregister =
                KernelBinaryRegister(deviceVersion.c_str(), kernelName.c_str(), matMulGeluFusionKernelBinData, counter);
            UNUSED_VALUE(binFusionMatmulGeluKernelascend310bregister);
        } else if (deviceVersion == "ascend310p") {
            static KernelBinaryRegister binFusionMatmulGeluKernelascend310pregister =
                KernelBinaryRegister(deviceVersion.c_str(), kernelName.c_str(), matMulGeluFusionKernelBinData, counter);
            UNUSED_VALUE(binFusionMatmulGeluKernelascend310pregister);
        } else {
            MKI_LOG(ERROR) << "MatMulGeluFusion operation add failed ";
        }
        return;
    }

    void MatMulSigmoidFusion()
    {
        static uint8_t matMulSigmoidFusionKernelBinData[DYNAMICSIZE];
        std::string path = std::getenv("HOME") ? std::string(std::getenv("HOME")) : "";
        if ("" == path) {
            MKI_LOG(ERROR) << "Get ENV HOME failed!";
            return;
        }
        path += "/.atb_auto_fusion/bishengir_bin/matmul_sigmoid.cpp";
        if (!Mki::CheckNameValid(path, 256)) { // 256: 最大路径路径长度
            MKI_LOG(ERROR) << "path is invalid, please check the path: " << path;
            return;
        }
        char resolvedPath[PATH_MAX] = {0};
        if (realpath(path.c_str(), resolvedPath) == nullptr) {
            MKI_LOG(ERROR) << "path resolve fail, please check the path: " << path;
            return;
        }
        if (IsSoftLink(resolvedPath)) {
            MKI_LOG(ERROR) << "MatMulSigmoidFusion CPP SHOULD NOT be a symbolic link ";
            return;
        }
        std::ifstream cpp(resolvedPath);
        std::string line;
        uint32_t counter = 0;
        while (std::getline(cpp, line)) {
            char c = static_cast<char>(std::stoi(line, nullptr, 16));
            matMulSigmoidFusionKernelBinData[counter++] = c;
        }
        cpp.close();
        std::string deviceVersion = PlatformInfo::Instance().GetPlatformName();
        std::string kernelName = "FusionMatmulSigmoidKernel";
        if (deviceVersion == "ascend910b") {
            static KernelBinaryRegister binFusionMatmulSigmoidKernelascend910bregister = KernelBinaryRegister(
                deviceVersion.c_str(), kernelName.c_str(), matMulSigmoidFusionKernelBinData, counter);
            UNUSED_VALUE(binFusionMatmulSigmoidKernelascend910bregister);
        } else if (deviceVersion == "ascend910") {
            static KernelBinaryRegister binFusionMatmulSigmoidKernelascend910register = KernelBinaryRegister(
                deviceVersion.c_str(), kernelName.c_str(), matMulSigmoidFusionKernelBinData, counter);
            UNUSED_VALUE(binFusionMatmulSigmoidKernelascend910register);
        } else if (deviceVersion == "ascend310b") {
            static KernelBinaryRegister binFusionMatmulSigmoidKernelascend310bregister = KernelBinaryRegister(
                deviceVersion.c_str(), kernelName.c_str(), matMulSigmoidFusionKernelBinData, counter);
            UNUSED_VALUE(binFusionMatmulSigmoidKernelascend310bregister);
        } else if (deviceVersion == "ascend310p") {
            static KernelBinaryRegister binFusionMatmulSigmoidKernelascend310pregister = KernelBinaryRegister(
                deviceVersion.c_str(), kernelName.c_str(), matMulSigmoidFusionKernelBinData, counter);
            UNUSED_VALUE(binFusionMatmulSigmoidKernelascend310pregister);
        } else {
            MKI_LOG(ERROR) << "MatMulSigmoidFusion operation add failed ";
        }
        return;
    }

    void MatMulSwigluFusion()
    {
        static uint8_t matMulSwigluFusionKernelBinData[DYNAMICSIZE];
        std::string path = std::getenv("HOME") ? std::string(std::getenv("HOME")) : "";
        if ("" == path) {
            MKI_LOG(ERROR) << "Get ENV HOME failed!";
            return;
        }
        path += "/.atb_auto_fusion/bishengir_bin/matmul_swiglu.cpp";
        if (!Mki::CheckNameValid(path, 256)) { // 256: 最大路径路径长度
            MKI_LOG(ERROR) << "path is invalid, please check the path: " << path;
            return;
        }
        char resolvedPath[PATH_MAX] = {0};
        if (realpath(path.c_str(), resolvedPath) == nullptr) {
            MKI_LOG(ERROR) << "path resolve fail, please check the path: " << path;
            return;
        }
        if (IsSoftLink(resolvedPath)) {
            MKI_LOG(ERROR) << "MatMulSigmoidFusion CPP SHOULD NOT be a symbolic link ";
            return;
        }
        std::ifstream cpp(resolvedPath);
        std::string line;
        uint32_t counter = 0;
        while (std::getline(cpp, line)) {
            char c = static_cast<char>(std::stoi(line, nullptr, 16));
            matMulSwigluFusionKernelBinData[counter++] = c;
        }
        cpp.close();
        std::string deviceVersion = PlatformInfo::Instance().GetPlatformName();
        std::string kernelName = "FusionMatmulSwiGluKernel";
        if (deviceVersion == "ascend910b") {
            static KernelBinaryRegister binFusionMatmulSwiGluKernelascend910bregister = KernelBinaryRegister(
                deviceVersion.c_str(), kernelName.c_str(), matMulSwigluFusionKernelBinData, counter);
            UNUSED_VALUE(binFusionMatmulSwiGluKernelascend910bregister);
        } else if (deviceVersion == "ascend910") {
            static KernelBinaryRegister binFusionMatmulSwiGluKernelascend910register = KernelBinaryRegister(
                deviceVersion.c_str(), kernelName.c_str(), matMulSwigluFusionKernelBinData, counter);
            UNUSED_VALUE(binFusionMatmulSwiGluKernelascend910register);
        } else if (deviceVersion == "ascend310b") {
            static KernelBinaryRegister binFusionMatmulSwiGluKernelascend310bregister = KernelBinaryRegister(
                deviceVersion.c_str(), kernelName.c_str(), matMulSwigluFusionKernelBinData, counter);
            UNUSED_VALUE(binFusionMatmulSwiGluKernelascend310bregister);
        } else if (deviceVersion == "ascend310p") {
            static KernelBinaryRegister binFusionMatmulSwiGluKernelascend310pregister = KernelBinaryRegister(
                deviceVersion.c_str(), kernelName.c_str(), matMulSwigluFusionKernelBinData, counter);
            UNUSED_VALUE(binFusionMatmulSwiGluKernelascend310pregister);
        } else {
            MKI_LOG(ERROR) << "MatMulSwigluFusion operation add failed ";
        }
        return;
    }

    void ErasedFusion() const
    {
        static uint8_t erasedFusionKernelBinData[DYNAMICNULLSIZE];
        uint32_t counter = 0;
        std::string deviceVersion = PlatformInfo::Instance().GetPlatformName();
        std::string kernelName = "FusionErasedKernel";
        if (deviceVersion == "ascend910b") {
            static KernelBinaryRegister binFusionErasedKernelascend910bregister =
                KernelBinaryRegister(deviceVersion.c_str(), kernelName.c_str(), erasedFusionKernelBinData, counter);
            UNUSED_VALUE(binFusionErasedKernelascend910bregister);
        } else if (deviceVersion == "ascend910") {
            static KernelBinaryRegister binFusionErasedKernelascend910register =
                KernelBinaryRegister(deviceVersion.c_str(), kernelName.c_str(), erasedFusionKernelBinData, counter);
            UNUSED_VALUE(binFusionErasedKernelascend910register);
        } else if (deviceVersion == "ascend310b") {
            static KernelBinaryRegister binFusionErasedKernelascend310bregister =
                KernelBinaryRegister(deviceVersion.c_str(), kernelName.c_str(), erasedFusionKernelBinData, counter);
            UNUSED_VALUE(binFusionErasedKernelascend310bregister);
        } else if (deviceVersion == "ascend310p") {
            static KernelBinaryRegister binFusionErasedKernelascend310pregister =
                KernelBinaryRegister(deviceVersion.c_str(), kernelName.c_str(), erasedFusionKernelBinData, counter);
            UNUSED_VALUE(binFusionErasedKernelascend310pregister);
        } else {
            MKI_LOG(ERROR) << "ErasedFusion operation add failed ";
        }
        return;
    }

    bool DynamicRegisterKernelByName(const LaunchParam &launchParam, const std::string &opName) override
    {
        OpParam::Fusion fusionType = launchParam.GetParam<OpParam::Fusion>();
        std::string deviceVersion = PlatformInfo::Instance().GetPlatformName();
        if (fusionType.fusionType == OpParam::Fusion::MATMUL_ADD) {
            MatMulAddFusion();
        } else if (fusionType.fusionType == OpParam::Fusion::MATMUL_GELU) {
            MatMulGeluFusion();
        } else if (fusionType.fusionType == OpParam::Fusion::MATMUL_SIGMOID) {
            MatMulSigmoidFusion();
        } else if (fusionType.fusionType == OpParam::Fusion::MATMUL_SWIGLU) {
            MatMulSwigluFusion();
        } else {
            ErasedFusion();
            MKI_LOG(DEBUG) << "DynamicRegisterKernelByName null kernel ";
        }
        return true;
    }

    int64_t GetInputNum(const Any &specificParam) const override
    {
        OpParam::Fusion fusionParam = AnyCast<OpParam::Fusion>(specificParam);
        if (fusionParam.fusionType == OpParam::Fusion::MATMUL_ADD) {
            return INPUT_NUM_THREE;
        }
        return INPUT_NUM_TWO;
    }

protected:
    Status InferShapeImpl(const LaunchParam &launchParam, SVector<Tensor> &outTensors) const override
    {
        OpParam::Fusion fusionType = launchParam.GetParam<OpParam::Fusion>();
        if (fusionType.fusionType == OpParam::Fusion::MATMUL_ADD) {
            auto inTensorDescA = launchParam.GetInTensor(2).desc;
            TensorDesc &tensorDescOut = outTensors[0].desc;
            tensorDescOut.dtype = TENSOR_DTYPE_FLOAT16;
            tensorDescOut.format = inTensorDescA.format;
            auto &outDims = tensorDescOut.dims;
            outDims.emplace_back(inTensorDescA.dims[DIM_0]);
            outDims.emplace_back(inTensorDescA.dims[DIM_1]);
        } else {
            auto inTensorDescA = launchParam.GetInTensor(0).desc;
            auto inTensorDescB = launchParam.GetInTensor(1).desc;

            TensorDesc &tensorDescOut = outTensors[0].desc;
            tensorDescOut.dtype = TENSOR_DTYPE_FLOAT16;
            tensorDescOut.format = inTensorDescA.format;
            auto &outDims = tensorDescOut.dims;
            outDims.emplace_back(inTensorDescA.dims[DIM_0]);
            outDims.emplace_back(inTensorDescB.dims[DIM_0]);
        }
        return Status::OkStatus();
    }

    bool IsSoftLink(const char *path) const
    {
        struct stat fileStat;
        if (lstat(path, &fileStat) != 0) {
            return false;
        }
        return S_ISLNK(fileStat.st_mode);
    }
};
// Registers FusionOperation through the REG_OPERATION macro (defined outside this file).
REG_OPERATION(FusionOperation);
} // namespace AtbOps