MaskedCausalConv1d
产品支持情况
| 产品 | 是否支持 |
|---|---|
| Ascend 950PR/Ascend 950DT | √ |
| Atlas A3 训练系列产品/Atlas A3 推理系列产品 | × |
| Atlas A2 训练系列产品/Atlas A2 推理系列产品 | × |
| Atlas 200I/500 A2 推理产品 | × |
| Atlas 推理系列产品 | × |
| Atlas 训练系列产品 | × |
功能说明
-
算子功能:对hidden层的token之间进行带mask的因果一维分组卷积操作。
-
计算公式:
假设输入x和输出y的shape是[S, B, H],卷积权重weight的shape是[W, H],i和j分别表示S和B轴的索引,那么输出将被表示为:
y[i,j]=mask[j,i]∗∑k=0W−1x[i−k,j]∗weight[W−1−k]y[i,j] = mask[j,i] * \sum_{k=0}^{W-1} x[i-k,j] * weight[W-1-k]
其中,无效位置的padding为0填充;当前W仅支持3;H轴为elementwise操作,上述公式不体现。
参数说明
| 参数名 | 输入/输出 | 描述 | 数据类型 | 数据格式 |
|---|---|---|---|---|
| x | 输入 | 输入序列,shape为[S, B, H],对应公式中x。不支持空Tensor。 | FLOAT16、BFLOAT16 | ND |
| weight | 输入 | 因果1维分组卷积核,shape为[W, H],W固定为3,对应公式中weight。不支持空Tensor。 | 数据类型与x一致 | ND |
| mask | 可选输入 | 布尔掩码,shape为[B, S],对应公式中mask。默认值是None,为None时等价于全True。不支持空Tensor。 | BOOL | ND |
| y | 输出 | 输出结果,shape与x一致。不支持空Tensor。 | 数据类型与x一致 | ND |
约束说明
- 输入值域限制:
- B * S:取值范围为1~512K。
- H(hiddenSize):取值范围384~24576(64的整数倍)。
- W:当前只支持3。
- 算子入参与中间计算结果,在对应运行数据类型(float16/bfloat16) 下,数值均不会超出该类型值域范围。
- 算子输入不支持有±inf和nan的情况。
调用说明
| 调用方式 | 样例代码 | 说明 |
|---|---|---|
| aclnn接口 | test_aclnn_masked_causal_conv1d | 通过aclnnMaskedCausalConv1d调用MaskedCausalConv1d算子 |