// 6c030913 — created on 2025-05-22 (commit history)
// Copyright (c) 2025 Huawei Technologies Co., Ltd
// All rights reserved.
//
// Licensed under the BSD 3-Clause License  (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.


#ifndef __TORCH_NPU_OP_PLUGIN_UTILS_FFT_COMMON_OP_API__
#define __TORCH_NPU_OP_PLUGIN_UTILS_FFT_COMMON_OP_API__

#include <cstdint>
#include <list>
#include <mutex>
#include <utility>

#include "op_plugin/utils/op_api_common.h"
#include "fft_plan_op_api.h"
#include "AsdSipNpuOpApi.h"

namespace op_api {
    // ---- Cache controls exposed to Python (torch_npu._C) ----
    // Definitions live in the .cpp; presumably they forward to the shared FFTMixCache's
    // setCapacity/getCapacity/getSize/clear — confirm there.
    void setFFTPlanCapacity(int64_t maxSize);
    int64_t getFFTPlanCapacity();
    int64_t getFFTPlanSize();
    void clearFFTPlanCache();

    // Used by FFTc2cKernelNpuOpApi.cpp: returns the FFT plan for the given prb_size, direction
    // (is_forward), plan_mode and scalar_dtype. Presumably served from the shared cache — see the .cpp.
    FFTPlanItem get_plan(int64_t prb_size, bool is_forward, PlanMode plan_mode, at::ScalarType scalar_dtype);

    // Used by EXEC_ASDSIP_FFT_NPU_CMD: returns the asdSip FFT handle for `param`
    // (presumably cached as well — see the .cpp).
    asdFftHandle getHandle(FFTParam param);

    // Key of the mixed cache, which holds two kinds of entries:
    //   isAsdSip == true  -> asdSip handle entry, identified by fftParam (planKey is ignored);
    //   isAsdSip == false -> FFT plan entry,      identified by planKey  (fftParam is ignored).
    struct FFTCacheKey {
        bool isAsdSip = false;  // selects which of the two fields below identifies the entry
        FFTParam fftParam;
        PlanKey planKey;
    };

    // Cached payload. Presumably only the member matching the key kind is populated
    // (handle for asdSip entries, plan otherwise) — NOTE(review): confirm in the .cpp.
    struct FFTCacheValue {
        asdFftHandle handle = nullptr;
        FFTPlanItem plan;
    };

    // Two keys are equal only if they are of the same kind and their identifying fields match.
    // The kind is checked on BOTH sides so that the relation is symmetric, and so that a key of one
    // kind can never match an entry of the other kind through its unused (possibly default) field.
    inline bool operator==(const FFTCacheKey &one, const FFTCacheKey &other)
    {
        if (one.isAsdSip != other.isAsdSip) {
            return false;
        }
        if (one.isAsdSip) {
            return one.fftParam == other.fftParam;
        }
        return one.planKey == other.planKey;
    }

    // Cache shared by asdSip handles and FFT plans; member functions are defined in the .cpp.
    // Entries are stored as (key, value) pairs in a std::list. The eviction policy and locking
    // discipline are implemented in the .cpp (presumably capacity-bounded and guarded by fftMutex).
    class FFTMixCache {
    public:
        using FFTPair = std::pair<FFTCacheKey, FFTCacheValue>;

        // Creates the cache with an initial capacity of `maxSize` (presumably a maximum entry count).
        FFTMixCache(int64_t maxSize);

        // Returns the value cached for `cacheKey`; lookup/creation/eviction are implemented in the .cpp.
        FFTCacheValue get(FFTCacheKey &cacheKey);

        void setCapacity(int64_t maxSize);
        int64_t getCapacity();
        int64_t getSize();
        void clear();

    private:
        int64_t capacity;           // configured capacity (see setCapacity/getCapacity)
        std::list<FFTPair> list{};  // cached (key, value) pairs
        std::mutex fftMutex;        // presumably guards `list` and `capacity` — confirm in the .cpp
    };

} // namespace op_api

/**
 * Falls back to another call when the asdSip FFT entry point "asdFftExec" #asdSipApi cannot be resolved.
 *
 * Must be expanded inside a function whose return type accepts `originCallExpression`. If the symbol cannot be
 * resolved, a warning naming the full symbol is logged and the ENCLOSING function returns `originCallExpression`.
 * Otherwise execution simply continues after the macro. The lookup result is held in a function-local static,
 * so the lookup runs once per expansion.
 *
 * @param asdSipApi            Suffix token of the entry point, e.g. C2C -> symbol "asdFftExecC2C".
 * @param originCallExpression Expression evaluated and returned as the fallback.
 */
