KernelBuilder

功能说明

kernel层对象构建类,负责创建kernel层对象,kernel层对象包含kernel层policy和kernel层调度。

所属头文件链接

/include/elewise/kernel/builder.h

函数原型

template <typename BlockOp, const auto &Policy = defaultKernelPolicy, typename ScheduleCfg = DefaultKernelConfig,
  template <typename, const auto&, typename> class Schedule = DefaultKernelSchedule>
class KernelBuilder

参数说明

参数名称 参数类型 输入/输出 数据类型 参数说明 默认值
BlockOp 模板参数 输入 NA block层对象类型,跟kernel层是被包含关系 NA
Policy 模板参数 输入 NA kernel层的用户静态策略类型 DefaultKernelPolicy
ScheduleCfg 模板参数 输入 NA kernel层调度配置类型 DefaultKernelConfig
Schedule 模板参数 输入 NA kernel层调度类型 DefaultKernelSchedule

返回值说明

返回值数据类型 返回值说明
KernelBuilder 返回kernel层对象

约束说明

NA

使用示例

template <typename InputDtype, typename OutputDtype>
struct AddSubConfig {
    struct AddSubCompute {
        template <template <typename> class Tensor>
        __host_aicore__ constexpr auto Compute() const
        {
            auto in1 = Atvoss::PlaceHolder<1, Tensor<InputDtype>, Atvoss::ParamUsage::IN>();
            auto in2 = Atvoss::PlaceHolder<2, Tensor<InputDtype>, Atvoss::ParamUsage::IN>();
            auto in3 = Atvoss::PlaceHolder<3, InputDtype, Atvoss::ParamUsage::IN>();
            auto out = Atvoss::PlaceHolder<4, Tensor<OutputDtype>, Atvoss::ParamUsage::OUT>();
            return (out = in1 + in2 - in3);
        };
    };

    using ArchTag = Atvoss::Arch::DAV_3510;
    using BlockOp = Atvoss::Ele::BlockBuilder<AddSubCompute, ArchTag>;

    // 🔥🔥🔥 使用示例 🔥🔥🔥
    using KernelOp = Atvoss::Ele::KernelBuilder<BlockOp>;
    // 🔥🔥🔥 使用示例 🔥🔥🔥

    using DeviceOp = Atvoss::DeviceAdapter<KernelOp>;
};

template <typename InputDtype, typename OutputDtype>
static void Run() {
    /* ACL init and stream create */
    ...

    Atvoss::Tensor<InputDtype> in1(deviceIn1, {{3, 4, 0, 0, 0, 0, 0, 0}}, 2);
    Atvoss::Tensor<InputDtype> in2(deviceIn2, {{3, 4, 0, 0, 0, 0, 0, 0}}, 2);
    InputDtype in3 = 5.0;
    Atvoss::Tensor<OutputDtype> out(deviceOut, {{3, 4, 0, 0, 0, 0, 0, 0}}, 2);

    auto arguments = Atvoss::ArgumentsBuilder{}.inputOutput(in1, in2, in3, out).attr("dim", 5).build();

    using DeviceOp = typename AddSubConfig<InputDtype, OutputDtype>::DeviceOp;
    DeviceOp deviceOp;
    deviceOp.Run(arguments, stream);
}

int main(int argc, char const* argv[]) {
    Run<float, float>();
    return 0;
}