文件最后提交记录最后更新时间
aclgraph示例指定kernel type Co-authored-by: wangkai<wangkai579@huawei.com> # message auto-generated for no-merge-commit merge: !4104 merge 123 into master aclgraph示例指定kernel type Created-by: mihudan Commit-by: wangkai Merged-by: ascend-robot Description: <!-- Thanks for sending a pull request! --> **What type of PR is this?** > Uncomment only one /kind <> line, hit enter to put that in a new line, and remove leading whitespaces from that line: /kind task **What does this PR do / why do we need it**: 增加的自定义算子<<<>>> + aclgraph调用的示例,指定kernel type适配不同设备环境 **Special notes for your reviewers**: 四个内核函数的声明限定符由 __global__ __aicore__ 替换为 __global__ __vector__。该修改仅调整了编译指令,确保了代码与目标硬件编程模型的一致性,内核的核心计算逻辑保持不变 See merge request: Ascend/op-plugin!41044 个月前
添加aclgraph+<<<>>> 示例代码 Co-authored-by: wangkai<wangkai579@huawei.com> # message auto-generated for no-merge-commit merge: !4019 merge master into master 添加aclgraph+<<<>>> 示例代码 Created-by: mihudan Commit-by: wangkai Merged-by: ascend-robot Description: <!-- Thanks for sending a pull request! --> **What type of PR is this?** > Uncomment only one /kind <> line, hit enter to put that in a new line, and remove leading whitespaces from that line: > > /kind bug /kind task > /kind feature **What does this PR do / why do we need it**: 添加aclgraph 与<<<>>>联调的demo测试代码,进行验证并供用户参考。 展示了如何使用PyTorch的torch.librar以及pybind两张方式注册自定义算子,通过<<<>>>内核调用符调用核函数,并适配aclgraph使用该自定义算子,以简单的Add算子和三角函数计算的原地算子为例,实现aclgraph下自定义算子的调用。 展示了3种aclgraph的使能方式,通过对比NPU输出与CPU标准加法结果来验证自定义算子的数值正确性。 1. torch.npu.NPUGraph() 2. torch.npu.make_graphed_callables 3. backend="npugraph_ex" ## 算子描述 ### Add算子 - 算子功能: Add算子实现了两个数据相加,返回相加结果的功能。对应的算子原型为: ``` ascendc_add(Tensor x, Tensor y) -> Tensor ``` - 算子规格: <table> <tr><td rowspan="1" align="center">核函数名</td><td colspan="4" align="center">add_custom</td></tr> </tr> <tr><td rowspan="3" align="center">算子输入</td><td align="center">name</td><td align="center">shape</td><td align="center">data type</td><td align="center">format</td></tr> <tr><td align="center">x</td><td align="center">8 * 2048</td><td align="center">int</td><td align="center">ND</td></tr> <tr><td align="center">y</td><td align="center">8 * 2048</td><td align="center">int</td><td align="center">ND</td></tr> </tr> </tr> <tr><td rowspan="1" align="center">算子输出</td><td align="center">z</td><td align="center">8 * 2048</td><td align="center">int</td><td align="center">ND</td></tr> </tr> </table> ### 原地三角函数算子 - 算子功能: 该算子入参为x, out_sin ,out_cos, 算子调用后,out_sin会被原地修改为sin(x)计算结果,out_cos会被原地修改为cos(x)计算结果,返回值tan(x)计算结果。对应的算子原型为: ``` ascendc_trig(Tensor x, Tensor(a!) out_sin, Tensor(b!) out_cos) -> Tensor ``` - 算子规格: <table> <tr><td rowspan="1" align="center">核函数名</td><td colspan="4" align="center">trig_inplace_custom</td></tr> </tr> <tr><td rowspan="4" align="center">算子输入</td><td align="center">name</td><td align="center">shape</td><td align="center">data type</td><td align="center">format</td></tr> <tr><td align="center">x</td><td align="center">8 * 2048</td><td align="center">float</td><td align="center">ND</td></tr> <tr><td align="center">out_sin</td><td align="center">8 * 2048</td><td align="center">float</td><td align="center">ND</td></tr> <tr><td align="center">out_cos</td><td align="center">8 * 2048</td><td align="center">float</td><td align="center">ND</td></tr> </tr> </tr> <tr><td rowspan="3" align="center">算子输出</td><td align="center">out_sin</td><td align="center">8 * 2048</td><td align="center">float</td><td align="center">ND</td></tr> <tr><td align="center">out_cos</td><td align="center">8 * 2048</td><td align="center">float</td><td align="center">ND</td></tr> <tr><td align="center">out_tan</td><td align="center">8 * 2048</td><td align="center">float</td><td align="center">ND</td></tr> </tr> </table> ## 代码实现介绍 - 以Add算子为例,样例在*.asc文件中定义了一个名为ascendc_ops的命名空间,并在其中注册了ascendc_add函数。在ascendc_add函数中通过c10_npu::getCurrentNPUStream()函数获取当前NPU上的流,并通过内核调用符<<<>>>调用自定义的Kernel函数add_custom,在NPU上执行算子。 ```c++ add_custom<<<blockDim, nullptr, aclStream>>>(xGm, yGm, zGm, totalLength); ``` - PyTorch提供TORCH_LIBRARY_FRAGMENT宏作为自定义算子注册的核心接口,用于创建并初始化自定义算子库,注册后在Python侧可以通过torch.ops.namespace.op_name方式进行调用,例如: ```c++ TORCH_LIBRARY_FRAGMENT(ascendc_ops, m) { m.def(ascendc_add"(Tensor x, Tensor y) -> Tensor"); } ``` - TORCH_LIBRARY_IMPL用于将算子逻辑绑定到特定的DispatchKey(PyTorch设备调度标识)。针对NPU设备,需要将算子实现注册到PrivateUse1这一专属的DispatchKey上,例如: ```c++ TORCH_LIBRARY_IMPL(ascendc_ops, PrivateUse1, m) { m.impl("ascendc_add", TORCH_FN(ascendc_ops::ascendc_add)); } ``` - 注册Meta函数: 注册Meta函数使faketensor流程正常工作,在使用fx, compile等功能涉及,注册代码如下: ```c++ TORCH_LIBRARY_IMPL(ascendc_ops, Meta, m) { m.impl("ascendc_add", &add_impl_meta); } ``` **Special notes for your reviewers**: See merge request: Ascend/op-plugin!40194 个月前
fix <<<>>> sample doc Co-authored-by: wangkai<wangkai579@huawei.com> # message auto-generated for no-merge-commit merge: !4080 merge master into master fix <<<>>> sample doc Created-by: mihudan Commit-by: wangkai Merged-by: ascend-robot Description: <!-- PR描述模板更新日期:20260203 --> # 【合入来源】 > (如有)请关联需求文档/issue链接 - [x] 需求 # 【修改方案】 > 请描述修改内容的具体实现,涉及哪些组件之间进行交互,可以用1、2、3、...进行罗列\ > 如果是需求或者重构类的PR,需要补充详细设计文档(说明上下游组件关系、时序图、类图、DFX能力等内容) <<<>>> aclgraph代码示例 注释描述有误 # 【资料变更】 > 请确认是否涉及资料变更。如涉及,需要在PR中体现,并简要说明修改内容。如不涉及,需填写“不涉及” 不涉及 # 【接口变更】 > 请确认是否涉及跨代码仓或者客户面可见的接口变更。如涉及,需要详细说明接口以及对应的变更内容,同时需要在资料中体现。如不涉及,需填写“不涉及” 不涉及 # 【功能验证】 > 说明测试场景,测试方法。如果本次测试方式与常规单元测试不同,请详细说明您的测试步骤\ > 新增/变更内容是否已新增/适配UT测试用例看护,并补充测试自验证截图 示例demo,需编译安装,无UT # 【CheckList】 > PR提交人对以下CheckList自检项进行全量自检,自检通过或不涉及,均修改 [ ] 为 [x] - [x] 代码注释完备,正确记录错误日志 - [x] 代码实现进行了返回值、空指针等校验 - [x] PR标题正确使用类型标签,如:feat、fix、refactor、docs、test等 - [x] PR持续集成流水线(CI)执行通过,代码检查无异常 See merge request: Ascend/op-plugin!40803 个月前
modify document Co-authored-by: molly123321<malei54@h-partners.com> # message auto-generated for no-merge-commit merge: !4896 merge master into master modify document Created-by: molly123321 Commit-by: molly123321 Merged-by: ascend-robot Description: <!-- PR描述模板更新日期:20260203 --> # 【合入来源】 > <font color="red">**如有社区issue,请关联issue链接**</font>\ > <font color="red">**请勿携带内部流程信息(需求链接、问题单、内部issue等)**</font> - [ ] 需求 - [ ] 问题单 - [ ] issue/工单 - [ ] 重构优化 - [x] 资料更新 # 【修改方案】 doc ci工具扫描问题清理 CANN安装语言风格统一、主干跳转链接修改至最新CANN版本 # 【资料变更】 修改格式和跳转问题 CANN安装语言风格统一、主干跳转链接修改至最新CANN版本 # 【接口变更】 “不涉及” # 【功能验证】 不涉及 # 【CheckList】 > PR提交人对以下CheckList自检项进行全量自检,自检通过或不涉及,均修改 [ ] 为 [x] - [x] 代码注释完备,正确记录错误日志 - [x] 代码实现进行了返回值、空指针等校验 - [x] PR标题正确使用类型标签,如:feat、fix、refactor、docs、test等 - [x] PR持续集成流水线(CI)执行通过,代码检查无异常 See merge request: Ascend/op-plugin!489626 天前
算子直调+aclgraph示例代码readme增加编译选项说明 Co-authored-by: wangkai<wangkai579@huawei.com> # message auto-generated for no-merge-commit merge: !4081 merge 120 into master 算子直调+aclgraph示例代码readme增加编译选项说明 Created-by: mihudan Commit-by: wangkai Merged-by: ascend-robot Description: <!-- Thanks for sending a pull request! --> **What type of PR is this?** > Uncomment only one /kind <> line, hit enter to put that in a new line, and remove leading whitespaces from that line: > > /kind bug /kind task > /kind feature **What does this PR do / why do we need it**: npu-arch编译选项参数增加说明 CANN包增加版本说明 **Special notes for your reviewers**: See merge request: Ascend/op-plugin!40814 个月前
README.md

