TORCH_LIBRARY_FRAGMENT(atb, m)
m.impl(...)
torch.ops.atb
at_npu::native::atb
torch_npu
torchnpugen/gen_atb_ops.py
::atb::*
m.def(...)
op_plugin/config/atb_ops.yaml
cpp_name -> impl_name
op_plugin/include/AtbOpsInterface.h
op_plugin/ops/atb/AtbOpsInterface.cpp
op_plugin/ops/atb/*.cpp
cpp_name
at_npu::native::atb::
torchnpugen/templates/AtbOpsInterface.h
AtbOpsInterface.cpp
string.Template
namespace at_npu::native::atb
op_plugin/include/atb_ops.h
AtbOpsInterface.h
gen_atb_ops.py
_npu_group_topk
libop_plugin_atb.so
at_npu::native::atb::*
python -m torchnpugen.gen_atb_ops --config op_plugin/config/atb_ops.yaml --header op_plugin/include/AtbOpsInterface.h --source op_plugin/ops/atb/AtbOpsInterface.cpp --atb-src-dir op_plugin/ops/atb
python test/check_atb_cpp_api_contract.py --check-source
cpp_name -> impl_name -> m.impl -> m.def
atb_ops.yaml
op-plugin
npu_grouped_matmul_swiglu_quant_v2
weightScale
[E, ceil(K/64), N, 2]
weightScale.dim2
npu_grouped_matmul_swiglu_quant_v2_meta
FP4_IN_INT8
output = [M, N / 2]
outputScale = [M, ceil((N / 2) / 64), 2]
GroupedMatmulSwigluQuantV2NpuOpapi.cpp
weight.size(2)
test/core_tests/test_fake_tensor.py
test/test_custom_ops/test_npu_grouped_matmul_swiglu_quant_v2.py
x_dtype/weight_dtype=torch_npu.float4_e2m1fn_x2
weight_scale_dtype/x_scale_dtype=torch_npu.float8_e8m0fnu
fix: grouped matmul swiglu quant v2 mxfp4 shape infer
TORCH_NPU_LAZY_FUSION=True