triton.language.join

1 功能作用说明

将两个相同形状的输入张量沿着新的最小维度连接,输出张量比输入张量多一个维度,大小为2,保持其他维度不变。

语法:

  • triton.language.join(x, y) - 函数调用形式
  • x.join(y) - 成员函数形式

功能:

  • 将两个相同形状的输入张量沿着新的最小维度连接
  • 输出张量比输入张量多一个维度,大小为2
  • 保持其他维度不变

2 参数规格

2.1 参数说明

参数名 类型 必需 说明
x tensor 第一个输入张量
y tensor 第二个输入张量

返回值:

  • 类型: tensor
  • 形状: 输入tensor广播后的形状加上一个大小为2的维度
  • 数据类型: 与输入张量相同
  • 内存布局: 在新增维度上堆叠x和y

约束条件:

  • 两个输入张量必须具有可以广播到相同形状的形状和数据类型

2.2 DataType支持表

支持情况 int8 int16 int32 int64 uint8 uint16 uint32 uint64 float16 float32 bfloat16 float8e4 float8e5 float64 bool
Ascend A2/A3 × × × × × ×
GPU支持

2.3 Shape支持表

支持任意维度数、任意形状大小。

2.4 特殊限制说明

2.5 使用方法

import torch
import triton
import triton.language as tl

@triton.jit
def join_example(out_ptr):
    # 创建两个2x3的张量
    x = tl.zeros([2, 3], dtype=tl.float32)
    y = tl.full([2, 3], 1.0, dtype=tl.float32)

    # 连接,变成2x2x3
    z = tl.join(x, y)

    # 将结果写回外部张量
    offs = (
        tl.arange(0, 2)[:, None, None] * (2 * 3)
        + tl.arange(0, 2)[None, :, None] * 3
        + tl.arange(0, 3)[None, None, :]
    )
    tl.store(out_ptr + offs, z)

## 调用示例
out = torch.empty((2, 2, 3), dtype=torch.float32, device="npu")
join_example[(1,)](out)
print(out.shape)  # 输出: torch.Size([2, 2, 3])