自定义算子直调并适配aclgraph

概述

本样例展示了如何使用PyTorch的torch.library注册自定义算子,通过<<<>>>内核调用符调用核函数,并适配aclgraph使用该自定义算子,以简单的Add算子和三角函数计算的原地算子为例,实现aclgraph下自定义算子的调用。

支持的产品

  • Atlas A3 训练系列产品/Atlas A3 推理系列产品
  • Atlas A2 训练系列产品/Atlas A2 推理系列产品

目录结构介绍

├── README.md                   // 示例介绍
├── setup.py                    // setup文件
├── csrc
│   ├── add_custom.asc          // Add算子实现 & 自定义算子注册
│   └── trig_inplace_custom.asc // 原地三角函数算子实现 & 自定义算子注册
├── op_extension
│   ├── __init__.py             // python初始化文件
│   └── _load.py                // 加载模块
└── test
    ├── add_aclgraph_test.py    // Add算子aclgraph测试demo
    └── trig_aclgraph_test.py   // 原地三角函数aclgraph测试demo             

算子描述

Add算子

  • 算子功能:

    Add算子实现了两个数据相加,返回相加结果的功能。对应的算子原型为:

    ascendc_add(Tensor x, Tensor y) -> Tensor
    
  • 算子规格:

    核函数名add_custom
    算子输入nameshapedata typeformat
    x8 * 2048intND
    y8 * 2048intND
    算子输出z8 * 2048intND

