triton.language.join
1 功能作用说明
将两个相同形状的输入张量沿着新的最小维度连接,输出张量比输入张量多一个维度,大小为2,保持其他维度不变。
语法:
triton.language.join(x, y)- 函数调用形式x.join(y)- 成员函数形式
功能:
- 将两个相同形状的输入张量沿着新的最小维度连接
- 输出张量比输入张量多一个维度,大小为2
- 保持其他维度不变
2 参数规格
2.1 参数说明
| 参数名 | 类型 | 必需 | 说明 |
|---|---|---|---|
| x | tensor | 是 | 第一个输入张量 |
| y | tensor | 是 | 第二个输入张量 |
返回值:
- 类型: tensor
- 形状: 输入tensor广播后的形状加上一个大小为2的维度
- 数据类型: 与输入张量相同
- 内存布局: 在新增维度上堆叠x和y
约束条件:
- 两个输入张量必须具有可以广播到相同形状的形状和数据类型
2.2 DataType支持表
| 支持情况 | int8 | int16 | int32 | int64 | uint8 | uint16 | uint32 | uint64 | float16 | float32 | bfloat16 | float8e4 | float8e5 | float64 | bool |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Ascend A2/A3 | ✓ | ✓ | ✓ | ✓ | ✓ | × | × | × | ✓ | ✓ | ✓ | × | × | × | ✓ |
| GPU支持 | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ | ✓ |
2.3 Shape支持表
支持任意维度数、任意形状大小。
2.4 特殊限制说明
无
2.5 使用方法
import torch
import triton
import triton.language as tl
@triton.jit
def join_example(out_ptr):
# 创建两个2x3的张量
x = tl.zeros([2, 3], dtype=tl.float32)
y = tl.full([2, 3], 1.0, dtype=tl.float32)
# 连接,变成2x2x3
z = tl.join(x, y)
# 将结果写回外部张量
offs = (
tl.arange(0, 2)[:, None, None] * (2 * 3)
+ tl.arange(0, 2)[None, :, None] * 3
+ tl.arange(0, 3)[None, None, :]
)
tl.store(out_ptr + offs, z)
## 调用示例
out = torch.empty((2, 2, 3), dtype=torch.float32, device="npu")
join_example[(1,)](out)
print(out.shape) # 输出: torch.Size([2, 2, 3])