#define DO_ASDSIP_COMPATIBILITY(asdSipApi, originCallExpression)                                                       \
    do {                                                                                                               \
        static const auto opApiFuncAddr = GetAsdSipApiFuncAddr("asdFftExec" #asdSipApi);                               \
        if (opApiFuncAddr == nullptr) {                                                                                \
            ASCEND_LOGW("%s not in %s, or %s not found. Will call %s", "asdFftExec" #asdSipApi,                        \
                        GetAsdSipApiLibName(), GetAsdSipApiLibName(), #originCallExpression);                          \
            return originCallExpression;                                                                               \
        }                                                                                                              \
    } while (false)

/**
 * Dispatches the asdSip FFT entry point "asdFftExec" #fftExecApi on the current NPU stream; yields no value.
 *
 * Steps:
 *  1. Resolve the entry point once per expansion (function-local static). Fail with TORCH_CHECK if it is
 *     missing, instead of later calling through a null function pointer.
 *  2. Call SetDeterministic, then fetch the handle for `fftParam` via op_api::getHandle.
 *  3. Query the handle's workspace size and allocate a workspace tensor only when it is non-zero.
 *  4. Attach the workspace and the current stream to the handle.
 *  5. Convert inData/outData with ConvertType.
 *  6. Hand a lambda that performs the exec call to OpCommand::RunOpApiV2. That lambda fails with TORCH_CHECK
 *     on a non-zero return code. RunOpApiV2 may run it later (presumably on the task-queue thread), hence
 *     "asynchronous".
 *
 * @param fftExecApi Suffix token of the entry point, e.g. C2C -> "asdFftExecC2C".
 * @param inData     Input, passed through ConvertType.
 * @param outData    Output, passed through ConvertType.
 * @param fftParam   FFTParam used to obtain the asdFftHandle.
 *
 * Internal locals are prefixed with `asdsip_` so they cannot collide with caller argument names.
 *
 * NOTE(review): the workspace and stream are stored on the handle here, on the calling thread, while the exec
 * runs inside the dispatched lambda. If getHandle returns a shared cached handle and RunOpApiV2 defers
 * execution, a later call with equal parameters could overwrite them before this exec runs — confirm.
 */
#define EXEC_ASDSIP_FFT_NPU_CMD(fftExecApi, inData, outData, fftParam)                                                 \
    do {                                                                                                               \
        static const auto asdsip_exec_addr = GetAsdSipApiFuncAddr("asdFftExec" #fftExecApi);                           \
        TORCH_CHECK(asdsip_exec_addr != nullptr, "asdFftExec" #fftExecApi " was not found in ",                        \
                    GetAsdSipApiLibName(), " (or the library could not be loaded)");                                   \
        auto asdsip_stream = c10_npu::getCurrentNPUStream().stream(false);                                             \
        at_npu::native::SetDeterministic();                                                                            \
        asdFftHandle asdsip_handle = op_api::getHandle(fftParam);                                                      \
        size_t asdsip_workspace_size = 0;                                                                              \
        asdSipFftGetWorkspaceSize(asdsip_handle, asdsip_workspace_size);                                               \
        void *asdsip_workspace_addr = nullptr;                                                                         \
        at::Tensor asdsip_workspace_tensor;                                                                            \
        if (asdsip_workspace_size != 0) {                                                                              \
            asdsip_workspace_tensor =                                                                                  \
                at_npu::native::OpPreparation::unsafe_empty_workspace(asdsip_workspace_size);                          \
            asdsip_workspace_addr = const_cast<void *>(asdsip_workspace_tensor.storage().data());                      \
        }                                                                                                              \
        asdSipFftSetWorkspace(asdsip_handle, asdsip_workspace_addr);                                                   \
        asdSipFftSetStream(asdsip_handle, asdsip_stream);                                                              \
        auto asdsip_input = ConvertType(inData);                                                                       \
        auto asdsip_output = ConvertType(outData);                                                                     \
        auto asdsip_call = [asdsip_handle, asdsip_input, asdsip_output]() mutable -> int {                             \
            FftExecApiFunc asdsip_exec_func = reinterpret_cast<FftExecApiFunc>(asdsip_exec_addr);                      \
            auto asdsip_ret = asdsip_exec_func(asdsip_handle, asdsip_input, asdsip_output);                            \
            TORCH_CHECK(asdsip_ret == 0, "call " "asdFftExec" #fftExecApi " failed");                                  \
            return asdsip_ret;                                                                                         \
        };                                                                                                             \
        at_npu::native::OpCommand::RunOpApiV2("asdFftExec" #fftExecApi, asdsip_call);                                  \
    } while (false)

#endif //  __TORCH_NPU_OP_PLUGIN_UTILS_FFT_COMMON_OP_API__