#!/usr/bin/env python3
# coding: utf-8
# ----------------------------------------------------------------------------------------------------------
# Copyright (c) 2026 Huawei Technologies Co., Ltd.
# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
# CANN Open Software License Agreement Version 2.0 (the "License").
# Please refer to the License for details. You may not use this file except in compliance with the License.
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
# See LICENSE in the root of the software repository for the full text of the License.
# ----------------------------------------------------------------------------------------------------------
"""PyPTO {op} kernel implementation.
模板说明:
- 本文件是 {op}_impl.py 的固定模板,由 pypto-op-develop 在 Stage 3 生成。
- 所有 {op} 占位符需替换为实际算子名称。
- 导出函数 {op}_wrapper() 供 test_{op}.py 调用。
- kernel 使用 @pypto.frontend.jit 装饰,内部使用 pypto API。
- 参考 examples/ 中的 kernel 实现风格(activation、softmax、layer_norm 等)。
"""
import pypto
import torch
# ─────────────────────────────────────────────
# 1. 核心计算函数(可选,复杂算子拆分用)
# ─────────────────────────────────────────────
def {op}_core(x: pypto.Tensor) -> pypto.Tensor:
"""核心计算逻辑,供 kernel 调用。
根据 DESIGN.md 中的 API 映射实现。
使用 pypto 基础 API(如 pypto.exp, pypto.sum, pypto.amax 等)。
Args:
x: pypto.Tensor 输入。
根据实际算子需求调整参数列表。
Returns:
pypto.Tensor 计算结果。
"""
# TODO: 替换为实际计算逻辑
# 示例(SiLU): return x * pypto.sigmoid(x)
# 示例(Softmax):
# row_max = pypto.amax(x, dim=-1, keepdim=True)
# exp = pypto.exp(x - row_max)
# return exp / pypto.sum(exp, dim=-1, keepdim=True)
return x
# ─────────────────────────────────────────────
# 2. JIT Kernel
# ─────────────────────────────────────────────
@pypto.frontend.jit
def {op}_kernel(
input_tensor: pypto.Tensor(),
output_tensor: pypto.Tensor(),
):
"""PyPTO jit kernel。
根据 DESIGN.md 实现。
- Tensor 描述符使用 pypto.Tensor()(shape 自动推断)或
pypto.Tensor([pypto.DYNAMIC, ...], pypto.DT_FP32)(显式指定)。
- 必须配置 tiling:pypto.set_vec_tile_shapes(...) 或 pypto.set_cube_tile_shapes(...)。
- 输出写回使用 output_tensor[:] = result 或 pypto.assemble(result, offset, output_tensor)。
"""
# TODO: 根据 DESIGN.md 配置 tiling
# 示例: pypto.set_vec_tile_shapes(64, 128)
# TODO: 替换为实际 kernel 逻辑
result = {op}_core(input_tensor)
output_tensor[:] = result
# ─────────────────────────────────────────────
# 3. Wrapper 函数(导出接口)
# ─────────────────────────────────────────────
def {op}_wrapper(x: torch.Tensor) -> torch.Tensor:
"""算子 wrapper,供 test_{op}.py 调用。
负责:
1. 构造输出 torch.Tensor
2. 调用 JIT kernel
3. 返回结果 torch.Tensor
Args:
x: 输入 torch.Tensor。
根据实际算子需求调整参数列表(可多输入)。
Returns:
输出 torch.Tensor。
根据实际算子需求调整返回值(可多输出)。
"""
# TODO: 根据实际算子调整 output shape 和 dtype
output = torch.empty_like(x)
# 调用 kernel
{op}_kernel(x, output)
return output