原地三角函数算子

  • 算子功能:

    该算子入参为x, out_sin ,out_cos, 算子调用后,out_sin会被原地修改为sin(x)计算结果,out_cos会被原地修改为cos(x)计算结果,返回值tan(x)计算结果。对应的算子原型为:

    ascendc_trig(Tensor x, Tensor(a!) out_sin, Tensor(b!) out_cos) -> Tensor
    
  • 算子规格:

    核函数名trig_inplace_custom
    算子输入nameshapedata typeformat
    x8 * 2048floatND
    out_sin8 * 2048floatND
    out_cos8 * 2048floatND
    算子输出out_sin8 * 2048floatND
    out_cos8 * 2048floatND
    out_tan8 * 2048floatND

代码实现介绍

  • 以Add算子为例,样例在*.asc文件中定义了一个名为ascendc_ops的命名空间,并在其中注册了ascendc_add函数。在ascendc_add函数中通过c10_npu::getCurrentNPUStream()函数获取当前NPU上的流,并通过内核调用符<<<>>>调用自定义的Kernel函数add_custom,在NPU上执行算子。

      add_custom<<<blockDim, nullptr, aclStream>>>(xGm, yGm, zGm, totalLength);
    
  • PyTorch提供TORCH_LIBRARY_FRAGMENT宏作为自定义算子注册的核心接口,用于创建并初始化自定义算子库,注册后在Python侧可以通过torch.ops.namespace.op_name方式进行调用,例如:

    TORCH_LIBRARY_FRAGMENT(ascendc_ops, m)
    {
        m.def(ascendc_add"(Tensor x, Tensor y) -> Tensor");
    }
    
  • TORCH_LIBRARY_IMPL用于将算子逻辑绑定到特定的DispatchKey(PyTorch设备调度标识)。针对NPU设备,需要将算子实现注册到PrivateUse1这一专属的DispatchKey上,例如:

    TORCH_LIBRARY_IMPL(ascendc_ops, PrivateUse1, m)
    {
        m.impl("ascendc_add", TORCH_FN(ascendc_ops::ascendc_add));
    }
    
  • 注册Meta函数:

    注册Meta函数使faketensor流程正常工作,在使用fx, compile等功能涉及,注册代码如下:

    TORCH_LIBRARY_IMPL(ascendc_ops, Meta, m)
    {
      m.impl("ascendc_add", &add_impl_meta);
    }
    
  • aclgraph的调用:

    示例代码中,通过torch.ops.load_library加载生成的自定义算子库,并展示了3种aclgraph的使能方式,通过对比NPU输出与CPU标准加法结果来验证自定义算子的数值正确性。

  1. torch.npu.NPUGraph()
  2. torch.npu.make_graphed_callables
  3. backend="npugraph_ex"

