* Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wfloat-equal"
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <pybind11/functional.h>
#pragma GCC diagnostic pop
#include <sstream>
#include <atb/utils/param_to_json.h>
#include "graph_operation_builder.h"
#include "enger_graph_builder.h"
#include "graph_node.h"
#include "resource/memory_manager.h"
#include "prof/prof_stats.h"
#include "operation_wrapper.h"
namespace py = pybind11;
using namespace atb;
using namespace atb::infer;
template <typename Container, typename MethodType>
void AddElements(Container &container, MethodType &method)
{
for (const auto &element : container) {
method.push_back(element);
}
}
PYBIND11_MODULE(_C, m)
{
m.doc() = "Python bindings for torch_atb";
m.def("set_buffer_size", static_cast<void(*)(uint64_t)>(&TorchAtb::MemoryManager::SetBufferSize),
py::arg("bytes"), "Set default workspace buffer size (bytes)");
py::class_<TorchAtb::ProfStats>(m, "Prof")
.def_static("get_prof_stats", &TorchAtb::ProfStats::GetProfStats, py::return_value_policy::reference)
.def("get_run_time_stats", &TorchAtb::ProfStats::GetRunTimeStats);
py::class_<TorchAtb::OperationWrapper>(m, "Operation")
.def(py::init<const LayerNormParam &>())
.def(py::init<const ElewiseParam &>())
.def(py::init<const LinearParam &>())
.def(py::init<const SoftmaxParam &>())
.def(py::init<const SelfAttentionParam &>())
.def(py::init<const PagedAttentionParam &>())
.def(py::init<const RopeParam &>())
.def(py::init<const SplitParam &>())
.def(py::init<const GatherParam &>())
.def(py::init<const ActivationParam &>())
.def(py::init<const RmsNormParam &>())
.def(py::init<const AllGatherParam &>())
.def(py::init<const AsStridedParam &>())
.def(py::init<const CumsumParam &>())
.def(py::init<const DynamicNTKParam &>())
.def(py::init<const MultinomialParam &>())
.def(py::init<const ConcatParam &>())
.def(py::init<const SliceParam &>())
.def(py::init<const TransposeParam &>())
.def(py::init<const GatingParam &>())
.def(py::init<const ReshapeAndCacheParam &>())
.def(py::init<const FillParam &>())
.def(py::init<const RazorFusionAttentionParam &>())
.def(py::init<const AllReduceParam &>())
.def(py::init<const BroadcastParam &>())
.def(py::init<const ReduceScatterParam &>())
.def(py::init<const ReduceScatterVParam &>())
.def(py::init<const FaUpdateParam &>())
.def(py::init<const LinearParallelParam &>())
.def(py::init<const LinearSparseParam &>())
.def(py::init<const RelayAttentionParam &>())
.def(py::init<const TopkToppSamplingParam &>())
.def(py::init<const AllToAllParam &>())
.def(py::init<const GraphParam &>())
.def_property_readonly("name", &TorchAtb::OperationWrapper::GetName)
.def_property_readonly("input_num", &TorchAtb::OperationWrapper::GetInputNum)
.def_property_readonly("output_num", &TorchAtb::OperationWrapper::GetOutputNum)
.def("forward", &TorchAtb::OperationWrapper::Forward)
.def("__repr__", [](const TorchAtb::OperationWrapper &opWrapper) {
std::stringstream ss;
ss << "op name: " << opWrapper.GetName() << ", input_num: " << opWrapper.GetInputNum()
<< ", output_num: " << opWrapper.GetOutputNum();
return ss.str();
});
py::class_<TorchAtb::GraphNode>(m, "Node")
.def("set_op", &TorchAtb::GraphNode::SetOperation)
.def("get_output", &TorchAtb::GraphNode::GetOutput)
.def("set_stream_id", &TorchAtb::GraphNode::SetStreamId)
.def("get_stream_id", &TorchAtb::GraphNode::GetStreamId)
.def_readonly("inTensorIds", &TorchAtb::GraphNode::inTensorIds)
.def_readonly("outTensorIds", &TorchAtb::GraphNode::outTensorIds)
.def_readonly("operation", &TorchAtb::GraphNode::operation);
py::class_<TorchAtb::GraphBuilder>(m, "Builder")
.def(py::init<const std::string &>())
.def("add_input", &TorchAtb::GraphBuilder::AddInput)
.def("add_node",
py::overload_cast<const std::vector<std::string> &, atb::Operation *>(&TorchAtb::GraphBuilder::AddNode))
.def("add_node", py::overload_cast<const std::vector<std::string> &, const atb::infer::LinearParam &>(
&TorchAtb::GraphBuilder::AddNode))
.def("add_node", py::overload_cast<const std::vector<std::string> &, const atb::infer::LayerNormParam &>(
&TorchAtb::GraphBuilder::AddNode))
.def("add_node", py::overload_cast<const std::vector<std::string> &, const atb::infer::ElewiseParam &>(
&TorchAtb::GraphBuilder::AddNode))
.def("add_node", py::overload_cast<const std::vector<std::string> &, const atb::infer::SoftmaxParam &>(
&TorchAtb::GraphBuilder::AddNode))
.def("add_node", py::overload_cast<const std::vector<std::string> &, const atb::infer::SelfAttentionParam &>(
&TorchAtb::GraphBuilder::AddNode))
.def("add_node", py::overload_cast<const std::vector<std::string> &, const atb::infer::RopeParam &>(
&TorchAtb::GraphBuilder::AddNode))
.def("add_node", py::overload_cast<const std::vector<std::string> &, const atb::infer::SplitParam &>(
&TorchAtb::GraphBuilder::AddNode))
.def("add_node", py::overload_cast<const std::vector<std::string> &, const atb::infer::GatherParam &>(
&TorchAtb::GraphBuilder::AddNode))
.def("add_node", py::overload_cast<const std::vector<std::string> &, const atb::infer::ActivationParam &>(
&TorchAtb::GraphBuilder::AddNode))
.def("add_node", py::overload_cast<const std::vector<std::string> &, const atb::infer::RmsNormParam &>(
&TorchAtb::GraphBuilder::AddNode))
.def("add_node", py::overload_cast<const std::vector<std::string> &, const atb::infer::AllGatherParam &>(
&TorchAtb::GraphBuilder::AddNode))
.def("add_node", py::overload_cast<const std::vector<std::string> &, const atb::infer::AsStridedParam &>(
&TorchAtb::GraphBuilder::AddNode))
.def("add_node", py::overload_cast<const std::vector<std::string> &, const atb::infer::CumsumParam &>(
&TorchAtb::GraphBuilder::AddNode))
.def("add_node", py::overload_cast<const std::vector<std::string> &, const atb::infer::DynamicNTKParam &>(
&TorchAtb::GraphBuilder::AddNode))
.def("add_node", py::overload_cast<const std::vector<std::string> &, const atb::infer::MultinomialParam &>(
&TorchAtb::GraphBuilder::AddNode))
.def("add_node", py::overload_cast<const std::vector<std::string> &, const atb::infer::ConcatParam &>(
&TorchAtb::GraphBuilder::AddNode))
.def("add_node", py::overload_cast<const std::vector<std::string> &, const atb::infer::SliceParam &>(
&TorchAtb::GraphBuilder::AddNode))
.def("add_node", py::overload_cast<const std::vector<std::string> &, const atb::infer::TransposeParam &>(
&TorchAtb::GraphBuilder::AddNode))
.def("add_node", py::overload_cast<const std::vector<std::string> &, const atb::infer::GatingParam &>(
&TorchAtb::GraphBuilder::AddNode))
.def("add_node", py::overload_cast<const std::vector<std::string> &, const atb::infer::ReshapeAndCacheParam &>(
&TorchAtb::GraphBuilder::AddNode))
.def("add_node", py::overload_cast<const std::vector<std::string> &, const atb::infer::FillParam &>(
&TorchAtb::GraphBuilder::AddNode))
.def("add_node",
py::overload_cast<const std::vector<std::string> &, const atb::infer::RazorFusionAttentionParam &>(
&TorchAtb::GraphBuilder::AddNode))
.def("add_node", py::overload_cast<const std::vector<std::string> &, const atb::infer::AllReduceParam &>(
&TorchAtb::GraphBuilder::AddNode))
.def("add_node", py::overload_cast<const std::vector<std::string> &, const atb::infer::BroadcastParam &>(
&TorchAtb::GraphBuilder::AddNode))
.def("add_node", py::overload_cast<const std::vector<std::string> &, const atb::infer::ReduceScatterParam &>(
&TorchAtb::GraphBuilder::AddNode))
.def("add_node", py::overload_cast<const std::vector<std::string> &, const atb::infer::ReduceScatterVParam &>(
&TorchAtb::GraphBuilder::AddNode))
.def("add_node", py::overload_cast<const std::vector<std::string> &, const atb::infer::LinearParallelParam &>(
&TorchAtb::GraphBuilder::AddNode))
.def("add_node", py::overload_cast<const std::vector<std::string> &, const atb::infer::LinearSparseParam &>(
&TorchAtb::GraphBuilder::AddNode))
.def("add_node", py::overload_cast<const std::vector<std::string> &, const atb::infer::RelayAttentionParam &>(
&TorchAtb::GraphBuilder::AddNode))
.def("add_node", py::overload_cast<const std::vector<std::string> &, const atb::infer::TopkToppSamplingParam &>(
&TorchAtb::GraphBuilder::AddNode))
.def("add_node", py::overload_cast<const std::vector<std::string> &, const atb::infer::AllToAllParam &>(
&TorchAtb::GraphBuilder::AddNode))
.def("add_node", py::overload_cast<const std::vector<std::string> &, TorchAtb::OperationWrapper &>(
&TorchAtb::GraphBuilder::AddNode))
.def("reshape", &TorchAtb::GraphBuilder::Reshape)
.def("mark_output", &TorchAtb::GraphBuilder::MarkOutput)
.def("set_execute_streams", &TorchAtb::GraphBuilder::SetExecuteStreams)
.def("build", &TorchAtb::GraphBuilder::Build);
py::class_<TorchAtb::GraphOperationBuilder>(m, "GraphBuilder")
.def(py::init<const std::string &>())
.def("set_input_output", &TorchAtb::GraphOperationBuilder::SetInputOutput)
.def("reshape", &TorchAtb::GraphOperationBuilder::Reshape)
.def("add_operation", &TorchAtb::GraphOperationBuilder::AddOperation)
.def("build", &TorchAtb::GraphOperationBuilder::Build);
py::enum_<aclDataType>(m, "AclDataType")
.value("ACL_DT_UNDEFINED", aclDataType::ACL_DT_UNDEFINED)
.value("ACL_FLOAT", aclDataType::ACL_FLOAT)
.value("ACL_FLOAT16", aclDataType::ACL_FLOAT16)
.value("ACL_INT8", aclDataType::ACL_INT8)
.value("ACL_INT32", aclDataType::ACL_INT32)
.value("ACL_UINT8", aclDataType::ACL_UINT8)
.value("ACL_INT16", aclDataType::ACL_INT16)
.value("ACL_UINT16", aclDataType::ACL_UINT16)
.value("ACL_UINT32", aclDataType::ACL_UINT32)
.value("ACL_INT64", aclDataType::ACL_INT64)
.value("ACL_UINT64", aclDataType::ACL_UINT64)
.value("ACL_DOUBLE", aclDataType::ACL_DOUBLE)
.value("ACL_BOOL", aclDataType::ACL_BOOL)
.value("ACL_STRING", aclDataType::ACL_STRING)
.value("ACL_COMPLEX64", aclDataType::ACL_COMPLEX64)
.value("ACL_COMPLEX128", aclDataType::ACL_COMPLEX128)
.value("ACL_BF16", aclDataType::ACL_BF16)
.value("ACL_INT4", aclDataType::ACL_INT4)
.value("ACL_UINT1", aclDataType::ACL_UINT1)
.value("ACL_COMPLEX32", aclDataType::ACL_COMPLEX32)
.export_values();
py::enum_<aclFormat>(m, "AclFormat")
.value("ACL_FORMAT_UNDEFINED", ACL_FORMAT_UNDEFINED)
.value("ACL_FORMAT_NCHW", ACL_FORMAT_NCHW)
.value("ACL_FORMAT_NHWC", ACL_FORMAT_NHWC)
.value("ACL_FORMAT_ND", ACL_FORMAT_ND)
.value("ACL_FORMAT_NC1HWC0", ACL_FORMAT_NC1HWC0)
.value("ACL_FORMAT_FRACTAL_Z", ACL_FORMAT_FRACTAL_Z)
.value("ACL_FORMAT_NC1HWC0_C04", ACL_FORMAT_NC1HWC0_C04)
.value("ACL_FORMAT_HWCN", ACL_FORMAT_HWCN)
.value("ACL_FORMAT_NDHWC", ACL_FORMAT_NDHWC)
.value("ACL_FORMAT_FRACTAL_NZ", ACL_FORMAT_FRACTAL_NZ)
.value("ACL_FORMAT_NCDHW", ACL_FORMAT_NCDHW)
.value("ACL_FORMAT_NDC1HWC0", ACL_FORMAT_NDC1HWC0)
.value("ACL_FRACTAL_Z_3D", ACL_FRACTAL_Z_3D)
.export_values();
py::enum_<InputLayout>(m, "InputLayout")
.value("TYPE_BSND", InputLayout::TYPE_BSND)
.value("TYPE_BNSD", InputLayout::TYPE_BNSD)
.export_values();
py::enum_<QuantType>(m, "QuantType")
.value("QUANT_UNQUANT", QuantType::QUANT_UNQUANT)
.value("QUANT_INT4", QuantType::QUANT_INT4)
.value("QUANT_INT8", QuantType::QUANT_INT8)
.value("QUANT_INT16", QuantType::QUANT_INT16)
.value("QUANT_FLOAT8", QuantType::QUANT_FLOAT8)
.value("QUANT_FLOAT16", QuantType::QUANT_FLOAT16)
.export_values();
py::enum_<DynamicQuantType>(m, "DynamicQuantType")
.value("DYNAMIC_QUANT_UNDEFINED", DynamicQuantType::DYNAMIC_QUANT_UNDEFINED)
.value("DYNAMIC_QUANT_SYMMETRIC", DynamicQuantType::DYNAMIC_QUANT_SYMMETRIC)
.value("DYNAMIC_QUANT_ASYMMETRIC", DynamicQuantType::DYNAMIC_QUANT_ASYMMETRIC)
.export_values();
py::enum_<ActivationType>(m, "ActivationType")
.value("ACTIVATION_UNDEFINED", ACTIVATION_UNDEFINED)
.value("ACTIVATION_RELU", ACTIVATION_RELU)
.value("ACTIVATION_GELU", ACTIVATION_GELU)
.value("ACTIVATION_FAST_GELU", ACTIVATION_FAST_GELU)
.value("ACTIVATION_SWISH", ACTIVATION_SWISH)
.value("ACTIVATION_LOG", ACTIVATION_LOG)
.value("ACTIVATION_SWIGLU_FORWARD", ACTIVATION_SWIGLU_FORWARD)
.value("ACTIVATION_SWIGLU_BACKWARD", ACTIVATION_SWIGLU_BACKWARD)
.value("ACTIVATION_SIGMOID", ACTIVATION_SIGMOID)
.value("ACTIVATION_FASTER_GELU_FORWARD", ACTIVATION_FASTER_GELU_FORWARD)
.value("ACTIVATION_MAX", ACTIVATION_MAX)
.export_values();
py::enum_<CommMode>(m, "CommMode")
.value("COMM_UNDEFINED", COMM_UNDEFINED)
.value("COMM_MULTI_PROCESS", COMM_MULTI_PROCESS)
.value("COMM_MULTI_THREAD", COMM_MULTI_THREAD)
.export_values();
py::class_<LayerNormParam> layerNorm(m, "LayerNormParam");
py::enum_<LayerNormParam::LayerNormType>(layerNorm, "LayerNormType")
.value("LAYER_NORM_UNDEFINED", LayerNormParam::LayerNormType::LAYER_NORM_UNDEFINED)
.value("LAYER_NORM_NORM", LayerNormParam::LayerNormType::LAYER_NORM_NORM)
.value("LAYER_NORM_PRENORM", LayerNormParam::LayerNormType::LAYER_NORM_PRENORM)
.value("LAYER_NORM_POSTNORM", LayerNormParam::LayerNormType::LAYER_NORM_POSTNORM)
.value("LAYER_NORM_MAX", LayerNormParam::LayerNormType::LAYER_NORM_MAX);
py::class_<LayerNormParam::NormParam>(layerNorm, "NormParam")
.def(py::init<QuantType, float, int32_t, int32_t, DynamicQuantType>(),
py::arg("quant_type") = QuantType::QUANT_UNQUANT,
py::arg("epsilon") = 1e-5,
py::arg("begin_norm_axis") = 0,
py::arg("begin_params_axis") = 0,
py::arg("dynamic_quant_type") = DynamicQuantType::DYNAMIC_QUANT_UNDEFINED)
.def_readwrite("quant_type", &LayerNormParam::NormParam::quantType)
.def_readwrite("epsilon", &LayerNormParam::NormParam::epsilon)
.def_readwrite("begin_norm_axis", &LayerNormParam::NormParam::beginNormAxis)
.def_readwrite("begin_params_axis", &LayerNormParam::NormParam::beginParamsAxis)
.def_readwrite("dynamic_quant_type", &LayerNormParam::NormParam::dynamicQuantType);
py::class_<LayerNormParam::PreNormParam>(layerNorm, "PreNormParam")
.def(py::init<QuantType, float, uint64_t, float>(),
py::arg("quant_type") = QuantType::QUANT_UNQUANT,
py::arg("epsilon") = 1e-5,
py::arg("op_mode") = 0,
py::arg("zoom_scale_value") = 1.0f)
.def_readwrite("quant_type", &LayerNormParam::PreNormParam::quantType)
.def_readwrite("epsilon", &LayerNormParam::PreNormParam::epsilon)
.def_readwrite("op_mode", &LayerNormParam::PreNormParam::opMode)
.def_readwrite("zoom_scale_value", &LayerNormParam::PreNormParam::zoomScaleValue);
py::class_<LayerNormParam::PostNormParam>(layerNorm, "PostNormParam")
.def(py::init<QuantType, float, uint64_t, float>(),
py::arg("quant_type") = QuantType::QUANT_UNQUANT,
py::arg("epsilon") = 1e-5,
py::arg("op_mode") = 0,
py::arg("zoom_scale_value") = 1.0f)
.def_readwrite("quant_type", &LayerNormParam::PostNormParam::quantType)
.def_readwrite("epsilon", &LayerNormParam::PostNormParam::epsilon)
.def_readwrite("op_mode", &LayerNormParam::PostNormParam::opMode)
.def_readwrite("zoom_scale_value", &LayerNormParam::PostNormParam::zoomScaleValue);
layerNorm
.def(py::init<LayerNormParam::LayerNormType, LayerNormParam::NormParam, LayerNormParam::PreNormParam,
LayerNormParam::PostNormParam>(),
py::arg("layer_type") = LayerNormParam::LayerNormType::LAYER_NORM_UNDEFINED,
py::arg("norm_param") = LayerNormParam::NormParam(),
py::arg("pre_norm_param") = LayerNormParam::PreNormParam(),
py::arg("post_norm_param") = LayerNormParam::PostNormParam())
.def_readwrite("layer_type", &LayerNormParam::layerType)
.def_readwrite("norm_param", &LayerNormParam::normParam)
.def_readwrite("pre_norm_param", &LayerNormParam::preNormParam)
.def_readwrite("post_norm_param", &LayerNormParam::postNormParam)
.def("__repr__", [](const LayerNormParam ¶m) { return "LayerNormParam: " + OpParamToJson(param).dump(); });
py::class_<ElewiseParam> elewise(m, "ElewiseParam");
py::enum_<ElewiseParam::ElewiseType>(elewise, "ElewiseType")
.value("ELEWISE_UNDEFINED", ElewiseParam::ElewiseType::ELEWISE_UNDEFINED)
.value("ELEWISE_CAST", ElewiseParam::ElewiseType::ELEWISE_CAST)
.value("ELEWISE_MULS", ElewiseParam::ElewiseType::ELEWISE_MULS)
.value("ELEWISE_COS", ElewiseParam::ElewiseType::ELEWISE_COS)
.value("ELEWISE_SIN", ElewiseParam::ElewiseType::ELEWISE_SIN)
.value("ELEWISE_NEG", ElewiseParam::ElewiseType::ELEWISE_NEG)
.value("ELEWISE_QUANT", ElewiseParam::ElewiseType::ELEWISE_QUANT)
.value("ELEWISE_LOGICAL_NOT", ElewiseParam::ElewiseType::ELEWISE_LOGICAL_NOT)
.value("ELEWISE_ADD", ElewiseParam::ElewiseType::ELEWISE_ADD)
.value("ELEWISE_MUL", ElewiseParam::ElewiseType::ELEWISE_MUL)
.value("ELEWISE_REALDIV", ElewiseParam::ElewiseType::ELEWISE_REALDIV)
.value("ELEWISE_LOGICAL_AND", ElewiseParam::ElewiseType::ELEWISE_LOGICAL_AND)
.value("ELEWISE_LOGICAL_OR", ElewiseParam::ElewiseType::ELEWISE_LOGICAL_OR)
.value("ELEWISE_LESS", ElewiseParam::ElewiseType::ELEWISE_LESS)
.value("ELEWISE_GREATER", ElewiseParam::ElewiseType::ELEWISE_GREATER)
.value("ELEWISE_SUB", ElewiseParam::ElewiseType::ELEWISE_SUB)
.value("ELEWISE_EQUAL", ElewiseParam::ElewiseType::ELEWISE_EQUAL)
.value("ELEWISE_QUANT_PER_CHANNEL", ElewiseParam::ElewiseType::ELEWISE_QUANT_PER_CHANNEL)
.value("ELEWISE_DEQUANT_PER_CHANNEL", ElewiseParam::ElewiseType::ELEWISE_DEQUANT_PER_CHANNEL)
.value("ELEWISE_DYNAMIC_QUANT", ElewiseParam::ElewiseType::ELEWISE_DYNAMIC_QUANT)
.value("ELEWISE_TANH", ElewiseParam::ElewiseType::ELEWISE_TANH);
py::class_<ElewiseParam::QuantParam>(elewise, "QuantParam")
.def(py::init<float, bool, int>(),
py::arg("input_scale") = 1.0f,
py::arg("asymmetric") = false,
py::arg("input_offset") = 0)
.def_readwrite("input_scale", &ElewiseParam::QuantParam::inputScale)
.def_readwrite("asymmetric", &ElewiseParam::QuantParam::asymmetric)
.def_readwrite("input_offset", &ElewiseParam::QuantParam::inputOffset);
py::class_<ElewiseParam::MulsParam>(elewise, "MulsParam")
.def(py::init<float>(),
py::arg("var_attr") = 0.0f)
.def_readwrite("var_attr", &ElewiseParam::MulsParam::varAttr);
elewise
.def(py::init<ElewiseParam::ElewiseType, ElewiseParam::QuantParam, ElewiseParam::MulsParam, aclDataType>(),
py::arg("elewise_type") = ElewiseParam::ElewiseType::ELEWISE_UNDEFINED,
py::arg("quant_param") = ElewiseParam::QuantParam(),
py::arg("muls_param") = ElewiseParam::MulsParam(),
py::arg("out_tensor_type") = aclDataType::ACL_DT_UNDEFINED)
.def_readwrite("elewise_type", &ElewiseParam::elewiseType)
.def_readwrite("quant_param", &ElewiseParam::quantParam)
.def_readwrite("muls_param", &ElewiseParam::mulsParam)
.def_readwrite("out_tensor_type", &ElewiseParam::outTensorType)
.def("__repr__", [](const ElewiseParam ¶m) { return "ElewiseParam: " + OpParamToJson(param).dump(); });
py::class_<LinearParam> linear(m, "LinearParam");
py::enum_<LinearParam::MatmulType>(linear, "MatmulType")
.value("MATMUL_UNDEFINED", LinearParam::MatmulType::MATMUL_UNDEFINED)
.value("MATMUL_EIN_SUM", LinearParam::MatmulType::MATMUL_EIN_SUM);
py::enum_<LinearParam::QuantMode>(linear, "QuantMode")
.value("QUANT_UNDEFINED", LinearParam::QuantMode::QUANT_UNDEFINED)
.value("PER_CHANNEL", LinearParam::QuantMode::PER_CHANNEL)
.value("PER_TOKEN", LinearParam::QuantMode::PER_TOKEN);
linear
.def(py::init<bool, bool, bool, aclDataType, bool, LinearParam::MatmulType, LinearParam::QuantMode>(),
py::arg("transpose_a") = false,
py::arg("transpose_b") = true,
py::arg("has_bias") = true,
py::arg("out_data_type") = aclDataType::ACL_DT_UNDEFINED,
py::arg("en_accum") = false,
py::arg("matmul_type") = LinearParam::MatmulType::MATMUL_UNDEFINED,
py::arg("quant_mode") = LinearParam::QuantMode::QUANT_UNDEFINED)
.def_readwrite("transpose_a", &LinearParam::transposeA)
.def_readwrite("transpose_b", &LinearParam::transposeB)
.def_readwrite("has_bias", &LinearParam::hasBias)
.def_readwrite("out_data_type", &LinearParam::outDataType)
.def_readwrite("en_accum", &LinearParam::enAccum)
.def_readwrite("matmul_type", &LinearParam::matmulType)
.def_readwrite("quant_mode", &LinearParam::quantMode)
.def("__repr__", [](const LinearParam ¶m) { return "LinearParam: " + OpParamToJson(param).dump(); });
py::class_<SoftmaxParam>(m, "SoftmaxParam")
.def(py::init([](const std::vector<int64_t> &axes) {
SoftmaxParam param;
AddElements(axes, param.axes);
return param;
}),
py::arg("axes") = py::list())
.def_readwrite("axes", &SoftmaxParam::axes)
.def_property(
"axes",
[](SoftmaxParam ¶m) -> std::vector<int64_t> {
std::vector<int64_t> axisVec;
axisVec.resize(param.axes.size());
for (size_t i = 0; i < param.axes.size(); i++) {
axisVec.at(i) = param.axes.at(i);
}
return axisVec;
},
[](SoftmaxParam ¶m, const std::vector<int64_t> &axes) {
param.axes.clear();
AddElements(axes, param.axes);
})
.def("__repr__", [](const SoftmaxParam ¶m) { return "SoftmaxParam: " + OpParamToJson(param).dump(); });
py::class_<SelfAttentionParam> selfAttention(m, "SelfAttentionParam");
py::enum_<SelfAttentionParam::CalcType>(selfAttention, "CalcType")
.value("UNDEFINED", SelfAttentionParam::UNDEFINED)
.value("ENCODER", SelfAttentionParam::ENCODER)
.value("DECODER", SelfAttentionParam::DECODER)
.value("PA_ENCODER", SelfAttentionParam::PA_ENCODER)
.value("PREFIX_ENCODER", SelfAttentionParam::PREFIX_ENCODER);
py::enum_<SelfAttentionParam::KernelType>(selfAttention, "KernelType")
.value("KERNELTYPE_DEFAULT", SelfAttentionParam::KERNELTYPE_DEFAULT)
.value("KERNELTYPE_HIGH_PRECISION", SelfAttentionParam::KERNELTYPE_HIGH_PRECISION);
py::enum_<SelfAttentionParam::ClampType>(selfAttention, "ClampType")
.value("CLAMP_TYPE_UNDEFINED", SelfAttentionParam::CLAMP_TYPE_UNDEFINED)
.value("CLAMP_TYPE_MIN_MAX", SelfAttentionParam::CLAMP_TYPE_MIN_MAX);
py::enum_<SelfAttentionParam::MaskType>(selfAttention, "MaskType")
.value("MASK_TYPE_UNDEFINED", SelfAttentionParam::MASK_TYPE_UNDEFINED)
.value("MASK_TYPE_NORM", SelfAttentionParam::MASK_TYPE_NORM)
.value("MASK_TYPE_ALIBI", SelfAttentionParam::MASK_TYPE_ALIBI)
.value("MASK_TYPE_NORM_COMPRESS", SelfAttentionParam::MASK_TYPE_NORM_COMPRESS)
.value("MASK_TYPE_ALIBI_COMPRESS", SelfAttentionParam::MASK_TYPE_ALIBI_COMPRESS)
.value("MASK_TYPE_ALIBI_COMPRESS_SQRT", SelfAttentionParam::MASK_TYPE_ALIBI_COMPRESS_SQRT)
.value("MASK_TYPE_ALIBI_COMPRESS_LEFT_ALIGN", SelfAttentionParam::MASK_TYPE_ALIBI_COMPRESS_LEFT_ALIGN)
.value("MASK_TYPE_SLIDING_WINDOW_NORM", SelfAttentionParam::MASK_TYPE_SLIDING_WINDOW_NORM)
.value("MASK_TYPE_SLIDING_WINDOW_COMPRESS", SelfAttentionParam::MASK_TYPE_SLIDING_WINDOW_COMPRESS);
py::enum_<SelfAttentionParam::KvCacheCfg>(selfAttention, "KvCacheCfg")
.value("K_CACHE_V_CACHE", SelfAttentionParam::K_CACHE_V_CACHE)
.value("K_BYPASS_V_BYPASS", SelfAttentionParam::K_BYPASS_V_BYPASS);
py::enum_<SelfAttentionParam::ScaleType>(selfAttention, "ScaleType")
.value("SCALE_TYPE_TOR", SelfAttentionParam::SCALE_TYPE_TOR)
.value("SCALE_TYPE_LOGN", SelfAttentionParam::SCALE_TYPE_LOGN)
.value("SCALE_TYPE_MAX", SelfAttentionParam::SCALE_TYPE_MAX);
py::enum_<SelfAttentionParam::QuantType>(selfAttention, "QuantType")
.value("TYPE_QUANT_UNQUANT", SelfAttentionParam::TYPE_QUANT_UNQUANT)
.value("TYPE_DEQUANT_FUSION", SelfAttentionParam::TYPE_DEQUANT_FUSION)
.value("TYPE_QUANT_QKV_OFFLINE", SelfAttentionParam::TYPE_QUANT_QKV_OFFLINE)
.value("TYPE_QUANT_QKV_ONLINE", SelfAttentionParam::TYPE_QUANT_QKV_ONLINE);
py::enum_<SelfAttentionParam::CacheType>(selfAttention, "CacheType")
.value("CACHE_TYPE_NORM", SelfAttentionParam::CACHE_TYPE_NORM)
.value("CACHE_TYPE_SWA", SelfAttentionParam::CACHE_TYPE_SWA);
selfAttention
.def(py::init<SelfAttentionParam::QuantType, aclDataType, int32_t, int32_t, float, float, bool, uint32_t,
SelfAttentionParam::CalcType, SelfAttentionParam::KernelType, SelfAttentionParam::ClampType,
float, float, SelfAttentionParam::MaskType, SelfAttentionParam::KvCacheCfg,
SelfAttentionParam::ScaleType, InputLayout, uint32_t, SelfAttentionParam::CacheType, uint32_t>(),
py::arg("quant_type") = SelfAttentionParam::QuantType::TYPE_QUANT_UNQUANT,
py::arg("out_data_type") = aclDataType::ACL_DT_UNDEFINED,
py::arg("head_num") = 0,
py::arg("kv_head_num") = 0,
py::arg("q_scale") = 1,
py::arg("qk_scale") = 1,
py::arg("batch_run_status_enable") = false,
py::arg("is_triu_mask") = 0,
py::arg("calc_type") = SelfAttentionParam::CalcType::UNDEFINED,
py::arg("kernel_type") = SelfAttentionParam::KernelType::KERNELTYPE_DEFAULT,
py::arg("clamp_type") = SelfAttentionParam::ClampType::CLAMP_TYPE_UNDEFINED,
py::arg("clamp_min") = 0,
py::arg("clamp_max") = 0,
py::arg("mask_type") = SelfAttentionParam::MaskType::MASK_TYPE_UNDEFINED,
py::arg("kvcache_cfg") = SelfAttentionParam::KvCacheCfg::K_CACHE_V_CACHE,
py::arg("scale_type") = SelfAttentionParam::ScaleType::SCALE_TYPE_TOR,
py::arg("input_layout") = InputLayout::TYPE_BSND,
py::arg("mla_v_head_size") = 0,
py::arg("cache_type") = SelfAttentionParam::CacheType::CACHE_TYPE_NORM,
py::arg("window_size") = 0)
.def_readwrite("quant_type", &SelfAttentionParam::quantType)
.def_readwrite("out_data_type", &SelfAttentionParam::outDataType)
.def_readwrite("head_num", &SelfAttentionParam::headNum)
.def_readwrite("kv_head_num", &SelfAttentionParam::kvHeadNum)
.def_readwrite("q_scale", &SelfAttentionParam::qScale)
.def_readwrite("qk_scale", &SelfAttentionParam::qkScale)
.def_readwrite("batch_run_status_enable", &SelfAttentionParam::batchRunStatusEnable)
.def_readwrite("is_triu_mask", &SelfAttentionParam::isTriuMask)
.def_readwrite("calc_type", &SelfAttentionParam::calcType)
.def_readwrite("kernel_type", &SelfAttentionParam::kernelType)
.def_readwrite("clamp_type", &SelfAttentionParam::clampType)
.def_readwrite("clamp_min", &SelfAttentionParam::clampMin)
.def_readwrite("clamp_max", &SelfAttentionParam::clampMax)
.def_readwrite("mask_type", &SelfAttentionParam::maskType)
.def_readwrite("kvcache_cfg", &SelfAttentionParam::kvcacheCfg)
.def_readwrite("scale_type", &SelfAttentionParam::scaleType)
.def_readwrite("input_layout", &SelfAttentionParam::inputLayout)
.def_readwrite("mla_v_head_size", &SelfAttentionParam::mlaVHeadSize)
.def_readwrite("cache_type", &SelfAttentionParam::cacheType)
.def_readwrite("window_size", &SelfAttentionParam::windowSize)
.def("__repr__",
[](const SelfAttentionParam ¶m) { return "SelfAttentionParam: " + OpParamToJson(param).dump(); });
py::class_<PagedAttentionParam> pageAttention(m, "PagedAttentionParam");
py::enum_<PagedAttentionParam::CalcType>(pageAttention, "CalcType")
.value("CALC_TYPE_UNDEFINED", PagedAttentionParam::CalcType::CALC_TYPE_UNDEFINED)
.value("CALC_TYPE_SPEC", PagedAttentionParam::CalcType::CALC_TYPE_SPEC);
py::enum_<PagedAttentionParam::CompressType>(pageAttention, "CompressType")
.value("COMPRESS_TYPE_UNDEFINED", PagedAttentionParam::CompressType::COMPRESS_TYPE_UNDEFINED)
.value("COMPRESS_TYPE_KVHEAD", PagedAttentionParam::CompressType::COMPRESS_TYPE_KVHEAD)
.value("COMPRESS_TYPE_KVHEAD_ROPE", PagedAttentionParam::CompressType::COMPRESS_TYPE_KVHEAD_ROPE)
.value("COMPRESS_TYPE_MAX", PagedAttentionParam::CompressType::COMPRESS_TYPE_MAX);
py::enum_<PagedAttentionParam::MaskType>(pageAttention, "MaskType")
.value("UNDEFINED", PagedAttentionParam::MaskType::UNDEFINED)
.value("MASK_TYPE_NORM", PagedAttentionParam::MaskType::MASK_TYPE_NORM)
.value("MASK_TYPE_ALIBI", PagedAttentionParam::MaskType::MASK_TYPE_ALIBI)
.value("MASK_TYPE_SPEC", PagedAttentionParam::MaskType::MASK_TYPE_SPEC);
py::enum_<PagedAttentionParam::ScaleType>(pageAttention, "ScaleType")
.value("SCALE_TYPE_TOR", PagedAttentionParam::ScaleType::SCALE_TYPE_TOR)
.value("SCALE_TYPE_LOGN", PagedAttentionParam::ScaleType::SCALE_TYPE_LOGN)
.value("SCALE_TYPE_MAX", PagedAttentionParam::ScaleType::SCALE_TYPE_MAX);
py::enum_<PagedAttentionParam::QuantType>(pageAttention, "QuantType")
.value("TYPE_QUANT_UNQUANT", PagedAttentionParam::QuantType::TYPE_QUANT_UNQUANT)
.value("TYPE_DEQUANT_FUSION", PagedAttentionParam::QuantType::TYPE_DEQUANT_FUSION)
.value("TYPE_QUANT_QKV_OFFLINE", PagedAttentionParam::QuantType::TYPE_QUANT_QKV_OFFLINE)
.value("TYPE_QUANT_QKV_ONLINE", PagedAttentionParam::QuantType::TYPE_QUANT_QKV_ONLINE);
pageAttention
.def(py::init<int32_t, float, int32_t, PagedAttentionParam::MaskType, bool, PagedAttentionParam::QuantType,
aclDataType, bool, PagedAttentionParam::CompressType, PagedAttentionParam::CalcType,
PagedAttentionParam::ScaleType, InputLayout>(),
py::arg("head_num") = 0,
py::arg("qk_scale") = 1.0,
py::arg("kv_head_num") = 0,
py::arg("mask_type") = PagedAttentionParam::MaskType::UNDEFINED,
py::arg("batch_run_status_enable") = false,
py::arg("quant_type") = PagedAttentionParam::QuantType::TYPE_QUANT_UNQUANT,
py::arg("out_data_type") = aclDataType::ACL_DT_UNDEFINED,
py::arg("has_quant_offset") = false,
py::arg("compress_type") = PagedAttentionParam::CompressType::COMPRESS_TYPE_UNDEFINED,
py::arg("calc_type") = PagedAttentionParam::CalcType::CALC_TYPE_UNDEFINED,
py::arg("scale_type") = PagedAttentionParam::ScaleType::SCALE_TYPE_TOR,
py::arg("input_layout") = InputLayout::TYPE_BSND)
.def_readwrite("head_num", &PagedAttentionParam::headNum)
.def_readwrite("qk_scale", &PagedAttentionParam::qkScale)
.def_readwrite("kv_head_num", &PagedAttentionParam::kvHeadNum)
.def_readwrite("mask_type", &PagedAttentionParam::maskType)
.def_readwrite("batch_run_status_enable", &PagedAttentionParam::batchRunStatusEnable)
.def_readwrite("quant_type", &PagedAttentionParam::quantType)
.def_readwrite("out_data_type", &PagedAttentionParam::outDataType)
.def_readwrite("has_quant_offset", &PagedAttentionParam::hasQuantOffset)
.def_readwrite("compress_type", &PagedAttentionParam::compressType)
.def_readwrite("calc_type", &PagedAttentionParam::calcType)
.def_readwrite("scale_type", &PagedAttentionParam::scaleType)
.def_readwrite("input_layout", &PagedAttentionParam::inputLayout)
.def("__repr__",
[](const PagedAttentionParam ¶m) { return "PagedAttentionParam: " + OpParamToJson(param).dump(); });
py::class_<RopeParam>(m, "RopeParam")
.def(py::init<int32_t, int32_t>(),
py::arg("rotary_coeff") = 4,
py::arg("cos_format") = 0)
.def_readwrite("rotary_coeff", &RopeParam::rotaryCoeff)
.def_readwrite("cos_format", &RopeParam::cosFormat)
.def("__repr__", [](const RopeParam ¶m) { return "RopeParam: " + OpParamToJson(param).dump(); });
py::class_<SplitParam>(m, "SplitParam")
.def(py::init([](int32_t splitDim, int32_t splitNum, const std::vector<int32_t> &splitSizes) {
SplitParam param;
param.splitDim = splitDim;
param.splitNum = splitNum;
AddElements(splitSizes, param.splitSizes);
return param;
}),
py::arg("split_dim") = 0,
py::arg("split_num") = 2,
py::arg("split_sizes") = py::list())
.def_readwrite("split_dim", &SplitParam::splitDim)
.def_readwrite("split_num", &SplitParam::splitNum)
.def_property(
"split_sizes",
[](SplitParam ¶m) -> std::vector<int32_t> {
std::vector<int32_t> vec;
vec.resize(param.splitSizes.size());
for (size_t i = 0; i < param.splitSizes.size(); i++) {
vec.at(i) = param.splitSizes.at(i);
}
return vec;
},
[](SplitParam ¶m, const std::vector<int32_t> &splitSizes) {
param.splitSizes.clear();
AddElements(splitSizes, param.splitSizes);
})
.def("__repr__", [](const SplitParam ¶m) { return "SplitParam: " + OpParamToJson(param).dump(); });
py::class_<GatherParam>(m, "GatherParam")
.def(py::init<int64_t, int64_t>(),
py::arg("axis") = 0,
py::arg("batch_dims") = 0)
.def_readwrite("axis", &GatherParam::axis)
.def_readwrite("batch_dims", &GatherParam::batchDims)
.def("__repr__", [](const GatherParam ¶m) { return "GatherParam: " + OpParamToJson(param).dump(); });
py::class_<ActivationParam> activation(m, "ActivationParam");
py::enum_<ActivationParam::GeLUMode>(activation, "GeLUMode")
.value("TANH_MODE", ActivationParam::TANH_MODE)
.value("NONE_MODE", ActivationParam::NONE_MODE);
activation
.def(py::init<ActivationType, float, int32_t, ActivationParam::GeLUMode>(),
py::arg("activation_type") = ActivationType::ACTIVATION_UNDEFINED,
py::arg("scale") = 1.0f,
py::arg("dim") = -1,
py::arg("gelu_mode") = ActivationParam::GeLUMode::TANH_MODE)
.def_readwrite("activation_type", &ActivationParam::activationType)
.def_readwrite("scale", &ActivationParam::scale)
.def_readwrite("dim", &ActivationParam::dim)
.def_readwrite("gelu_mode", &ActivationParam::geluMode)
.def("__repr__",
[](const ActivationParam ¶m) { return "ActivationParam: " + OpParamToJson(param).dump(); });
py::class_<RmsNormParam> rmsNorm(m, "RmsNormParam");
py::enum_<RmsNormParam::RmsNormType>(rmsNorm, "RmsNormType")
.value("RMS_NORM_UNDEFINED", RmsNormParam::RMS_NORM_UNDEFINED)
.value("RMS_NORM_NORM", RmsNormParam::RMS_NORM_NORM)
.value("RMS_NORM_PRENORM", RmsNormParam::RMS_NORM_PRENORM)
.value("RMS_NORM_POSTNORM", RmsNormParam::RMS_NORM_POSTNORM);
py::enum_<RmsNormParam::PrecisionMode>(rmsNorm, "PrecisionMode")
.value("HIGH_PRECISION_MODE", RmsNormParam::HIGH_PRECISION_MODE)
.value("HIGH_PERFORMANCE_MODE", RmsNormParam::HIGH_PERFORMANCE_MODE);
py::enum_<RmsNormParam::ModelType>(rmsNorm, "ModelType")
.value("LLAMA_MODEL", RmsNormParam::LLAMA_MODEL)
.value("GEMMA_MODEL", RmsNormParam::GEMMA_MODEL);
py::class_<RmsNormParam::NormParam>(rmsNorm, "NormParam")
.def(py::init<QuantType, float, double, bool, RmsNormParam::PrecisionMode, RmsNormParam::ModelType,
DynamicQuantType>(),
py::arg("quant_type") = QuantType::QUANT_UNQUANT,
py::arg("epsilon") = 1e-5,
py::arg("layer_norm_eps") = 1e-5,
py::arg("rstd") = false,
py::arg("precision_mode") = RmsNormParam::PrecisionMode::HIGH_PRECISION_MODE,
py::arg("model_type") = RmsNormParam::ModelType::LLAMA_MODEL,
py::arg("dynamic_quant_type") = DynamicQuantType::DYNAMIC_QUANT_UNDEFINED)
.def_readwrite("quant_type", &RmsNormParam::NormParam::quantType)
.def_readwrite("epsilon", &RmsNormParam::NormParam::epsilon)
.def_readwrite("layer_norm_eps", &RmsNormParam::NormParam::layerNormEps)
.def_readwrite("rstd", &RmsNormParam::NormParam::rstd)
.def_readwrite("precision_mode", &RmsNormParam::NormParam::precisionMode)
.def_readwrite("model_type", &RmsNormParam::NormParam::modelType)
.def_readwrite("dynamic_quant_type", &RmsNormParam::NormParam::dynamicQuantType);
py::class_<RmsNormParam::PreNormParam>(rmsNorm, "PreNormParam")
.def(py::init<QuantType, float, bool>(),
py::arg("quant_type") = QuantType::QUANT_UNQUANT,
py::arg("epsilon") = 1e-5,
py::arg("has_bias") = false)
.def_readwrite("quant_type", &RmsNormParam::PreNormParam::quantType)
.def_readwrite("epsilon", &RmsNormParam::PreNormParam::epsilon)
.def_readwrite("has_bias", &RmsNormParam::PreNormParam::hasBias);
py::class_<RmsNormParam::PostNormParam>(rmsNorm, "PostNormParam")
.def(py::init<QuantType, float, bool>(),
py::arg("quant_type") = QuantType::QUANT_UNQUANT,
py::arg("epsilon") = 1e-5,
py::arg("has_bias") = false)
.def_readwrite("quant_type", &RmsNormParam::PostNormParam::quantType)
.def_readwrite("epsilon", &RmsNormParam::PostNormParam::epsilon)
.def_readwrite("has_bias", &RmsNormParam::PostNormParam::hasBias);
rmsNorm
.def(py::init<RmsNormParam::RmsNormType, RmsNormParam::NormParam, RmsNormParam::PreNormParam,
RmsNormParam::PostNormParam>(),
py::arg("layer_type") = RmsNormParam::RmsNormType::RMS_NORM_UNDEFINED,
py::arg("norm_param") = RmsNormParam::NormParam(),
py::arg("pre_norm_param") = RmsNormParam::PreNormParam(),
py::arg("post_norm_param") = RmsNormParam::PostNormParam())
.def_readwrite("layer_type", &RmsNormParam::layerType)
.def_readwrite("norm_param", &RmsNormParam::normParam)
.def_readwrite("pre_norm_param", &RmsNormParam::preNormParam)
.def_readwrite("post_norm_param", &RmsNormParam::postNormParam)
.def("__repr__", [](const RmsNormParam ¶m) { return "RmsNormParam: " + OpParamToJson(param).dump(); });
py::class_<AllGatherParam>(m, "AllGatherParam")
.def(py::init<int, int, int, std::string, HcclComm, CommMode, std::string, std::string>(),
py::arg("rank") = 0,
py::arg("rank_size") = 0,
py::arg("rank_root") = 0,
py::arg("backend") = "hccl",
py::arg("hccl_comm") = nullptr,
py::arg("comm_mode") = CommMode::COMM_MULTI_PROCESS,
py::arg("rank_table_file") = "",
py::arg("comm_domain") = "")
.def_readwrite("rank", &AllGatherParam::rank)
.def_readwrite("rank_size", &AllGatherParam::rankSize)
.def_readwrite("rank_root", &AllGatherParam::rankRoot)
.def_readwrite("backend", &AllGatherParam::backend)
.def_readwrite("hccl_comm", &AllGatherParam::hcclComm)
.def_readwrite("comm_mode", &AllGatherParam::commMode)
.def_readwrite("rank_table_file", &AllGatherParam::rankTableFile)
.def_readwrite("comm_domain", &AllGatherParam::commDomain)
.def("__repr__", [](const AllGatherParam ¶m) { return "AllGatherParam: " + OpParamToJson(param).dump(); });
py::class_<AsStridedParam>(m, "AsStridedParam")
.def(py::init([](const std::vector<int64_t> &size,
const std::vector<int64_t> &stride,
const std::vector<int64_t> &offset) {
AsStridedParam param;
AddElements(size, param.size);
AddElements(stride, param.stride);
AddElements(offset, param.offset);
return param;
}),
py::arg("size") = py::list(),
py::arg("stride") = py::list(),
py::arg("offset") = py::list())
.def_readwrite("size", &AsStridedParam::size)
.def_readwrite("stride", &AsStridedParam::stride)
.def_readwrite("offset", &AsStridedParam::offset)
.def_property(
"size",
[](AsStridedParam ¶m) -> std::vector<int64_t> {
std::vector<int64_t> sizeVec;
sizeVec.resize(param.size.size());
for (size_t i = 0; i < param.size.size(); i++) {
sizeVec.at(i) = param.size.at(i);
}
return sizeVec;
},
[](AsStridedParam ¶m, const std::vector<int64_t> &size) {
param.size.clear();
AddElements(size, param.size);
})
.def_property(
"stride",
[](AsStridedParam ¶m) -> std::vector<int64_t> {
std::vector<int64_t> strideVec;
strideVec.resize(param.stride.size());
for (size_t i = 0; i < param.stride.size(); i++) {
strideVec.at(i) = param.stride.at(i);
}
return strideVec;
},
[](AsStridedParam ¶m, const std::vector<int64_t> &stride) {
param.stride.clear();
AddElements(stride, param.stride);
})
.def_property(
"offset",
[](AsStridedParam ¶m) -> std::vector<int64_t> {
std::vector<int64_t> offsetVec;
offsetVec.resize(param.offset.size());
for (size_t i = 0; i < param.offset.size(); i++) {
offsetVec.at(i) = param.offset.at(i);
}
return offsetVec;
},
[](AsStridedParam ¶m, const std::vector<int64_t> &offset) {
param.offset.clear();
AddElements(offset, param.offset);
})
.def("__repr__", [](const AsStridedParam ¶m) { return "AsStridedParam: " + OpParamToJson(param).dump(); });
py::class_<CumsumParam>(m, "CumsumParam")
.def(py::init([](const std::vector<int64_t> &axes, bool exclusive, bool reverse) {
CumsumParam param;
AddElements(axes, param.axes);
param.exclusive = exclusive;
param.reverse = reverse;
return param;
}),
py::arg("axes") = py::list(),
py::arg("exclusive") = false,
py::arg("reverse") = false)
.def_readwrite("axes", &CumsumParam::axes)
.def_readwrite("exclusive", &CumsumParam::exclusive)
.def_readwrite("reverse", &CumsumParam::reverse)
.def_property(
"axes",
[](CumsumParam ¶m) -> std::vector<int64_t> {
std::vector<int64_t> vec;
vec.resize(param.axes.size());
for (size_t i = 0; i < param.axes.size(); i++) {
vec.at(i) = param.axes.at(i);
}
return vec;
},
[](CumsumParam ¶m, const std::vector<int64_t> &axes) {
param.axes.clear();
AddElements(axes, param.axes);
})
.def("__repr__", [](const CumsumParam ¶m) { return "CumsumParam: " + OpParamToJson(param).dump(); });
py::class_<DynamicNTKParam>(m, "DynamicNTKParam")
.def(py::init<aclDataType>(),
py::arg("out_data_type") = aclDataType::ACL_DT_UNDEFINED)
.def_readwrite("out_data_type", &DynamicNTKParam::outDataType)
.def("__repr__",
[](const DynamicNTKParam ¶m) { return "DynamicNTKParam: " + OpParamToJson(param).dump(); });
py::class_<MultinomialParam>(m, "MultinomialParam")
.def(py::init<uint32_t, uint32_t>(),
py::arg("num_samples") = 1,
py::arg("rand_seed") = 0)
.def_readwrite("num_samples", &MultinomialParam::numSamples)
.def_readwrite("rand_seed", &MultinomialParam::randSeed)
.def("__repr__",
[](const MultinomialParam ¶m) { return "MultinomialParam: " + OpParamToJson(param).dump(); });
py::class_<ConcatParam>(m, "ConcatParam")
.def(py::init<int>(),
py::arg("concat_dim") = 0)
.def_readwrite("concat_dim", &ConcatParam::concatDim)
.def("__repr__", [](const ConcatParam ¶m) { return "ConcatParam: " + OpParamToJson(param).dump(); });
py::class_<SliceParam>(m, "SliceParam")
.def(py::init([](const std::vector<int64_t> &offsets, const std::vector<int64_t> &size) {
SliceParam param;
AddElements(offsets, param.offsets);
AddElements(size, param.size);
return param;
}),
py::arg("offsets") = py::list(),
py::arg("size") = py::list())
.def_readwrite("offsets", &SliceParam::offsets)
.def_readwrite("size", &SliceParam::size)
.def_property(
"offsets",
[](SliceParam ¶m) -> std::vector<int64_t> {
std::vector<int64_t> offsetsVec;
offsetsVec.resize(param.offsets.size());
for (size_t i = 0; i < param.offsets.size(); i++) {
offsetsVec.at(i) = param.offsets.at(i);
}
return offsetsVec;
},
[](SliceParam ¶m, const std::vector<int64_t> &offsets) {
param.offsets.clear();
AddElements(offsets, param.offsets);
})
.def_property(
"size",
[](SliceParam ¶m) -> std::vector<int64_t> {
std::vector<int64_t> sizeVec;
sizeVec.resize(param.size.size());
for (size_t i = 0; i < param.size.size(); i++) {
sizeVec.at(i) = param.size.at(i);
}
return sizeVec;
},
[](SliceParam ¶m, const std::vector<int64_t> &size) {
param.size.clear();
AddElements(size, param.size);
})
.def("__repr__", [](const SliceParam ¶m) { return "SliceParam: " + OpParamToJson(param).dump(); });
py::class_<TransposeParam>(m, "TransposeParam")
.def(py::init([](const std::vector<int32_t> &perm) {
TransposeParam param;
AddElements(perm, param.perm);
return param;
}),
py::arg("perm") = py::list())
.def_readwrite("perm", &TransposeParam::perm)
.def_property(
"perm",
[](TransposeParam ¶m) -> std::vector<int32_t> {
std::vector<int32_t> vec;
vec.resize(param.perm.size());
for (size_t i = 0; i < param.perm.size(); i++) {
vec.at(i) = param.perm.at(i);
}
return vec;
},
[](TransposeParam ¶m, const std::vector<int32_t> &perm) {
param.perm.clear();
AddElements(perm, param.perm);
})
.def("__repr__", [](const TransposeParam ¶m) { return "TransposeParam: " + OpParamToJson(param).dump(); });
py::class_<GatingParam>(m, "GatingParam")
.def(py::init([](int32_t topkExpertNum, int32_t cumSumNum, bool cumSumInt64,
const std::vector<int32_t> &deviceExpert) {
GatingParam param;
param.topkExpertNum = topkExpertNum;
param.cumSumNum = cumSumNum;
param.cumSumInt64 = cumSumInt64;
AddElements(deviceExpert, param.deviceExpert);
return param;
}),
py::arg("topk_expert_num") = 1,
py::arg("cum_sum_num") = 0,
py::arg("cum_sum_int64") = false,
py::arg("device_expert") = py::list())
.def_readwrite("topk_expert_num", &GatingParam::topkExpertNum)
.def_readwrite("cum_sum_num", &GatingParam::cumSumNum)
.def_readwrite("cum_sum_int64", &GatingParam::cumSumInt64)
.def_readwrite("device_expert", &GatingParam::deviceExpert)
.def("__repr__", [](const GatingParam ¶m) { return "GatingParam: " + OpParamToJson(param).dump(); });
py::class_<ReshapeAndCacheParam> reshapeAndCache(m, "ReshapeAndCacheParam");
py::enum_<ReshapeAndCacheParam::CompressType>(reshapeAndCache, "CompressType")
.value("COMPRESS_TYPE_UNDEFINED", ReshapeAndCacheParam::COMPRESS_TYPE_UNDEFINED)
.value("COMPRESS_TYPE_KVHEAD", ReshapeAndCacheParam::COMPRESS_TYPE_KVHEAD)
.value("COMPRESS_TYPE_KVHEAD_ROPE", ReshapeAndCacheParam::COMPRESS_TYPE_KVHEAD_ROPE);
py::enum_<ReshapeAndCacheParam::KvCacheCfg>(reshapeAndCache, "KvCacheCfg")
.value("K_CACHE_V_CACHE", ReshapeAndCacheParam::K_CACHE_V_CACHE)
.value("K_CACHE_V_BYPASS", ReshapeAndCacheParam::K_CACHE_V_BYPASS);
reshapeAndCache
.def(py::init<ReshapeAndCacheParam::CompressType, ReshapeAndCacheParam::KvCacheCfg>(),
py::arg("compress_type") = ReshapeAndCacheParam::CompressType::COMPRESS_TYPE_UNDEFINED,
py::arg("kv_cache_cfg") = ReshapeAndCacheParam::KvCacheCfg::K_CACHE_V_CACHE)
.def_readwrite("compress_type", &ReshapeAndCacheParam::compressType)
.def_readwrite("kv_cache_cfg", &ReshapeAndCacheParam::kvCacheCfg)
.def("__repr__",
[](const ReshapeAndCacheParam ¶m) { return "ReshapeAndCacheParam: " + OpParamToJson(param).dump(); });
py::class_<FillParam>(m, "FillParam")
.def(py::init([](bool withMask, const std::vector<float> &value, const std::vector<int64_t> &outDim) {
FillParam param;
param.withMask = withMask;
AddElements(value, param.value);
AddElements(outDim, param.outDim);
return param;
}),
py::arg("with_mask") = true,
py::arg("value") = py::list(),
py::arg("out_dim") = py::list())
.def_readwrite("with_mask", &FillParam::withMask)
.def_readwrite("value", &FillParam::value)
.def_readwrite("out_dim", &FillParam::outDim)
.def_property(
"value",
[](FillParam ¶m) -> std::vector<float> {
std::vector<float> valueVec;
valueVec.resize(param.value.size());
for (size_t i = 0; i < param.value.size(); i++) {
valueVec.at(i) = param.value.at(i);
}
return valueVec;
},
[](FillParam ¶m, const std::vector<float> &value) {
param.value.clear();
AddElements(value, param.value);
})
.def_property(
"out_dim",
[](FillParam ¶m) -> std::vector<int64_t> {
std::vector<int64_t> outDimVec;
outDimVec.resize(param.outDim.size());
for (size_t i = 0; i < param.outDim.size(); i++) {
outDimVec.at(i) = param.outDim.at(i);
}
return outDimVec;
},
[](FillParam ¶m, const std::vector<float> &outDim) {
param.outDim.clear();
AddElements(outDim, param.outDim);
})
.def("__repr__", [](const FillParam ¶m) { return "FillParam: " + OpParamToJson(param).dump(); });
py::class_<RazorFusionAttentionParam>(m, "RazorFusionAttentionParam")
.def(py::init<int32_t, int32_t, float, int32_t, int32_t, int32_t, int32_t, int32_t, int32_t, int32_t>(),
py::arg("head_num") = 1,
py::arg("kv_head_num") = 1,
py::arg("qk_scale") = 1,
py::arg("razor_len") = 0,
py::arg("pre_tokens") = 0,
py::arg("next_tokens") = 0,
py::arg("tile_q") = 0,
py::arg("tile_kv") = 0,
py::arg("text_q_len") = 0,
py::arg("text_kv_len") = 0)
.def_readwrite("head_num", &RazorFusionAttentionParam::headNum)
.def_readwrite("kv_head_num", &RazorFusionAttentionParam::kvHeadNum)
.def_readwrite("qk_scale", &RazorFusionAttentionParam::qkScale)
.def_readwrite("razor_len", &RazorFusionAttentionParam::razorLen)
.def_readwrite("pre_tokens", &RazorFusionAttentionParam::preTokens)
.def_readwrite("next_tokens", &RazorFusionAttentionParam::nextTokens)
.def_readwrite("tile_q", &RazorFusionAttentionParam::tileQ)
.def_readwrite("tile_kv", &RazorFusionAttentionParam::tileKv)
.def_readwrite("text_q_len", &RazorFusionAttentionParam::textQLen)
.def_readwrite("text_kv_len", &RazorFusionAttentionParam::textKvLen)
.def("__repr__",
[](const RazorFusionAttentionParam ¶m) {
return "RazorFusionAttentionParam: " + OpParamToJson(param).dump();
});
py::class_<AllReduceParam> allReduce(m, "AllReduceParam");
py::enum_<AllReduceParam::QuantType>(reshapeAndCache, "QuantType")
.value("QUANT_TYPE_UNQUANT", AllReduceParam::QuantType::QUANT_TYPE_UNQUANT)
.value("QUANT_TYPE_UNDEFINED", AllReduceParam::QuantType::QUANT_TYPE_UNDEFINED)
.value("QUANT_TYPE_PER_TENSOR", AllReduceParam::QuantType::QUANT_TYPE_PER_TENSOR)
.value("QUANT_TYPE_PER_CHANNEL", AllReduceParam::QuantType::QUANT_TYPE_PER_CHANNEL)
.value("QUANT_TYPE_MAX", AllReduceParam::QuantType::QUANT_TYPE_MAX);
allReduce
.def(py::init<int, int, int, std::string, std::string, HcclComm, CommMode, std::string, std::string,
AllReduceParam::QuantType, aclDataType>(),
py::arg("rank") = 0,
py::arg("rank_size") = 0,
py::arg("rank_root") = 0,
py::arg("all_reduce_type") = "sum",
py::arg("backend") = "hccl",
py::arg("hccl_comm") = nullptr,
py::arg("comm_mode") = CommMode::COMM_MULTI_PROCESS,
py::arg("rank_table_file") = "",
py::arg("comm_domain") = "",
py::arg("quant_type") = AllReduceParam::QuantType::QUANT_TYPE_UNQUANT,
py::arg("out_data_type") = aclDataType::ACL_DT_UNDEFINED)
.def_readwrite("rank", &AllReduceParam::rank)
.def_readwrite("rank_size", &AllReduceParam::rankSize)
.def_readwrite("rank_root", &AllReduceParam::rankRoot)
.def_readwrite("all_reduce_type", &AllReduceParam::allReduceType)
.def_readwrite("backend", &AllReduceParam::backend)
.def_readwrite("hccl_comm", &AllReduceParam::hcclComm)
.def_readwrite("comm_mode", &AllReduceParam::commMode)
.def_readwrite("rank_table_file", &AllReduceParam::rankTableFile)
.def_readwrite("comm_domain", &AllReduceParam::commDomain)
.def_readwrite("quant_type", &AllReduceParam::quantType)
.def_readwrite("out_data_type", &AllReduceParam::outDataType)
.def("__repr__", [](const AllReduceParam ¶m) { return "AllReduceParam: " + OpParamToJson(param).dump(); });
py::class_<BroadcastParam>(m, "BroadcastParam")
.def(py::init<int, int, int, HcclComm, CommMode, std::string, std::string, std::string>(),
py::arg("rank") = 0,
py::arg("rank_size") = 0,
py::arg("rank_root") = 0,
py::arg("hccl_comm") = nullptr,
py::arg("comm_mode") = CommMode::COMM_MULTI_PROCESS,
py::arg("backend") = "hccl",
py::arg("rank_table_file") = "",
py::arg("comm_domain") = "")
.def_readwrite("rank", &BroadcastParam::rank)
.def_readwrite("rank_size", &BroadcastParam::rankSize)
.def_readwrite("rank_root", &BroadcastParam::rankRoot)
.def_readwrite("hccl_comm", &BroadcastParam::hcclComm)
.def_readwrite("comm_mode", &BroadcastParam::commMode)
.def_readwrite("backend", &BroadcastParam::backend)
.def_readwrite("rank_table_file", &BroadcastParam::rankTableFile)
.def_readwrite("comm_domain", &BroadcastParam::commDomain)
.def("__repr__",
[](const BroadcastParam ¶m) { return "BroadcastParam: " + OpParamToJson(param).dump(); });
py::class_<ReduceScatterParam>(m, "ReduceScatterParam")
.def(py::init<int, int, int, std::string, HcclComm, CommMode, std::string, std::string, std::string>(),
py::arg("rank") = 0,
py::arg("rank_size") = 0,
py::arg("rank_root") = 0,
py::arg("reduce_type") = "sum",
py::arg("hccl_comm") = nullptr,
py::arg("comm_mode") = CommMode::COMM_MULTI_PROCESS,
py::arg("backend") = "lccl",
py::arg("rank_table_file") = "",
py::arg("comm_domain") = "")
.def_readwrite("rank", &ReduceScatterParam::rank)
.def_readwrite("rank_size", &ReduceScatterParam::rankSize)
.def_readwrite("rank_root", &ReduceScatterParam::rankRoot)
.def_readwrite("reduce_type", &ReduceScatterParam::reduceType)
.def_readwrite("hccl_comm", &ReduceScatterParam::hcclComm)
.def_readwrite("comm_mode", &ReduceScatterParam::commMode)
.def_readwrite("backend", &ReduceScatterParam::backend)
.def_readwrite("rank_table_file", &ReduceScatterParam::rankTableFile)
.def_readwrite("comm_domain", &ReduceScatterParam::commDomain)
.def("__repr__",
[](const ReduceScatterParam ¶m) { return "ReduceScatterParam: " + OpParamToJson(param).dump(); });
py::class_<ReduceScatterVParam>(m, "ReduceScatterVParam")
.def(py::init([](int rank, int rankSize, int rankRoot, const std::vector<int64_t> &sendCounts,
const std::vector<int64_t> &sdispls, std::int64_t recvCount, std::string reduceType,
HcclComm hcclComm, CommMode commMode, std::string backend, std::string rankTableFile,
std::string commDomain) {
ReduceScatterVParam param;
param.rank = rank;
param.rankSize = rankSize;
param.rankRoot = rankRoot;
AddElements(sendCounts, param.sendCounts);
AddElements(sdispls, param.sdispls);
param.recvCount = recvCount;
param.reduceType = reduceType;
param.hcclComm = hcclComm;
param.commMode = commMode;
param.backend = backend;
param.rankTableFile = rankTableFile;
param.commDomain = commDomain;
return param;
}),
py::arg("rank") = 0,
py::arg("rank_size") = 0,
py::arg("rank_root") = 0,
py::arg("send_counts") = py::list(),
py::arg("sdispls") = py::list(),
py::arg("recv_count") = 0,
py::arg("reduce_type") = "sum",
py::arg("hccl_comm") = nullptr,
py::arg("comm_mode") = CommMode::COMM_MULTI_PROCESS,
py::arg("backend") = "hccl",
py::arg("rank_table_file") = "",
py::arg("comm_domain") = "")
.def_readwrite("rank", &ReduceScatterVParam::rank)
.def_readwrite("rank_size", &ReduceScatterVParam::rankSize)
.def_readwrite("rank_root", &ReduceScatterVParam::rankRoot)
.def_readwrite("send_counts", &ReduceScatterVParam::sendCounts)
.def_readwrite("sdispls", &ReduceScatterVParam::sdispls)
.def_readwrite("recv_count", &ReduceScatterVParam::recvCount)
.def_readwrite("reduce_type", &ReduceScatterVParam::reduceType)
.def_readwrite("hccl_comm", &ReduceScatterVParam::hcclComm)
.def_readwrite("comm_mode", &ReduceScatterVParam::commMode)
.def_readwrite("backend", &ReduceScatterVParam::backend)
.def_readwrite("rank_table_file", &ReduceScatterVParam::rankTableFile)
.def_readwrite("comm_domain", &ReduceScatterVParam::commDomain)
.def("__repr__",
[](const ReduceScatterVParam ¶m) { return "ReduceScatterVParam: " + OpParamToJson(param).dump(); });
py::class_<FaUpdateParam> faUpdate(m, "FaUpdateParam");
py::enum_<FaUpdateParam::FaUpdateType>(faUpdate, "FaUpdateType")
.value("DECODE_UPDATE", FaUpdateParam::FaUpdateType::DECODE_UPDATE);
faUpdate.def(py::init<FaUpdateParam::FaUpdateType, uint32_t>(),
py::arg("fa_update_type") = FaUpdateParam::FaUpdateType::DECODE_UPDATE,
py::arg("sp") = 1)
.def_readwrite("fa_update_type", &FaUpdateParam::faUpdateType)
.def_readwrite("sp", &FaUpdateParam::sp)
.def("__repr__", [](const FaUpdateParam ¶m) { return "FaUpdateParam: " + OpParamToJson(param).dump(); });
py::class_<LinearParallelParam> linearParallel(m, "LinearParallelParam");
py::enum_<LinearParallelParam::ParallelType>(linearParallel, "ParallelType")
.value("UNDEFINED", LinearParallelParam::ParallelType::UNDEFINED)
.value("LINEAR_ALL_REDUCE", LinearParallelParam::ParallelType::LINEAR_ALL_REDUCE)
.value("LINEAR_REDUCE_SCATTER", LinearParallelParam::ParallelType::LINEAR_REDUCE_SCATTER)
.value("ALL_GATHER_LINEAR", LinearParallelParam::ParallelType::ALL_GATHER_LINEAR)
.value("PURE_LINEAR", LinearParallelParam::ParallelType::PURE_LINEAR)
.value("ALL_GATHER_LINEAR_REDUCE_SCATTER", LinearParallelParam::ParallelType::ALL_GATHER_LINEAR_REDUCE_SCATTER)
.value("MAX", LinearParallelParam::ParallelType::MAX);
py::enum_<LinearParallelParam::QuantType>(linearParallel, "QuantType")
.value("QUANT_TYPE_UNDEFINED", LinearParallelParam::QuantType::QUANT_TYPE_UNDEFINED)
.value("QUANT_TYPE_UNQUANT", LinearParallelParam::QuantType::QUANT_TYPE_UNQUANT)
.value("QUANT_TYPE_PER_TENSOR", LinearParallelParam::QuantType::QUANT_TYPE_PER_TENSOR)
.value("QUANT_TYPE_PER_CHANNEL", LinearParallelParam::QuantType::QUANT_TYPE_PER_CHANNEL)
.value("QUANT_TYPE_PER_GROUP", LinearParallelParam::QuantType::QUANT_TYPE_PER_GROUP)
.value("QUANT_TYPE_PER_TOKEN", LinearParallelParam::QuantType::QUANT_TYPE_PER_TOKEN)
.value("QUANT_TYPE_MAX", LinearParallelParam::QuantType::QUANT_TYPE_MAX);
py::class_<LinearParallelParam::TwoDimTPInfo>(linearParallel, "TwoDimTPInfo")
.def(py::init<uint16_t, uint16_t, uint8_t>(),
py::arg("ag_dim") = 0,
py::arg("rs_dim") = 0,
py::arg("inner_dim_is_ag") = 1)
.def_readwrite("ag_dim", &LinearParallelParam::TwoDimTPInfo::agDim)
.def_readwrite("rs_dim", &LinearParallelParam::TwoDimTPInfo::rsDim)
.def_readwrite("inner_dim_is_ag", &LinearParallelParam::TwoDimTPInfo::innerDimIsAg);
py::class_<LinearParallelParam::MoeInfo>(linearParallel, "MoeInfo")
.def(py::init<int16_t, int8_t, int8_t>(),
py::arg("localExpertNums") = 0,
py::arg("epSize") = 0,
py::arg("tpSize") = 0)
.def_readwrite("localExpertNums", &LinearParallelParam::MoeInfo::localExpertNums)
.def_readwrite("epSize", &LinearParallelParam::MoeInfo::epSize)
.def_readwrite("tpSize", &LinearParallelParam::MoeInfo::tpSize);
linearParallel
.def(py::init<bool, int, int, int, bool, std::string, HcclComm, CommMode, std::string,
LinearParallelParam::ParallelType, bool, LinearParallelParam::QuantType, int32_t, aclDataType,
std::string, LinearParallelParam::TwoDimTPInfo, LinearParallelParam::MoeInfo>(),
py::arg("trans_weight") = true,
py::arg("rank") = 0,
py::arg("rank_size") = 0,
py::arg("rank_root") = 0,
py::arg("has_residual") = false,
py::arg("backend") = "hccl",
py::arg("hccl_comm") = nullptr,
py::arg("comm_mode") = CommMode::COMM_MULTI_PROCESS,
py::arg("rank_table_file") = "",
py::arg("type") = LinearParallelParam::ParallelType::LINEAR_ALL_REDUCE,
py::arg("keep_intermediate") = false,
py::arg("quant_type") = LinearParallelParam::QuantType::QUANT_TYPE_UNQUANT,
py::arg("quant_group_size") = 0,
py::arg("out_data_type") = aclDataType::ACL_DT_UNDEFINED,
py::arg("comm_domain") = "",
py::arg("two_dim_TP_info") = LinearParallelParam::TwoDimTPInfo(),
py::arg("moe_info") = LinearParallelParam::MoeInfo())
.def_readwrite("trans_weight", &LinearParallelParam::transWeight)
.def_readwrite("rank", &LinearParallelParam::rank)
.def_readwrite("rank_size", &LinearParallelParam::rankSize)
.def_readwrite("rank_root", &LinearParallelParam::rankRoot)
.def_readwrite("has_residual", &LinearParallelParam::hasResidual)
.def_readwrite("backend", &LinearParallelParam::backend)
.def_readwrite("hccl_comm", &LinearParallelParam::hcclComm)
.def_readwrite("comm_mode", &LinearParallelParam::commMode)
.def_readwrite("rank_table_file", &LinearParallelParam::rankTableFile)
.def_readwrite("type", &LinearParallelParam::type)
.def_readwrite("keep_intermediate", &LinearParallelParam::keepIntermediate)
.def_readwrite("quant_type", &LinearParallelParam::quantType)
.def_readwrite("quant_group_size", &LinearParallelParam::quantGroupSize)
.def_readwrite("out_data_type", &LinearParallelParam::outDataType)
.def_readwrite("comm_domain", &LinearParallelParam::commDomain)
.def_readwrite("two_dim_TP_info", &LinearParallelParam::twoDimTPInfo)
.def_readwrite("moe_info", &LinearParallelParam::moeInfo)
.def("__repr__",
[](const LinearParallelParam ¶m) { return "LinearParallelParam: " + OpParamToJson(param).dump(); });
py::class_<LinearSparseParam>(m, "LinearSparseParam")
.def(py::init<bool, bool, uint32_t, uint32_t>(),
py::arg("transpose_a") = false,
py::arg("transpose_b") = true,
py::arg("tiling_k") = 8,
py::arg("tiling_n") = 8)
.def_readwrite("transpose_a", &LinearSparseParam::transposeA)
.def_readwrite("transpose_b", &LinearSparseParam::transposeB)
.def_readwrite("tiling_k", &LinearSparseParam::tilingK)
.def_readwrite("tiling_n", &LinearSparseParam::tilingN)
.def("__repr__",
[](const LinearSparseParam ¶m) { return "LinearSparseParam: " + OpParamToJson(param).dump(); });
py::class_<RelayAttentionParam> relayAttention(m, "RelayAttentionParam");
py::enum_<RelayAttentionParam::MaskType>(relayAttention, "MaskType")
.value("MASK_TYPE_UNDEFINED", RelayAttentionParam::MaskType::MASK_TYPE_UNDEFINED)
.value("MASK_TYPE_NORM", RelayAttentionParam::MaskType::MASK_TYPE_NORM);
relayAttention
.def(py::init<int32_t, float, int32_t, RelayAttentionParam::MaskType>(),
py::arg("head_num") = 0,
py::arg("qk_scale") = 1,
py::arg("expected_kv_head_num") = 0,
py::arg("mask_type") = RelayAttentionParam::MaskType::MASK_TYPE_UNDEFINED)
.def_readwrite("head_num", &RelayAttentionParam::headNum)
.def_readwrite("qk_scale", &RelayAttentionParam::qkScale)
.def_readwrite("kv_head_num", &RelayAttentionParam::kvHeadNum)
.def_readwrite("mask_type", &RelayAttentionParam::maskType)
.def("__repr__",
[](const RelayAttentionParam ¶m) { return "RelayAttentionParam: " + OpParamToJson(param).dump(); });
py::class_<TopkToppSamplingParam> topkToppSampling(m, "TopkToppSamplingParam");
py::enum_<TopkToppSamplingParam::TopkToppSamplingType>(topkToppSampling, "TopkToppSamplingType")
.value("SAMPLING_UNDEFINED", TopkToppSamplingParam::TopkToppSamplingType::SAMPLING_UNDEFINED)
.value("SINGLE_TOPK_SAMPLING", TopkToppSamplingParam::TopkToppSamplingType::SINGLE_TOPK_SAMPLING)
.value("BATCH_TOPK_MULTINOMIAL_SAMPLING",
TopkToppSamplingParam::TopkToppSamplingType::BATCH_TOPK_MULTINOMIAL_SAMPLING)
.value("BATCH_TOPK_EXPONENTIAL_SAMPLING",
TopkToppSamplingParam::TopkToppSamplingType::BATCH_TOPK_EXPONENTIAL_SAMPLING)
.value("BATCH_TOPK_MULTINOMIAL_LOGPROBS_SAMPLING",
TopkToppSamplingParam::TopkToppSamplingType::BATCH_TOPK_MULTINOMIAL_LOGPROBS_SAMPLING)
.value("BATCH_TOPK_EXPONENTIAL_LOGPROBS_SAMPLING",
TopkToppSamplingParam::TopkToppSamplingType::BATCH_TOPK_EXPONENTIAL_LOGPROBS_SAMPLING)
.value("SAMPLING_MAX", TopkToppSamplingParam::TopkToppSamplingType::SAMPLING_MAX);
topkToppSampling
.def(py::init([](TopkToppSamplingParam::TopkToppSamplingType topkToppSamplingType,
const std::vector<uint32_t> &randSeeds, uint32_t randSeed, uint32_t topk,
int32_t logProbsSize) {
TopkToppSamplingParam param;
param.topkToppSamplingType = topkToppSamplingType;
AddElements(randSeeds, param.randSeeds);
param.randSeed = randSeed;
param.topk = topk;
param.logProbsSize = logProbsSize;
return param;
}),
py::arg("topk_topp_sampling_type") = TopkToppSamplingParam::TopkToppSamplingType::SINGLE_TOPK_SAMPLING,
py::arg("rand_seeds") = py::list(),
py::arg("rand_seed") = 0,
py::arg("topk") = 100,
py::arg("log_probs_size") = 0)
.def_readwrite("topk_topp_sampling_type", &TopkToppSamplingParam::topkToppSamplingType)
.def_readwrite("rand_seeds", &TopkToppSamplingParam::randSeeds)
.def_readwrite("rand_seed", &TopkToppSamplingParam::randSeed)
.def_readwrite("topk", &TopkToppSamplingParam::topk)
.def_readwrite("log_probs_size", &TopkToppSamplingParam::logProbsSize)
.def("__repr__", [](const TopkToppSamplingParam ¶m) {
return "TopkToppSamplingParam: " + OpParamToJson(param).dump();
});
py::class_<AllToAllParam>(m, "AllToAllParam")
.def(py::init<int, int, int, std::string, HcclComm, CommMode, std::string, std::string, bool>(),
py::arg("rank") = 0,
py::arg("rank_size") = 0,
py::arg("rank_root") = 0,
py::arg("backend") = "hccl",
py::arg("hccl_comm") = nullptr,
py::arg("comm_mode") = CommMode::COMM_MULTI_PROCESS,
py::arg("rank_table_file") = "",
py::arg("comm_domain") = "",
py::arg("transpose") = false)
.def_readwrite("rank", &AllToAllParam::rank)
.def_readwrite("rank_size", &AllToAllParam::rankSize)
.def_readwrite("rank_root", &AllToAllParam::rankRoot)
.def_readwrite("backend", &AllToAllParam::backend)
.def_readwrite("hccl_comm", &AllToAllParam::hcclComm)
.def_readwrite("comm_mode", &AllToAllParam::commMode)
.def_readwrite("rank_table_file", &AllToAllParam::rankTableFile)
.def_readwrite("comm_domain", &AllToAllParam::commDomain)
.def_readwrite("transpose", &AllToAllParam::transpose)
.def("__repr__",
[](const AllToAllParam ¶m) { return "AllToAllParam: " + atb::OpParamToJson(param).dump(); });
}