编译运行

在本样例根目录下执行如下步骤,编译并执行算子。

  • 环境安装
  1. 请参考与您当前使用的版本配套的《Ascend Extension for PyTorch 软件安装指南》,获取PyTorch和torch_npu详细的安装步骤。

    本样例需torch2.6.0及以上版本,支持backend="npugraph_ex"需7.3.0及以上版本。

  2. 根据实际环境安装CANN toolkit包,本样例需8.5.0及以上版本,安装指导详见《CANN 软件安装指南》。

  3. 根据实际环境安装CANN ops包。根据产品型号和环境架构,下载对应安装包,可参考下载链接并执行如下命令安装:

    # 确保安装包具有可执行权限
    chmod +x Ascend-cann-${soc_name}-ops_${cann_version}_linux-${arch}.run
    # 安装命令
    ./Ascend-cann-${soc_name}-ops_${cann_version}_linux-${arch}.run  --install --quiet --install-path=${install_path}
    
    • ${soc_name}:表示NPU型号名称,即${soc_version}删除“ascend”后剩余的内容。
    • ${install_path}:表示指定安装路径,需要与toolkit包安装在相同路径,默认安装在/usr/local/Ascend目录。
  • 配置环境变量

    请根据当前环境上CANN开发套件包的安装位置,执行如下配置环境变量的命令。

    source ${install_path}/ascend-toolkit/set_env.sh
    
  • 样例执行

    参考表格,根据实际昇腾AI处理器架构修改setup.py中的--npu-arch参数,并执行如下命令:

    python setup.py bdist_wheel
    pip install dist/*.whl --force-reinstall
    cd test
    python ./add_aclgraph_test.py
    

执行结果如下,说明精度对比成功。

Ran * test in **s.
OK