import math
import os
import torch
import torch_npu
from torch.library import Library, impl
from torch.fx.node import has_side_effect
from torch_npu.utils._error_code import ErrCode, ops_error
from torch_npu.npu.utils import get_cann_version
from torch_npu.npu._backends import get_soc_version
'''
Registering Meta implementations for custom ops
'''
# Numeric limits shared by the meta implementations.
# NOTE(review): the use sites of most of these constants are outside the part
# of the file reviewed here -- confirm their exact meaning where they are read.
BIT_NUMBER = 128
UINT8_BIT_NUMBER = 8
NPU_TENSOR_DIM_LIMIT = 8
INPUTS_DIM_LIMIT_QUANTCONV2D = 4
ATTR_DIM_LIMIT_QUANTCONV2D = 2
# Divisor applied to weight_dq.size(1) by the npu_mla_prolog_v3 meta functions
# when weight_quant_mode == 3.
FP8_E4M3_BLOCK_SIZE = 32

# Implementation libraries bound to the "Meta" dispatch key for the "npu" and
# "aten" operator namespaces; the @impl decorators below register into them.
m = Library("npu", "IMPL", "Meta")
m_aten = Library("aten", "IMPL", "Meta")

# torch dtype -> integer dtype code.  Every code listed here maps back to the
# same dtype in TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP.
TORCH_DTYPE_MAP = {
    torch.float16: 5,
    torch.bfloat16: 15,
    torch.float32: 6,
    torch.float8_e5m2: 23,
    torch.float8_e4m3fn: 24,
    torch.bits8: 21,
    torch.int8: 1,
    torch.int32: 3,
}

# Integer dtype code -> torch dtype.
# Codes 290, 296 and 297 carry torch_npu names in TORCH_NPU_DTYPE_TO_STRING_MAP
# and are represented here by torch.uint8.
# NOTE(review): codes 285 and above look like torch_npu extensions of the
# standard enum -- confirm against the torch_npu dtype definitions.
TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP = {
    0: torch.uint8,
    1: torch.int8,
    2: torch.int16,
    3: torch.int32,
    4: torch.int64,
    5: torch.float16,
    6: torch.float32,
    7: torch.float64,
    8: torch.complex32,
    9: torch.complex64,
    10: torch.complex128,
    11: torch.bool,
    12: torch.qint8,
    13: torch.quint8,
    14: torch.qint32,
    15: torch.bfloat16,
    16: torch.quint4x2,
    21: torch.bits8,
    23: torch.float8_e5m2,
    24: torch.float8_e4m3fn,
    285: torch.uint8,
    290: torch.uint8,
    291: torch.float8_e5m2,
    292: torch.float8_e4m3fn,
    296: torch.uint8,
    297: torch.uint8,
}

# Integer dtype code -> display name.  npu_dtype_to_str consults this map
# before falling back to TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP.
TORCH_NPU_DTYPE_TO_STRING_MAP = {
    290: "torch_npu.hifloat8",
    293: "torch_npu.float8_e8m0fnu",
    296: "torch_npu.float4_e2m1fn_x2",
    297: "torch_npu.float4_e1m2fn_x2",
}
def npu_dtype_to_str(dtype):
    """Return a printable name for a dtype code.

    Lookup order:
      1. ``TORCH_NPU_DTYPE_TO_STRING_MAP`` -- the stored name is returned as is.
      2. ``TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP`` -- ``str()`` of the
         mapped torch dtype is returned.
      3. Otherwise ``str(dtype)``.
    """
    npu_name = TORCH_NPU_DTYPE_TO_STRING_MAP.get(dtype)
    if npu_name is not None:
        return npu_name
    scalar_type = TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP.get(dtype)
    return str(dtype) if scalar_type is None else str(scalar_type)
def _matmul_get_output_size(tensor1, tensor2):
dim_tensor1 = tensor1.dim()
dim_tensor2 = tensor2.dim()
if dim_tensor1 == 1 and dim_tensor2 == 1:
return []
elif dim_tensor1 == 2 and dim_tensor2 == 1:
return [tensor1.size(0)]
elif dim_tensor1 == 1 and dim_tensor2 == 2:
return [tensor2.size(1)]
elif dim_tensor1 == 2 and dim_tensor2 == 2:
return [tensor1.size(0), tensor2.size(1)]
elif dim_tensor1 >= 3 and (dim_tensor2 == 1 or dim_tensor2 == 2):
size1 = list(tensor1.size())[:-1]
if dim_tensor2 > 1:
size2 = list(tensor2.size())
output_size = size1 + [size2[-1]]
else:
output_size = size1
return output_size
elif (dim_tensor1 == 1 or dim_tensor1 == 2) and dim_tensor2 >= 3:
size2 = list(tensor2.size())[:-2]
if dim_tensor1 > 1:
size1 = list(tensor1.size())
output_size = size2 + [size1[0]] + [tensor2.size(-1)]
else:
output_size = size2 + [tensor2.size(-1)]
return output_size
elif dim_tensor1 >= 3 and dim_tensor2 >= 3:
n = tensor1.size(-2)
batch_tensor1 = list(tensor1.size())[:-2]
p = tensor2.size(-1)
batch_tensor2 = list(tensor2.size())[:-2]
broadcast_batch = torch.broadcast_shapes(torch.Size(batch_tensor1), torch.Size(batch_tensor2))
output_size = list(broadcast_batch) + [n, p]
return output_size
else:
raise RuntimeError(f"matmul got error sizes: ({dim_tensor1}, {dim_tensor2})")
if os.getenv("TORCH_NPU_USE_COMPATIBLE_IMPL") != "1":
    @impl(m_aten, "matmul_backward")
    def matmul_backward_meta(grad, self, other, mask):
        """Meta implementation of ``aten::matmul_backward``.

        Registered only when the ``TORCH_NPU_USE_COMPATIBLE_IMPL`` environment
        variable is not ``"1"``.  Allocates uninitialised tensors with the
        shapes computed below for the two gradients; no values are computed.

        Args:
            grad: Upstream gradient tensor.
            self: First matmul operand.
            other: Second matmul operand.
            mask: Not read by this function.

        Returns:
            Tuple ``(self_grad, other_grad)``.  ``self_grad`` takes its dtype
            and device from ``grad``; ``other_grad`` takes them from ``self``.
        """
        # ---- gradient for ``self`` ----
        lhs, rhs, upstream = self, other, grad
        # Drop leading size-1 dims of the first operand while it is > 2-D.
        while lhs.dim() > 2 and lhs.size(0) == 1:
            lhs = lhs.squeeze(0)
        # Promote 1-D operands to matrices and pad the gradient to match.
        if rhs.dim() == 1:
            rhs = rhs.unsqueeze(-1)
            upstream = upstream.unsqueeze(-1)
        if lhs.dim() == 1:
            lhs = lhs.unsqueeze(0)
            upstream = upstream.unsqueeze(-2)
        if lhs.dim() == 2 and rhs.dim() > 2:
            self_grad_size = list(self.size())
        else:
            self_grad_size = _matmul_get_output_size(upstream, rhs.transpose(-2, -1))
        self_grad = torch.empty(self_grad_size, dtype=upstream.dtype, device=upstream.device)

        # ---- gradient for ``other`` (restart from the untouched inputs) ----
        lhs, rhs, upstream = self, other, grad
        while rhs.dim() > 2 and rhs.size(0) == 1:
            rhs = rhs.squeeze(0)
        if rhs.dim() == 1:
            rhs = rhs.unsqueeze(-1)
            upstream = upstream.unsqueeze(-1)
        if lhs.dim() == 1:
            lhs = lhs.unsqueeze(0)
            upstream = upstream.unsqueeze(-2)
        if rhs.dim() == 2 and lhs.dim() > 2:
            # Uses the (possibly squeezed / unsqueezed) second operand's shape.
            other_grad_size = list(rhs.size())
        else:
            other_grad_size = _matmul_get_output_size(lhs.transpose(-2, -1), upstream)
        other_grad = torch.empty(other_grad_size, dtype=lhs.dtype, device=lhs.device)

        # Undo the trailing unit dim that was added for a 1-D ``other``.
        if other.dim() == 1 and other_grad.size(-1) == 1 and other_grad.dim() != 1:
            other_grad = other_grad.squeeze(-1)
        return (self_grad, other_grad)
@impl(m, "npu_mhc_pre")
def npu_mhc_pre_meta(x, phi, alpha, bias, *, gamma=None, norm_eps=1e-6, hc_eps=1e-6, out_flag=0):
    """Meta implementation of ``npu_mhc_pre``: validate inputs, allocate outputs.

    ``x`` must be 3-D or 4-D.  Its last two dims are read as
    ``(num_residual, dim)``; the dims before them are carried over as the
    leading dims of every output.  ``gamma``, ``norm_eps`` and ``hc_eps`` are
    not read here, and ``out_flag`` is only validated.

    Returns:
        Tuple ``(out_hin, out_hpost, out_hres, out_inv_rms, out_hmix, out_hpre)``
        of meta tensors.  ``out_hin`` has ``x.dtype``; the rest are float32.

    Raises:
        RuntimeError: via ``torch._check`` when ``x``, ``phi`` or ``bias`` is
            empty, ``alpha`` does not have exactly 3 elements, ``x`` is not
            3-D/4-D, or ``out_flag`` is not 0 or 1.
    """
    torch._check(x.numel() > 0, lambda: "Input x should not be empty." + ops_error(ErrCode.VALUE))
    torch._check(phi.numel() > 0, lambda: "Input phi should not be empty." + ops_error(ErrCode.VALUE))
    torch._check(
        alpha.numel() == 3,
        lambda: "Input alpha must have 3 elements, but got " + str(alpha.numel()) + "." + ops_error(ErrCode.VALUE),
    )
    torch._check(bias.numel() > 0, lambda: "Input bias should not be empty." + ops_error(ErrCode.VALUE))
    x_rank = x.dim()
    torch._check(
        x_rank in (3, 4),
        lambda: f"Input x must be 3D or 4D, but got {x.dim()}D." + ops_error(ErrCode.VALUE),
    )
    torch._check(
        out_flag in (0, 1),
        lambda: f"Input out_flag must be 0 or 1, but got {out_flag}." + ops_error(ErrCode.VALUE),
    )

    if x_rank == 4:
        lead = [x.size(0), x.size(1)]
        num_residual, dim = x.size(2), x.size(3)
    else:
        lead = [x.size(0)]
        num_residual, dim = x.size(1), x.size(2)
    mat_k = phi.size(0)

    def _fp32(tail):
        # float32 meta tensor shaped as the leading dims followed by ``tail``.
        return torch.empty(lead + tail, dtype=torch.float32, device="meta")

    out_hin = torch.empty(lead + [dim], dtype=x.dtype, device="meta")
    out_hpost = _fp32([num_residual])
    out_hres = _fp32([num_residual, num_residual])
    out_inv_rms = _fp32([])
    out_hmix = _fp32([mat_k])
    out_hpre = _fp32([num_residual])
    return (out_hin, out_hpost, out_hres, out_inv_rms, out_hmix, out_hpre)
@impl(m, "npu_incre_flash_attention")
def npu_incre_flash_attention_forward(query, key, value, *, padding_mask=None, atten_mask=None, pse_shift=None, actual_seq_lengths=None,
                                      antiquant_scale=None, antiquant_offset=None, block_table=None,
                                      dequant_scale1=None, quant_scale1=None, dequant_scale2=None, quant_scale2=None,
                                      quant_offset2=None, kv_padding_size=None, num_heads=1, scale_value=1.0, input_layout="BSH",
                                      num_key_value_heads=0, block_size=0, inner_precise=1):
    """Meta implementation of ``npu_incre_flash_attention``.

    The output is always ``empty_like(query)``; only its dtype varies:
      * int8 when ``quant_scale2`` is given,
      * half when ``query`` is int8 and ``quant_scale2`` is absent,
      * ``query``'s own dtype otherwise.

    All other arguments are not read here.
    """
    if quant_scale2 is not None:
        out_dtype = torch.int8
    elif query.dtype == torch.int8:
        out_dtype = torch.half
    else:
        out_dtype = query.dtype
    return torch.empty_like(query, dtype=out_dtype)
@impl(m, "batch_norm_gather_stats_update")
def batch_norm_gather_stats_update_meta(inp, mean, invstd, running_mean, running_var, momentum, eps, counts):
    """Meta implementation of ``batch_norm_gather_stats_update``.

    Returns:
        Tuple ``(batch_mean, batch_invstd)`` of 1-D tensors of length
        ``inp.shape[1]``, with ``mean``'s dtype on ``inp``'s device.  The
        remaining arguments are not read here.
    """
    stat_shape = (inp.shape[1],)

    def _new_stat():
        return torch.empty(stat_shape, dtype=mean.dtype, device=inp.device)

    batch_mean = _new_stat()
    batch_invstd = _new_stat()
    return (batch_mean, batch_invstd)
@impl(m, "npu_sparse_flash_attention")
def npu_sparse_flash_attention_forward(query, key, value, sparse_indices, scale_value, *, block_table=None,
                                       actual_seq_lengths_query=None, actual_seq_lengths_kv=None, query_rope=None,
                                       key_rope=None, sparse_block_size=1, layout_query="BSND", layout_kv="BSND",
                                       sparse_mode=3, pre_tokens=(1 << 63) - 1, next_tokens=(1 << 63) - 1,
                                       attention_mode=0, return_softmax_lse=False):
    """Meta implementation of ``npu_sparse_flash_attention``.

    Validates that ``query``, ``key``, ``value`` and ``sparse_indices`` are
    present and non-empty, that ``return_softmax_lse`` is false, and that
    ``query`` has the rank implied by ``layout_query`` ("TND" -> 3,
    "BSND" -> 4).

    Returns:
        Tuple ``(attention_out, softmax_max, softmax_sum)`` of meta tensors.
        ``attention_out`` has ``query``'s shape and dtype.  Because
        ``return_softmax_lse=True`` is rejected, the two softmax outputs are
        always float32 tensors of shape ``[0]``.

    Raises:
        RuntimeError: via ``torch._check`` when any validation above fails.
    """
    required_inputs = {"query": query, "key": key, "value": value, "sparse_indices": sparse_indices}
    for item_name, item in required_inputs.items():
        torch._check(
            item is not None,
            lambda: item_name + " should not be None, but the actual value is None" + ops_error(ErrCode.VALUE),
        )
    # Same tensors, same order: each must hold at least one element.
    for label, tensor in required_inputs.items():
        torch._check(
            tensor.numel() > 0,
            lambda: "Input " + label + " should not be empty." + ops_error(ErrCode.VALUE),
        )
    torch._check(
        not return_softmax_lse,
        lambda: "when return_softmax_lse is true, not support pytorch compile." + ops_error(ErrCode.VALUE),
    )

    if layout_query == "TND":
        expected_rank = 3
    elif layout_query == "BSND":
        expected_rank = 4
    else:
        torch._check(
            False,
            lambda: "Not support layout of query: " + layout_query + ops_error(ErrCode.VALUE),
        )
    torch._check(
        query.dim() == expected_rank,
        lambda: "When the layout of query is " + layout_query + ", the query dimension must be "
        + str(expected_rank) + ", but got " + str(query.dim()) + ops_error(ErrCode.VALUE),
    )
    attention_out = torch.empty(
        [query.size(axis) for axis in range(expected_rank)], dtype=query.dtype, device='meta'
    )

    if return_softmax_lse:
        # Not reachable while the return_softmax_lse check above rejects True.
        if layout_query == "TND":
            lse_shape = [key.size(1), query.size(0), query.size(1) // key.size(1)]
        if layout_query == "BSND":
            lse_shape = [query.size(0), key.size(2), query.size(1), query.size(2) // key.size(2)]
    else:
        lse_shape = [0]
    softmax_max = torch.empty(lse_shape, dtype=torch.float32, device='meta')
    softmax_sum = torch.empty(lse_shape, dtype=torch.float32, device='meta')
    return (attention_out, softmax_max, softmax_sum)
@impl(m, "npu_sparse_flash_attention_grad")
def npu_sparse_flash_attention_grad_meta(query, key, value, sparse_indices, d_out, out, softmax_max, softmax_sum, scale_value, sparse_block_size, query_rope=None, key_rope=None, actual_seq_qlen=None, actual_seq_kvlen=None, layout="BSND", sparse_mode=3, pre_tokens=9223372036854775807, next_tokens=9223372036854775807, attention_mode=0):
    """Meta implementation of ``npu_sparse_flash_attention_grad``.

    Returns:
        Tuple ``(d_query, d_key, d_value, d_query_rope, d_key_rope)`` of meta
        tensors.  Each gradient has the shape and dtype of its input.  When a
        rope input is ``None`` its gradient is a shape-``[0]`` tensor with the
        dtype of ``query`` (for ``query_rope``) or ``key`` (for ``key_rope``).
    """
    def _grad_like(src):
        # Meta tensor with the same shape and dtype as ``src``.
        return src.new_empty(src.shape, dtype=src.dtype, device='meta')

    d_query = _grad_like(query)
    d_key = _grad_like(key)
    d_value = _grad_like(value)

    if query_rope is None:
        d_query_rope = torch.empty([0], dtype=query.dtype, device='meta')
    else:
        d_query_rope = _grad_like(query_rope)

    if key_rope is None:
        d_key_rope = torch.empty([0], dtype=key.dtype, device='meta')
    else:
        d_key_rope = _grad_like(key_rope)

    return (d_query, d_key, d_value, d_query_rope, d_key_rope)
@impl(m, "npu_mla_prolog")
def npu_mla_prolog_forward(token_x, weight_dq, weight_uq_qr, weight_uk, weight_dkv_kr, rmsnorm_gamma_cq, rmsnorm_gamma_ckv,
                           rope_sin, rope_cos, cache_index, kv_cache, kr_cache, *, dequant_scale_x=None, dequant_scale_w_dq=None, dequant_scale_w_uq_qr=None, dequant_scale_w_dkv_kr=None,
                           quant_scale_ckv=None, quant_scale_ckr=None, smooth_scales_cq=None,
                           rmsnorm_epsilon_cq=1e-5, rmsnorm_epsilon_ckv=1e-5, cache_mode="PA_BSND"):
    """Meta implementation of ``npu_mla_prolog``.

    Rejects SoC versions >= 260, requires all twelve positional tensors to be
    non-None, and requires ``token_x`` to be 2-D/3-D, ``weight_uk`` 3-D and
    ``rope_sin`` 2-D/3-D.

    Returns:
        Tuple ``(query, query_rope, kv_cache_out, kr_cache_out)`` of meta
        tensors.  ``query`` and ``query_rope`` use ``rope_sin.dtype`` and
        start with ``token_x``'s leading dims (all but the last); the cache
        outputs mirror ``kv_cache`` / ``kr_cache``.

    Raises:
        RuntimeError: via ``torch._check`` when any validation above fails.
    """
    torch._check(
        get_soc_version() < 260,
        lambda: "npu_mla_prolog not support on this soc version, please use npu_mla_prolog_v3" + ops_error(ErrCode.NOT_SUPPORT),
    )
    required_inputs = dict(
        token_x=token_x, weight_dq=weight_dq, weight_uq_qr=weight_uq_qr, weight_uk=weight_uk,
        weight_dkv_kr=weight_dkv_kr, rmsnorm_gamma_cq=rmsnorm_gamma_cq, rmsnorm_gamma_ckv=rmsnorm_gamma_ckv,
        rope_sin=rope_sin, rope_cos=rope_cos, cache_index=cache_index, kv_cache=kv_cache, kr_cache=kr_cache,
    )
    for item_name, item in required_inputs.items():
        torch._check(
            item is not None,
            lambda: item_name + " should not be None, but the actual value is None" + ops_error(ErrCode.VALUE),
        )

    token_x_dim = token_x.dim()
    torch._check(
        token_x_dim in (2, 3),
        lambda: "token_x dim num should be 2 or 3, but the actual value is " + str(token_x_dim) + ops_error(ErrCode.VALUE),
    )
    weight_uk_dim = weight_uk.dim()
    torch._check(
        weight_uk_dim == 3,
        lambda: "weight_uk dim num should be 3, but the actual value is " + str(weight_uk_dim) + ops_error(ErrCode.VALUE),
    )
    rope_sin_dim = rope_sin.dim()
    torch._check(
        rope_sin_dim in (2, 3),
        lambda: "rope_sin dim num should be 2 or 3, but the actual value is " + str(rope_sin_dim) + ops_error(ErrCode.VALUE),
    )

    # Leading dims come from token_x; the rope width is read from the rope_sin
    # axis at the same position as token_x's last axis.
    if token_x_dim == 3:
        lead = [token_x.size(0), token_x.size(1)]
        rope_width = rope_sin.size(2)
    else:
        lead = [token_x.size(0)]
        rope_width = rope_sin.size(1)
    query_shape = lead + [weight_uk.size(0), weight_uk.size(2)]
    query_rope_shape = lead + [weight_uk.size(0), rope_width]

    query = torch.empty(query_shape, dtype=rope_sin.dtype, device='meta')
    query_rope = torch.empty(query_rope_shape, dtype=rope_sin.dtype, device='meta')
    kv_cache_out = torch.empty_like(kv_cache, dtype=kv_cache.dtype, device='meta')
    kr_cache_out = torch.empty_like(kr_cache, dtype=kr_cache.dtype, device='meta')
    return (query, query_rope, kv_cache_out, kr_cache_out)
@impl(m, "npu_mla_prolog_v2")
def npu_mla_prolog_v2_forward(token_x, weight_dq, weight_uq_qr, weight_uk, weight_dkv_kr, rmsnorm_gamma_cq, rmsnorm_gamma_ckv,
                              rope_sin, rope_cos, cache_index, kv_cache, kr_cache, *, dequant_scale_x=None, dequant_scale_w_dq=None, dequant_scale_w_uq_qr=None, dequant_scale_w_dkv_kr=None,
                              quant_scale_ckv=None, quant_scale_ckr=None, smooth_scales_cq=None,
                              rmsnorm_epsilon_cq=1e-5, rmsnorm_epsilon_ckv=1e-5, cache_mode="PA_BSND"):
    """Meta implementation of ``npu_mla_prolog_v2``.

    Performs the same validation as the ``npu_mla_prolog`` meta function
    (SoC version < 260, required tensors non-None, rank checks) and
    additionally returns a dequantisation scale for ``query``.

    Returns:
        Tuple ``(query, query_rope, kv_cache_out, kr_cache_out,
        dequant_scale_q_nope)`` of meta tensors.  When ``token_x`` is int8 and
        ``quant_scale_ckv`` is given, ``query`` is int8 and the scale has shape
        ``[tokens, weight_uk.size(0), 1]``, where ``tokens`` is the product of
        ``token_x``'s leading dims.  Otherwise ``query`` has ``rope_sin.dtype``
        and the scale has shape ``[1]``.  ``query_rope`` is always bfloat16.

    Raises:
        RuntimeError: via ``torch._check`` when any validation fails.
    """
    torch._check(
        get_soc_version() < 260,
        lambda: "npu_mla_prolog_v2 not support on this soc version, please use npu_mla_prolog_v3" + ops_error(ErrCode.NOT_SUPPORT),
    )
    required_inputs = dict(
        token_x=token_x, weight_dq=weight_dq, weight_uq_qr=weight_uq_qr, weight_uk=weight_uk,
        weight_dkv_kr=weight_dkv_kr, rmsnorm_gamma_cq=rmsnorm_gamma_cq, rmsnorm_gamma_ckv=rmsnorm_gamma_ckv,
        rope_sin=rope_sin, rope_cos=rope_cos, cache_index=cache_index, kv_cache=kv_cache, kr_cache=kr_cache,
    )
    for item_name, item in required_inputs.items():
        torch._check(
            item is not None,
            lambda: item_name + " should not be None, but the actual value is None" + ops_error(ErrCode.VALUE),
        )

    token_x_dim = token_x.dim()
    torch._check(
        token_x_dim in (2, 3),
        lambda: "token_x dim num should be 2 or 3, but the actual value is " + str(token_x_dim) + ops_error(ErrCode.VALUE),
    )
    weight_uk_dim = weight_uk.dim()
    torch._check(
        weight_uk_dim == 3,
        lambda: "weight_uk dim num should be 3, but the actual value is " + str(weight_uk_dim) + ops_error(ErrCode.VALUE),
    )
    rope_sin_dim = rope_sin.dim()
    torch._check(
        rope_sin_dim in (2, 3),
        lambda: "rope_sin dim num should be 2 or 3, but the actual value is " + str(rope_sin_dim) + ops_error(ErrCode.VALUE),
    )

    if token_x_dim == 3:
        lead = [token_x.size(0), token_x.size(1)]
        rope_width = rope_sin.size(2)
        token_count = token_x.size(0) * token_x.size(1)
    else:
        lead = [token_x.size(0)]
        rope_width = rope_sin.size(1)
        token_count = token_x.size(0)
    query_shape = lead + [weight_uk.size(0), weight_uk.size(2)]
    query_rope_shape = lead + [weight_uk.size(0), rope_width]
    dequant_scale_q_nope_shape = [token_count, weight_uk.size(0), 1]

    quantized_query = token_x.dtype == torch.int8 and quant_scale_ckv is not None
    if quantized_query:
        query = torch.empty(query_shape, dtype=torch.int8, device='meta')
        dequant_scale_q_nope = torch.empty(dequant_scale_q_nope_shape, dtype=torch.float32, device='meta')
    else:
        query = torch.empty(query_shape, dtype=rope_sin.dtype, device='meta')
        dequant_scale_q_nope = torch.empty([1], dtype=torch.float32, device='meta')

    query_rope = torch.empty(query_rope_shape, dtype=torch.bfloat16, device='meta')
    kv_cache_out = torch.empty_like(kv_cache, dtype=kv_cache.dtype, device='meta')
    kr_cache_out = torch.empty_like(kr_cache, dtype=kr_cache.dtype, device='meta')
    return (query, query_rope, kv_cache_out, kr_cache_out, dequant_scale_q_nope)
def _mla_prolog_v3_infer_outputs(required_inputs, query_norm_flag, weight_quant_mode, kv_cache_quant_mode):
    """Validate the inputs and allocate the five outputs shared by both v3 meta functions.

    Args:
        required_inputs: Mapping of argument name to value for every tensor
            that must not be None.  Must contain ``token_x``, ``weight_dq``,
            ``weight_uq_qr``, ``weight_uk`` and ``rope_sin``.
        query_norm_flag: When true, ``query_norm`` and its scale get their
            full shapes; otherwise they are placeholders.
        weight_quant_mode: Selects output dtypes and scale shapes (values
            1, 2 and 3 are distinguished; mode 3 requires SoC version >= 260).
        kv_cache_quant_mode: Together with ``weight_quant_mode`` selects the
            dtype of ``query``.

    Returns:
        Tuple ``(query, query_rope, dequant_scale_q_nope, query_norm,
        dequant_scale_q_norm)`` of meta tensors.

    Raises:
        RuntimeError: via ``torch._check`` on an unsupported SoC version, a
            missing required input, or a rank mismatch.
    """
    if weight_quant_mode == 3:
        torch._check(
            get_soc_version() >= 260,
            lambda: "When weight_quant_mode is 3, not support on this soc version." + ops_error(ErrCode.NOT_SUPPORT),
        )
    for item_name, item in required_inputs.items():
        torch._check(
            item is not None,
            lambda: item_name + " should not be None, but the actual value is None" + ops_error(ErrCode.VALUE),
        )

    token_x = required_inputs["token_x"]
    weight_dq = required_inputs["weight_dq"]
    weight_uq_qr = required_inputs["weight_uq_qr"]
    weight_uk = required_inputs["weight_uk"]
    rope_sin = required_inputs["rope_sin"]

    token_x_dim = token_x.dim()
    torch._check(
        token_x_dim == 2 or token_x_dim == 3,
        lambda: "token_x dim num should be 2 or 3, but the actual value is " + str(token_x_dim) + ops_error(ErrCode.VALUE),
    )
    weight_uk_dim = weight_uk.dim()
    torch._check(
        weight_uk_dim == 3,
        lambda: "weight_uk dim num should be 3, but the actual value is " + str(weight_uk_dim) + ops_error(ErrCode.VALUE),
    )
    # rope_sin must have the same rank as token_x (2 or 3 at this point).
    rope_sin_dim = rope_sin.dim()
    torch._check(
        rope_sin_dim == token_x_dim,
        lambda: f"when token_x dim num is {token_x_dim}, rope_sin dim num should be {token_x_dim}, but the actual value is "
        + str(rope_sin_dim) + ops_error(ErrCode.VALUE),
    )

    # Leading dims are token_x's dims without its last one; token_count is
    # their product and forms the first dim of the dequant scales.
    if token_x_dim == 3:
        lead = [token_x.size(0), token_x.size(1)]
        token_count = token_x.size(0) * token_x.size(1)
    else:
        lead = [token_x.size(0)]
        token_count = token_x.size(0)

    query_shape = lead + [weight_uk.size(0), weight_uk.size(2)]
    query_rope_shape = lead + [weight_uk.size(0), rope_sin.size(token_x_dim - 1)]
    dequant_scale_q_nope_shape = [token_count, weight_uk.size(0), 1]
    query_norm_shape = lead + [weight_dq.size(1)]
    if weight_quant_mode == 3:
        # Integer floor division: the block count stays an exact integer
        # instead of going through float division and truncation.
        scale_columns = weight_dq.size(1) // FP8_E4M3_BLOCK_SIZE
    else:
        scale_columns = 1
    dequant_scale_q_norm_shape = [token_count, scale_columns]

    # Outputs that are not produced are returned as placeholders whose shape
    # depends on the CANN version: [0] from 8.5.0.alpha003 on, [1] before.
    is_cann_version_gte_required = torch_npu.npu.utils._is_gte_cann_version("8.5.0.alpha003", "CANN")
    placeholder_shape = [0] if is_cann_version_gte_required else [1]

    if weight_quant_mode == 3 and kv_cache_quant_mode == 1:
        query = torch.empty(query_shape, dtype=torch.float8_e4m3fn, device='meta')
        dequant_scale_q_nope = torch.empty(dequant_scale_q_nope_shape, dtype=torch.float32, device='meta')
    elif weight_quant_mode == 2 and kv_cache_quant_mode == 1:
        query = torch.empty(query_shape, dtype=torch.int8, device='meta')
        dequant_scale_q_nope = torch.empty(dequant_scale_q_nope_shape, dtype=torch.float32, device='meta')
    else:
        query = torch.empty(query_shape, dtype=rope_sin.dtype, device='meta')
        dequant_scale_q_nope = torch.empty(placeholder_shape, dtype=torch.float32, device='meta')

    query_rope = torch.empty(query_rope_shape, dtype=torch.bfloat16, device='meta')

    if query_norm_flag:
        query_norm = torch.empty(query_norm_shape, dtype=weight_uq_qr.dtype, device='meta')
        if weight_quant_mode == 1 or weight_quant_mode == 2:
            dequant_scale_q_norm = torch.empty(dequant_scale_q_norm_shape, dtype=torch.float32, device='meta')
        elif weight_quant_mode == 3:
            dequant_scale_q_norm = torch.empty(dequant_scale_q_norm_shape, dtype=torch.float8_e8m0fnu, device='meta')
        else:
            dequant_scale_q_norm = torch.empty(placeholder_shape, dtype=torch.float32, device='meta')
    else:
        query_norm = torch.empty(placeholder_shape, dtype=weight_uq_qr.dtype, device='meta')
        dequant_scale_q_norm = torch.empty(
            placeholder_shape,
            dtype=torch.float8_e8m0fnu if weight_quant_mode == 3 else torch.float32,
            device='meta',
        )

    return (query, query_rope, dequant_scale_q_nope, query_norm, dequant_scale_q_norm)


@impl(m, "npu_mla_prolog_v3")
def npu_mla_prolog_v3_forward(token_x, weight_dq, weight_uq_qr, weight_uk, weight_dkv_kr, rmsnorm_gamma_cq, rmsnorm_gamma_ckv,
                              rope_sin, rope_cos, kv_cache, kr_cache, *, cache_index=None, dequant_scale_x=None, dequant_scale_w_dq=None, dequant_scale_w_uq_qr=None, dequant_scale_w_dkv_kr=None,
                              quant_scale_ckv=None, quant_scale_ckr=None, smooth_scales_cq=None, actual_seq_len=None, k_nope_clip_alpha=None, rmsnorm_epsilon_cq=1e-5, rmsnorm_epsilon_ckv=1e-5,
                              cache_mode="PA_BSND", query_norm_flag=False, weight_quant_mode=0, kv_cache_quant_mode=0, query_quant_mode=0, ckvkr_repo_mode=0, quant_scale_repo_mode=0, tile_size=128, qc_qr_scale=1.0, kc_scale=1.0):
    """Meta implementation of ``npu_mla_prolog_v3``.

    The eleven positional tensors must be non-None.  Only ``query_norm_flag``,
    ``weight_quant_mode`` and ``kv_cache_quant_mode`` among the keyword
    arguments affect the result; the others are not read here.

    Returns:
        Tuple ``(query, query_rope, dequant_scale_q_nope, query_norm,
        dequant_scale_q_norm)`` of meta tensors.

    Raises:
        RuntimeError: via ``torch._check`` when validation fails.
    """
    required_inputs = {"token_x": token_x, "weight_dq": weight_dq, "weight_uq_qr": weight_uq_qr, "weight_uk": weight_uk, "weight_dkv_kr": weight_dkv_kr, "rmsnorm_gamma_cq": rmsnorm_gamma_cq, "rmsnorm_gamma_ckv": rmsnorm_gamma_ckv, "rope_sin": rope_sin, "rope_cos": rope_cos, "kv_cache": kv_cache, "kr_cache": kr_cache}
    return _mla_prolog_v3_infer_outputs(required_inputs, query_norm_flag, weight_quant_mode, kv_cache_quant_mode)


@impl(m, "npu_mla_prolog_v3_functional")
def npu_mla_prolog_v3_functional_forward(token_x, weight_dq, weight_uq_qr, weight_uk, weight_dkv_kr, rmsnorm_gamma_cq, rmsnorm_gamma_ckv,
                                         rope_sin, rope_cos, kv_cache, kr_cache, *, cache_index=None, dequant_scale_x=None, dequant_scale_w_dq=None, dequant_scale_w_uq_qr=None, dequant_scale_w_dkv_kr=None,
                                         quant_scale_ckv=None, quant_scale_ckr=None, smooth_scales_cq=None, actual_seq_len=None, k_nope_clip_alpha=None, rmsnorm_epsilon_cq=1e-5, rmsnorm_epsilon_ckv=1e-5,
                                         cache_mode="PA_BSND", query_norm_flag=False, weight_quant_mode=0, kv_cache_quant_mode=0, query_quant_mode=0, ckvkr_repo_mode=0, quant_scale_repo_mode=0, tile_size=128, qc_qr_scale=1.0, kc_scale=1.0):
    """Meta implementation of ``npu_mla_prolog_v3_functional``.

    Same validation and first five outputs as the ``npu_mla_prolog_v3`` meta
    function, followed by two tensors mirroring ``kv_cache`` and ``kr_cache``.

    Returns:
        Tuple ``(query, query_rope, dequant_scale_q_nope, query_norm,
        dequant_scale_q_norm, kv_cache_out, kr_cache_out)`` of meta tensors.

    Raises:
        RuntimeError: via ``torch._check`` when validation fails.
    """
    required_inputs = {"token_x": token_x, "weight_dq": weight_dq, "weight_uq_qr": weight_uq_qr, "weight_uk": weight_uk, "weight_dkv_kr": weight_dkv_kr, "rmsnorm_gamma_cq": rmsnorm_gamma_cq, "rmsnorm_gamma_ckv": rmsnorm_gamma_ckv, "rope_sin": rope_sin, "rope_cos": rope_cos, "kv_cache": kv_cache, "kr_cache": kr_cache}
    shared_outputs = _mla_prolog_v3_infer_outputs(required_inputs, query_norm_flag, weight_quant_mode, kv_cache_quant_mode)
    kv_cache_out = torch.empty_like(kv_cache, dtype=kv_cache.dtype, device='meta')
    kr_cache_out = torch.empty_like(kr_cache, dtype=kr_cache.dtype, device='meta')
    return shared_outputs + (kv_cache_out, kr_cache_out)
if "2.1." in torch.__version__:
    @impl(m, "npu_prompt_flash_attention")
    def npu_prompt_flash_attention_forward(query, key, value, *, padding_mask=None, atten_mask=None, pse_shift=None, actual_seq_lengths=None, deq_scale1=None, quant_scale1=None, deq_scale2=None, quant_scale2=None, quant_offset2=None, num_heads=1, scale_value=1.0, pre_tokens=2147483647, next_tokens=0, input_layout="BSH", num_key_value_heads=0, actual_seq_lengths_kv=None, sparse_mode=0):
        """Meta implementation of ``npu_prompt_flash_attention`` (torch 2.1.x variant).

        The output shape is built explicitly per ``input_layout``; an unknown
        layout raises through ``torch._check``.  The output dtype is int8 when
        ``quant_scale2`` is given, half when ``query`` is int8, and
        ``query``'s dtype otherwise.
        """
        # Default allocation; every supported layout below replaces it.
        tmp_out = torch.empty_like(query, dtype=query.dtype, device='meta')
        if input_layout == "BNSD_BSND":
            # Output swaps query's dims 1 and 2.
            out_shape = [query.size(0), query.size(2), query.size(1), query.size(3)]
        elif input_layout == "SH":
            out_shape = [query.size(0), query.size(1)]
        elif input_layout in ("BSH", "NSD"):
            out_shape = [query.size(0), query.size(1), query.size(2)]
        elif input_layout == "TND":
            # Last dim is taken from value, not query.
            out_shape = [query.size(0), query.size(1), value.size(2)]
        elif input_layout in ("BNSD", "BSND"):
            out_shape = [query.size(0), query.size(1), query.size(2), query.size(3)]
        else:
            torch._check(
                False,
                lambda: "not support layout: " + str(input_layout) + ops_error(ErrCode.VALUE),
            )
        tmp_out = torch.empty(out_shape, dtype=query.dtype, device='meta')

        if quant_scale2 is not None:
            out_dtype = torch.int8
        elif query.dtype == torch.int8:
            out_dtype = torch.half
        else:
            out_dtype = query.dtype
        return torch.empty_like(tmp_out, dtype=out_dtype)
else:
    @impl(m, "npu_prompt_flash_attention")
    def npu_prompt_flash_attention_forward(query, key, value, *, padding_mask=None, atten_mask=None, pse_shift=None, actual_seq_lengths=None, deq_scale1=None, quant_scale1=None, deq_scale2=None, quant_scale2=None, quant_offset2=None, num_heads=1, scale_value=1.0, pre_tokens=2147483647, next_tokens=0, input_layout="BSH", num_key_value_heads=0, actual_seq_lengths_kv=None, sparse_mode=0):
        """Meta implementation of ``npu_prompt_flash_attention`` (all other torch versions).

        The output mirrors ``query`` except for two layouts: "TND" takes its
        last dim from ``value`` and "BNSD_BSND" swaps query's dims 1 and 2.
        The layout is not otherwise validated.  The output dtype is int8 when
        ``quant_scale2`` is given, half when ``query`` is int8, and
        ``query``'s dtype otherwise.
        """
        tmp_out = torch.empty_like(query, dtype=query.dtype, device='meta')
        if input_layout == "TND":
            tmp_out = torch.empty(
                [query.size(0), query.size(1), value.size(2)], dtype=query.dtype, device='meta'
            )
        elif input_layout == "BNSD_BSND":
            tmp_out = torch.empty(
                [query.size(0), query.size(2), query.size(1), query.size(3)], dtype=query.dtype, device='meta'
            )

        if quant_scale2 is not None:
            out_dtype = torch.int8
        elif query.dtype == torch.int8:
            out_dtype = torch.half
        else:
            out_dtype = query.dtype
        return torch.empty_like(tmp_out, dtype=out_dtype)
@impl(m, "npu_mm_reduce_scatter_base")
def npu_mm_reduce_scatter_base_meta(self, x2, hcom, world_size, reduce_op='sum',
                                    bias=None, x1_scale=None, x2_scale=None, comm_turn=0,
                                    output_dtype=None, comm_mode=None):
    """Meta kernel for ``npu_mm_reduce_scatter_base``.

    Returns an empty meta tensor of shape
    ``[floor(self.size(0) / world_size), x2.size(1)]``.  Without ``x2_scale``
    the dtype follows ``self``; with it, an int64 scale selects float16,
    otherwise ``output_dtype`` is used when given and bfloat16 when not.
    """
    # A non-positive world size is treated as 1 instead of being rejected.
    ranks = 1 if world_size <= 0 else world_size
    rows_per_rank = math.floor(self.size(0) / ranks)
    out_shape = [rows_per_rank, x2.size(1)]

    if x2_scale is None:
        out_dtype = self.dtype
    elif x2_scale.dtype == torch.int64:
        out_dtype = torch.float16
    elif output_dtype is not None:
        out_dtype = output_dtype
    else:
        out_dtype = torch.bfloat16
    return torch.empty(out_shape, dtype=out_dtype, device='meta')
@impl(m, "npu_quant_mm_reduce_scatter")
def npu_quant_mm_reduce_scatter_meta(self, x2, hcom, world_size, reduce_op='sum',
                                     bias=None, x1_scale=None, x2_scale=None, quant_scale=None,
                                     block_size=0, comm_turn=0, group_sizes=None, amax_output=False, y_dtype=None,
                                     x1_dtype=None, x2_dtype=None, x1_scale_dtype=None, x2_scale_dtype=None):
    """Meta kernel for ``npu_quant_mm_reduce_scatter``.

    Returns a pair: an empty tensor of shape
    ``(floor(self.size(0) / world_size), x2.size(1))`` whose dtype is looked
    up from ``y_dtype`` (or taken from ``self`` when ``y_dtype`` is None), and
    an empty float32 tensor of length 0.

    Raises:
        RuntimeError: if ``world_size`` is not positive.
    """
    if world_size <= 0:
        raise RuntimeError("world_size must be bigger than zero")

    rows_per_rank = math.floor(self.size(0) / world_size)
    if y_dtype is None:
        out_dtype = self.dtype
    else:
        out_dtype = TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP[y_dtype]

    main_out = self.new_empty(rows_per_rank, x2.size(1), dtype=out_dtype)
    second_out = self.new_empty(0, dtype=torch.float32)
    return (main_out, second_out)
@impl(m, "npu_quant_reduce_scatter")
def npu_quant_reduce_scatter_meta(x, scales, hcom_name, world_size, reduce_op='sum',
                                  output_dtype=None, x_dtype=None, scales_dtype=None):
    """Meta kernel for ``npu_quant_reduce_scatter``.

    Args:
        x: 2-D or 3-D input tensor; must not be None.
        scales: must not be None (not otherwise inspected here).
        world_size: must be one of 2, 4 or 8.
        output_dtype: optional key into
            ``TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP``; bfloat16 is used
            when it is None.

    Returns:
        An empty meta tensor.  For 2-D ``x`` the shape is
        ``[x.size(0) // world_size, x.size(1)]``; for 3-D ``x`` the first two
        axes are merged before the division.

    Raises:
        RuntimeError: via ``torch._check`` when an argument is invalid,
            including an ``x`` whose rank is neither 2 nor 3.
    """
    torch._check(
        x is not None,
        lambda: "x cannot be None, please input some value" + ops_error(ErrCode.TYPE),
    )
    torch._check(
        scales is not None,
        lambda: "scales cannot be None, please input some value" + ops_error(ErrCode.TYPE),
    )
    world_size_list = [2, 4, 8]
    torch._check(
        world_size in world_size_list,
        lambda: "world_size must be in " + str(world_size_list) + ", but actual value is: " + str(world_size) + ops_error(ErrCode.VALUE),
    )
    # Reject unsupported ranks up front; previously `size` stayed unbound and
    # the function failed with an UnboundLocalError at the return statement.
    x_rank = x.dim()
    torch._check(
        x_rank in (2, 3),
        lambda: "x dim must be 2 or 3, but actual value is: " + str(x_rank) + ops_error(ErrCode.VALUE),
    )
    if x_rank == 2:
        size = [x.size(0) // world_size, x.size(1)]
    else:
        size = [x.size(0) * x.size(1) // world_size, x.size(2)]

    if output_dtype is not None:
        dtype = TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP[output_dtype]
    else:
        dtype = torch.bfloat16
    return torch.empty(size, dtype=dtype, device='meta')
@impl(m, "npu_quant_all_reduce")
def npu_quant_all_reduce_meta(x, scales, hcom_name, world_size, reduce_op='sum',
                              output_dtype=None, x_dtype=None, scales_dtype=None):
    """Meta kernel for ``npu_quant_all_reduce``.

    After checking that ``x`` and ``scales`` are not None and that
    ``world_size`` is 2, 4 or 8, returns an empty meta tensor with the shape
    of ``x``.  The dtype is looked up from ``output_dtype`` when given and is
    bfloat16 otherwise.
    """
    torch._check(
        x is not None,
        lambda: "x cannot be None, please input some value" + ops_error(ErrCode.TYPE),
    )
    torch._check(
        scales is not None,
        lambda: "scales cannot be None, please input some value" + ops_error(ErrCode.TYPE),
    )

    world_size_list = [2, 4, 8]
    torch._check(
        world_size in world_size_list,
        lambda: "world_size must be in " + str(world_size_list) + ", but actual value is: " + str(world_size) + ops_error(ErrCode.VALUE),
    )

    out_shape = x.size()
    if output_dtype is None:
        out_dtype = torch.bfloat16
    else:
        out_dtype = TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP[output_dtype]
    return torch.empty(out_shape, dtype=out_dtype, device='meta')
@impl(m, "npu_gmm_alltoallv")
def npu_gmm_alltoallv_meta(gmm_x, gmm_weight, hcom, ep_world_size, send_counts,
                           recv_counts, *, send_counts_tensor=None,
                           recv_counts_tensor=None, mm_x=None,
                           mm_weight=None, trans_gmm_weight=False,
                           trans_mm_weight=False):
    """Meta kernel for ``npu_gmm_alltoallv``.

    Returns ``(y, mm_y)``.  ``y`` has ``sum(recv_counts)`` rows and as many
    columns as axis 2 of ``gmm_weight`` (axis 1 when ``trans_gmm_weight``).
    ``mm_y`` is None unless ``mm_x`` is given, in which case it has
    ``mm_x.size(0)`` rows and ``mm_weight.size(1)`` columns (axis 0 when
    ``trans_mm_weight``).  Both outputs take the dtype of their input tensor.
    """
    # The clamped value is not read again below.
    if ep_world_size <= 0:
        ep_world_size = 1

    gmm_rows = sum(recv_counts)
    gmm_cols = gmm_weight.size(2)
    if trans_gmm_weight:
        gmm_cols = gmm_weight.size(1)

    mm_y = None
    if mm_x is not None:
        mm_rows = mm_x.size(0)
        mm_cols = mm_weight.size(1)
        if trans_mm_weight:
            mm_cols = mm_weight.size(0)
        mm_y = torch.empty([mm_rows, mm_cols], dtype=mm_x.dtype, device='meta')

    y = torch.empty([gmm_rows, gmm_cols], dtype=gmm_x.dtype, device='meta')
    return (y, mm_y)
@impl(m, "npu_quant_gmm_alltoallv")
def npu_quant_gmm_alltoallv_meta(gmm_x, gmm_weight, gmm_x_scale, gmm_weight_scale, hcom, ep_world_size,
                                 send_counts, recv_counts, gmm_y_dtype, *, send_counts_tensor=None,
                                 recv_counts_tensor=None, mm_x=None, mm_weight=None, mm_x_scale=None,
                                 mm_weight_scale=None, comm_quant_scale=None, gmm_x_quant_mode=None,
                                 gmm_weight_quant_mode=None, mm_x_quant_mode=None, mm_weight_quant_mode=None,
                                 comm_quant_mode=None, group_size=None, gmm_x_dtype=None, gmm_weight_dtype=None,
                                 gmm_x_scale_dtype=None, gmm_weight_scale_dtype=None, mm_x_dtype=None,
                                 mm_weight_dtype=None, mm_x_scale_dtype=None, mm_weight_scale_dtype=None,
                                 comm_quant_dtype=None, mm_y_dtype=None):
    """Meta kernel for ``npu_quant_gmm_alltoallv``.

    Checks the ranks of the (optional) matrix inputs and that every supplied
    scale has shape ``(1,)``, then returns ``(y, mm_y)``:

    * ``y``: ``[sum(recv_counts), gmm_weight.size(2)]`` with the dtype mapped
      from ``gmm_y_dtype``.
    * ``mm_y``: None unless ``mm_x`` is given; otherwise
      ``[mm_x.size(0), mm_weight.size(1)]`` with the dtype mapped from
      ``mm_y_dtype``.
    """
    # The clamped value is not read again below.
    if ep_world_size <= 0:
        ep_world_size = 1

    if gmm_x is not None:
        torch._check(
            gmm_x.dim() == 2,
            lambda: f"The gmm_x's dim should be 2, but got {gmm_x.dim()}.",
        )
    if gmm_weight is not None:
        torch._check(
            gmm_weight.dim() == 3,
            lambda: f"The gmm_weight's dim should be 3, but got {gmm_weight.dim()}.",
        )
    if mm_x is not None:
        torch._check(
            mm_x.dim() == 2,
            lambda: f"The mm_x's dim should be 2, but got {mm_x.dim()}.",
        )
    if mm_weight is not None:
        torch._check(
            mm_weight.dim() == 2,
            lambda: f"The mm_weight's dim should be 2, but got {mm_weight.dim()}.",
        )

    for scale in (gmm_x_scale, gmm_weight_scale, mm_x_scale, mm_weight_scale):
        if scale is None:
            continue
        torch._check(
            scale.dim() == 1 and scale.size(0) == 1,
            lambda: f"Scale's shape should be (1,), but got {list(scale.shape)}.",
        )

    gmm_rows = sum(recv_counts)
    gmm_cols = gmm_weight.size(2)
    gmm_out_dtype = TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP.get(gmm_y_dtype)

    mm_y = None
    if mm_x is not None:
        mm_rows = mm_x.size(0)
        mm_cols = mm_weight.size(1)
        mm_out_dtype = TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP.get(mm_y_dtype)
        mm_y = torch.empty([mm_rows, mm_cols], dtype=mm_out_dtype, device='meta')

    y = torch.empty([gmm_rows, gmm_cols], dtype=gmm_out_dtype, device='meta')
    return (y, mm_y)
@impl(m, "npu_fused_matmul")
def npu_fused_matmul_meta(x, x2, *, bias=None, x3=None, fused_op_type=''):
    """Meta kernel for ``npu_fused_matmul``.

    Validates that ``x`` and ``x2`` are present, share a dtype, are at least
    2-D and agree on the contraction axis, then returns an empty meta tensor
    whose batch axes are the broadcast of both inputs' batch axes and whose
    last two axes are ``(x.size(-2), x2.size(-1))``.  ``x3`` must be absent
    for "gelu_erf"/"gelu_tanh" and present for "add"/"mul".
    """
    torch._check(
        x is not None,
        lambda: "x must not be None, please input some value" + ops_error(ErrCode.TYPE),
    )
    torch._check(
        x2 is not None,
        lambda: "x2 must not be None, please input some value" + ops_error(ErrCode.TYPE),
    )
    torch._check(
        x.dtype == x2.dtype,
        lambda: "x and x2 type not same. x.dtype is " + str(x.dtype) + ", x2.dtype is " + str(x2.dtype) + ops_error(ErrCode.TYPE),
    )

    x_dim = x.dim()
    torch._check(
        x_dim >= 2,
        lambda: "x dim num cannot be less than 2,but the actual value is " + str(x_dim) + ops_error(ErrCode.VALUE),
    )
    x2_dim = x2.dim()
    torch._check(
        x2_dim >= 2,
        lambda: "x2 dim num cannot be less than 2,but the actual value is " + str(x2_dim) + ops_error(ErrCode.VALUE),
    )

    rows = x.size(x_dim - 2)
    inner_left = x.size(x_dim - 1)
    inner_right = x2.size(x2_dim - 2)
    cols = x2.size(x2_dim - 1)
    torch._check(
        inner_left == inner_right,
        lambda: "ka and kb should be the same" + ops_error(ErrCode.TYPE),
    )

    # Broadcast the batch axes: the lower-rank operand is treated as if it
    # were left-padded with axes of size 1.
    out_rank = max(x_dim, x2_dim)
    if x_dim > x2_dim:
        longer, shorter = x, x2
    else:
        longer, shorter = x2, x
    pad = out_rank - min(x_dim, x2_dim)

    output_shape = []
    for axis in range(0, out_rank - 2):
        if axis < pad:
            short_dim = 1
        else:
            short_dim = shorter.size(axis - pad)
        long_dim = longer.size(axis)
        torch._check(
            short_dim <= 1 or long_dim <= 1 or short_dim == long_dim,
            lambda: "the batch shape cannot be broadcast" + ops_error(ErrCode.VALUE),
        )
        output_shape.append(max(short_dim, long_dim))
    output_shape.extend([rows, cols])

    if fused_op_type in ("gelu_erf", "gelu_tanh"):
        torch._check(
            x3 is None,
            lambda: "there is no x3 for gelu_erf and gelu_tanh" + ops_error(ErrCode.TYPE),
        )
    if fused_op_type in ("add", "mul"):
        torch._check(
            x3 is not None,
            lambda: "there must have x3 for add and mul" + ops_error(ErrCode.TYPE),
        )

    template = torch.empty(output_shape, dtype=x.dtype, device='meta')
    return torch.empty_like(template, dtype=x.dtype)
@impl(m, "npu_alltoallv_gmm")
def npu_alltoallv_gmm_meta(gmm_x, gmm_weight, hcom, ep_world_size, send_counts,
                           recv_counts, *, send_counts_tensor=None,
                           recv_counts_tensor=None, mm_x=None,
                           mm_weight=None, trans_gmm_weight=False,
                           trans_mm_weight=False, permute_out_flag=False):
    """Meta kernel for ``npu_alltoallv_gmm``.

    Returns ``(gmm_y, mm_y, permute_out)``:

    * ``gmm_y``: ``[sum(recv_counts), gmm_weight.size(2)]`` (axis 1 of the
      weight when ``trans_gmm_weight``), dtype of ``gmm_x``.
    * ``mm_y``: None unless ``mm_x`` is given; otherwise
      ``[mm_x.size(0), mm_weight.size(1)]`` (axis 0 when ``trans_mm_weight``),
      dtype of ``mm_x``.
    * ``permute_out``: None unless ``permute_out_flag``; otherwise
      ``[sum(recv_counts), gmm_x.size(1)]``, dtype of ``gmm_x``.
    """
    # The clamped value is not read again below.
    if ep_world_size <= 0:
        ep_world_size = 1

    gmm_rows = sum(recv_counts)
    gmm_cols = gmm_weight.size(2)
    if trans_gmm_weight:
        gmm_cols = gmm_weight.size(1)

    mm_y = None
    if mm_x is not None:
        mm_rows = mm_x.size(0)
        mm_cols = mm_weight.size(1)
        if trans_mm_weight:
            mm_cols = mm_weight.size(0)
        mm_y = torch.empty([mm_rows, mm_cols], dtype=mm_x.dtype, device='meta')

    permute_out = None
    if permute_out_flag:
        permute_out = torch.empty([gmm_rows, gmm_x.size(1)], dtype=gmm_x.dtype, device='meta')

    gmm_y = torch.empty([gmm_rows, gmm_cols], dtype=gmm_x.dtype, device='meta')
    return (gmm_y, mm_y, permute_out)
@impl(m, "npu_all_gather_base_mm")
def npu_all_gather_base_mm_meta(self, x2, hcom, world_size, bias=None,
                                x1_scale=None, x2_scale=None,
                                gather_index=0, gather_output=True, comm_turn=0,
                                output_dtype=None, comm_mode=None):
    """Meta kernel for ``npu_all_gather_base_mm``.

    Returns ``(out, gathered)``.  ``out`` has ``x2.size(1)`` columns and
    ``self.size(0)`` rows, multiplied by ``world_size`` when
    ``gather_index == 0``.  ``gathered`` is shaped like the gathered operand
    (``self`` for ``gather_index == 0``, else ``x2``) with its first axis
    multiplied by ``world_size``, or has shape ``[0]`` when ``gather_output``
    is false; it always keeps the dtype of ``self``.  The dtype of ``out``
    follows ``self`` unless ``x2_scale`` is given (int64 scale -> float16,
    else ``output_dtype`` or bfloat16).
    """
    # A non-positive world size is treated as 1 instead of being rejected.
    ranks = 1 if world_size <= 0 else world_size

    if gather_index == 0:
        out_rows = self.size(0) * ranks
        gathered_shape = (self.size(0) * ranks, self.size(1))
    else:
        out_rows = self.size(0)
        gathered_shape = (x2.size(0) * ranks, x2.size(1))
    out_shape = (out_rows, x2.size(1))

    # Size 0 yields a 1-D empty tensor when the gathered output is not wanted.
    gather_output_size = gathered_shape if gather_output else 0

    if x2_scale is None:
        out_dtype = self.dtype
    elif x2_scale.dtype == torch.int64:
        out_dtype = torch.float16
    elif output_dtype is not None:
        out_dtype = output_dtype
    else:
        out_dtype = torch.bfloat16

    out = torch.empty(out_shape, dtype=out_dtype, device='meta')
    gathered = torch.empty(gather_output_size, dtype=self.dtype, device='meta')
    return (out, gathered)
@impl(m, "npu_alltoallv_quant_gmm")
def npu_alltoallv_quant_gmm_meta(gmm_x, gmm_weight, gmm_x_scale, gmm_weight_scale, hcom, ep_world_size,
                                 send_counts, recv_counts, gmm_y_dtype, *, send_counts_tensor=None,
                                 recv_counts_tensor=None, mm_x=None, mm_weight=None, mm_x_scale=None,
                                 mm_weight_scale=None, gmm_x_quant_mode=None, gmm_weight_quant_mode=None,
                                 mm_x_quant_mode=None, mm_weight_quant_mode=None, permute_out_flag=False,
                                 group_size=None, gmm_x_dtype=None, gmm_weight_dtype=None, gmm_x_scale_dtype=None,
                                 gmm_weight_scale_dtype=None, mm_x_dtype=None, mm_weight_dtype=None,
                                 mm_x_scale_dtype=None, mm_weight_scale_dtype=None, mm_y_dtype=None):
    """Meta kernel for ``npu_alltoallv_quant_gmm``.

    Checks the ranks of the (optional) matrix inputs and that every supplied
    scale has shape ``(1,)``, then returns ``(gmm_y, mm_y, permute_out)``:

    * ``gmm_y``: ``[sum(recv_counts), gmm_weight.size(2)]`` with the dtype
      mapped from ``gmm_y_dtype``.
    * ``mm_y``: None unless ``mm_x`` is given; otherwise
      ``[mm_x.size(0), mm_weight.size(1)]`` with the dtype mapped from
      ``mm_y_dtype``.
    * ``permute_out``: None unless ``permute_out_flag``; otherwise a
      ``gmm_x.new_empty`` of ``(sum(recv_counts), gmm_x.size(1))``.
    """
    # The clamped value is not read again below.
    if ep_world_size <= 0:
        ep_world_size = 1

    if gmm_x is not None:
        torch._check(
            gmm_x.dim() == 2,
            lambda: f"The gmm_x's dim should be 2, but got {gmm_x.dim()}.",
        )
    if gmm_weight is not None:
        torch._check(
            gmm_weight.dim() == 3,
            lambda: f"The gmm_weight's dim should be 3, but got {gmm_weight.dim()}.",
        )
    if mm_x is not None:
        torch._check(
            mm_x.dim() == 2,
            lambda: f"The mm_x's dim should be 2, but got {mm_x.dim()}.",
        )
    if mm_weight is not None:
        torch._check(
            mm_weight.dim() == 2,
            lambda: f"The mm_weight's dim should be 2, but got {mm_weight.dim()}.",
        )

    for scale in (gmm_x_scale, gmm_weight_scale, mm_x_scale, mm_weight_scale):
        if scale is None:
            continue
        torch._check(
            scale.dim() == 1 and scale.size(0) == 1,
            lambda: f"Scale's shape should be (1,), but got {list(scale.shape)}.",
        )

    gmm_rows = sum(recv_counts)
    gmm_cols = gmm_weight.size(2)
    gmm_out_dtype = TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP.get(gmm_y_dtype)

    mm_y = None
    if mm_x is not None:
        mm_rows = mm_x.size(0)
        mm_cols = mm_weight.size(1)
        mm_out_dtype = TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP.get(mm_y_dtype)
        mm_y = torch.empty([mm_rows, mm_cols], dtype=mm_out_dtype, device='meta')

    permute_out = None
    if permute_out_flag:
        permute_out = gmm_x.new_empty(gmm_rows, gmm_x.size(1))

    gmm_y = torch.empty([gmm_rows, gmm_cols], dtype=gmm_out_dtype, device='meta')
    return (gmm_y, mm_y, permute_out)
@impl(m, "npu_all_gather_quant_mm")
def npu_all_gather_quant_mm_meta(self, x2, hcom, world_size, bias=None, x1_scale=None, x2_scale=None,
                                 quant_scale=None, block_size=0, gather_index=0, gather_output=True,
                                 comm_turn=0, group_sizes=None, amax_output=False, y_dtype=None, x1_dtype=None,
                                 x2_dtype=None, x1_scale_dtype=None, x2_scale_dtype=None):
    """Meta kernel for ``npu_all_gather_quant_mm``.

    Returns a triple of tensors created with ``self.new_empty``:

    1. the matmul result, ``(rows, x2.size(1))`` where ``rows`` is
       ``self.size(0)`` (times ``world_size`` when ``gather_index == 0``),
       with the dtype mapped from ``y_dtype`` or the dtype of ``self``;
    2. the gathered operand (first axis multiplied by ``world_size``) when
       ``gather_output`` is true, otherwise a length-0 tensor;
    3. a length-0 float32 tensor.

    Raises:
        RuntimeError: if ``world_size`` is not positive.
    """
    if world_size <= 0:
        raise RuntimeError("world_size must be bigger than zero")

    if gather_index == 0:
        out_rows = self.size(0) * world_size
        gathered_rows = self.size(0) * world_size
        gathered_cols = self.size(1)
    else:
        out_rows = self.size(0)
        gathered_rows = x2.size(0) * world_size
        gathered_cols = x2.size(1)
    out_cols = x2.size(1)

    if y_dtype is None:
        out_dtype = self.dtype
    else:
        out_dtype = TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP[y_dtype]

    result = self.new_empty((out_rows, out_cols), dtype=out_dtype)
    if gather_output:
        gathered = self.new_empty(gathered_rows, gathered_cols)
    else:
        gathered = self.new_empty(0)
    return (result, gathered, self.new_empty(0, dtype=torch.float32))
@impl(m, "npu_mhc_post")
def npu_mhc_post(x, h_res, h_out, h_post):
    """Meta kernel for ``npu_mhc_post``: an empty meta tensor shaped and typed like ``x``."""
    return torch.empty(tuple(x.shape), dtype=x.dtype, device='meta')
@impl(m, "npu_moe_init_routing")
def npu_moe_init_routing_meta(x, row_idx, expert_idx, active_num=99):
    """Meta kernel for ``npu_moe_init_routing``.

    With ``n, h = x.size(0), x.size(1)`` and ``k = row_idx.size(1)``, returns

    * an ``x.new_empty`` of shape ``(min(n, active_num) * k, h)``, and
    * two ``row_idx.new_empty`` tensors of shape ``(n * k,)``.
    """
    num_rows = x.size(0)
    hidden = x.size(1)
    top_k = row_idx.size(1)
    kept_rows = min(num_rows, active_num)

    flat_len = num_rows * top_k
    expanded_x = x.new_empty((kept_rows * top_k, hidden))
    expanded_row_idx = row_idx.new_empty((flat_len,))
    expanded_expert_idx = row_idx.new_empty((flat_len,))
    return (expanded_x, expanded_row_idx, expanded_expert_idx)
@impl(m, "npu_mhc_sinkhorn")
def npu_mhc_sinkhorn(x, eps, num_iters, out_flag):
    """Meta kernel for ``npu_mhc_sinkhorn``.

    Requires the last axis of ``x`` to have size 4, 6 or 8, ``x`` to be 3-D or
    4-D, and ``num_iters`` to lie in ``(0, 100]``.  Returns three empty meta
    tensors, each shaped and typed like ``x``.
    """
    x_shape = list(x.shape)
    torch._check(
        x_shape[-1] in [4, 6, 8],
        lambda: f"The x's dim should be 4 or 6 or 8, but got {x_shape[-1]}.",
    )
    torch._check(
        x.dim() == 3 or x.dim() == 4,
        lambda: f"The x's dim should be 3 or 4, but got {x.dim()}.",
    )
    torch._check(
        num_iters > 0 and num_iters <= 100,
        lambda: f"num_iters must be within (0, 100], but got {num_iters}.",
    )

    out_shape = tuple(x.shape)
    first = torch.empty(out_shape, dtype=x.dtype, device='meta')
    second = torch.empty(out_shape, dtype=x.dtype, device='meta')
    third = torch.empty(out_shape, dtype=x.dtype, device='meta')
    return (first, second, third)
# Meta implementation registered for "npu_moe_init_routing_v2".
# It validates the arguments with torch._check and returns four empty tensors
# created with x.new_empty:
#   1. expanded x            (dtype chosen from x_dtype / quant_mode)
#   2. expanded row index    (shape [bs * k], int32)
#   3. expert token cumsum/count (int64; empty dim list when
#      expert_tokens_num_flag is False)
#   4. expanded scale        (float32, or float8_e8m0fnu for quant_mode 2/3)
# NOTE(review): active_expert_range has a mutable list default ([]); it is
# only read in this function, never mutated.
@impl(m, "npu_moe_init_routing_v2")
def npu_moe_init_routing_v2_meta(x, expert_idx, *, scale=None, offset=None, active_num=-1, expert_capacity=-1, expert_num=-1, drop_pad_mode=0, expert_tokens_num_type=0, expert_tokens_num_flag=False, quant_mode=-1, active_expert_range=[], row_idx_type=0, x_dtype=None):
# --- Shape checks on x and expert_idx: both 2-D with matching first dim.
x_dim = x.dim()
torch._check(
x_dim == 2,
lambda: "the x shape support only 2d" + ops_error(ErrCode.VALUE),
)
expert_idx_dim = expert_idx.dim()
torch._check(
expert_idx_dim == 2,
lambda: "the expert_idx shape support only 2d" + ops_error(ErrCode.VALUE),
)
torch._check(
x.size(0) == expert_idx.size(0),
lambda: "the first dim of expert_idx and x should be the same" + ops_error(ErrCode.VALUE),
)
# --- A non-empty active_expert_range must be an increasing [start, end]
# pair within [0, 10240]; an empty/falsy one falls back to expert_num as the
# range length.
if active_expert_range:
torch._check(
active_expert_range is not None and isinstance(active_expert_range, list) and len(active_expert_range) == 2,
lambda: "active_expert_range is None or invalid. must be int[2]"
)
torch._check(
active_expert_range[1] > active_expert_range[0],
lambda: "active_expert_range is invalid. must be increasing"
)
torch._check(
active_expert_range[0] >= 0 and active_expert_range[1] <= 10240,
lambda: "active_expert_range must be within [0, 10240]"
)
expert_range_length = active_expert_range[1] - active_expert_range[0]
else:
expert_range_length = expert_num
# --- Type and value checks on the scalar mode arguments.
torch._check(
drop_pad_mode is not None and isinstance(drop_pad_mode, int) and drop_pad_mode in [0, 1],
lambda: "drop_pad_mode is None or invalid. must be in [0, 1]"
)
torch._check(
expert_tokens_num_type is not None and isinstance(expert_tokens_num_type, int) and expert_tokens_num_type in [0, 1, 2],
lambda: "expert_tokens_num_type is None or invalid. must be in [0, 1, 2]"
)
torch._check(
expert_tokens_num_flag is not None and isinstance(expert_tokens_num_flag, bool) and expert_tokens_num_flag in [True, False],
lambda: "expert_tokens_num_flag is None or invalid. must be in [True, False]"
)
torch._check(
quant_mode is not None and isinstance(quant_mode, int) and quant_mode in [-1, 0, 1, 2, 3, 6, 7, 8],
lambda: "quant_mode is None or invalid. must be in [-1, 0, 1, 2, 3, 6, 7, 8]"
)
torch._check(
row_idx_type is not None and isinstance(row_idx_type, int) and row_idx_type in [0, 1],
lambda: "row_idx_type is None or invalid. must be in [0, 1]"
)
# --- When a scale tensor is supplied, its expected shape depends on
# quant_mode (-1, 0, 1 and 7 have explicit checks below).
if scale is not None:
scale_dim = scale.dim()
if quant_mode == -1:
torch._check(
scale_dim == 1,
lambda: "the scale shape support only 1D (bs,) in no quant mode" + ops_error(ErrCode.VALUE),
)
torch._check(
x.size(0) == scale.size(0),
lambda: "the first dim of scale and the first dim of x should be the same" + ops_error(ErrCode.VALUE),
)
elif quant_mode == 0:
torch._check(
scale_dim == 1,
lambda: "the scale shape support only 1D in static quant mode" + ops_error(ErrCode.VALUE),
)
torch._check(
scale.size(0) == 1,
lambda: "the shape of scale should be 1" + ops_error(ErrCode.VALUE),
)
# offset, when given, must be 1-D and match scale's first dim.
if offset is not None:
offset_dim = offset.dim()
torch._check(
offset_dim == 1,
lambda: "the offset shape support only 1D" + ops_error(ErrCode.VALUE),
)
torch._check(
scale.size(0) == offset.size(0),
lambda: "the 1st dim of offset and the 1st dim of scale should be the same" + ops_error(ErrCode.VALUE),
)
elif quant_mode == 1:
torch._check(
scale_dim == 2,
lambda: "the scale shape support only 2D in dynamic quant mode" + ops_error(ErrCode.VALUE),
)
torch._check(
scale.size(0) in [expert_range_length, 1],
lambda: "the first dim of scale must be in [expert_range_length, 1]" + ops_error(ErrCode.VALUE),
)
torch._check(
x.size(1) == scale.size(1),
lambda: "the 2nd dim of scale should be the same with the 2nd dim of x" + ops_error(ErrCode.VALUE),
)
# NOTE(review): this check reads scale_dim, so it presumably sits inside the
# `scale is not None` branch - confirm nesting in the original file.
if quant_mode == 7:
torch._check(
scale_dim == 1,
lambda: "the scale shape support only 1D (bs,) in no quant mode" + ops_error(ErrCode.VALUE),
)
# --- Output dtypes: start from x.dtype / float32, then override by x_dtype
# and quant_mode.
bs = x.size(0)
h = x.size(1)
k = expert_idx.size(1)
expanded_x_dtype = x.dtype
expanded_scale_dtype = torch.float32
if x_dtype == torch_npu.hifloat8:
expanded_x_dtype = torch.uint8
if quant_mode in [0, 1]:
expanded_x_dtype = torch.int8
elif quant_mode == 2:
expanded_x_dtype = torch.float8_e5m2
expanded_scale_dtype = torch.float8_e8m0fnu
elif quant_mode == 3:
expanded_x_dtype = torch.float8_e4m3fn
expanded_scale_dtype = torch.float8_e8m0fnu
elif quant_mode in [6, 7, 8]:
expanded_x_dtype = torch.uint8
# --- Output shapes. drop_pad_mode == 1 uses [expert_num, expert_capacity, h];
# otherwise the row count is bs * k, capped by active_num when it is positive.
if drop_pad_mode == 1:
expanded_x_dim_list = [expert_num, expert_capacity, h]
expanded_scale_dim_list = [expert_num * expert_capacity]
else:
num_expanded_rows = bs * k if active_num <= 0 else min(active_num, bs * k)
expanded_x_dim_list = [num_expanded_rows, h]
if quant_mode in [2, 3]:
# One scale column per 32 elements of h (rounded up), then the column count
# is rounded up to a multiple of 2.
MXQUANT_BLOCK_SIZE = 32
PAD_TO_EVEN_FACTOR = 2
scale_cols = (h + MXQUANT_BLOCK_SIZE - 1) // MXQUANT_BLOCK_SIZE
scale_cols = (scale_cols + PAD_TO_EVEN_FACTOR - 1) // PAD_TO_EVEN_FACTOR * PAD_TO_EVEN_FACTOR
expanded_scale_dim_list = [num_expanded_rows, scale_cols]
elif quant_mode in [-1, 1, 8]:
expanded_scale_dim_list = [num_expanded_rows]
# quant_mode 0/6/7 use an empty dim list for the scale output.
# NOTE(review): whether this override also applies when drop_pad_mode == 1
# depends on its nesting level - confirm against the original file.
if quant_mode in [0, 6, 7]:
expanded_scale_dim_list = []
expanded_row_idx_dim_list = [bs * k]
# Third output: empty dim list when the flag is off, [expert_range_length]
# for types 0/1, [expert_num, 2] for type 2.
if not expert_tokens_num_flag:
expert_token_cumsum_or_count_dim_list = []
elif (expert_tokens_num_type in range(0, 2)):
expert_token_cumsum_or_count_dim_list = [expert_range_length]
elif (expert_tokens_num_type == 2):
expert_token_cumsum_or_count_dim_list = [expert_num, 2]
return (x.new_empty(tuple(expanded_x_dim_list), dtype=expanded_x_dtype),
x.new_empty(tuple(expanded_row_idx_dim_list), dtype=torch.int32),
x.new_empty(tuple(expert_token_cumsum_or_count_dim_list), dtype=torch.int64),
x.new_empty(tuple(expanded_scale_dim_list), dtype=expanded_scale_dtype))
@impl(m, "ffn_worker_scheduler_")
def ffn_worker_scheduler__meta(self, *, sync_group_size=1, execute_mode=0):
    """Meta kernel for ``ffn_worker_scheduler_``: hands ``self`` back unchanged."""
    return self
@impl(m, "attention_worker_scheduler_")
def attention_worker_scheduler__meta(self):
    """Meta kernel for ``attention_worker_scheduler_``: hands ``self`` back unchanged."""
    return self
@impl(m, "ffn_worker_scheduler")
def ffn_worker_scheduler_meta(self, *, sync_group_size=1, execute_mode=0):
    """Meta kernel for ``ffn_worker_scheduler``: a fresh empty tensor like ``self``."""
    result = torch.empty_like(self)
    return result
@impl(m, "attention_worker_scheduler")
def attention_worker_scheduler_meta(self):
    """Meta kernel for ``attention_worker_scheduler``: a fresh empty tensor like ``self``."""
    result = torch.empty_like(self)
    return result
@impl(m, "npu_ffn_worker_batching")
def npu_ffn_worker_batching(schedule_context, expert_num, max_out_shape, *, token_dtype=0, need_schedule=0, layer_num=0):
    """Meta kernel for ``npu_ffn_worker_batching``.

    With ``Y`` the product of the first three entries of ``max_out_shape`` and
    ``H`` its fourth entry, returns eight empty tensors on the device of
    ``schedule_context``: ``(Y, H)`` in the token dtype, ``(expert_num, 2)``
    int64, four ``(Y,)`` int32, one ``(Y,)`` float32 and one ``(1,)`` int64.
    The token dtype is bfloat16 for ``token_dtype == 1``, int8 for
    ``token_dtype == 2`` and float16 otherwise.
    """
    total_rows = max_out_shape[0] * max_out_shape[1] * max_out_shape[2]
    hidden = max_out_shape[3]

    if token_dtype == 2:
        token_torch_dtype = torch.int8
    elif token_dtype == 1:
        token_torch_dtype = torch.bfloat16
    else:
        token_torch_dtype = torch.float16

    device = schedule_context.device
    outputs = [
        torch.empty(total_rows, hidden, dtype=token_torch_dtype, device=device),
        torch.empty(expert_num, 2, dtype=torch.int64, device=device),
    ]
    # Four per-row int32 outputs of identical shape.
    for _ in range(4):
        outputs.append(torch.empty(total_rows, dtype=torch.int32, device=device))
    outputs.append(torch.empty(total_rows, dtype=torch.float32, device=device))
    outputs.append(torch.empty(1, dtype=torch.int64, device=device))
    return tuple(outputs)
@impl(m, "npu_moe_gating_top_k_softmax")
def npu_moe_gating_top_k_softmax_meta(x, finished=None, k=1):
    """Meta kernel for ``npu_moe_gating_top_k_softmax``.

    ``x`` must be 2-D or 3-D.  Returns three ``x.new_empty`` tensors whose
    shape is that of ``x`` with the last axis replaced by ``k``: the first in
    the dtype of ``x``, the other two int32.
    """
    x_dim = x.dim()
    torch._check(
        x_dim == 2 or x_dim == 3,
        lambda: "the x shape support only 2d and 3d)" + ops_error(ErrCode.VALUE),
    )

    if x_dim == 3:
        out_shape = (x.size(0), x.size(1), k)
    else:
        out_shape = (x.size(0), k)

    y = x.new_empty(out_shape, dtype=x.dtype)
    expert_idx = x.new_empty(out_shape, dtype=torch.int32)
    row_idx = x.new_empty(out_shape, dtype=torch.int32)
    return (y, expert_idx, row_idx)
@impl(m, "npu_moe_gating_top_k_softmax_v2")
def npu_moe_gating_top_k_softmax_v2_meta(x, *, k=1, finished=None, renorm=0, output_softmax=False):
    """Meta kernel for ``npu_moe_gating_top_k_softmax_v2``.

    Args:
        x: 2-D or 3-D input tensor.
        k: size of the last axis of the first two outputs.
        renorm, output_softmax: the third output is a full softmax result
            only when ``renorm == 0`` and ``output_softmax`` is true.

    Returns:
        ``(y, expert_idx, softmax_result)`` created with ``x.new_empty``:
        ``y`` (dtype of ``x``) and ``expert_idx`` (int32) have the shape of
        ``x`` with the last axis replaced by ``k``; ``softmax_result``
        (float32) has the full shape of ``x`` when requested and shape
        ``(0,)`` otherwise.

    Raises:
        RuntimeError: via ``torch._check`` when ``x`` is not 2-D or 3-D.
    """
    x_dim = x.dim()
    torch._check(
        x_dim == 2 or x_dim == 3,
        lambda: "the x shape support only 2d and 3d)" + ops_error(ErrCode.VALUE),
    )
    if x_dim == 3:
        y_dim_list = [x.size(0), x.size(1), k]
        expert_idx_dim_list = [x.size(0), x.size(1), k]
    else:
        y_dim_list = [x.size(0), k]
        expert_idx_dim_list = [x.size(0), k]

    if renorm == 0 and output_softmax:
        # Compare the rank value, not the bound method `x.dim`: the old
        # `x.dim == 3` was always False, so 3-D inputs got a 2-D softmax shape.
        if x_dim == 3:
            softmax_result_dim_list = [x.size(0), x.size(1), x.size(2)]
        else:
            softmax_result_dim_list = [x.size(0), x.size(1)]
    else:
        softmax_result_dim_list = [0, ]

    return (x.new_empty(tuple(y_dim_list), dtype=x.dtype),
            x.new_empty(tuple(expert_idx_dim_list), dtype=torch.int32),
            x.new_empty(tuple(softmax_result_dim_list), dtype=torch.float32))
@impl(m, "npu_moe_gating_top_k")
def npu_moe_gating_top_k_meta(x, k=1, bias=None, k_group=1, group_count=1, group_select_mode=0, renorm=0, norm_type=0, out_flag=False, routed_scaling_factor=1.0, eps=1e-20):
    """Meta kernel for ``npu_moe_gating_top_k``.

    ``x`` must be 2-D and ``bias``, when given, 1-D.  Returns three
    ``x.new_empty`` tensors: ``(x.size(0), k)`` in the dtype of ``x``,
    ``(x.size(0), k)`` int32, and ``(x.size(0), x.size(1))`` float32.
    """
    x_dim = x.dim()
    torch._check(
        x_dim == 2,
        lambda: "the x shape support only 2d)" + ops_error(ErrCode.VALUE),
    )
    if bias is not None:
        bias_dim = bias.dim()
        torch._check(
            bias_dim == 1,
            lambda: "the bias shape support only 1d)" + ops_error(ErrCode.VALUE),
        )

    rows = x.size(0)
    topk_shape = (rows, k)
    full_shape = (rows, x.size(1))
    y = x.new_empty(topk_shape, dtype=x.dtype)
    expert_idx = x.new_empty(topk_shape, dtype=torch.int32)
    y2 = x.new_empty(full_shape, dtype=torch.float32)
    return (y, expert_idx, y2)
def get_query_and_attention_out_layout(query, input_layout):
class ParserLayout:
def __init__(self, qLayout: str, outLayout: str, qDim: int):
self.qLayout = qLayout
self.outLayout = outLayout
self.qDim = qDim
LAYOUT_MAP: Dict[str, ParserLayout] = {
"BSH": ParserLayout("BSH", "BSH", 3),
"BSND": ParserLayout("BSND", "BSND", 4),
"BNSD": ParserLayout("BNSD", "BNSD", 4),
"TND": ParserLayout("TND", "TND", 3),
"NTD": ParserLayout("NTD", "NTD", 3),
"BNSD_BSND": ParserLayout("BNSD", "BSND", 4),
"BSH_BNSD": ParserLayout("BSH", "BNSD", 3),
"BSND_BNSD": ParserLayout("BSND", "BNSD", 4),
"NTD_TND": ParserLayout("NTD", "TND", 3),
"BSH_NBSD": ParserLayout("BSH", "NBSD", 3),
"BSND_NBSD": ParserLayout("BSND", "NBSD", 4),
"BNSD_NBSD": ParserLayout("BNSD", "NBSD", 4),
"TND_NTD": ParserLayout("TND", "NTD", 3),
"NSD": ParserLayout("NSD", "NSD", 3)
}
if input_layout in LAYOUT_MAP:
layout_entry = LAYOUT_MAP[input_layout]
query_layout = layout_entry.qLayout
attention_out_layout = layout_entry.outLayout
query_dim = layout_entry.qDim
torch._check(
query.dim() == query_dim,
lambda: (
f"Layout {query_layout}, queryDims({query.dim()}) must be {query_dim}!" + ops_error(ErrCode.VALUE)
),
)
else:
torch._check(
False,
lambda: (
f"Layout {input_layout} is not supported!" + ops_error(ErrCode.VALUE)
),
)
return query_layout, attention_out_layout
def get_query_b_s_n_d(query, query_layout, num_heads):
    """Read ``(b, s1, n1, d1)`` out of ``query`` for a batch-style layout.

    Supported layouts are "BSH" (``n1`` is ``num_heads`` and ``d1`` is
    ``query.size(2) // num_heads``), "BSND", "BNSD" and "NSD" (``b`` is 1).
    Any other layout fails a ``torch._check``.
    """
    if query_layout == "BSH":
        return query.size(0), query.size(1), num_heads, query.size(2) // num_heads
    if query_layout == "BSND":
        return query.size(0), query.size(1), query.size(2), query.size(3)
    if query_layout == "BNSD":
        return query.size(0), query.size(2), query.size(1), query.size(3)
    if query_layout == "NSD":
        # No batch axis in this layout.
        return 1, query.size(1), query.size(0), query.size(2)
    torch._check(
        False,
        lambda: (
            f"Layout {query_layout} is not supported in get_query_b_s_n_d function!" + ops_error(ErrCode.VALUE)
        ),
    )
def get_query_t_n_d(query, query_layout):
    """Read ``(t, n1, d1)`` out of ``query`` for the "TND" or "NTD" layout.

    Any other layout fails a ``torch._check``.
    """
    if query_layout == "TND":
        return query.size(0), query.size(1), query.size(2)
    if query_layout == "NTD":
        return query.size(1), query.size(0), query.size(2)
    torch._check(
        False,
        lambda: (
            f"Layout {query_layout} is not supported in get_query_t_n_d function!" + ops_error(ErrCode.VALUE)
        ),
    )
def get_value_d(block_table, value, query, query_layout, num_key_value_heads):
if block_table is not None:
if value.dim() == 3:
value_d = value.size(2) // num_key_value_heads
elif value.dim() == 4:
value_d = value.size(3)
elif value.dim() == 5:
value_d = value.size(2) * value.size(4)
else:
torch._check(
False,
lambda: "when Page Attention enabled, value's dim should be 3/4/5, but got " + str(value.dim()) +
ops_error(ErrCode.VALUE),
)
else:
torch._check(
value.dim() == query.dim(),
lambda: (
f"when Page Attention not enabled, value'dim{value.dim()} should equal to query's dim{query.dim()}!" +
ops_error(ErrCode.VALUE)
),
)
if query_layout == "BSH":
value_d = value.size(2) // num_key_value_heads
if query_layout == "BNSD" or query_layout == "BSND":
value_d = value.size(3)
if query_layout == "TND" or query_layout == "NTD" or query_layout == "NSD":
value_d = value.size(2)
return value_d
def get_change_d_scale(value):
    """Return 8 when ``value`` is an int32 tensor, otherwise 1 (also for None)."""
    is_int32 = value is not None and value.dtype == torch.int32
    return 8 if is_int32 else 1
def get_change_d_scale_v2(value, value_dtype):
    """Return the multiplier applied to the last axis of ``value``.

    1 when ``value`` is None; 8 for an int32 tensor; 2 when the tensor dtype
    (if this torch build defines it) or ``value_dtype`` names one of the
    ``float4_*_x2`` types; 1 otherwise.
    """
    if value is None:
        return 1

    factor = 8 if value.dtype == torch.int32 else 1
    for fp4_name in ('float4_e2m1fn_x2', 'float4_e1m2fn_x2'):
        # The native torch dtype only exists on some torch builds.
        native_match = hasattr(torch, fp4_name) and value.dtype == getattr(torch, fp4_name)
        if native_match or value_dtype == getattr(torch_npu, fp4_name):
            factor = 2
    return factor
def get_infer_attention_out_d(query_d, value_d):
    """Pick the output head dim: ``value_d``, unless either dim is 0, then ``query_d``."""
    if value_d == 0 or query_d == 0:
        return query_d
    return value_d
def infer_attention_out_shape(attention_out_layout, query, query_layout, num_heads, value_d):
    """Build an empty meta tensor shaped like the attention output.

    The axes are read from ``query`` via ``get_query_b_s_n_d`` or
    ``get_query_t_n_d`` and arranged according to ``attention_out_layout``;
    the head dim comes from ``get_infer_attention_out_d``.  For "BSH" the last
    axis is ``n1 * value_d``, falling back to ``query.size(2)`` when either is
    0.  An unrecognised output layout yields a tensor shaped like ``query``.
    The dtype is always that of ``query``.
    """
    if attention_out_layout == "BSH":
        b, s1, n1, _ = get_query_b_s_n_d(query, query_layout, num_heads)
        out_h = n1 * value_d
        if out_h == 0 or query.size(2) == 0:
            out_h = query.size(2)
        out_shape = [b, s1, out_h]
    elif attention_out_layout in ("BSND", "BNSD", "NBSD", "NSD"):
        b, s1, n1, d1 = get_query_b_s_n_d(query, query_layout, num_heads)
        out_d = get_infer_attention_out_d(d1, value_d)
        if attention_out_layout == "BSND":
            out_shape = [b, s1, n1, out_d]
        elif attention_out_layout == "BNSD":
            out_shape = [b, n1, s1, out_d]
        elif attention_out_layout == "NBSD":
            out_shape = [n1, b, s1, out_d]
        else:
            out_shape = [n1, s1, out_d]
    elif attention_out_layout in ("TND", "NTD"):
        t, n1, d1 = get_query_t_n_d(query, query_layout)
        out_d = get_infer_attention_out_d(d1, value_d)
        out_shape = [t, n1, out_d] if attention_out_layout == "TND" else [n1, t, out_d]
    else:
        return torch.empty_like(query, dtype=query.dtype, device='meta')
    return torch.empty(out_shape, dtype=query.dtype, device='meta')
def infer_lse_out_shape(query, input_layout, query_layout, num_heads):
    """Build an empty float32 meta tensor shaped like the softmax-LSE output.

    For the "TND"/"NTD" family of input layouts the shape is ``[t, n1, 1]``;
    for every other layout it is ``[b, n1, s1, 1]``.
    """
    if input_layout in {"TND", "NTD", "TND_NTD", "NTD_TND"}:
        t, n1, _ = get_query_t_n_d(query, query_layout)
        lse_shape = [t, n1, 1]
    else:
        b, s1, n1, _ = get_query_b_s_n_d(query, query_layout, num_heads)
        lse_shape = [b, n1, s1, 1]
    return torch.empty(lse_shape, dtype=torch.float32, device='meta')
@impl(m, "_npu_attention_pioneer")
def npu_attention_pioneer_forward(query, key, value, *, atten_mask=None, actual_seq_lengths=None, actual_seq_lengths_kv=None,
                                  block_table=None, query_rope=None, key_rope=None, key_sink=None, key_rope_sink=None, value_sink=None, num_heads=1, scale=1.0, pre_tokens=2147483647, next_tokens=2147483647,
                                  input_layout="BSH", num_key_value_heads=0, sparse_mode=0, block_size=0, softmax_lse_flag=False):
    """Meta kernel for ``_npu_attention_pioneer``.

    Returns ``(attention_out, lse_out)``.  The attention shape is inferred
    from ``input_layout``, ``query`` and the head dim of ``value``.  Its dtype
    is that of ``query``, except for an int8 ``query`` where it is the dtype
    of ``query_rope`` when given and half otherwise.  ``lse_out`` is float32:
    the inferred LSE shape when ``softmax_lse_flag`` is set, shape ``[0]``
    otherwise.
    """
    torch._check(
        num_heads > 0,
        lambda: "numHeads should be greater than 0, but got " + str(num_heads) +
        ops_error(ErrCode.VALUE),
    )
    # 0 means "same number of key/value heads as query heads".
    kv_heads = num_heads if num_key_value_heads == 0 else num_key_value_heads

    query_layout, attention_out_layout = get_query_and_attention_out_layout(query, input_layout)
    value_d = get_value_d(block_table, value, query, query_layout, kv_heads) * get_change_d_scale(value)
    shape_template = infer_attention_out_shape(attention_out_layout, query, query_layout, num_heads, value_d)

    if query.dtype != torch.int8:
        out_dtype = query.dtype
    elif query_rope is not None:
        out_dtype = query_rope.dtype
    else:
        out_dtype = torch.half
    attention_out = torch.empty_like(shape_template, dtype=out_dtype)

    # The LSE shape is inferred (and thereby validated) even when it is not returned.
    lse_template = infer_lse_out_shape(query, input_layout, query_layout, num_heads)
    if softmax_lse_flag:
        lse_out = torch.empty_like(lse_template, dtype=torch.float32)
    else:
        lse_out = torch.empty([0], dtype=torch.float32, device='meta')
    return attention_out, lse_out
@impl(m, "npu_fused_infer_attention_score")
def npu_fused_infer_attention_score_forward(query, key, value, *, pse_shift=None, atten_mask=None, actual_seq_lengths=None, actual_seq_lengths_kv=None,
                                            dequant_scale1=None, quant_scale1=None, dequant_scale2=None, quant_scale2=None,
                                            quant_offset2=None, antiquant_scale=None, antiquant_offset=None, block_table=None,
                                            query_padding_size=None, kv_padding_size=None, key_antiquant_scale=None, key_antiquant_offset=None,
                                            value_antiquant_scale=None, value_antiquant_offset=None, key_shared_prefix=None, value_shared_prefix=None,
                                            actual_shared_prefix_len=None, query_rope=None, key_rope=None, key_rope_antiquant_scale=None, num_heads=1, scale=1.0, pre_tokens=2147483647, next_tokens=2147483647,
                                            input_layout="BSH", num_key_value_heads=0, sparse_mode=0, inner_precise=0, block_size=0, antiquant_mode=0,
                                            softmax_lse_flag=False, key_antiquant_mode=0, value_antiquant_mode=0):
    """Meta implementation of ``npu_fused_infer_attention_score``.

    Returns ``(attention_out, lse_out)``. The attention output dtype is chosen
    in this order: int8 when ``quant_scale2`` is provided; for an int8 query,
    ``query_rope.dtype`` if a rope tensor is given, else ``torch.half``;
    otherwise ``query.dtype``. ``lse_out`` is float32, shaped by
    ``infer_lse_out_shape`` when ``softmax_lse_flag`` is set, else a
    one-dimensional tensor with zero elements.

    Raises through ``torch._check`` when ``num_heads`` is not positive.
    """
    torch._check(
        num_heads > 0,
        lambda: "numHeads should be greater than 0, but got " + str(num_heads) +
        ops_error(ErrCode.VALUE),
    )
    # A KV head count of 0 means "same as the query head count".
    if num_key_value_heads == 0:
        num_key_value_heads = num_heads
    query_layout, attention_out_layout = get_query_and_attention_out_layout(query, input_layout)
    head_dim_v = get_value_d(block_table, value, query, query_layout, num_key_value_heads) * get_change_d_scale(value)
    shape_probe = infer_attention_out_shape(attention_out_layout, query, query_layout, num_heads, head_dim_v)

    if quant_scale2 is not None:
        result_dtype = torch.int8
    elif query.dtype != torch.int8:
        result_dtype = query.dtype
    elif query_rope is not None:
        result_dtype = query_rope.dtype
    else:
        result_dtype = torch.half
    attention_out = torch.empty_like(shape_probe, dtype=result_dtype)

    lse_probe = infer_lse_out_shape(query, input_layout, query_layout, num_heads)
    lse_out = (
        torch.empty_like(lse_probe, dtype=torch.float32)
        if softmax_lse_flag
        else torch.empty([0], dtype=torch.float32, device='meta')
    )
    return attention_out, lse_out
@impl(m, "npu_fused_infer_attention_score_v2")
def npu_fused_infer_attention_score_v2_forward(query, key, value, *, query_rope=None, key_rope=None, pse_shift=None, atten_mask=None, actual_seq_qlen=None, actual_seq_kvlen=None,
                                               block_table=None, dequant_scale_query=None, dequant_scale_key=None, dequant_offset_key=None, dequant_scale_value=None,
                                               dequant_offset_value=None, dequant_scale_key_rope=None, quant_scale_out=None, quant_offset_out=None, quant_scale_p=None, learnable_sink=None,
                                               num_query_heads=1, num_key_value_heads=0, softmax_scale=1.0, pre_tokens=2147483647, next_tokens=2147483647,
                                               input_layout="BSH", sparse_mode=0, block_size=0, query_quant_mode=0, key_quant_mode=0, value_quant_mode=0, inner_precise=0,
                                               return_softmax_lse=False, query_dtype=None, key_dtype=None, value_dtype=None, query_rope_dtype=None, key_rope_dtype=None,
                                               key_shared_prefix_dtype=None, value_shared_prefix_dtype=None, dequant_scale_query_dtype=None,
                                               dequant_scale_key_dtype=None, dequant_scale_value_dtype=None, dequant_scale_key_rope_dtype=None, out_dtype=None):
    """Meta implementation of ``npu_fused_infer_attention_score_v2``.

    Returns ``(attention_out, lse_out)``. Output dtype selection:

    * ``quant_scale_out`` given: int8, or the dtype mapped from ``out_dtype``
      when that enum value is supplied.
    * query is int8, float8_e4m3fn, or uint8 tagged as ``torch_npu.hifloat8``
      via ``query_dtype``: the dtype mapped from ``out_dtype`` if supplied,
      else ``query_rope.dtype`` if a rope tensor is given, else ``torch.half``.
    * otherwise: ``query.dtype``.

    An ``out_dtype`` missing from TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP
    raises KeyError. Raises through ``torch._check`` when ``num_query_heads``
    is not positive.
    """
    torch._check(
        num_query_heads > 0,
        lambda: "numHeads should be greater than 0, but got " + str(num_query_heads) +
        ops_error(ErrCode.VALUE),
    )
    # A KV head count of 0 means "same as the query head count".
    if num_key_value_heads == 0:
        num_key_value_heads = num_query_heads
    query_layout, attention_out_layout = get_query_and_attention_out_layout(query, input_layout)
    head_dim_v = get_value_d(block_table, value, query, query_layout, num_key_value_heads) * get_change_d_scale_v2(value, value_dtype)
    shape_probe = infer_attention_out_shape(attention_out_layout, query, query_layout, num_query_heads, head_dim_v)

    is_hifloat8_input = query.dtype == torch.uint8 and query_dtype is not None and query_dtype == torch_npu.hifloat8
    if quant_scale_out is not None:
        if out_dtype is None:
            result_dtype = torch.int8
        else:
            result_dtype = TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP[out_dtype]
    elif query.dtype == torch.int8 or query.dtype == torch.float8_e4m3fn or is_hifloat8_input:
        if out_dtype is not None:
            result_dtype = TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP[out_dtype]
        elif query_rope is not None:
            result_dtype = query_rope.dtype
        else:
            result_dtype = torch.half
    else:
        result_dtype = query.dtype
    attention_out = torch.empty_like(shape_probe, dtype=result_dtype)

    lse_probe = infer_lse_out_shape(query, input_layout, query_layout, num_query_heads)
    lse_out = (
        torch.empty_like(lse_probe, dtype=torch.float32)
        if return_softmax_lse
        else torch.empty([0], dtype=torch.float32, device='meta')
    )
    return attention_out, lse_out
@impl(m, "npu_quant_lightning_indexer")
def npu_quant_lightning_indexer_forward(query, key, weights, query_dequant_scale, key_dequant_scale, query_quant_mode, key_quant_mode, *, actual_seq_lengths_query=None,
                                        actual_seq_lengths_key=None, block_table=None, layout_query="BSND", layout_key="BSND", sparse_count=2048, sparse_mode=3,
                                        pre_tokens=9223372036854775807, next_tokens=9223372036854775807, query_dtype=None, key_dtype=None):
    """Meta implementation of ``npu_quant_lightning_indexer``.

    Validates that the tensor inputs are present, that ``query`` and ``key``
    are non-empty and that ``sparse_count`` is positive, then returns an int32
    meta tensor of shape ``[q0, q1, key_heads, sparse_count]`` for a "BSND"
    query or ``[q0, key_heads, sparse_count]`` for a "TND" query. ``key_heads``
    is ``key.size(1)`` when ``layout_key`` is "TND" and ``key.size(2)``
    otherwise. Any other ``layout_query`` fails a ``torch._check``.
    """
    mandatory = {"query": query, "key": key, "weights": weights, "query_dequant_scale": query_dequant_scale, "key_dequant_scale": key_dequant_scale}
    for item_name, item in mandatory.items():
        torch._check(
            item is not None,
            lambda: item_name + " should not be None, but the actual value is None" + ops_error(ErrCode.VALUE),
        )
    torch._check(
        query.numel() > 0,
        lambda: "Input query should not be empty." + ops_error(ErrCode.VALUE),
    )
    torch._check(
        key.numel() > 0,
        lambda: "Input key should not be empty." + ops_error(ErrCode.VALUE),
    )
    torch._check(
        sparse_count > 0,
        lambda: "sparse_count should be greater than 0, but got " + str(sparse_count) +
        ops_error(ErrCode.VALUE),
    )

    key_head_num = key.size(1) if layout_key == "TND" else key.size(2)
    if layout_query == "BSND":
        out_shape = [query.size(0), query.size(1), key_head_num, sparse_count]
    elif layout_query == "TND":
        out_shape = [query.size(0), key_head_num, sparse_count]
    else:
        # Always fails: unsupported query layout.
        torch._check(
            False,
            lambda: "No support of query: " + str(layout_query) + ops_error(ErrCode.VALUE),
        )
    return torch.empty(out_shape, dtype=torch.int32, device='meta')
@impl(m, "npu_kv_quant_sparse_flash_attention")
def npu_kv_quant_sparse_flash_attention_forward(query, key, value, sparse_indices, scale_value, key_quant_mode,
                                                value_quant_mode, *, key_dequant_scale=None, value_dequant_scale=None, block_table=None,
                                                actual_seq_lengths_query=None, actual_seq_lengths_kv=None, sparse_block_size=1, layout_query="BSND",
                                                layout_kv="BSND", sparse_mode=3, pre_tokens=9223372036854775807, next_tokens=9223372036854775807, attention_mode=0,
                                                quant_scale_repo_mode=1, tile_size=128, rope_head_dim=64, key_dtype=None, value_dtype=None):
    """Meta implementation of ``npu_kv_quant_sparse_flash_attention``.

    After checking that query/key/value/sparse_indices are present and
    non-empty, returns a meta tensor with ``query.dtype`` whose shape equals the
    query's shape with ``rope_head_dim`` subtracted from the last dimension.
    "BSND" requires a 4-D query and "TND" a 3-D query; any other
    ``layout_query`` fails a ``torch._check``.
    """
    mandatory = {"query": query, "key": key, "value": value, "sparse_indices": sparse_indices}
    for item_name, item in mandatory.items():
        torch._check(
            item is not None,
            lambda: item_name + " should not be None, but the actual value is None" + ops_error(ErrCode.VALUE),
        )
    torch._check(
        query.numel() > 0,
        lambda: "Input query should not be empty." + ops_error(ErrCode.VALUE),
    )
    torch._check(
        key.numel() > 0,
        lambda: "Input key should not be empty." + ops_error(ErrCode.VALUE),
    )
    torch._check(
        value.numel() > 0,
        lambda: "Input value should not be empty." + ops_error(ErrCode.VALUE),
    )
    torch._check(
        sparse_indices.numel() > 0,
        lambda: "Input sparse_indices should not be empty." + ops_error(ErrCode.VALUE),
    )

    if layout_query == "BSND":
        torch._check(
            query.dim() == 4,
            lambda: "When the layout of query is BSND, the query dimension must be 4, but got " + str(query.dim()) + ops_error(ErrCode.VALUE),
        )
        out_shape = [query.size(0), query.size(1), query.size(2), query.size(3) - rope_head_dim]
    elif layout_query == "TND":
        torch._check(
            query.dim() == 3,
            lambda: "When the layout of query is TND, the query dimension must be 3, but got " + str(query.dim()) + ops_error(ErrCode.VALUE),
        )
        out_shape = [query.size(0), query.size(1), query.size(2) - rope_head_dim]
    else:
        # Always fails: unsupported query layout.
        torch._check(
            False,
            lambda: "Not support layout of query:" + layout_query + ops_error(ErrCode.VALUE),
        )
    return torch.empty(out_shape, dtype=query.dtype, device='meta')
@impl(m, "_npu_kv_quant_sparse_flash_attention_pioneer")
def _npu_kv_quant_sparse_flash_attention_pioneer_forward(query, key, value, sparse_indices, scale_value, key_quant_mode,
                                                         value_quant_mode, *, key_dequant_scale=None, value_dequant_scale=None, block_table=None,
                                                         actual_seq_lengths_query=None, actual_seq_lengths_kv=None, key_sink=None, value_sink=None, sparse_block_size=1, layout_query="BSND",
                                                         layout_kv="BSND", sparse_mode=3, pre_tokens=9223372036854775807, next_tokens=9223372036854775807, attention_mode=0,
                                                         quant_scale_repo_mode=1, tile_size=128, rope_head_dim=64, key_dtype=None, value_dtype=None):
    """Meta implementation of ``_npu_kv_quant_sparse_flash_attention_pioneer``.

    Checks that query/key/value/sparse_indices are present and non-empty and
    that the query rank matches ``layout_query`` (4 for "BSND", 3 for "TND").
    The returned meta tensor has ``query.dtype`` and the query's shape, except
    that the last dimension is reduced by ``rope_head_dim``. Unsupported
    layouts fail a ``torch._check``.
    """
    needed_inputs = {"query": query, "key": key, "value": value, "sparse_indices": sparse_indices}
    for item_name, item in needed_inputs.items():
        torch._check(
            item is not None,
            lambda: item_name + " should not be None, but the actual value is None" + ops_error(ErrCode.VALUE),
        )
    torch._check(
        query.numel() > 0,
        lambda: "Input query should not be empty." + ops_error(ErrCode.VALUE),
    )
    torch._check(
        key.numel() > 0,
        lambda: "Input key should not be empty." + ops_error(ErrCode.VALUE),
    )
    torch._check(
        value.numel() > 0,
        lambda: "Input value should not be empty." + ops_error(ErrCode.VALUE),
    )
    torch._check(
        sparse_indices.numel() > 0,
        lambda: "Input sparse_indices should not be empty." + ops_error(ErrCode.VALUE),
    )

    if layout_query == "BSND":
        torch._check(
            query.dim() == 4,
            lambda: "When the layout of query is BSND, the query dimension must be 4, but got " + str(query.dim()) + ops_error(ErrCode.VALUE),
        )
        result_shape = [query.size(0), query.size(1), query.size(2), query.size(3) - rope_head_dim]
    elif layout_query == "TND":
        torch._check(
            query.dim() == 3,
            lambda: "When the layout of query is TND, the query dimension must be 3, but got " + str(query.dim()) + ops_error(ErrCode.VALUE),
        )
        result_shape = [query.size(0), query.size(1), query.size(2) - rope_head_dim]
    else:
        # Always fails: unsupported query layout.
        torch._check(
            False,
            lambda: "Not support layout of query:" + layout_query + ops_error(ErrCode.VALUE),
        )
    return torch.empty(result_shape, dtype=query.dtype, device='meta')
@impl(m, "npu_fusion_attention")
def npu_fusion_attention_forward(query, key, value, head_num, input_layout, pse=None, padding_mask=None,
                                 atten_mask=None, scale=1.0, keep_prob=1.0, pre_tockens=2147483647, next_tockens=2147483647,
                                 inner_precise=0, prefix=None, actual_seq_qlen=None, actual_seq_kvlen=None, sparse_mode=0,
                                 gen_mask_parallel=True, sync=False, softmax_layout="", sink=None, dropout_mask=None, seed=0, offset=0):
    """Meta implementation of ``npu_fusion_attention``.

    Returns ``(attention_score, softmax_max, softmax_sum, softmax_out, seed,
    offset, numels)``. ``attention_score`` matches the query's shape and dtype.
    The softmax statistics are float32 with shape ``[B, head_num, 8]`` for
    "TND" and ``[B, head_num, S1, 8]`` otherwise, where ``B``/``S1`` are read
    from the query dimensions selected by ``input_layout``. ``softmax_out`` is
    a zero-element tensor. The returned seed, offset and numels are always 0;
    the incoming ``seed`` and ``offset`` arguments are not propagated.
    """
    batch = query.size(0)
    seq_q = query.size(2)
    if input_layout == "SBH":
        batch, seq_q = query.size(1), query.size(0)
    elif input_layout == "BSH" or input_layout == "BSND":
        seq_q = query.size(1)

    seed = offset = numels = 0
    attention_score = query.new_empty(
        query.shape, dtype=query.dtype, device='meta')

    if input_layout == "TND":
        stats_shape = [batch, head_num, 8]
    else:
        stats_shape = [batch, head_num, seq_q, 8]
    softmax_max = torch.empty(stats_shape, dtype=torch.float32, device='meta')
    softmax_sum = torch.empty(stats_shape, dtype=torch.float32, device='meta')
    softmax_out = torch.empty([0], dtype=query.dtype, device='meta')
    return (attention_score, softmax_max, softmax_sum, softmax_out, seed, offset, numels)
@impl(m, "npu_fusion_attention_grad")
def npu_fusion_attention_backward(query, key, value, dy, head_num, input_layout, *, pse=None, padding_mask=None, atten_mask=None,
                                  softmax_max=None, softmax_sum=None, softmax_in=None, attention_in=None, scale_value=1.0,
                                  keep_prob=1.0, pre_tockens=2147483647, next_tockens=2147483647, inner_precise=0, seed=0, offset=0,
                                  numels=0, prefix=None, actual_seq_qlen=None, actual_seq_kvlen=None, sparse_mode=0,
                                  gen_mask_parallel=True, sync=False, softmax_layout="", sink=None):
    """Meta implementation of ``npu_fusion_attention_grad``.

    Returns ``(dq, dk, dv, dpse, dsink)``: gradients shaped like query, key and
    value (all with ``query.dtype``), ``None`` for ``dpse``, and a ``dsink``
    that mirrors ``sink``'s shape and dtype, or is a zero-element tensor of the
    default dtype when ``sink`` is None.
    """
    grad_dtype = query.dtype
    dq = query.new_empty(query.shape, dtype=grad_dtype, device='meta')
    dk = key.new_empty(key.shape, dtype=grad_dtype, device='meta')
    dv = value.new_empty(value.shape, dtype=grad_dtype, device='meta')
    if sink is None:
        dsink = torch.empty([0], device='meta')
    else:
        dsink = torch.empty(sink.shape, dtype=sink.dtype, device='meta')
    return (dq, dk, dv, None, dsink)
@impl(m, "npu_quant_fusion_attention")
def npu_quant_fusion_attention_forward(query, key, value, head_num, input_layout, *, d_scale_q, d_scale_k,
                                       d_scale_v, p_scale=None, scale=1.0, query_dtype=None, out_dtype=None):
    """Meta implementation of ``npu_quant_fusion_attention``.

    Returns ``(attention_score, softmax_max, softmax_sum, softmax_out, seed,
    offset, numels)``.

    ``out_dtype`` selects the floating-point output type: the value ``1``
    yields bfloat16, anything else (including the default ``None``) yields
    half. It is a keyword with a default because the body previously read an
    undefined ``out_dtype`` name and raised NameError on every call.

    ``attention_score`` has the query's shape; the softmax statistics are
    float32 ``[B, head_num, S1, 8]`` with ``B``/``S1`` taken from the query
    dimensions selected by ``input_layout``; ``softmax_out`` has zero elements.
    seed, offset and numels are always 0.
    """
    batch = query.size(0)
    seq_q = query.size(2)
    if input_layout == "BSH":
        seq_q = query.size(1)
    elif input_layout == "SBH":
        batch = query.size(1)
        seq_q = query.size(0)

    float_dtype = torch.bfloat16 if out_dtype == 1 else torch.half
    attention_score = torch.empty_like(query, dtype=float_dtype, device='meta')
    softmax_out = torch.empty([0], dtype=float_dtype, device='meta')
    softmax_max = torch.empty([batch, head_num, seq_q, 8], dtype=torch.float32, device='meta')
    softmax_sum = torch.empty([batch, head_num, seq_q, 8], dtype=torch.float32, device='meta')
    seed = offset = numels = 0
    return (attention_score, softmax_max, softmax_sum, softmax_out, seed, offset, numels)
@impl(m, "npu_fusion_attention_v2")
def npu_fusion_attention_forward_v2(query, key, value, head_num, input_layout, *, pse=None, padding_mask=None, atten_mask=None, query_rope=None,
                                    key_rope=None, scale=1.0, keep_prob=1.0,
                                    pre_tokens=2147483647, next_tokens=2147483647, inner_precise=0, prefix=None, actual_seq_qlen=None, actual_seq_kvlen=None,
                                    sparse_mode=0, gen_mask_parallel=True, sync=False, pse_type=1, q_start_idx=None, kv_start_idx=None,
                                    softmax_layout="", sink=None, dropout_mask=None, seed=0, offset=0):
    """Meta implementation of ``npu_fusion_attention_v2``.

    Returns ``(attention_score, softmax_max, softmax_sum, softmax_out, seed,
    offset, numels)``. ``attention_score`` mirrors the query; the softmax
    statistics are float32 ``[B, head_num, S1, 8]`` with ``B``/``S1`` read from
    the query dimensions selected by ``input_layout``; ``softmax_out`` has zero
    elements. ``seed`` and ``offset`` are passed through unchanged and
    ``numels`` is 0.
    """
    batch = query.size(0)
    seq_q = query.size(2)
    # The key sequence length is not used for any output shape; it is still
    # read so that an out-of-range key dimension raises here.
    seq_kv = key.size(2)
    if input_layout == "BSH":
        seq_q = query.size(1)
        seq_kv = key.size(1)
    elif input_layout == "SBH":
        batch = query.size(1)
        seq_q = query.size(0)
        seq_kv = key.size(0)

    numels = 0
    stats_shape = [batch, head_num, seq_q, 8]
    attention_score = torch.empty_like(query, dtype=query.dtype, device='meta')
    softmax_max = torch.empty(stats_shape, dtype=torch.float32, device='meta')
    softmax_sum = torch.empty(stats_shape, dtype=torch.float32, device='meta')
    softmax_out = torch.empty([0], dtype=query.dtype, device='meta')
    outputs = tuple(
        torch.empty_like(t) for t in (attention_score, softmax_max, softmax_sum, softmax_out)
    )
    return outputs + (seed, offset, numels)
@impl(m, "npu_fused_floyd_attention")
def npu_fused_floyd_attention(query_ik, key_ij, value_ij, key_jk, value_jk, *, atten_mask=None, scale_value=1.):
    """Meta implementation of ``npu_fused_floyd_attention``.

    Returns ``(out0, out1, out2)``: two float32 tensors whose shape is the
    first four dimensions of ``query_ik`` followed by 8, and one tensor that
    mirrors ``query_ik``.
    """
    stats_shape = (*(query_ik.shape[axis] for axis in range(4)), 8)
    out0 = torch.empty(stats_shape, dtype=torch.float32, device='meta')
    out1 = torch.empty_like(out0, device='meta')
    out2 = torch.empty_like(query_ik, device='meta')
    return (out0, out1, out2)
@impl(m, "npu_fused_floyd_attention_backward")
def npu_fused_floyd_attention_backward(grad_output, query_ik, key_ij, value_ij, key_jk, value_jk, attention_out, softmax_max, softmax_sum, *, atten_mask=None, scale_value=1.):
    """Meta implementation of ``npu_fused_floyd_attention_backward``.

    Returns one meta gradient per differentiable input, each mirroring the
    corresponding tensor: ``(dquery, dkey_0, dvalue_0, dkey_1, dvalue_1)``.
    """
    return tuple(
        torch.empty_like(source, device='meta')
        for source in (query_ik, key_ij, value_ij, key_jk, value_jk)
    )
@impl(m, "npu_lightning_indexer")
def npu_lightning_indexer_forward(query, key, weights, *, actual_seq_lengths_query=None,
                                  actual_seq_lengths_key=None, block_table=None, layout_query="BSND", layout_key="BSND", sparse_count=2048, sparse_mode=3,
                                  pre_tokens=9223372036854775807, next_tokens=9223372036854775807, return_value=False):
    """Meta implementation of ``npu_lightning_indexer``.

    Returns ``(sparse_indices_out, sparse_values_out)``. The indices are int32
    with shape ``[q0, q1, key.size(2), sparse_count]`` for a "BSND" query, and
    ``[q0, key_heads, sparse_count]`` otherwise, where ``key_heads`` is
    ``key.size(1)`` for a "TND" key and ``key.size(2)`` for any other key
    layout.

    ``return_value=True`` is rejected by a ``torch._check``, so in practice the
    values output is always a zero-element tensor of ``query.dtype``.
    """
    mandatory = {"query": query, "key": key, "weights": weights}
    for item_name, item in mandatory.items():
        torch._check(
            item is not None,
            lambda: item_name + " should not be None, but the actual value is None" + ops_error(ErrCode.VALUE),
        )
    torch._check(
        query.numel() > 0,
        lambda: "Input query should not be empty." + ops_error(ErrCode.VALUE),
    )
    torch._check(
        key.numel() > 0,
        lambda: "Input key should not be empty." + ops_error(ErrCode.VALUE),
    )
    torch._check(
        sparse_count > 0,
        lambda: "sparse_count should be greater than 0, but got " + str(sparse_count) +
        ops_error(ErrCode.VALUE),
    )
    torch._check(
        not return_value,
        lambda: "when return_value is true, not support pytorch compile." + ops_error(ErrCode.VALUE),
    )

    if layout_query == "BSND":
        topk_shape = [query.size(0), query.size(1), key.size(2), sparse_count]
    else:
        head_axis = 1 if layout_key == "TND" else 2
        topk_shape = [query.size(0), key.size(head_axis), sparse_count]
    sparse_indices_out = torch.empty(topk_shape, dtype=torch.int32, device='meta')

    # The values tensor shares the indices' shape when it is requested.
    if return_value:
        sparse_values_out = torch.empty(topk_shape, dtype=query.dtype, device='meta')
    else:
        sparse_values_out = torch.empty([0], dtype=query.dtype, device='meta')
    return (sparse_indices_out, sparse_values_out)
@impl(m, "npu_lightning_indexer_grad")
def npu_lightning_indexer_grad_meta(query, key, dy, sparse_indices, weights, actual_seq_lengths_query=None, actual_seq_lengths_key=None, layout="BSND", sparse_mode=3, pre_tokens=9223372036854775807, next_tokens=9223372036854775807):
    """Meta implementation of ``npu_lightning_indexer_grad``.

    Returns ``(d_query, d_key, d_weights)``, each a meta tensor with the shape
    and dtype of the corresponding input.
    """
    return tuple(
        source.new_empty(source.shape, dtype=source.dtype, device='meta')
        for source in (query, key, weights)
    )
@impl(m, "npu_sparse_lightning_indexer_grad_kl_loss")
def npu_sparse_lightning_indexer_grad_kl_loss_meta(query, key, query_index, key_index, weights, sparse_indices, softmax_max, softmax_sum, scale_value, *, query_rope=None, key_rope=None, actual_seq_qlen=None, actual_seq_klen=None, layout='BSND', sparse_mode=3, pre_tokens=9223372036854775807, next_tokens=9223372036854775807):
    """Meta implementation of ``npu_sparse_lightning_indexer_grad_kl_loss``.

    Returns ``(d_query_index, d_key_index, d_weights, loss)``. The three
    gradients mirror ``query_index``, ``key_index`` and ``weights``; ``loss``
    is a float32 tensor of shape ``[1]``.
    """
    grads = tuple(
        source.new_empty(source.shape, dtype=source.dtype, device='meta')
        for source in (query_index, key_index, weights)
    )
    loss = torch.empty([1], dtype=torch.float32, device='meta')
    return grads + (loss,)
@impl(m, "npu_dense_lightning_indexer_grad_kl_loss")
def npu_dense_lightning_indexer_grad_kl_loss_meta(query, key, query_index, key_index, weights, softmax_max, softmax_sum, softmax_max_index, softmax_sum_index, scale_value, *, query_rope=None, key_rope=None, actual_seq_qlen=None, actual_seq_klen=None, layout='BSND', sparse_mode=3, pre_tokens=9223372036854775807, next_tokens=9223372036854775807):
    """Meta implementation of ``npu_dense_lightning_indexer_grad_kl_loss``.

    Returns ``(d_query_index, d_key_index, d_weights, loss)``: gradients that
    mirror ``query_index``, ``key_index`` and ``weights`` plus a float32
    ``loss`` tensor of shape ``[1]``.
    """
    def _mirror(source):
        # Same shape and dtype as the source, allocated on the meta device.
        return source.new_empty(source.shape, dtype=source.dtype, device='meta')

    loss = torch.empty([1], dtype=torch.float32, device='meta')
    return (_mirror(query_index), _mirror(key_index), _mirror(weights), loss)
@impl(m, "npu_dense_lightning_indexer_softmax_lse")
def npu_dense_lightning_indexer_softmax_lse_meta(query_index, key_index, weights, *, actual_seq_qlen=None, actual_seq_klen=None, layout='BSND', sparse_mode=3, pre_tokens=9223372036854775807, next_tokens=9223372036854775807):
    """Meta implementation of ``npu_dense_lightning_indexer_softmax_lse``.

    Returns ``(softmax_max_out, softmax_sum_out)``, two float32 meta tensors of
    shape ``[key_index.size(1), query_index.size(0)]`` for "TND" and
    ``[query_index.size(0), key_index.size(2), query_index.size(1)]`` for any
    other layout.
    """
    output_size = (
        [key_index.size(1), query_index.size(0)]
        if layout == "TND"
        else [query_index.size(0), key_index.size(2), query_index.size(1)]
    )
    softmax_max_out, softmax_sum_out = (
        torch.empty(output_size, dtype=torch.float32, device='meta') for _ in range(2)
    )
    return (softmax_max_out, softmax_sum_out)
@impl(m, "npu_quant_fusion_attention_backward")
def npu_quant_fusion_attention_backward(query, key, value, dy, head_num, input_layout, d_scale_q, d_scale_k, d_scale_v, d_scale_dy, *, p_scale=None, ds_scale=None, softmax_max=None, softmax_sum=None, attention_in=None, scale_value=1.0, query_dtype=None, out_dtype=None, sink=None):
    """Meta implementation of ``npu_quant_fusion_attention_backward``.

    Returns ``(dq, dk, dv, dpse, dq_rope, dk_rope, dsink)``.

    ``out_dtype == 1`` selects bfloat16 gradients; any other value (including
    the default ``None``) selects half. ``dq``/``dk``/``dv`` mirror the shapes
    of query/key/value; ``dpse``, ``dq_rope`` and ``dk_rope`` are zero-element
    placeholders; ``dsink`` mirrors ``sink`` or is None when no sink is given.

    ``out_dtype`` and ``sink`` are keywords with defaults because the body
    previously read both names without defining them (NameError). The
    placeholders use ``torch.empty([0], ...)`` because ``torch.empty_like``
    requires a tensor argument and rejects the list ``[0]``.
    """
    grad_dtype = torch.bfloat16 if out_dtype == 1 else torch.half

    def _placeholder():
        # Zero-element tensor standing in for an output that is not produced.
        return torch.empty([0], dtype=grad_dtype, device='meta')

    dq = torch.empty_like(query, dtype=grad_dtype, device='meta')
    dk = torch.empty_like(key, dtype=grad_dtype, device='meta')
    dv = torch.empty_like(value, dtype=grad_dtype, device='meta')
    dq_rope = _placeholder()
    dk_rope = _placeholder()
    dpse = _placeholder()
    dsink = None if sink is None else torch.empty_like(sink, dtype=sink.dtype, device='meta')
    return (dq, dk, dv, dpse, dq_rope, dk_rope, dsink)
@impl(m, "npu_fusion_attention_grad_v2")
def npu_fusion_attention_backward_v2(query, key, value, dy, head_num, input_layout, *, pse=None, padding_mask=None, atten_mask=None,
                                     softmax_max=None,
                                     softmax_sum=None, softmax_in=None, attention_in=None, query_rope=None, key_rope=None, scale_value=1.0,
                                     keep_prob=1.0, pre_tokens=2147483647, next_tokens=2147483647, inner_precise=0, seed=0, offset=0,
                                     numels=0, prefix=None, actual_seq_qlen=None, actual_seq_kvlen=None, sparse_mode=0,
                                     gen_mask_parallel=True, sync=False, pse_type=1, q_start_idx=None, kv_start_idx=None,
                                     softmax_layout="", sink=None):
    """Meta implementation of ``npu_fusion_attention_grad_v2``.

    Returns ``(dq, dk, dv, dpse, dq_rope, dk_rope, dsink)``. ``dq``/``dk``/
    ``dv`` mirror the shapes of query/key/value with ``query.dtype``; ``dpse``,
    ``dq_rope`` and ``dk_rope`` are zero-element placeholders of
    ``query.dtype``; ``dsink`` mirrors ``sink`` or is None when no sink is
    given.

    The placeholders use ``torch.empty([0], ...)`` because ``torch.empty_like``
    requires a tensor argument and rejects the list ``[0]``.
    """
    grad_dtype = query.dtype

    def _placeholder():
        # Zero-element tensor standing in for an output that is not produced.
        return torch.empty([0], dtype=grad_dtype, device='meta')

    dq = torch.empty_like(query, dtype=grad_dtype, device='meta')
    dk = torch.empty_like(key, dtype=grad_dtype, device='meta')
    dv = torch.empty_like(value, dtype=grad_dtype, device='meta')
    dq_rope = _placeholder()
    dk_rope = _placeholder()
    dpse = _placeholder()
    dsink = None if sink is None else torch.empty_like(sink, dtype=sink.dtype, device='meta')
    return (dq, dk, dv, dpse, dq_rope, dk_rope, dsink)
@impl(m, "npu_rotary_mul")
def npu_rotary_mul_meta(embedding, cosine, sine, mode='half', rotate=None):
    """Meta implementation of ``npu_rotary_mul``: the output mirrors ``embedding``."""
    rotated = torch.empty_like(embedding)
    return rotated
@impl(m, "npu_rotary_mul_backward")
def npu_rotary_mul_backward(grad, embedding, cosine, sine, mode=0):
    """Meta implementation of ``npu_rotary_mul_backward``.

    Returns ``(dx, dr1, dr2)`` shaped like ``embedding``, ``cosine`` and
    ``sine`` respectively, all using ``embedding.dtype``.
    """
    grad_dtype = embedding.dtype
    dx, dr1, dr2 = (
        torch.empty_like(source, dtype=grad_dtype, device='meta')
        for source in (embedding, cosine, sine)
    )
    return (dx, dr1, dr2)
@impl(m, "fast_gelu")
def fast_gelu_meta(self):
    """Meta implementation of ``fast_gelu``: the output mirrors the input tensor."""
    activated = torch.empty_like(self)
    return activated
@impl(m, "npu_fast_gelu_backward")
def npu_fast_gelu_backward_meta(grad, self):
    """Meta implementation of ``npu_fast_gelu_backward``: the gradient mirrors ``self``."""
    grad_input = torch.empty_like(self)
    return grad_input
@impl(m, "npu_fast_gelu")
def npu_fast_gelu_meta(self):
    """Meta implementation of ``npu_fast_gelu``: the output mirrors the input tensor."""
    activated = torch.empty_like(self)
    return activated
@impl(m, "npu_gelu")
def npu_gelu_meta(self, *, approximate="none"):
    """Meta implementation of ``npu_gelu``; ``approximate`` does not affect the output metadata."""
    activated = torch.empty_like(self)
    return activated
@impl(m, "npu_gelu_backward")
def npu_gelu_backward_meta(grad, self, *, approximate="none"):
    """Meta implementation of ``npu_gelu_backward``: the gradient mirrors ``self``."""
    grad_input = torch.empty_like(self)
    return grad_input
@impl(m, "npu_silu")
def npu_silu_meta(self):
    """Meta implementation of ``npu_silu``: the output mirrors the input tensor."""
    activated = torch.empty_like(self)
    return activated
@impl(m, "npu_silu_backward")
def npu_silu_backward_meta(grad_output, x0, x1):
    """Meta implementation of ``npu_silu_backward``: the result mirrors ``grad_output``."""
    grad_input = torch.empty_like(grad_output)
    return grad_input
@impl(m, "npu_layer_norm_eval")
def npu_layer_norm_eval_meta(input, normalized_shape, weight=None, bias=None, eps=1e-5):
    """Meta implementation of ``npu_layer_norm_eval``: the output mirrors ``input``."""
    normalized = torch.empty_like(input)
    return normalized
@impl(m, "npu_add_layer_norm")
def npu_add_layer_norm_meta(x1, x2, gamma, beta, epsilon=1e-5, additional_output=False):
    """Meta implementation of ``npu_add_layer_norm``.

    Returns four meta tensors: one like ``x1``, two float32 statistics tensors,
    and another like ``x1``. The statistics keep ``x1``'s leading
    ``x1.dim() - gamma.dim()`` dimensions and set every remaining dimension
    to 1.
    """
    kept_dims = x1.dim() - gamma.dim()
    stat_shape = [x1.size(axis) if axis < kept_dims else 1 for axis in range(x1.dim())]
    stat = torch.empty(stat_shape, dtype=torch.float32, device='meta')
    return (
        torch.empty_like(x1, dtype=x1.dtype),
        torch.empty_like(stat),
        torch.empty_like(stat),
        torch.empty_like(x1, dtype=x1.dtype),
    )
@impl(m, "npu_gelu_mul")
def npu_gelu_mul_meta(input_tensor, *, approximate="none"):
    """Meta implementation of ``npu_gelu_mul``.

    The output keeps the input's dtype and shape, except that the last
    dimension is floor-divided by 2.
    """
    halved_shape = list(input_tensor.shape)
    halved_shape[-1] //= 2
    return torch.empty(
        size=tuple(halved_shape),
        dtype=input_tensor.dtype,
        device=torch.device("meta"),
    )
@impl(m, "npu_gelu_quant")
def npu_gelu_quant_meta(self, *, input_scale=None, input_offset=None,
                        approximate="none", quant_mode="dynamic", dst_type=1, round_mode='rint'):
    """Meta implementation of ``npu_gelu_quant``.

    Returns ``(y, out_scale)``. ``y`` mirrors ``self`` with the dtype that
    ``dst_type`` maps to in TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP.
    ``out_scale`` is None in "static" mode and, in "dynamic" mode, a float32
    tensor with ``self``'s shape minus its last dimension.

    Raises:
        RuntimeError: if ``quant_mode`` is neither "dynamic" nor "static", if
            ``input_scale`` is missing in "static" mode, or if ``dst_type`` is
            not a key of the dtype map.
    """
    if quant_mode not in ("dynamic", "static"):
        raise RuntimeError("Parameter(quant_mode) must be 'dynamic' or 'static', got " + quant_mode + ops_error(ErrCode.VALUE))

    out_scale = None
    if quant_mode == "static":
        if input_scale is None:
            raise RuntimeError("input_scale cannot be None when quant_mode is 'static'.")
    else:
        out_scale = self.new_empty(self.shape[:-1], dtype=torch.float32)

    y_dst_dtype = TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP.get(dst_type)
    # Test the lookup result, not dst_type itself: an unknown enum value would
    # otherwise reach empty_like with dtype=None and silently keep the input
    # dtype.
    if y_dst_dtype is None:
        raise RuntimeError("Parameter(dst_type) enum value:{} not found in "
                           "TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP, please check.".format(dst_type) +
                           ops_error(ErrCode.PARAM))
    y = torch.empty_like(self, dtype=y_dst_dtype)
    return (y, out_scale)
@impl(m, "npu_dtype_cast")
def npu_dtype_cast_meta(self, dtype, input_dtype=None):
    """Meta implementation of ``npu_dtype_cast``.

    ``dtype`` and ``input_dtype`` are integer enum values. When ``input_dtype``
    is 296 or 297 the last dimension is doubled; when ``dtype`` is 285, 296 or
    297 the last dimension must be even and is halved. Targets 285, 290, 296
    and 297 are produced as uint8; any other target is looked up in
    TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP.

    Raises:
        RuntimeError: for a 0-dim input combined with a 4-bit source or target,
            for an odd last dimension with a packed target, or for an unknown
            ``dtype`` enum value.
    """
    rank = self.dim()
    out_shape = [self.size(axis) for axis in range(rank)]

    if input_dtype == 296 or input_dtype == 297:
        if rank == 0:
            raise RuntimeError("Scalar input cannot be float4_e2m1fn_x2 or float4_e1m2fn_x2" +
                               ops_error(ErrCode.PARAM))
        out_shape[-1] *= 2

    if dtype == 285 or dtype == 296 or dtype == 297:
        if rank == 0 or out_shape[-1] % 2:
            raise RuntimeError("If output dtype is float4_e2m1fn_x2, float4_e1m2fn_x2 or int4, "
                               "the last dim of input must be divisible by 2" +
                               ops_error(ErrCode.PARAM))
        out_shape[-1] //= 2

    if dtype in [285, 290, 296, 297]:
        return self.new_empty(out_shape, dtype=torch.uint8)

    mapped_dtype = TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP.get(dtype)
    if mapped_dtype is None:
        raise RuntimeError("Parameter(dtype) enum value:{} not found in "
                           "TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP, please check.".format(dtype) +
                           ops_error(ErrCode.PARAM))
    return self.new_empty(out_shape, dtype=mapped_dtype)
@impl(m, "_npu_dtype_cast")
def _npu_dtype_cast_meta(self, dtype):
    """Meta implementation of ``_npu_dtype_cast``: same shape as ``self``, dtype ``dtype``."""
    cast_out = self.new_empty(self.shape, dtype=dtype)
    return cast_out
@impl(m, "_npu_dtype_cast_backward")
def _npu_dtype_cast_backward_meta(self, dtype):
    """Meta implementation of ``_npu_dtype_cast_backward``: same shape as ``self``, dtype ``dtype``."""
    grad_out = self.new_empty(self.shape, dtype=dtype)
    return grad_out
@impl(m, "npu_dtype_cast_backward")
def npu_dtype_cast_backward_meta(self, dtype, grad_dtype=None, input_dtype=None):
    """Meta implementation of ``npu_dtype_cast_backward``.

    When ``grad_dtype`` is 296 or 297 the last dimension is doubled; when
    ``input_dtype`` is 296 or 297 the last dimension must be even and is
    halved. The result is uint8 when ``input_dtype`` is 290, 296 or 297, and
    uses ``dtype`` otherwise.

    Raises:
        RuntimeError: for a 0-dim tensor combined with a 4-bit enum, or for an
            odd last dimension when halving is required.
    """
    rank = self.dim()
    out_shape = [self.size(axis) for axis in range(rank)]

    if grad_dtype == 296 or grad_dtype == 297:
        if rank == 0:
            raise RuntimeError("Scalar input cannot be float4_e2m1 or float4_e1m2" +
                               ops_error(ErrCode.PARAM))
        out_shape[-1] *= 2

    if input_dtype == 296 or input_dtype == 297:
        if rank == 0 or out_shape[-1] % 2:
            raise RuntimeError("If output dtype is float4_e2m1 or float4_e1m2, "
                               "the last dim of input must be divisible by 2" +
                               ops_error(ErrCode.PARAM))
        out_shape[-1] //= 2

    result_dtype = torch.uint8 if input_dtype in [290, 296, 297] else dtype
    return self.new_empty(out_shape, dtype=result_dtype)
@impl(m, "npu_block_sparse_attention")
def npu_block_sparse_attention_meta(query, key, value, block_sparse_mask, block_shape, *, q_input_layout='TND', kv_input_layout='TND',
                                    num_key_value_heads=1, scale_value=0.0, inner_precise=1,
                                    actual_seq_lengths=None, actual_seq_lengths_kv=None, softmax_lse_flag=0):
    """Meta implementation of ``npu_block_sparse_attention``.

    Returns ``(attention_out, softmax_lse)``. ``attention_out`` has the query's
    shape with the last dimension replaced by ``value.size(-1)``;
    ``softmax_lse`` is float32 with the query's leading dimensions (all but the
    last) followed by 1.

    Fails a ``torch._check`` when the query is not 3-D or 4-D, when
    ``block_shape`` has fewer than two entries, when ``block_shape[1]`` is not
    a multiple of 128, or when a bfloat16 query is combined with a non-zero
    ``inner_precise``.
    """
    torch._check(
        query.dim() == 3 or query.dim() == 4,
        lambda: "query should be 3 or 4 dimensional, but got " + str(query.dim()) + ops_error(ErrCode.PARAM),
    )
    # Check the length on its own: the multiple-of-128 message indexes
    # block_shape[1], which would raise IndexError on a too-short block_shape
    # instead of reporting the check failure.
    torch._check(
        len(block_shape) >= 2,
        lambda: "block_shape must contain at least 2 elements, got " + str(len(block_shape))
        + ops_error(ErrCode.PARAM),
    )
    torch._check(
        block_shape[1] % 128 == 0,
        lambda: "block_shape[1] (blockShapeY) must be a multiple of 128, got " + str(block_shape[1])
        + ops_error(ErrCode.PARAM),
    )
    torch._check(
        query.dtype != torch.bfloat16 or inner_precise == 0,
        lambda: "when query/key/value are bfloat16, inner_precise must be 0, got " + str(inner_precise)
        + ops_error(ErrCode.PARAM),
    )

    attention_out_shape = list(query.size())
    attention_out_shape[-1] = value.size(-1)
    attention_out = query.new_empty(attention_out_shape)

    if query.dim() == 4:
        lse_shape = [query.size(0), query.size(1), query.size(2), 1]
    else:
        lse_shape = [query.size(0), query.size(1), 1]
    softmax_lse = query.new_empty(lse_shape, dtype=torch.float32)
    return (attention_out, softmax_lse)
@impl(m, "npu_block_sparse_attention_backward")
def npu_block_sparse_attention_backward_meta(d_out, query, key, value, attention_out, softmax_lse, block_sparse_mask,
                                             block_shape, actual_seq_lengths, actual_seq_lengths_kv,
                                             q_input_layout, kv_input_layout, num_key_value_heads, scale_value):
    """Meta implementation of ``npu_block_sparse_attention_backward``.

    Returns ``(d_query, d_key, d_value)``, each allocated via ``new_empty``
    with the size of the matching input.
    """
    return tuple(source.new_empty(source.size()) for source in (query, key, value))
@impl(m, "npu_bmmV2")
def npu_bmmV2_meta(self, mat2, output_sizes):
    """Meta implementation of ``npu_bmmV2``.

    The output shape is ``(self.size(0), self.size(1), mat2.size(2))``;
    ``output_sizes`` is not consulted.
    """
    result_shape = (self.size(0), self.size(1), mat2.size(2))
    return self.new_empty(result_shape)
@impl(m, "npu_transpose")
def npu_transpose_meta(self, perm, require_contiguous=True):
    """Meta implementation of ``npu_transpose``: mirrors ``self.permute(perm)`` with ``self.dtype``."""
    permuted = self.permute(perm)
    return torch.empty_like(permuted, dtype=self.dtype)
@impl(m, "npu_deep_norm")
def npu_deep_norm_meta(self, gx, beta, gamma, alpha=0.3, epsilon=1e-6):
    """Meta implementation of ``npu_deep_norm``.

    Returns two float32 statistics tensors followed by a tensor like ``self``.
    The statistics keep ``self``'s leading ``self.dim() - gamma.dim()``
    dimensions and set every remaining dimension to 1.
    """
    kept_dims = self.dim() - gamma.dim()
    stat_shape = [self.size(axis) if axis < kept_dims else 1 for axis in range(self.dim())]
    stat = torch.empty(stat_shape, dtype=torch.float32, device='meta')
    return (
        torch.empty_like(stat),
        torch.empty_like(stat),
        torch.empty_like(self, dtype=self.dtype),
    )
@impl(m, "npu_deep_norm_backward")
def npu_deep_norm_backward_meta(dy, x, gx, gamma, mean, rstd, alpha=0.3):
    """Meta implementation of ``npu_deep_norm_backward``.

    Returns four tensors mirroring ``x``, ``gx``, ``gamma`` and ``gamma``.
    """
    return tuple(torch.empty_like(source) for source in (x, gx, gamma, gamma))
@impl(m, "npu_group_norm_swish")
def npu_group_norm_swish_meta(input, num_groups, weight, bias, eps=1e-5, swish_scale=1.0):
    """Meta implementation of ``npu_group_norm_swish``.

    Returns a tensor like ``input`` plus two ``[input.size(0), num_groups]``
    tensors, all using ``input.dtype``.
    """
    stat_shape = [input.size(0), num_groups]
    meta_device = torch.device('meta')
    out = torch.empty_like(input, dtype=input.dtype)
    first_stat = torch.empty(stat_shape, dtype=input.dtype, device=meta_device)
    second_stat = torch.empty(stat_shape, dtype=input.dtype, device=meta_device)
    return (out, first_stat, second_stat)
@impl(m, "npu_rms_norm")
def npu_rms_norm_meta(self, gamma, epsilon=1e-6):
    """Meta implementation of ``npu_rms_norm``.

    Returns a tensor like ``self`` and a float32 statistics tensor that keeps
    ``self``'s leading ``self.dim() - gamma.dim()`` dimensions and sets every
    remaining dimension to 1.
    """
    kept_dims = self.dim() - gamma.dim()
    stat_shape = [self.size(axis) if axis < kept_dims else 1 for axis in range(self.dim())]
    stat = torch.empty(stat_shape, dtype=torch.float32, device='meta')
    return (torch.empty_like(self, dtype=self.dtype), torch.empty_like(stat))
@impl(m, "npu_gemma_rms_norm")
def npu_gemma_rms_norm_meta(self, gamma, epsilon=1e-6):
    """Meta implementation of ``npu_gemma_rms_norm``.

    Returns a tensor like ``self`` and a float32 statistics tensor whose first
    ``self.dim() - gamma.dim()`` dimensions match ``self`` and whose remaining
    dimensions are 1.
    """
    leading = self.dim() - gamma.dim()
    stat_shape = []
    for axis in range(self.dim()):
        stat_shape.append(self.size(axis) if axis < leading else 1)
    stat = torch.empty(stat_shape, dtype=torch.float32, device='meta')
    normalized = torch.empty_like(self, dtype=self.dtype)
    return (normalized, torch.empty_like(stat))
@impl(m, "npu_add_rms_norm")
def npu_add_rms_norm_meta(x1, x2, gamma, epsilon=1e-6):
    """Meta implementation of ``npu_add_rms_norm``.

    Returns ``(like x1, statistics, like x1)``. The float32 statistics tensor
    keeps ``x1``'s leading ``x1.dim() - gamma.dim()`` dimensions and sets every
    remaining dimension to 1.
    """
    kept_dims = x1.dim() - gamma.dim()
    stat_shape = [x1.size(axis) if axis < kept_dims else 1 for axis in range(x1.dim())]
    stat = torch.empty(stat_shape, dtype=torch.float32, device='meta')
    return (
        torch.empty_like(x1, dtype=x1.dtype),
        torch.empty_like(stat),
        torch.empty_like(x1, dtype=x1.dtype),
    )
@impl(m, "npu_add_rms_norm_v2")
def npu_add_rms_norm_v2_meta(x1, x2, gamma, epsilon=1e-6):
    """Meta kernel for ``npu_add_rms_norm_v2``.

    Returns a single float32 tensor that keeps the first
    ``x1.dim() - gamma.dim()`` sizes of ``x1`` and has size 1 elsewhere.
    """
    kept_axes = x1.dim() - gamma.dim()
    rstd_shape = [x1.size(axis) if axis < kept_axes else 1 for axis in range(x1.dim())]
    rstd_proto = torch.empty(rstd_shape, dtype=torch.float32, device='meta')
    return torch.empty_like(rstd_proto)
@impl(m, "npu_add_rms_norm_v2_functional")
def npu_add_rms_norm_v2_functional_meta(x1, x2, gamma, epsilon=1e-6):
    """Meta kernel for ``npu_add_rms_norm_v2_functional``.

    Returns a float32 statistics tensor (leading ``x1.dim() - gamma.dim()``
    sizes of ``x1``, then 1s) followed by tensors like ``x1`` and ``x2``.
    """
    kept_axes = x1.dim() - gamma.dim()
    rstd_shape = [x1.size(axis) if axis < kept_axes else 1 for axis in range(x1.dim())]
    rstd_proto = torch.empty(rstd_shape, dtype=torch.float32, device='meta')
    stats = torch.empty_like(rstd_proto)
    like_x1 = torch.empty_like(x1)
    like_x2 = torch.empty_like(x2)
    return (stats, like_x1, like_x2)
@impl(m, "npu_rms_norm_quant")
def npu_rms_norm_quant_meta(x, gamma, beta, scale, offset, epsilon=1e-06, dst_dtype=None):
    """Meta kernel for ``npu_rms_norm_quant``.

    ``dst_dtype`` is a dtype enum value (defaults to 1); unknown values
    resolve to ``torch.int8``. For ``torch.quint4x2`` the result is an int32
    tensor whose last dimension is ``x``'s last dimension divided by 8;
    otherwise the result has ``x``'s shape and the resolved dtype.

    Raises:
        RuntimeError: quint4x2 was requested and the last dimension of ``x``
            is not a multiple of 8.
    """
    if dst_dtype is None:
        dst_dtype = 1
    resolved_dtype = TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP.get(dst_dtype, torch.int8)
    if resolved_dtype != torch.quint4x2:
        return torch.empty(x.size(), dtype=resolved_dtype, device=x.device)

    rank = x.dim()
    if x.size(rank - 1) % 8:
        raise RuntimeError("If dtype is quint4x2, the last dim of input must be divided by 8" +
                           ops_error(ErrCode.NOT_SUPPORT))
    # Eight 4-bit values are packed into each int32 element of the last axis.
    packed_shape = [x.size(axis) for axis in range(rank - 1)]
    packed_shape.append(x.size(rank - 1) // 8)
    return torch.empty(packed_shape, dtype=torch.int32, device=x.device)
@impl(m, "npu_rms_norm_dynamic_mx_quant")
def npu_rms_norm_dynamic_mx_quant_meta(x, gamma, *, beta=None, epsilon=1e-06, scale_alg=0, round_mode='rint', dst_type=296):
    """Meta kernel for ``npu_rms_norm_dynamic_mx_quant``.

    Returns ``(y, mxscale, rstd)``:
      * ``y`` is like ``x`` with an fp8 dtype when ``dst_type`` resolves to
        float8_e5m2 / float8_e4m3fn; otherwise it is uint8 with the last
        dimension halved.
      * ``mxscale`` is uint8, shaped like ``x`` with the last dimension
        replaced by ``ceil(last / 64)`` and a trailing axis of size 2.
      * ``rstd`` is float32 on the meta device, keeping the first
        ``x.dim() - gamma.dim()`` sizes of ``x`` and 1 elsewhere.

    Raises:
        RuntimeError: ``scale_alg`` is not 0 or 1, or the packed output is
            requested and the last dimension of ``x`` is odd.
    """
    if scale_alg not in [0, 1]:
        raise RuntimeError(f"Invalid scale_alg value: {scale_alg}. Expected 0 or 1." +
                           ops_error(ErrCode.PARAM))
    pack_factor = 2
    scale_block = 32
    rank = x.dim()
    last_axis = rank - 1

    resolved_dtype = TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP.get(dst_type, torch.int8)
    if resolved_dtype == torch.float8_e5m2 or dst_type == 291:
        y = torch.empty_like(x, dtype=torch.float8_e5m2)
    elif resolved_dtype == torch.float8_e4m3fn or dst_type == 292:
        y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
    else:
        if x.size(last_axis) % 2:
            raise RuntimeError("If output dtype is float4_e2m1 or float4_e1m2, " \
                               "the last dim of input must be divisible by 2, " +
                               ops_error(ErrCode.PARAM))
        packed_shape = [x.size(axis) for axis in range(last_axis)]
        packed_shape.append(x.size(last_axis) // pack_factor)
        y = x.new_empty(packed_shape, dtype=torch.uint8)

    scale_shape = [x.size(axis) for axis in range(rank)]
    scale_shape.append(2)
    scale_shape[last_axis] = int(math.ceil(scale_shape[last_axis] / (scale_block * pack_factor)))
    mxscale = x.new_empty(scale_shape, dtype=torch.uint8)

    kept_axes = rank - gamma.dim()
    rstd_shape = [x.size(axis) if axis < kept_axes else 1 for axis in range(rank)]
    rstd = torch.empty(rstd_shape, dtype=torch.float32, device='meta')
    return (y, mxscale, rstd)
@impl(m, "npu_add_rms_norm_cast")
def npu_add_rms_norm_cast_meta(x1, x2, gamma, epsilon=1e-6):
    """Meta kernel for ``npu_add_rms_norm_cast``.

    Returns four tensors: a float32 tensor shaped like ``x1``, a tensor like
    ``x1``, a float32 statistics tensor (leading ``x1.dim() - gamma.dim()``
    sizes of ``x1``, then 1s), and another tensor like ``x1``.
    """
    kept_axes = x1.dim() - gamma.dim()
    rstd_shape = [x1.size(axis) if axis < kept_axes else 1 for axis in range(x1.dim())]
    rstd_proto = torch.empty(rstd_shape, dtype=torch.float32, device='meta')
    fp32_like_x1 = torch.empty_like(x1, dtype=torch.float32)
    first_like_x1 = torch.empty_like(x1, dtype=x1.dtype)
    stats = torch.empty_like(rstd_proto)
    second_like_x1 = torch.empty_like(x1, dtype=x1.dtype)
    return (fp32_like_x1, first_like_x1, stats, second_like_x1)
@impl(m, "npu_add_rms_norm_dynamic_quant")
def npu_add_rms_norm_dynamic_quant_meta(x1, x2, gamma, *, smooth_scale1=None, smooth_scale2=None, beta=None, epsilon=1e-6, output_mask=None, y_dtype=None):
    """Meta kernel for ``npu_add_rms_norm_dynamic_quant``.

    Returns five tensors on ``x1.device``:
      * two quantized outputs: int8 with ``x1``'s shape when ``y_dtype`` is
        ``None`` or ``torch.int8``, otherwise int32 with the last dimension
        divided by 8; the second is shape ``(0,)`` when ``output_mask`` is
        given and does not enable index 1,
      * a tensor with ``x1``'s shape and dtype,
      * two float32 tensors with ``x1``'s shape minus the last dimension.
    """
    if y_dtype is None or y_dtype == torch.int8:
        quant_shape = x1.size()
        quant_dtype = torch.int8
    else:
        quant_shape = list(x1.size())
        quant_shape[-1] = quant_shape[-1] // 8
        quant_dtype = torch.int32

    second_enabled = output_mask is None or (len(output_mask) > 1 and output_mask[1])
    second_shape = quant_shape if second_enabled else (0,)

    y1 = torch.empty(quant_shape, dtype=quant_dtype, device=x1.device)
    y2 = torch.empty(second_shape, dtype=quant_dtype, device=x1.device)
    x_out = torch.empty(x1.size(), dtype=x1.dtype, device=x1.device)
    scale1 = torch.empty(x1.size()[:-1], dtype=torch.float32, device=x1.device)
    scale2 = torch.empty(x1.size()[:-1], dtype=torch.float32, device=x1.device)
    return (y1, y2, x_out, scale1, scale2)
@impl(m, "npu_add_rms_norm_dynamic_mx_quant")
def npu_add_rms_norm_dynamic_mx_quant_meta(x1, x2, gamma, *, beta=None, epsilon=1e-6, scale_alg=0, round_mode='rint', dst_type=296):
    """Meta kernel for ``npu_add_rms_norm_dynamic_mx_quant``.

    Returns ``(y, x_out, mxscale, rstd)``:
      * ``y`` is like ``x1`` with an fp8 dtype when ``dst_type`` resolves to
        float8_e5m2 / float8_e4m3fn; otherwise uint8 with the last dimension
        halved.
      * ``x_out`` matches ``x1`` in shape and dtype.
      * ``mxscale`` is uint8, shaped like ``x1`` with the last dimension
        replaced by ``ceil(last / 64)`` and a trailing axis of size 2.
      * ``rstd`` is float32 on the meta device, keeping the first
        ``x1.dim() - gamma.dim()`` sizes of ``x1`` and 1 elsewhere.

    Raises:
        RuntimeError: ``scale_alg`` is not 0 or 1, or the packed output is
            requested and the last dimension of ``x1`` is odd.
    """
    if scale_alg not in [0, 1]:
        raise RuntimeError(f"Invalid scale_alg value: {scale_alg}. Expected 0 or 1." +
                           ops_error(ErrCode.PARAM))
    rank = x1.dim()
    last_axis = rank - 1
    pack_factor = 2
    scale_block = 32

    resolved_dtype = TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP.get(dst_type, torch.int8)
    if resolved_dtype == torch.float8_e5m2 or dst_type == torch_npu.float8_e5m2:
        y = torch.empty_like(x1, dtype=torch.float8_e5m2)
    elif resolved_dtype == torch.float8_e4m3fn or dst_type == torch_npu.float8_e4m3fn:
        y = torch.empty_like(x1, dtype=torch.float8_e4m3fn)
    else:
        if x1.size(last_axis) % 2:
            raise RuntimeError("If output dtype is float4_e2m1 or float4_e1m2, " \
                               "the last dim of input must be divisible by 2, " +
                               ops_error(ErrCode.PARAM))
        packed_shape = [x1.size(axis) for axis in range(last_axis)]
        packed_shape.append(x1.size(last_axis) // pack_factor)
        y = x1.new_empty(packed_shape, dtype=torch.uint8)
    x_out = torch.empty_like(x1, dtype=x1.dtype)

    scale_shape = [x1.size(axis) for axis in range(rank)]
    scale_shape.append(2)
    scale_shape[last_axis] = int(math.ceil(scale_shape[last_axis] / (scale_block * pack_factor)))
    mxscale = x1.new_empty(scale_shape, dtype=torch.uint8)

    kept_axes = rank - gamma.dim()
    rstd_shape = [x1.size(axis) if axis < kept_axes else 1 for axis in range(rank)]
    rstd = torch.empty(rstd_shape, dtype=torch.float32, device='meta')
    return (y, x_out, mxscale, rstd)
@impl(m, "npu_rms_norm_backward")
def npu_rms_norm_backward_meta(dy, self, gamma, rstd):
    """Meta kernel for ``npu_rms_norm_backward``.

    Returns a tensor like ``self`` and a float32 tensor shaped like ``gamma``.
    """
    like_self = torch.empty_like(self, dtype=self.dtype)
    fp32_like_gamma = torch.empty_like(gamma, dtype=torch.float32)
    return (like_self, fp32_like_gamma)
@impl(m, "scatter_update")
def scatter_update_meta(self, indices, updates, axis):
    """Meta kernel for ``scatter_update``: the result is shaped like ``self``."""
    result = torch.empty_like(self)
    return result
@impl(m, "scatter_update_")
def scatter_update__meta(self, indices, updates, axis):
    """Meta kernel for the in-place ``scatter_update_``: hands back ``self``."""
    # No new tensor is created; the same object is returned.
    return self
@impl(m, "_npu_dropout")
def _npu_dropout_meta(self, p):
    """Meta kernel for ``_npu_dropout``.

    Returns a tensor like ``self`` and a 1-D uint8 mask on the meta device.
    The mask holds one bit per element of ``self``, with the bit count
    rounded up to a multiple of ``BIT_NUMBER`` and then converted to bytes.
    """
    # Exact integer arithmetic: the previous true-division + math.floor form
    # round-tripped through float for a purely integral quantity.
    aligned_bits = (self.numel() + BIT_NUMBER - 1) // BIT_NUMBER * BIT_NUMBER
    mask_bytes = aligned_bits // UINT8_BIT_NUMBER
    return (torch.empty_like(self, dtype=self.dtype), torch.empty(mask_bytes, dtype=torch.uint8, device='meta'))
@impl(m, "npu_quant_scatter")
def npu_quant_scatter_meta(self, indices, updates, quant_scales, quant_zero_points=None, axis=-2, quant_axis=-1,
                           reduce='update', dst_type=1, round_mode='rint'):
    """Meta kernel for ``npu_quant_scatter``: the result is shaped like ``self``."""
    result = torch.empty_like(self)
    return result
@impl(m, "npu_quant_scatter_")
def npu_quant_scatter__meta(self, indices, updates, quant_scales, quant_zero_points=None, axis=-2, quant_axis=-1,
                            reduce='update', dst_type=1, round_mode='rint'):
    """Meta kernel for the in-place ``npu_quant_scatter_``: hands back ``self``."""
    # No new tensor is created; the same object is returned.
    return self
@impl(m, "npu_scatter_list_")
def scatter_list__meta(self, indices, updates, mask, reduce='update', axis=-2):
    """Meta kernel for the in-place ``npu_scatter_list_``: hands back ``self``."""
    # No new tensors are created; the same object is returned.
    return self
@impl(m, "npu_scatter_list")
def scatter_list_meta(self, indices, updates, mask, reduce='update', axis=-2):
    """Meta kernel for ``npu_scatter_list``.

    Returns a new list with one empty tensor per item of ``self``, each
    shaped like the corresponding item.
    """
    return [torch.empty_like(entry) for entry in self]
@impl(m, "npu_scatter_nd_update")
def scatter_nd_update_meta(self, indices, updates):
    """Meta kernel for ``npu_scatter_nd_update``: result matches ``self`` in shape and dtype."""
    result = torch.empty_like(self, dtype=self.dtype)
    return result
@impl(m, "npu_scatter_nd_update_")
def scatter_nd_update__meta(self, indices, updates):
    """Meta kernel for the in-place ``npu_scatter_nd_update_``: hands back ``self``."""
    # No new tensor is created; the same object is returned.
    return self
@impl(m, "npu_scatter_pa_kv_cache_functional")
def npu_scatter_pa_kv_cache_functional_meta(key, value, key_cache, value_cache, slot_mapping, *, compress_lens=None,
                                            compress_seq_offsets=None, seq_lens=None, cache_mode='PA_NZ'):
    """Meta kernel for ``npu_scatter_pa_kv_cache_functional``.

    Returns two tensors matching ``key_cache`` and ``value_cache`` in shape
    and dtype.
    """
    like_key_cache = torch.empty_like(key_cache, dtype=key_cache.dtype)
    like_value_cache = torch.empty_like(value_cache, dtype=value_cache.dtype)
    return (like_key_cache, like_value_cache)
@impl(m, "npu_scatter_pa_kv_cache")
def npu_scatter_pa_kv_cache_meta(key, value, key_cache, value_cache, slot_mapping, *, compress_lens=None,
                                 compress_seq_offsets=None, seq_lens=None, cache_mode='PA_NZ'):
    """Meta kernel for ``npu_scatter_pa_kv_cache``: produces no output tensors."""
    return None
@impl(m, "npu_geglu")
def npu_geglu_meta(self, dim=-1, approximate=1, activate_left=False):
    """Meta kernel for ``npu_geglu``.

    Returns two tensors with ``self``'s dtype whose shape equals ``self``'s
    shape with the size along ``dim`` halved.

    Raises:
        RuntimeError: the rank of ``self`` is outside [1, 8], ``dim`` is out
            of range for that rank, or the size along ``dim`` is odd.
    """
    rank = self.dim()
    out_shape = list(self.shape)
    if not 1 <= rank <= 8:
        raise RuntimeError("dim num out of range [1, 8]" + ops_error(ErrCode.PARAM))
    if not -rank <= dim < rank:
        raise RuntimeError("attribute [dim] out of range [-" + str(rank) + ", " + str(rank - 1) + "]" + ops_error(ErrCode.VALUE))
    split_len = out_shape[dim]
    if split_len % 2 == 1:
        raise RuntimeError("x shape: " + str(out_shape) + ". Dim [" + str(dim) + "] of x should be divisible by 2, but get [" + str(split_len) + "]" + ops_error(ErrCode.PARAM))
    out_shape[dim] = split_len // 2
    first = self.new_empty(out_shape, dtype=self.dtype)
    second = self.new_empty(out_shape, dtype=self.dtype)
    return (first, second)
@impl(m, "npu_geglu_grad")
def npu_geglu_backward_meta(grad_output, self, gelu, dim, approximate, activate_left=False):
    """Meta kernel for ``npu_geglu_grad``: returns two tensors matching ``self``."""
    first = torch.empty_like(self, dtype=self.dtype)
    second = torch.empty_like(self, dtype=self.dtype)
    return (first, second)
@impl(m, "npu_dropout_backward")
def npu_dropout_backward_meta(grad_output, mask, p):
    """Meta kernel for ``npu_dropout_backward``: result matches ``grad_output``."""
    grad_input = torch.empty_like(grad_output, dtype=grad_output.dtype)
    return grad_input
@impl(m, "npu_masked_softmax_with_rel_pos_bias")
def npu_masked_softmax_with_rel_pos_bias_meta(x, atten_mask, relative_pos_bias, scale_value=1.0, inner_precision_mode=0):
    """Meta kernel for ``npu_masked_softmax_with_rel_pos_bias``: result matches ``x``."""
    result = torch.empty_like(x, dtype=x.dtype)
    return result
@impl(m, "npu_scaled_masked_softmax")
def npu_scaled_masked_softmax_meta(x, mask, scale=1, fixed_triu_mask=False):
    """Meta kernel for ``npu_scaled_masked_softmax``: result matches ``x``."""
    result = torch.empty_like(x, dtype=x.dtype)
    return result
@impl(m, "npu_scaled_masked_softmax_backward")
def npu_scaled_masked_softmax_backward_meta(y_grad, y, mask, scale, fixed_triu_mask):
    """Meta kernel for ``npu_scaled_masked_softmax_backward``: result matches ``y_grad``."""
    x_grad = torch.empty_like(y_grad, dtype=y_grad.dtype)
    return x_grad
@impl(m, "npu_moe_distribute_dispatch")
def npu_moe_distribute_dispatch_meta(x, expert_ids, group_ep, ep_world_size, ep_rank_id, moe_expert_num, scales=None, x_active_mask=None, expert_scales=None, group_tp="", tp_world_size=0,
                                     tp_rank_id=0, expert_shard_type=0, shared_expert_num=1, shared_expert_rank_num=0, quant_mode=0, global_bs=0, expert_token_nums_type=1):
    """Meta kernel for ``npu_moe_distribute_dispatch``.

    Computes the output shapes from ``x`` (read as 2-D via ``size(0)`` /
    ``size(1)``), ``expert_ids.size(1)`` and the integer attributes, and
    returns ``(expand_x, dynamic_scales, expand_idx, expert_token_nums,
    ep_recv_counts, tp_recv_counts, expand_scales)``.

    ``expand_x`` is int8 when ``scales`` is given or ``quant_mode != 0``,
    otherwise it uses ``x.dtype``.
    """
    token_count = x.size(0)
    hidden = x.size(1)
    top_k = expert_ids.size(1)

    global_bs_real = token_count * ep_world_size if global_bs == 0 else global_bs

    # With expert_shard_type == 0 the low rank ids take the shared-expert
    # branch; otherwise the high rank ids do.
    if expert_shard_type == 0:
        on_shared_rank = ep_rank_id < shared_expert_rank_num
    else:
        on_shared_rank = ep_rank_id >= ep_world_size - shared_expert_rank_num

    if on_shared_rank:
        local_moe_expert_num = 1
        a = global_bs_real // shared_expert_rank_num
    else:
        local_moe_expert_num = moe_expert_num // (ep_world_size - shared_expert_rank_num)
        a = global_bs_real * min(local_moe_expert_num, top_k)

    if tp_world_size == 2:
        ep_recv_cnt_num = ep_world_size * local_moe_expert_num * tp_world_size
    else:
        ep_recv_cnt_num = ep_world_size * local_moe_expert_num

    out_dtype = torch.int8 if (scales is not None or quant_mode != 0) else x.dtype
    local_moe_expert_num = int(local_moe_expert_num)

    expand_idx = x.new_empty((token_count * top_k,), dtype=torch.int32)
    expanded_rows = a if tp_world_size == 0 else a * tp_world_size
    expand_x = x.new_empty((expanded_rows, hidden), dtype=out_dtype)
    dynamic_scales = x.new_empty((expanded_rows,), dtype=torch.float32)
    expert_token_nums = x.new_empty((local_moe_expert_num,), dtype=torch.int64)
    ep_recv_counts = x.new_empty((ep_recv_cnt_num,), dtype=torch.int32)
    tp_recv_counts = x.new_empty((tp_world_size,), dtype=torch.int32)
    expand_scales = x.new_empty((0,), dtype=torch.float32)
    if expert_scales is not None:
        # With expert_scales the receive-count buffer is enlarged and
        # expand_scales becomes non-empty.
        ep_recv_cnt_num = ep_world_size * local_moe_expert_num + global_bs_real * 2 * top_k * (ep_world_size // 8)
        ep_recv_counts = x.new_empty((ep_recv_cnt_num,), dtype=torch.int32)
        expand_scales = x.new_empty((a,), dtype=torch.float32)
    return (expand_x, dynamic_scales, expand_idx, expert_token_nums, ep_recv_counts, tp_recv_counts, expand_scales)
def get_dispatch_dynamic_scales_dtype(x, scales, quant_mode):
    """Pick the dtype of the dispatch ``dynamic_scales`` output.

    * ``quant_mode == 4`` -> ``torch.uint8``.
    * ``quant_mode == 0`` with ``scales`` given and ``x`` neither bfloat16
      nor float16 -> ``scales.dtype``.
    * Anything else -> ``torch.float32``.
    """
    if quant_mode == 4:
        return torch.uint8
    if quant_mode == 0:
        x_is_half = x.dtype == torch.bfloat16 or x.dtype == torch.float16
        if not x_is_half and scales is not None:
            return scales.dtype
    return torch.float32
def get_dispatch_dynamic_shape(scales, quant_mode, a, h):
    """Return the shape tuple of the dispatch ``dynamic_scales`` output.

    * ``quant_mode == 0`` with ``scales`` given: ``(a * scales.shape[1],)``.
    * ``quant_mode == 3``: ``(a, ceil(h / 128))``.
    * ``quant_mode == 4``: ``(a, ceil(h / 32) rounded up to an even number)``.
    * Every other case: ``(a,)``.

    Raises:
        RuntimeError: ``quant_mode == 0`` and ``scales`` has fewer than 2 dims.
    """
    if quant_mode == 0 and scales is not None:
        if scales.dim() < 2:
            raise RuntimeError(f"Expected scales to be at least 2-d, but got {scales.dim()}-d.")
        return (a * scales.shape[1],)
    if quant_mode == 3:
        return (a, math.ceil(h / 128))
    if quant_mode == 4:
        return (a, (math.ceil(h / 32) + 1) // 2 * 2)
    return (a,)
@impl(m, "npu_moe_distribute_dispatch_v2")
def npu_moe_distribute_dispatch_v2_meta(x, expert_ids, group_ep, ep_world_size, ep_rank_id, moe_expert_num, scales=None, x_active_mask=None, expert_scales=None, elastic_info=None, performance_info=None, group_tp="", tp_world_size=0,
                                        tp_rank_id=0, expert_shard_type=0, shared_expert_num=1, shared_expert_rank_num=0, quant_mode=0, global_bs=0, expert_token_nums_type=1, comm_alg="",
                                        zero_expert_num=0, copy_expert_num=0, const_expert_num=0, y_dtype=None, x_dtype=None, scales_dtype=None):
    """Meta kernel for ``npu_moe_distribute_dispatch_v2``.

    Validates the rank / shared-expert attributes, then derives the output
    shapes and returns ``(expand_x, dynamic_scales, expand_idx,
    expert_token_nums, ep_recv_counts, tp_recv_counts, expand_scales)``.

    ``expand_x`` uses ``x.dtype`` when ``quant_mode == 0``; otherwise the
    dtype mapped from ``y_dtype`` when given, else int8.

    Raises:
        RuntimeError: via ``torch._check`` when ``ep_rank_id`` or
            ``shared_expert_rank_num`` is outside ``[0, ep_world_size)``,
            the shared-expert setting is inconsistent, or
            ``expert_token_nums_type`` is not 0 or 1.
    """
    rank_in_range = (ep_rank_id >= 0) and (ep_rank_id < ep_world_size)
    torch._check(
        rank_in_range,
        lambda: (
            f"ep_rank_id should be in [0, ep_world_size), "
            f"but got {ep_world_size=}, {ep_rank_id=}."
            f"{ops_error(ErrCode.VALUE)}."
        ),
    )
    shared_rank_num_in_range = (shared_expert_rank_num >= 0) and (shared_expert_rank_num < ep_world_size)
    torch._check(
        shared_rank_num_in_range,
        lambda: (
            f"shared_expert_rank_num should be in [0, ep_world_size), "
            f"but got {ep_world_size=}, {shared_expert_rank_num=}."
            f"{ops_error(ErrCode.VALUE)}."
        ),
    )
    is_shared_default = ((shared_expert_num == 1) and (shared_expert_rank_num == 0))
    is_no_shared = ((shared_expert_num == 0) and (shared_expert_rank_num == 0))
    is_valid_shared = (
        (shared_expert_num > 0)
        and ((shared_expert_rank_num // shared_expert_num) > 0)
        and ((shared_expert_rank_num % shared_expert_num) == 0)
    )
    torch._check(
        is_shared_default or is_no_shared or is_valid_shared,
        lambda: (
            f"shared expert setting invalid, "
            f"got {shared_expert_num=}, {shared_expert_rank_num=}."
            f"{ops_error(ErrCode.VALUE)}."
        ),
    )
    torch._check(
        expert_token_nums_type in [0, 1],
        lambda: "the expert_token_nums_type should be 0 or 1" + ops_error(ErrCode.VALUE)
    )

    bs = x.size(0)
    h = x.size(1)
    k = expert_ids.size(1)
    global_bs_real = bs * ep_world_size if global_bs == 0 else global_bs

    # The two helpers are evaluated lazily, only on the branches that need
    # them, so the divisions run in exactly the same situations as before.
    def moe_experts_per_rank():
        return moe_expert_num // (ep_world_size - shared_expert_rank_num)

    def shared_rank_capacity():
        max_bs = global_bs_real // ep_world_size
        rank_num_per_shared_expert = shared_expert_rank_num // shared_expert_num
        max_shared_group_num = (ep_world_size + rank_num_per_shared_expert - 1) // rank_num_per_shared_expert
        return max_bs * max_shared_group_num

    local_moe_expert_num = 1
    a = 0
    if expert_shard_type == 0:
        if ep_rank_id < shared_expert_rank_num:
            local_moe_expert_num = 1
            a = shared_rank_capacity()
        else:
            local_moe_expert_num = moe_experts_per_rank()
            a = global_bs_real * min(local_moe_expert_num, k)

    if elastic_info is not None:
        if is_shared_default or is_no_shared:
            local_moe_expert_num = max(local_moe_expert_num, moe_experts_per_rank())
            a = global_bs_real * min(local_moe_expert_num, k)
        else:
            a = max(shared_rank_capacity(), global_bs_real * min(moe_experts_per_rank(), k))
            local_moe_expert_num = max(local_moe_expert_num, moe_experts_per_rank())

    if tp_world_size == 2:
        ep_recv_cnt_num = ep_world_size * local_moe_expert_num * tp_world_size
    else:
        ep_recv_cnt_num = ep_world_size * local_moe_expert_num

    if quant_mode == 0:
        out_dtype = x.dtype
    elif y_dtype is not None:
        out_dtype = TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP[y_dtype]
    else:
        out_dtype = torch.int8

    expand_idx = x.new_empty(max(bs * k, a * 128), dtype=torch.int32)
    expand_x = x.new_empty((max(a, a * tp_world_size), h), dtype=out_dtype)

    dynamic_scales_dtype = get_dispatch_dynamic_scales_dtype(x, scales, quant_mode)
    if tp_world_size == 0:
        dynamic_scales_shape = a
    elif tp_world_size == 1:
        dynamic_scales_shape = get_dispatch_dynamic_shape(scales, quant_mode, a, h)
    else:
        dynamic_scales_shape = a * tp_world_size
    dynamic_scales = x.new_empty(dynamic_scales_shape, dtype=dynamic_scales_dtype)

    expert_token_nums = x.new_empty(local_moe_expert_num, dtype=torch.int64)
    ep_recv_counts = x.new_empty(ep_recv_cnt_num, dtype=torch.int32)
    tp_recv_counts = x.new_empty(tp_world_size, dtype=torch.int32)
    expand_scales = x.new_empty(0, dtype=torch.float32)
    if expert_scales is not None:
        # With expert_scales the receive-count buffer is enlarged and
        # expand_scales becomes non-empty.
        ep_recv_cnt_num = ep_world_size * local_moe_expert_num + global_bs_real * 2 * k * (ep_world_size // 8)
        ep_recv_counts = x.new_empty(ep_recv_cnt_num, dtype=torch.int32)
        expand_scales = x.new_empty(a, dtype=torch.float32)
    return (expand_x, dynamic_scales, expand_idx, expert_token_nums, ep_recv_counts, tp_recv_counts, expand_scales)
@impl(m, "npu_moe_distribute_combine")
def npu_moe_distribute_combine_meta(expand_x, expert_ids, expand_idx, ep_send_counts, expert_scales, group_ep, ep_world_size, ep_rank_id, moe_expert_num,
                                    tp_send_counts=None, x_active_mask=None, activation_scale=None, weight_scale=None, group_list=None, expand_scales=None, group_tp="", tp_world_size=0,
                                    tp_rank_id=0, expert_shard_type=0, shared_expert_num=1, shared_expert_rank_num=0, global_bs=0, out_dtype=0, comm_quant_mode=0, group_list_type=0):
    """Meta kernel for ``npu_moe_distribute_combine``.

    Returns a tensor of shape ``(expert_ids.size(0), expand_x.size(1))`` with
    ``expand_x``'s dtype.
    """
    out_shape = (expert_ids.size(0), expand_x.size(1))
    return expand_x.new_empty(out_shape, dtype=expand_x.dtype)
@impl(m, "npu_moe_distribute_combine_v2")
def npu_moe_distribute_combine_v2_meta(expand_x, expert_ids, assist_info_for_combine, ep_send_counts, expert_scales, group_ep, ep_world_size, ep_rank_id, moe_expert_num,
                                       tp_send_counts=None, x_active_mask=None, expand_scales=None, shared_expert_x=None, elastic_info=None, ori_x=None, const_expert_alpha_1=None, const_expert_alpha_2=None, const_expert_v=None, performance_info=None, group_tp="", tp_world_size=0,
                                       tp_rank_id=0, expert_shard_type=0, shared_expert_num=1, shared_expert_rank_num=0, global_bs=0, comm_quant_mode=0, comm_alg="", zero_expert_num=0, copy_expert_num=0, const_expert_num=0):
    """Meta kernel for ``npu_moe_distribute_combine_v2``.

    Returns ``expand_x.new_empty`` of shape
    ``(expert_ids.size(0), expand_x.size(1))``.
    """
    row_count = expert_ids.size(0)
    col_count = expand_x.size(1)
    return expand_x.new_empty((row_count, col_count))
@impl(m, "npu_moe_distribute_combine_add_rms_norm")
def npu_moe_distribute_combine_add_rms_norm_meta(expand_x, expert_ids, expand_idx, ep_send_counts, expert_scales, residual_x, gamma, group_ep, ep_world_size, ep_rank_id, moe_expert_num,
                                                 tp_send_counts=None, x_active_mask=None, activation_scale=None, weight_scale=None, group_list=None, expand_scales=None, shared_expert_x=None, elastic_info=None, ori_x=None, const_expert_alpha_1=None, const_expert_alpha_2=None, const_expert_v=None,
                                                 group_tp="", tp_world_size=0, tp_rank_id=0, expert_shard_type=0, shared_expert_num=1, shared_expert_rank_num=0, global_bs=0, out_dtype=0, comm_quant_mode=0, group_list_type=0, comm_alg="", norm_eps=0, zero_expert_num=0, copy_expert_num=0, const_expert_num=0):
    """Meta kernel for ``npu_moe_distribute_combine_add_rms_norm``.

    With ``bs = expert_ids.size(0)`` and ``h = expand_x.size(1)`` this returns
    a ``(bs, 1, h)`` tensor in ``expand_x``'s dtype, a ``(bs, 1, 1)`` float32
    tensor, and a second ``(bs, 1, h)`` tensor in ``expand_x``'s dtype.
    """
    wide_shape = (expert_ids.size(0), 1, expand_x.size(1))
    stat_shape = (expert_ids.size(0), 1, 1)
    first_wide = expand_x.new_empty(wide_shape, dtype=expand_x.dtype)
    stats = expand_x.new_empty(stat_shape, dtype=torch.float32)
    second_wide = expand_x.new_empty(wide_shape, dtype=expand_x.dtype)
    return (first_wide, stats, second_wide)
@impl(m, "npu_moe_distribute_dispatch_setup")
def npu_moe_distribute_dispatch_setup_meta(x, expert_ids, group_ep, ep_world_size, ep_rank_id, moe_expert_num,
                                           scales=None, x_active_mask=None, expert_shard_type=0, shared_expert_num=1,
                                           shared_expert_rank_num=0, quant_mode=0, global_bs=0, comm_type=0,
                                           comm_alg="", y_dtype=None):
    """Meta kernel for ``npu_moe_distribute_dispatch_setup``.

    Returns ``(y, expand_idx, comm_cmd_info)``:
      * ``y``: ``(bs * (k + shared_expert_num), hs)`` where ``hs`` is a
        512-aligned row width that depends on ``quant_mode``,
      * ``expand_idx``: ``(bs * k,)`` int32,
      * ``comm_cmd_info``: int32 with
        ``(bs * (k + shared_expert_num) + ep_world_size * local_moe_expert_num) * 16``
        elements.

    ``y`` uses ``x.dtype`` unless ``scales`` is given or ``quant_mode != 0``;
    in that case it uses the dtype mapped from ``y_dtype``, falling back to
    int8 when ``y_dtype`` is ``None`` (previously a ``KeyError``; int8 is the
    quantized default of the sibling dispatch meta kernels in this file).

    Raises:
        RuntimeError: via ``torch._check`` when ``x`` or ``expert_ids`` is
            not 2-D, ``quant_mode`` is not in ``[0, 4]``, ``ep_rank_id`` is
            outside ``[0, ep_world_size)``, or ``shared_expert_rank_num`` is
            outside ``[0, ep_world_size // 2]``.
    """
    def align_up(value, align_len):
        # Round ``value`` up to a multiple of ``align_len``.
        if align_len <= 0:
            return -1
        return math.ceil(value / align_len) * align_len

    DIM_2 = 2
    UNQUANT = 0
    STATIC_QUANT = 1
    PERTOKEN_DYNAMIC_QUANT = 2
    PERGROUP_DYNAMIC_QUANT = 3
    MX_QUANT = 4
    BYTE_2 = 2
    BYTE_4 = 4
    ALIGN_2 = 2
    ALIGN_32 = 32
    ALIGN_128 = 128
    ALIGN_256 = 256
    ALIGN_512 = 512

    torch._check(
        (x.dim() == DIM_2),
        lambda: (
            f"The dims of input x should be 2 dimensional, "
            f"but got {x.dim()}."
            f"{ops_error(ErrCode.VALUE)}."
        ),
    )
    torch._check(
        (expert_ids.dim() == DIM_2),
        lambda: (
            f"The dims of input expert_ids should be 2 dimensional, "
            f"but got {expert_ids.dim()}."
            f"{ops_error(ErrCode.VALUE)}."
        ),
    )
    # Without this check an unknown quant_mode left ``hs`` unassigned and the
    # function failed later with an UnboundLocalError.
    torch._check(
        quant_mode in (UNQUANT, STATIC_QUANT, PERTOKEN_DYNAMIC_QUANT, PERGROUP_DYNAMIC_QUANT, MX_QUANT),
        lambda: (
            f"quant_mode should be in [0, 4], "
            f"but got {quant_mode}."
            f"{ops_error(ErrCode.VALUE)}."
        ),
    )

    bs = x.size(0)
    h = x.size(1)
    k = expert_ids.size(1)

    if quant_mode == UNQUANT:
        hs = align_up(align_up(h * BYTE_2, ALIGN_32), ALIGN_512) // BYTE_2
    elif quant_mode == STATIC_QUANT:
        hs = align_up(align_up(h, ALIGN_32), ALIGN_512)
    elif quant_mode == PERTOKEN_DYNAMIC_QUANT:
        hs = align_up(align_up(h, ALIGN_32) + BYTE_4, ALIGN_512)
    elif quant_mode == PERGROUP_DYNAMIC_QUANT:
        hs = align_up((align_up(h, ALIGN_128) + math.ceil(h / ALIGN_128) * BYTE_4), ALIGN_512)
    else:  # MX_QUANT
        hs = align_up(align_up(h, ALIGN_256) + align_up(math.ceil(h / ALIGN_32), ALIGN_2), ALIGN_512)

    torch._check(
        (ep_rank_id >= 0 and ep_rank_id < ep_world_size),
        lambda: (
            f"ep_rank_id should be in [0, ep_world_size), "
            f"but got ep_rank_id: {ep_rank_id}, ep_world_size: {ep_world_size}."
            f"{ops_error(ErrCode.VALUE)}."
        ),
    )
    torch._check(
        (shared_expert_rank_num >= 0 and shared_expert_rank_num <= ep_world_size // 2),
        lambda: (
            f"shared_expert_rank_num should be in [0, ep_world_size / 2], "
            f"but got shared_expert_rank_num: {shared_expert_rank_num}, ep_world_size: {ep_world_size}."
            f"{ops_error(ErrCode.VALUE)}."
        ),
    )

    # Only expert_shard_type == 0 assigns local experts here; any other
    # value keeps the count at 0, as before.
    local_moe_expert_num = 0
    if expert_shard_type == 0:
        if ep_rank_id < shared_expert_rank_num:
            local_moe_expert_num = 1
        else:
            local_moe_expert_num = moe_expert_num // (ep_world_size - shared_expert_rank_num)
    local_moe_expert_num = int(local_moe_expert_num)

    out_dtype = x.dtype
    if scales is not None or quant_mode != 0:
        if y_dtype is not None:
            out_dtype = TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP[y_dtype]
        else:
            out_dtype = torch.int8

    y = x.new_empty((bs * (k + shared_expert_num), hs), dtype=out_dtype)
    expand_idx = x.new_empty((bs * k), dtype=torch.int32)
    comm_cmd_info = x.new_empty(((bs * (k + shared_expert_num) + ep_world_size * local_moe_expert_num) * 16), dtype=torch.int32)
    return (y, expand_idx, comm_cmd_info)
@impl(m, "npu_moe_distribute_dispatch_teardown")
def npu_moe_distribute_dispatch_teardown_meta(x, y, expert_ids, comm_cmd_info, group_ep, ep_world_size, ep_rank_id,
                                              moe_expert_num, expert_shard_type=0, shared_expert_num=1,
                                              shared_expert_rank_num=0, quant_mode=0, global_bs=0,
                                              expert_token_nums_type=1, comm_type=0, comm_alg=""):
    """Meta kernel for ``npu_moe_distribute_dispatch_teardown``.

    Returns ``(expand_x, dynamic_scales, assist_info_for_combine,
    expert_token_nums)`` with shapes ``(a, h)``, ``(a,)``, ``(a * 128,)`` and
    ``(local_moe_expert_num,)``. ``expand_x`` is int8 when
    ``quant_mode != 0``, otherwise it uses ``x.dtype``.

    Raises:
        RuntimeError: via ``torch._check`` when ``x`` or ``expert_ids`` is
            not 2-D.
    """
    expected_rank = 2
    torch._check(
        (x.dim() == expected_rank),
        lambda: (
            f"The dims of input x should be 2 dimensional, "
            f"but got {x.dim()}."
            f"{ops_error(ErrCode.VALUE)}."
        ),
    )
    torch._check(
        (expert_ids.dim() == expected_rank),
        lambda: (
            f"The dims of input expert_ids should be 2 dimensional, "
            f"but got {expert_ids.dim()}."
            f"{ops_error(ErrCode.VALUE)}."
        ),
    )
    bs = x.size(0)
    h = x.size(1)
    k = expert_ids.size(1)
    if global_bs == 0:
        global_bs_real = bs * ep_world_size
    else:
        global_bs_real = global_bs

    # With expert_shard_type == 0 the low rank ids take the shared-expert
    # branch; otherwise the high rank ids do.
    if expert_shard_type == 0:
        on_shared_rank = ep_rank_id < shared_expert_rank_num
    else:
        on_shared_rank = ep_rank_id >= ep_world_size - shared_expert_rank_num

    if on_shared_rank:
        local_moe_expert_num = 1
        a = global_bs_real // shared_expert_rank_num
    else:
        local_moe_expert_num = moe_expert_num // (ep_world_size - shared_expert_rank_num)
        a = global_bs_real * min(local_moe_expert_num, k)
    local_moe_expert_num = int(local_moe_expert_num)

    if quant_mode != 0:
        out_dtype = torch.int8
    else:
        out_dtype = x.dtype
    expand_x = x.new_empty((a, h), dtype=out_dtype)
    dynamic_scales = x.new_empty((a,), dtype=torch.float32)
    assist_info_for_combine = x.new_empty((a * 128,), dtype=torch.int32)
    expert_token_nums = x.new_empty((local_moe_expert_num,), dtype=torch.int64)
    return (expand_x, dynamic_scales, assist_info_for_combine, expert_token_nums)
@impl(m, "npu_moe_distribute_combine_setup")
def npu_moe_distribute_combine_setup_meta(expand_x, expert_ids, assist_info_for_combine, group_ep, ep_world_size,
                                          ep_rank_id, moe_expert_num, expert_shard_type=0, shared_expert_num=1,
                                          shared_expert_rank_num=0, global_bs=0, comm_quant_mode=0, comm_type=0,
                                          comm_alg="", y_dtype=0):
    """Meta kernel for ``npu_moe_distribute_combine_setup``.

    For a 2-D ``expand_x`` of shape ``(a, h)`` returns ``(quant_expand_x,
    comm_cmd_info)``: an int8 ``(a, hs)`` tensor, where ``hs`` is a
    512-aligned width derived from ``h``, and an int32 tensor with
    ``(a + ep_world_size) * 16`` elements.

    Raises:
        RuntimeError: via ``torch._check`` when ``expand_x`` is not 2-D.
    """
    def round_up(value, multiple):
        # Round ``value`` up to a multiple of ``multiple``.
        if (multiple <= 0):
            return -1
        return math.ceil(value / multiple) * multiple

    expected_rank = 2
    torch._check(
        (expand_x.dim() == expected_rank),
        lambda: (
            f"The dims of input expand_x should be 2 dimensional, "
            f"but got {expand_x.dim()}."
            f"{ops_error(ErrCode.VALUE)}."
        ),
    )
    a = expand_x.size(0)
    h = expand_x.size(1)
    payload_width = round_up(h, 32)
    extra_width = round_up(h, 8) // 8 * 4
    hs = round_up(payload_width + extra_width, 512)
    quant_expand_x = expand_x.new_empty((a, hs), dtype=torch.int8)
    comm_cmd_info = expand_x.new_empty(((a + ep_world_size) * 16,), dtype=torch.int32)
    return (quant_expand_x, comm_cmd_info)
@impl(m, "npu_moe_distribute_combine_teardown")
def npu_moe_distribute_combine_teardown_meta(expand_x, quant_expand_x, expert_ids, expand_idx, expert_scales,
                                             group_ep, ep_world_size, ep_rank_id, moe_expert_num, x_active_mask=None,
                                             shared_expert_x=None, expert_shard_type=0, shared_expert_num=1,
                                             shared_expert_rank_num=0, global_bs=0, comm_quant_mode=0, comm_type=0):
    """Meta kernel for ``npu_moe_distribute_combine_teardown``.

    Returns a tensor of shape ``(expert_ids.size(0), expand_x.size(1))`` in
    ``expand_x``'s dtype.

    Raises:
        RuntimeError: via ``torch._check`` when ``expand_x`` or
            ``expert_ids`` is not 2-D.
    """
    expected_rank = 2
    torch._check(
        (expand_x.dim() == expected_rank),
        lambda: (
            f"The dims of input expand_x should be 2 dimensional, "
            f"but got {expand_x.dim()}."
            f"{ops_error(ErrCode.VALUE)}."
        ),
    )
    torch._check(
        (expert_ids.dim() == expected_rank),
        lambda: (
            f"The dims of input expert_ids should be 2 dimensional, "
            f"but got {expert_ids.dim()}."
            f"{ops_error(ErrCode.VALUE)}."
        ),
    )
    out_shape = (expert_ids.size(0), expand_x.size(1))
    return expand_x.new_empty(out_shape, dtype=expand_x.dtype)
@impl(m, "_npu_distribute_barrier")
def _npu_distribute_barrier(x_ref, group, world_size, *, time_out=None, elastic_info=None):
    """Meta kernel for ``_npu_distribute_barrier``: result is shaped like ``x_ref``."""
    result = torch.empty_like(x_ref)
    return result
@impl(m, "npu_moe_update_expert")
def npu_moe_update_expert_meta(expert_ids, eplb_table, expert_scales=None, pruning_threshold=None, active_mask=None, local_rank_id=-1, world_size=-1, balance_mode=0):
    """Meta kernel for ``npu_moe_update_expert``.

    Returns two ``(expert_ids.size(0), expert_ids.size(1))`` tensors: one in
    ``expert_ids``'s dtype and one boolean.
    """
    out_shape = (expert_ids.size(0), expert_ids.size(1))
    ids_out = expert_ids.new_empty(out_shape, dtype=expert_ids.dtype)
    flags_out = expert_ids.new_empty(out_shape, dtype=torch.bool)
    return (ids_out, flags_out)
@impl(m, "npu_ffn")
def npu_ffn_meta(x, weight1, weight2, activation, *, expert_tokens=None, expert_tokens_index=None, bias1=None,
                 bias2=None, scale=None, offset=None, deq_scale1=None, deq_scale2=None, antiquant_scale1=None,
                 antiquant_scale2=None, antiquant_offset1=None, antiquant_offset2=None, inner_precise=0,
                 output_dtype=None):
    """Meta kernel for ``npu_ffn``.

    The output keeps every dimension of ``x`` except the last, which becomes
    the last dimension of ``weight2``. For int8 ``x`` the output dtype is
    bfloat16 when ``output_dtype`` is ``torch.bfloat16`` and float16
    otherwise; for any other ``x`` dtype, ``new_empty`` is called without a
    dtype.
    """
    out_shape = [x.size(axis) for axis in range(0, x.dim() - 1)]
    out_shape.append(weight2.size(weight2.dim() - 1))
    out_shape = tuple(out_shape)

    if x.dtype != torch.int8:
        return x.new_empty(out_shape)
    wants_bf16 = output_dtype is not None and output_dtype == torch.bfloat16
    if wants_bf16:
        return x.new_empty(out_shape, dtype=torch.bfloat16)
    return x.new_empty(out_shape, dtype=torch.float16)
def gmm_get_dtype(output_dtype):
    """Translate a grouped-matmul output dtype enum value into a torch dtype.

    A falsy ``output_dtype`` (e.g. ``None``) is returned unchanged. Otherwise
    the value must be the enum code of float16, bfloat16, float32 or int32
    from ``TORCH_DTYPE_MAP``; the matching ``torch.dtype`` is returned.

    Raises:
        RuntimeError: ``output_dtype`` is not one of the supported codes.
    """
    if not output_dtype:
        return output_dtype
    supported_codes = (
        TORCH_DTYPE_MAP[torch.float16],
        TORCH_DTYPE_MAP[torch.bfloat16],
        TORCH_DTYPE_MAP[torch.float32],
        TORCH_DTYPE_MAP[torch.int32],
    )
    if output_dtype not in supported_codes:
        # Build one message string; passing the pieces as separate arguments
        # made the exception render as a tuple.
        raise RuntimeError("The output dtype " + str(output_dtype) + " is not supported for now." +
                           ops_error(ErrCode.NOT_SUPPORT))
    return TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP.get(output_dtype)
def is_transpose_weight(weight):
    """Report whether the last two dims of ``weight`` have transposed strides.

    True when the stride of the second-to-last dimension is 1 and the stride
    of the last dimension equals the size of the second-to-last dimension.
    """
    strides = weight.stride()
    if strides[-2] != 1:
        return False
    return strides[-1] == weight.shape[-2]
@impl(m, "npu_grouped_matmul")
@impl(m, "npu_grouped_matmul.List")
def npu_grouped_matmul_meta(x, weight, *, bias=None, scale=None, offset=None, antiquant_scale=None,
                            antiquant_offset=None, per_token_scale=None, group_list=None,
                            activation_input=None, activation_quant_scale=None, activation_quant_offset=None,
                            split_item=0, group_type=None, group_list_type=0, act_type=0, tuning_config=None,
                            output_dtype=None, x_dtype=None, weight_dtype=None, scale_dtype=None, per_token_scale_dtype=None):
    """Meta kernel for ``npu_grouped_matmul`` and ``npu_grouped_matmul.List``.

    Validates ``group_type`` and the optional dtype hints, then returns a
    list of empty output tensors whose layout depends on ``split_item``:
      * 0: one output per ``x[i]``, last dimension replaced by the weight's
        output width;
      * 1: one output per ``group_list`` entry, with row counts taken from
        differences of consecutive ``group_list`` values;
      * 2 or 3: a single output with the summed row count of all ``x[i]``
        (with a leading ``group_list.shape[0]`` dimension when
        ``group_type == 2``).
    Any other ``split_item`` yields an empty list.

    The output dtype comes from ``gmm_get_dtype(output_dtype)``, defaulting
    to ``x[0].dtype``. An int32 weight widens the output width by a factor
    of 8.
    """
    torch._check(
        group_type in (-1, 0, 2) or (isinstance(group_list, list) and group_type is None),
        lambda: f"group_type only supports -1, 0 and 2, but got {group_type} {ops_error(ErrCode.VALUE)}",
    )
    if x_dtype is not None:
        torch._check(
            x_dtype == torch_npu.hifloat8 or x_dtype == torch_npu.float4_e2m1fn_x2,
            lambda: "x_dtype supports hifloat8, mxfp4 for now, but it is " + npu_dtype_to_str(x_dtype),
        )
    if weight_dtype is not None:
        torch._check(
            weight_dtype == torch_npu.hifloat8 or weight_dtype == torch_npu.float4_e2m1fn_x2,
            lambda: "weight_dtype only supports hifloat8, mxfp4 for now, but it is " + npu_dtype_to_str(weight_dtype),
        )
    if scale_dtype is not None:
        torch._check(
            scale_dtype == torch_npu.float8_e8m0fnu,
            lambda: "scale_dtype only supports float8_e8m0fnu for now, but it is " + npu_dtype_to_str(scale_dtype),
        )
    if per_token_scale_dtype is not None:
        torch._check(
            per_token_scale_dtype == torch_npu.float8_e8m0fnu,
            lambda: "per_token_scale_dtype only supports float8_e8m0fnu for now, but it is " + npu_dtype_to_str(per_token_scale_dtype),
        )

    int4_per_int32 = 8
    fp4_per_int8 = 2
    outputs = []
    x_count = len(x)
    single_weight = len(weight) == 1 and len(weight[0].shape) == 3
    n = weight[0].shape[2] if single_weight else weight[0].shape[1]
    output_dtype = gmm_get_dtype(output_dtype)
    both_mxfp4 = x_dtype == torch_npu.float4_e2m1fn_x2 and weight_dtype == torch_npu.float4_e2m1fn_x2
    if x_count > 0 and output_dtype is None:
        output_dtype = x[0].dtype

    def output_width(index):
        # Output width contributed by weight[index]; int32 weights count
        # eight times per stored element.
        cols = n if single_weight else weight[index].shape[1]
        if weight[index].dtype == torch.int32:
            return cols * int4_per_int32
        return cols

    if split_item == 0:
        for i in range(x_count):
            width = output_width(i)
            outputs.append(x[i].new_empty((*x[i].shape[:-1], width), dtype=output_dtype))
    elif split_item == 1:
        if isinstance(group_list, torch.Tensor):
            group_count = group_list.shape[0]
        else:
            group_count = len(group_list)
        prev_offset = group_list[0]
        first_width = n * int4_per_int32 if weight[0].dtype == torch.int32 else n
        outputs.append(x[0].new_empty((prev_offset, first_width), dtype=output_dtype))
        for i in range(1, group_count):
            cols = n if single_weight else weight[i].shape[1]
            next_offset = group_list[i]
            width = cols * int4_per_int32 if weight[i].dtype == torch.int32 else cols
            outputs.append(x[0].new_empty((next_offset - prev_offset, width), dtype=output_dtype))
            prev_offset = next_offset
    elif split_item in (2, 3):
        first_weight = weight[0]
        packed_weight = (first_weight.dtype == torch.int32
                         or (first_weight.dtype == torch.float32 and first_weight.dtype != x[0].dtype))
        if packed_weight and not is_transpose_weight(first_weight):
            width = n * int4_per_int32
        else:
            width = n
        total_rows = 0
        for i in range(x_count):
            total_rows += x[i].shape[0]
        if both_mxfp4:
            if x[0].size(x[0].dim() - 1) == first_weight.size(first_weight.dim() - 2):
                width = n
            else:
                width = n * fp4_per_int8
        if group_type != 2:
            outputs.append(x[0].new_empty((total_rows, width), dtype=output_dtype))
        else:
            group_count = group_list.shape[0]
            outputs.append(x[0].new_empty((group_count, total_rows, width), dtype=output_dtype))
    return outputs
@impl(m, "npu_grouped_matmul_add_")
def npu_grouped_matmul_add__meta(y, x1, x2, group_list, *, transpose_x=True,
                                 transpose_weight=False, group_type=2, group_list_type=0):
    """Meta kernel for the in-place ``npu_grouped_matmul_add_``: hands back ``y``.

    Raises:
        RuntimeError: via ``torch._check`` when ``group_type`` is not 2.
    """
    group_type_ok = group_type == 2
    torch._check(
        group_type_ok,
        lambda: f"group_type only supports 2, but got {group_type} {ops_error(ErrCode.VALUE)}",
    )
    return y
@impl(m, "npu_matmul_all_to_all")
def npu_matmul_all_to_all_meta(x1, x2, hcom, world_size, bias=None, all2all_axes=None):
    """Meta kernel for ``npu_matmul_all_to_all``.

    Returns a meta tensor of shape
    ``[x1.size(0) * world_size, x2.size(1) // world_size]`` with ``x1``'s dtype.

    Raises:
        RuntimeError: via ``torch._check`` when ``x1`` or ``x2`` is not 2-D,
            when ``x1.size(1)``, ``x2.size(0)`` or ``x2.size(1)`` is 0, or
            when ``world_size`` is 0.
    """
    x1_dim = x1.dim()
    x2_dim = x2.dim()
    torch._check(
        x1_dim == 2,
        lambda: "x1_dim should be 2, but now it is " + str(x1_dim) + "." + ops_error(ErrCode.VALUE),
    )
    torch._check(
        x2_dim == 2,
        lambda: "x2_dim should be 2, but now it is " + str(x2_dim) + "." + ops_error(ErrCode.VALUE),
    )

    # (value, message) pairs checked in order; each value must be non-zero.
    nonzero_requirements = (
        (x1.size(1), "The second dim of x1 should not be 0."),
        (x2.size(0), "The first dim of x2 should not be 0."),
        (x2.size(1), "The second dim of x2 should not be 0."),
        (world_size, "world_size should not be 0."),
    )
    for value, text in nonzero_requirements:
        torch._check(
            value != 0,
            lambda text=text: text + ops_error(ErrCode.VALUE),
        )

    out_rows = x1.size(0) * world_size
    out_cols = x2.size(1) // world_size
    return torch.empty([out_rows, out_cols], dtype=x1.dtype, device='meta')
@impl(m, "npu_quant_matmul_all_to_all")
def npu_quant_matmul_all_to_all_meta(x1, x2, hcom, world_size, bias=None, x1_scale=None, x2_scale=None, common_scale=None,
                                     x1_offset=None, x2_offset=None, x1_quant_mode=None, x2_quant_mode=None, common_quant_mode=None,
                                     group_sizes=None, all2all_axes=None, comm_quant_dtype=None, x1_dtype=None, x2_dtype=None,
                                     x1_scale_dtype=None, x2_scale_dtype=None,
                                     output_scale_dtype=None, comm_scale_dtype=None, y_dtype=None):
    """Meta kernel for ``npu_quant_matmul_all_to_all``.

    Requires 2-D ``x1`` and ``x2`` with no zero-sized dimension and a
    non-zero ``world_size``.  Returns an empty meta tensor of shape
    ``[x1.size(0) * world_size, x2.size(1) // world_size]``.  Its dtype is
    float32 when ``y_dtype`` is None, otherwise the entry of
    ``TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP`` for ``y_dtype``.
    """

    def _require(condition, text):
        # ops_error is only evaluated if the check fails.
        torch._check(condition, lambda: text + ops_error(ErrCode.VALUE))

    rank_x1 = x1.dim()
    rank_x2 = x2.dim()
    _require(rank_x1 == 2, "x1_dim should be 2, but now it is " + str(rank_x1) + ".")
    _require(rank_x2 == 2, "x2_dim should be 2, but now it is " + str(rank_x2) + ".")

    x1_rows, x1_cols = x1.size(0), x1.size(1)
    x2_rows, x2_cols = x2.size(0), x2.size(1)
    _require(x1_rows != 0, "The first dim of x1 should not be 0.")
    _require(x1_cols != 0, "The second dim of x1 should not be 0.")
    _require(x2_rows != 0, "The first dim of x2 should not be 0.")
    _require(x2_cols != 0, "The second dim of x2 should not be 0.")
    _require(world_size != 0, "world_size should not be 0.")

    out_shape = [x1.size(0) * world_size, x2.size(1) // world_size]
    out_dtype = torch.float32 if y_dtype is None else TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP[y_dtype]
    return torch.empty(out_shape, dtype=out_dtype, device='meta')
@impl(m, "npu_all_to_all_matmul")
def npu_all_to_all_matmul_meta(x1, x2, hcom, world_size, bias=None, all2all_axes=None, all2all_out_flag=True):
    """Meta kernel for ``npu_all_to_all_matmul``.

    Requires 2-D ``x1`` and ``x2`` with non-zero ``x1.size(1)``,
    ``x2.size(0)``, ``x2.size(1)`` and a non-zero ``world_size``.

    Returns a 2-tuple.  The first element is an empty meta tensor of shape
    ``[x1.size(0) // world_size, x2.size(1)]`` with the dtype of ``x1``.
    The second is an empty meta tensor of shape
    ``[x1.size(0) // world_size, x1.size(1) * world_size]`` when
    ``all2all_out_flag`` is truthy, otherwise ``None``.
    """

    def _require(condition, text):
        # ops_error is only evaluated if the check fails.
        torch._check(condition, lambda: text + ops_error(ErrCode.VALUE))

    rank_x1 = x1.dim()
    rank_x2 = x2.dim()
    _require(rank_x1 == 2, "x1_dim should be 2, but now it is " + str(rank_x1) + ".")
    _require(rank_x2 == 2, "x2_dim should be 2, but now it is " + str(rank_x2) + ".")

    x1_cols = x1.size(1)
    x2_rows = x2.size(0)
    x2_cols = x2.size(1)
    _require(x1_cols != 0, "The second dim of x1 should not be 0.")
    _require(x2_rows != 0, "The first dim of x2 should not be 0.")
    _require(x2_cols != 0, "The second dim of x2 should not be 0.")
    _require(world_size != 0, "world_size should not be 0.")

    rows = x1.size(0) // world_size
    out_dtype = x1.dtype
    matmul_out = torch.empty([rows, x2.size(1)], dtype=out_dtype, device='meta')
    if not all2all_out_flag:
        return (matmul_out, None)
    gathered_shape = [rows, x1.size(1) * world_size]
    return (matmul_out, torch.empty(gathered_shape, dtype=out_dtype, device='meta'))
@impl(m, "npu_all_to_all_quant_matmul")
def npu_all_to_all_quant_matmul_meta(x1, x2, hcom, world_size, all2all_out_flag=True, bias=None, x1_scale=None, x2_scale=None, common_scale=None,
                                     x1_offset=None, x2_offset=None, x1_quant_mode=None, x2_quant_mode=None, common_quant_mode=None, group_sizes=None,
                                     all2all_axes=None, comm_quant_dtype=None, x1_quant_dtype=None, x1_dtype=None, x2_dtype=None, x1_scale_dtype=None,
                                     x2_scale_dtype=None, output_scale_dtype=None, comm_scale_dtype=None, y_dtype=None):
    """Meta kernel for ``npu_all_to_all_quant_matmul``.

    Requires 2-D ``x1`` and ``x2`` with no zero-sized dimension and a
    non-zero ``world_size``.

    Returns a 2-tuple.  The first element has shape
    ``[x1.size(0) // world_size, N]`` where ``N`` is ``x2.size(1) * 8`` for an
    int32 ``x2`` and ``x2.size(1)`` otherwise.  Its dtype is float32 when
    ``y_dtype`` is None, else the mapped scalar type.  The second element is
    an ``x1``-dtype meta tensor of shape
    ``[x1.size(0) // world_size, x1.size(1) * world_size]`` when
    ``all2all_out_flag`` is truthy, otherwise ``None``.
    """

    def _require(condition, text):
        # ops_error is only evaluated if the check fails.
        torch._check(condition, lambda: text + ops_error(ErrCode.VALUE))

    rank_x1 = x1.dim()
    rank_x2 = x2.dim()
    _require(rank_x1 == 2, "x1_dim should be 2, but now it is " + str(rank_x1) + ".")
    _require(rank_x2 == 2, "x2_dim should be 2, but now it is " + str(rank_x2) + ".")

    x1_rows, x1_cols = x1.size(0), x1.size(1)
    x2_rows, x2_cols = x2.size(0), x2.size(1)
    _require(x1_rows != 0, "The first dim of x1 should not be 0.")
    _require(x1_cols != 0, "The second dim of x1 should not be 0.")
    _require(x2_rows != 0, "The first dim of x2 should not be 0.")
    _require(x2_cols != 0, "The second dim of x2 should not be 0.")
    _require(world_size != 0, "world_size should not be 0.")

    int4_per_int32 = 8
    # An int32 x2 is widened by a factor of eight along N.
    cols = x2.size(1) * int4_per_int32 if x2.dtype == torch.int32 else x2.size(1)
    rows = x1.size(0) // world_size
    out_dtype = torch.float32 if y_dtype is None else TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP[y_dtype]

    matmul_out = torch.empty([rows, cols], dtype=out_dtype, device='meta')
    if not all2all_out_flag:
        return (matmul_out, None)
    gathered_shape = [rows, x1.size(1) * world_size]
    return (matmul_out, torch.empty(gathered_shape, dtype=x1.dtype, device='meta'))
def add_quant_gmm_check(*args):
group_sizes, x1_dtype, x2_dtype, x1_scale_dtype, x2_scale_dtype = args
torch._check(
group_sizes is None,
lambda: "group_sizes is not supported for now",
)
if x1_dtype is not None:
torch._check(
x1_dtype == torch_npu.hifloat8,
lambda: "x1_dtype is only supported hifloat8 for now, but it is " + str(x1_dtype),
)
if x2_dtype is not None:
torch._check(
x2_dtype == torch_npu.hifloat8,
lambda: "x2_dtype is only supported hifloat8 for now, but it is " + str(x2_dtype),
)
if x1_scale_dtype is not None:
torch._check(
x1_scale_dtype == torch_npu.float8_e8m0fnu,
lambda: "x1_scale_dtype is only supported float8_e8m0fnu for now, but it is " + str(x1_scale_dtype),
)
if x2_scale_dtype is not None:
torch._check(
x2_scale_dtype == torch_npu.float8_e8m0fnu,
lambda: "x2_scale_dtype is only supported float8_e8m0fnu for now, but it is " + str(x2_scale_dtype),
)
@impl(m, "npu_add_quant_gmm_")
def npu_add_quant_gmm__meta(y, x1, x2, x2_scale, group_list, *, x1_scale=None, group_list_type=0, group_sizes=None,
                            x1_dtype=None, x2_dtype=None, x1_scale_dtype=None, x2_scale_dtype=None):
    """Meta kernel for the in-place ``npu_add_quant_gmm_`` op.

    Runs ``add_quant_gmm_check`` on the optional arguments, then returns
    ``y`` itself.
    """
    optional_args = (group_sizes, x1_dtype, x2_dtype, x1_scale_dtype, x2_scale_dtype)
    add_quant_gmm_check(*optional_args)
    return y
@impl(m, "npu_add_quant_gmm")
def npu_add_quant_gmm_meta(y, x1, x2, x2_scale, group_list, *, x1_scale=None, group_list_type=0, group_sizes=None,
                           x1_dtype=None, x2_dtype=None, x1_scale_dtype=None, x2_scale_dtype=None):
    """Meta kernel for the out-of-place ``npu_add_quant_gmm`` op.

    Runs ``add_quant_gmm_check`` on the optional arguments, then returns a
    new empty tensor with the same metadata as ``y``.
    """
    optional_args = (group_sizes, x1_dtype, x2_dtype, x1_scale_dtype, x2_scale_dtype)
    add_quant_gmm_check(*optional_args)
    result = torch.empty_like(y)
    return result
@impl(m, "npu_add_quant_matmul_")
def npu_add_quant_matmul__meta(y, x1, x2, x2_scale, *, x1_scale=None, group_sizes=None,
                               x1_dtype=None, x2_dtype=None, x1_scale_dtype=None, x2_scale_dtype=None):
    """Meta kernel for the in-place ``npu_add_quant_matmul_`` op.

    Performs no validation; the result is ``y`` itself.
    """
    accumulator = y
    return accumulator
@impl(m, "npu_add_quant_matmul")
def npu_add_quant_matmul_meta(y, x1, x2, x2_scale, *, x1_scale=None, group_sizes=None,
                              x1_dtype=None, x2_dtype=None, x1_scale_dtype=None, x2_scale_dtype=None):
    """Meta kernel for the out-of-place ``npu_add_quant_matmul`` op.

    Performs no validation; the result is a new empty tensor with the same
    metadata as ``y``.
    """
    result = torch.empty_like(y)
    return result
@impl(m, "npu_grouped_matmul_finalize_routing")
def npu_grouped_matmul_finalize_routing_meta(x, w, group_list, *, scale=None, bias=None, offset=None,
                                             pertoken_scale=None, shared_input=None, logit=None,
                                             row_index=None, dtype=None, shared_input_weight=1.0,
                                             shared_input_offset=0, output_bs=0, group_list_type=1, tuning_config=None,
                                             x_dtype=None, w_dtype=None, scale_dtype=None, pertoken_scale_dtype=None):
    """Meta kernel for ``npu_grouped_matmul_finalize_routing``.

    Validation, in order:
      * ``x`` and ``w`` must be tensors;
      * ``x_dtype`` / ``w_dtype``, when given, must be ``torch_npu.hifloat8``
        or ``torch_npu.float4_e2m1fn_x2``;
      * ``scale_dtype`` / ``pertoken_scale_dtype``, when given, must be
        ``torch_npu.float8_e8m0fnu``;
      * ``x`` must be 2-D and ``w`` 3-D;
      * when both ``shared_input`` and ``logit`` are given, the output dtype
        must be float32.

    Returns an empty float32 tensor of shape ``(M, N)``.  ``M`` is
    ``output_bs``, or ``x.size(0)`` when ``output_bs`` is 0.  ``N`` is
    ``w.size(-1)`` multiplied by 8 for an int32 ``w``, or by 2 for
    float4_e2m1fn_x2 inputs whose ``x.size(-1)`` differs from ``w.size(-2)``.

    Raises:
        RuntimeError: a check fails, or the resolved output dtype is not
            float32.
    """
    torch._check(
        torch.is_tensor(x),
        lambda: "x must be tensor." + ops_error(ErrCode.VALUE)
    )
    torch._check(
        torch.is_tensor(w),
        lambda: "w must be tensor." + ops_error(ErrCode.VALUE)
    )
    if x_dtype is not None:
        torch._check(
            x_dtype == torch_npu.hifloat8 or x_dtype == torch_npu.float4_e2m1fn_x2,
            lambda: "x_dtype supports float4_e2m1fn_x2 or hifloat8 for now, but it is " + npu_dtype_to_str(x_dtype),
        )
    if w_dtype is not None:
        torch._check(
            w_dtype == torch_npu.hifloat8 or w_dtype == torch_npu.float4_e2m1fn_x2,
            lambda: "weight_dtype only supports float4_e2m1fn_x2 or hifloat8 for now, but it is " + npu_dtype_to_str(w_dtype),
        )
    if scale_dtype is not None:
        torch._check(
            scale_dtype == torch_npu.float8_e8m0fnu,
            lambda: "scale_dtype only supports float8_e8m0fnu for now, but it is " + npu_dtype_to_str(scale_dtype),
        )
    if pertoken_scale_dtype is not None:
        torch._check(
            pertoken_scale_dtype == torch_npu.float8_e8m0fnu,
            # Fixed: this used the undefined name ``per_token_scale_dtype``,
            # which raised NameError instead of reporting the bad dtype.
            lambda: "pertoken_scale_dtype only supports float8_e8m0fnu for now, but it is " + npu_dtype_to_str(pertoken_scale_dtype),
        )

    x_dim = x.dim()
    w_dim = w.dim()
    # Check ranks before indexing sizes so a wrong-rank input is reported by
    # this message rather than by an IndexError from ``size()``.
    torch._check(
        x_dim == 2 and w_dim == 3,
        lambda: "x_dim should be 2 and w_dim should be 3." + ops_error(ErrCode.VALUE),
    )
    dimm = x.size(0)
    dimn = w.size(w_dim - 1)

    INT4_IN_INT32 = 8
    FP4_IN_INT8 = 2

    if dtype is None:
        dtype = torch.float32
    if shared_input is not None and logit is not None:
        torch._check(
            dtype == torch.float32,
            lambda: "When shared_input is not None, output_dtype must be float32, but it is " +
                    str(dtype) + ops_error(ErrCode.TYPE),
        )

    # output_bs == 0 means "use the row count of x".
    y_dimm = dimm if output_bs == 0 else output_bs

    w_trans = x.size(-1) == w.size(-2)
    is_a4w4_input = False
    if x_dtype is not None and w_dtype is not None:
        is_a4w4_input = x_dtype == torch_npu.float4_e2m1fn_x2 and w_dtype == torch_npu.float4_e2m1fn_x2

    if w.dtype == torch.int32:
        dim_n = dimn * INT4_IN_INT32
    elif is_a4w4_input and not w_trans:
        dim_n = dimn * FP4_IN_INT8
    else:
        dim_n = dimn

    if dtype == torch.float32:
        return x.new_empty((y_dimm, dim_n), dtype=torch.float32)
    raise RuntimeError("Not supportted output dtype is " + str(dtype))
@impl(m, "npu_group_norm_silu")
def group_norm_silu_meta(self, gemma, beta, group, eps=0.00001):
    """Meta kernel for ``npu_group_norm_silu``.

    Returns three empty tensors: one shaped like ``self``, followed by two
    of shape ``(self.size(0), group)``.  When both ``gemma`` and ``beta`` are
    given, the second and third outputs are allocated from them and take
    their dtypes.  Otherwise both are allocated from ``self``.
    """
    stats_shape = (self.size(0), group)
    activation = torch.empty_like(self, dtype=self.dtype)
    if gemma is not None and beta is not None:
        second = gemma.new_empty(stats_shape, dtype=gemma.dtype)
        third = beta.new_empty(stats_shape, dtype=beta.dtype)
    else:
        second = self.new_empty(stats_shape, dtype=self.dtype)
        third = self.new_empty(stats_shape, dtype=self.dtype)
    return (activation, second, third)
@impl(m, "npu_mm_all_reduce_base")
def npu_mm_all_reduce_base_forward(x1, x2, hcom, reduce_op='sum', bias=None, antiquant_scale=None,
                                   antiquant_offset=None, x3=None, dequant_scale=None, pertoken_scale=None,
                                   comm_quant_scale_1=None, comm_quant_scale_2=None, antiquant_group_size=0,
                                   comm_turn=0, group_sizes=None, y_dtype=None, x1_dtype=None, x2_dtype=None,
                                   dequant_scale_dtype=None, pertoken_scale_dtype=None, comm_quant_mode=0):
    """Meta kernel for ``npu_mm_all_reduce_base``.

    The output has the shape of ``x1`` with its last dimension replaced by
    ``x2.size(1)``.  Output dtype:
      * no ``dequant_scale``: the dtype of ``x1``;
      * ``dequant_scale`` given and ``y_dtype`` None: bfloat16 when
        ``dequant_scale`` is bfloat16, otherwise float16;
      * ``dequant_scale`` and ``y_dtype`` given: the mapped scalar type for
        ``y_dtype``, falling back to float16 for unknown values.
    """
    out_shape = [x1.size(axis) for axis in range(x1.dim())]
    out_shape[-1] = x2.size(1)
    out_shape = tuple(out_shape)

    if dequant_scale is None:
        return x1.new_empty(out_shape)
    if y_dtype is not None:
        mapped = TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP.get(y_dtype, torch.float16)
        return x1.new_empty(out_shape, dtype=mapped)
    if dequant_scale.dtype == torch.bfloat16:
        inferred = torch.bfloat16
    else:
        inferred = torch.float16
    return x1.new_empty(out_shape, dtype=inferred)
@impl(m, "npu_weight_quant_batchmatmul")
def npu_weight_quant_batchmatmul_meta(x, weight, antiquant_scale, antiquant_offset=None, quant_scale=None, quant_offset=None, bias=None, antiquant_group_size=0, inner_precise=0,
                                      weight_dtype=None):
    """Meta kernel for ``npu_weight_quant_batchmatmul``.

    Returns an empty ``(x.size(0), N)`` tensor.  ``N`` is
    ``weight.size(1) * 8`` for a contiguous int32 or float32 ``weight``, and
    ``weight.size(1)`` otherwise.  The dtype is int8 when ``quant_scale`` is
    given, else the dtype of ``x``.
    """
    rows = x.size(0)
    widened = (weight.dtype == torch.int32 or weight.dtype == torch.float32) and weight.is_contiguous()
    cols = weight.size(1) * 8 if widened else weight.size(1)
    out_dtype = x.dtype if quant_scale is None else torch.int8
    return x.new_empty((rows, cols), dtype=out_dtype)
def is_transpose_last_two_dims(tensor):
    """Report whether the strides of ``tensor`` match a transposed layout.

    Returns True only when all of the following hold:
      * the rank is between 2 and 6 inclusive;
      * the second-to-last dimension has stride 1 and the last dimension has
        a stride equal to the size of the second-to-last dimension;
      * walking the leading dimensions from innermost to outermost, each
        stride equals the product of all sizes to its right;
      * the last two dimensions are not both of size 1.
    """
    rank = tensor.dim()
    if not 2 <= rank <= 6:
        return False

    last = rank - 1
    second_last = rank - 2
    if tensor.stride(second_last) != 1 or tensor.stride(last) != tensor.size(second_last):
        return False

    expected_stride = tensor.size(last) * tensor.size(second_last)
    for batch_axis in reversed(range(rank - 2)):
        if tensor.stride(batch_axis) != expected_stride:
            return False
        expected_stride = expected_stride * tensor.size(batch_axis)

    # A 1x1 trailing matrix satisfies the stride test trivially; reject it.
    return not (tensor.size(last) == 1 and tensor.size(second_last) == 1)
def bias_shape_check(*args):
x2, bias, batch_val, is_a4w4, is_a8w4_float, transpose_x2, is_mxfp4_valid = args
bias_dim_num = bias.dim()
if is_a4w4:
torch._check(
bias_dim_num == 1,
lambda: "bias_dim_num should be 1 when x1's dtype is int32, please check bias dim num " + ops_error(ErrCode.VALUE),
)
elif is_a8w4_float:
torch._check(bias_dim_num == 2,
lambda: "in a8w4 float, bias_dim_num should be 2 , please check bias dim num " + ops_error(ErrCode.VALUE),
)
return
else:
torch._check(
bias_dim_num == 1 or bias_dim_num == 3,
lambda: "bias_dim_num should be 1 or 3 when x1's dtype is int8, please check bias dim num " + ops_error(ErrCode.VALUE),
)
x2_dim_num = x2.dim()
x2_n_dim = x2.size(x2_dim_num - 1) * 8 if (is_a4w4 and not transpose_x2) else x2.size(x2_dim_num - 1)
if is_mxfp4_valid:
x2_n_dim = x2.size(x2_dim_num - 1) if transpose_x2 else x2.size(x2_dim_num - 1) * 2
bias_first_dim = bias.size(0)
if bias_dim_num == 1:
torch._check(
bias_first_dim == x2_n_dim,
lambda: "bias_first_dim should be equal to x2 n dim, please check bias 1st dim value " + ops_error(ErrCode.VALUE),
)
return
bias_second_dim = bias.size(1)
bias_third_dim = bias.size(2)
torch._check(
bias_first_dim == batch_val,
lambda: "infered batch value should be equal to bias batch dim value, please check bias batch dim value" + ops_error(ErrCode.VALUE),
)
torch._check(
bias_second_dim == 1,
lambda: "bias_second_dim should be 1, please check bias second dim value " + ops_error(ErrCode.VALUE),
)
torch._check(
bias_third_dim == x2_n_dim,
lambda: "bias_third_dim should be equal to x2_n_dim, please check bias third dim value " + ops_error(ErrCode.VALUE),
)
def quant_matmul_shape_check(*args):
x1, x2, scale, offset, pertoken_scale, is_a4w4, transpose_x1, transpose_x2, is_a8w4_int, is_a8w4_float, group_sizes, is_mxfp4_valid = args
X_MAX_DIM = 6
X_MIN_DIM = 2
INT4_IN_INT32 = 8
FP4_IN_INT8 = 2
GROUP_SIZE_A8W4 = 256
x1_dim_num = x1.dim()
x2_dim_num = x2.dim()
x1_m_dim = x1.size(x1_dim_num - 2)
x1_k_dim = x1.size(x1_dim_num - 1)
x2_k_dim = x2.size(x2_dim_num - 2)
x2_n_dim = x2.size(x2_dim_num - 1) * INT4_IN_INT32 if ((is_a4w4 and not transpose_x2) or is_a8w4_int) else x2.size(x2_dim_num - 1)
if is_mxfp4_valid:
x1_m_dim = x1.size(x1_dim_num - 2) if not transpose_x1 else x1.size(x1_dim_num - 2) * FP4_IN_INT8
x1_k_dim = x1.size(x1_dim_num - 1) * FP4_IN_INT8 if not transpose_x1 else x1.size(x1_dim_num - 1)
x2_k_dim = x2.size(x2_dim_num - 2) if not transpose_x2 else x2.size(x2_dim_num - 2) * FP4_IN_INT8
x2_n_dim = x2.size(x2_dim_num - 1) * FP4_IN_INT8 if not transpose_x2 else x2.size(x2_dim_num - 1)
torch._check(
x1_dim_num >= X_MIN_DIM and x1_dim_num <= X_MAX_DIM,
lambda: f"x1 dim num should be 2 ~ 6, please check x1 dim num {ops_error(ErrCode.VALUE)}",
)
if is_a4w4 and not transpose_x2:
torch._check(
x1_k_dim * INT4_IN_INT32 == x2_k_dim,
lambda: f"k dim of x2 should be 8 multiple of k dim of x1, \
please check k dim of x1 and x2 {ops_error(ErrCode.VALUE)}",
)
elif is_a8w4_float:
if (x2.dtype == torch.float32):
if pertoken_scale is not None:
torch._check(
x1_k_dim == x2_k_dim * INT4_IN_INT32,
lambda: "a8w4 nz mx quant only support x1 not transpose and x2 transpose and k dim of x1 should be 8 multiple of k dim of x2." + ops_error(ErrCode.VALUE),
)
else:
torch._check(
x1_k_dim == x2_k_dim,
lambda: "a8w4 nz t-cg quant only support x1 not transpose and x2 not transpose and k dim of x1 and x2 need be same." + ops_error(ErrCode.VALUE),
)
else:
torch._check(
x1_k_dim == x2_k_dim * FP4_IN_INT8,
lambda: "a8w4_float nd only support x1 not transpose and x2 transpose and k dim of x1 should be 2 multiple of k dim of x2, please check k dim of x1 and x2" + ops_error(ErrCode.VALUE),
)
else:
torch._check(
x1_k_dim == x2_k_dim,
lambda: f"k dim of x1 and x2 need be same, please check k dim of x1 and x2 {ops_error(ErrCode.VALUE)}",
)
if is_a4w4:
torch._check(
x2_dim_num == X_MIN_DIM,
lambda: f"x2 dim num should be 2 when x1's dtype is int32, \
please check x2 dim num {ops_error(ErrCode.VALUE)}",
)
else:
torch._check(
x2_dim_num >= X_MIN_DIM and x2_dim_num <= X_MAX_DIM,
lambda: f"x2 dim num should be 2 ~ 6 when x1's dtype is int8, \
please check x2 dim num {ops_error(ErrCode.VALUE)}",
)
if offset is not None:
offset_dim_num = offset.dim()
torch._check(
offset_dim_num == 1,
lambda: f"the offset dim num must be 1, please check offset dim num {ops_error(ErrCode.VALUE)}",
)
offset_first_dim = offset.size(0)
torch._check(
offset_first_dim == 1 or offset_first_dim == x2_n_dim,
lambda: f"the offset 1st dim value must be 1 or x2 n dim value, \
please check offset 1st dim value {ops_error(ErrCode.VALUE)}",
)
if group_sizes is None:
if pertoken_scale is not None:
pertoken_scale_dim_num = pertoken_scale.dim()
if is_a8w4_int:
torch._check(
pertoken_scale_dim_num == 2,
lambda: f"the pertoken_scale dim num must be 2, please check scale dim num {ops_error(ErrCode.VALUE)}",
)
else:
torch._check(
pertoken_scale_dim_num == 1,
lambda: f"the pertoken_scale dim num must be 1, please check scale dim num {ops_error(ErrCode.VALUE)}",
)
scale_dim_num = scale.dim()
if is_a8w4_int:
torch._check(
scale_dim_num == 2,
lambda: f"the scale dim num must be 2, please check scale dim num {ops_error(ErrCode.VALUE)}",
)
scale_first_dim = scale.size(0)
torch._check(
scale_first_dim == x1_k_dim // GROUP_SIZE_A8W4,
lambda: f"the scale 1st dim value must equal to x1 k dim divide 256, \
please check scale 1st dim value {ops_error(ErrCode.VALUE)}",
)
scale_last_dim = scale.size(1)
torch._check(
scale_last_dim == x2_n_dim,
lambda: f"the scale last dim value must equal to x2 n dim value, \
please check scale last dim value {ops_error(ErrCode.VALUE)}",
)
else:
torch._check(
scale_dim_num == 1,
lambda: f"the scale dim num must be 1, please check scale dim num {ops_error(ErrCode.VALUE)}",
)
scale_first_dim = scale.size(0)
torch._check(
scale_first_dim == 1 or scale_first_dim == x2_n_dim,
lambda: f"the scale 1st dim value must be 1 or x2 n dim value, \
please check scale 1st dim value {ops_error(ErrCode.VALUE)}",
)
def quant_matmul_bias_dtype_check(bias, pertoken_scale, output_dtype):
    """Validate the dtype of ``bias`` against the requested output dtype.

    Rules, checked in order:
      * ``bias`` must be int32, bfloat16, float32 or float16;
      * a bfloat16 bias requires a bfloat16 output;
      * an int32 output requires an int32 bias;
      * a float16 bias is only allowed together with ``pertoken_scale``, and
        then requires a float16 output.

    ``output_dtype`` is compared against the integer codes stored in
    ``TORCH_DTYPE_MAP``.
    """
    allowed_bias_dtypes = [torch.int32, torch.bfloat16, torch.float32, torch.float16]
    bias_dtype = bias.dtype
    torch._check(
        bias_dtype in allowed_bias_dtypes,
        lambda: "bias's type supported for int32, bfloat16, float16 and float32, but bias.dtype is " + str(bias.dtype) + ops_error(ErrCode.TYPE),
    )

    if bias_dtype == torch.bfloat16:
        torch._check(
            output_dtype == TORCH_DTYPE_MAP[torch.bfloat16],
            lambda: "When bias dtype is bfloat16, output_dtype must be bfloat16, but it is " +
                    str(output_dtype) + ops_error(ErrCode.TYPE),
        )

    if output_dtype == TORCH_DTYPE_MAP[torch.int32]:
        torch._check(
            bias_dtype == torch.int32,
            lambda: "When output_dtype dtype is int32, bias_dtype must be int32, but it is " +
                    str(bias.dtype) + ops_error(ErrCode.TYPE),
        )

    if pertoken_scale is None:
        torch._check(
            bias_dtype != torch.float16,
            lambda: "Bias dtype cannot be float16 when pertoken not given." + ops_error(ErrCode.TYPE),
        )
    elif bias_dtype == torch.float16:
        torch._check(
            output_dtype == TORCH_DTYPE_MAP[torch.float16],
            lambda: "When bias dtype is float16 and pertoken is given, output_dtype must be float16, but it is " +
                    str(output_dtype) + ops_error(ErrCode.TYPE),
        )
def quant_matmul_extra_dtype_check(*args):
    """Validate the optional ``*_dtype`` overrides of ``npu_quant_matmul``.

    ``args`` is ``(x1, x2, scale, pertoken_scale, x1_dtype, x2_dtype,
    scale_dtype, is_a8w4_float, pertoken_scale_dtype)``.

    For each override that is not None:
      * ``x1_dtype`` / ``x2_dtype`` must be ``torch_npu.float4_e2m1fn_x2`` or
        ``torch_npu.hifloat8`` (the ``x2_dtype`` rule is skipped when
        ``is_a8w4_float`` is set);
      * ``scale_dtype`` / ``pertoken_scale_dtype`` must be
        ``torch_npu.float8_e8m0fnu``;
      * the matching tensor must have ``element_size() == 1``.

    Each failure is reported through ``torch._check``.
    """
    x1, x2, scale, pertoken_scale, x1_dtype, x2_dtype, scale_dtype, is_a8w4_float, pertoken_scale_dtype = args
    if x1_dtype is not None:
        torch._check(
            x1_dtype == torch_npu.float4_e2m1fn_x2 or x1_dtype == torch_npu.hifloat8,
            # Fixed: the message used to print x2_dtype for an x1_dtype failure.
            lambda: "The x1_dtype supported for torch_npu.float4_e2m1fn_x2, torch_npu.hifloat8, but x1_dtype is " +
                    npu_dtype_to_str(x1_dtype) + ops_error(ErrCode.TYPE),
        )
        torch._check(
            x1.element_size() == 1,
            lambda: "When x1_dtype is not None, x1 must be a 1 byte tensor, but the byte size of x1 is" +
                    str(x1.element_size()) + ops_error(ErrCode.TYPE),
        )
    if x2_dtype is not None and not is_a8w4_float:
        torch._check(
            x2_dtype == torch_npu.float4_e2m1fn_x2 or x2_dtype == torch_npu.hifloat8,
            # Fixed: the message used to say "x1_dtype" for an x2_dtype failure.
            lambda: "The x2_dtype supported for torch_npu.float4_e2m1fn_x2, torch_npu.hifloat8, but x2_dtype is " +
                    npu_dtype_to_str(x2_dtype) + ops_error(ErrCode.TYPE),
        )
        torch._check(
            x2.element_size() == 1,
            lambda: "When x2_dtype is not None, x2 must be a 1 byte tensor, but the byte size of x2 is" +
                    str(x2.element_size()) + ops_error(ErrCode.TYPE),
        )
    if scale_dtype is not None:
        torch._check(
            scale_dtype == torch_npu.float8_e8m0fnu,
            lambda: "The scale_dtype supported for torch_npu.float8_e8m0fnu, but scale_dtype is " +
                    npu_dtype_to_str(scale_dtype) + ops_error(ErrCode.TYPE),
        )
        torch._check(
            scale.element_size() == 1,
            lambda: "When scale_dtype is not None, scale must be a 1 byte tensor, but the byte size of scale is" +
                    str(scale.element_size()) + ops_error(ErrCode.TYPE),
        )
    if pertoken_scale_dtype is not None:
        torch._check(
            pertoken_scale_dtype == torch_npu.float8_e8m0fnu,
            lambda: "The pertoken_scale_dtype supported for torch_npu.float8_e8m0fnu, but pertoken_scale_dtype is " +
                    npu_dtype_to_str(pertoken_scale_dtype) + ops_error(ErrCode.TYPE),
        )
        torch._check(
            pertoken_scale.element_size() == 1,
            lambda: "When pertoken_scale_dtype is not None, pertoken_scale must be a 1 byte tensor, but the byte size of pertoken_scale is" +
                    str(pertoken_scale.element_size()) + ops_error(ErrCode.TYPE),
        )
def quant_matmul_dtype_check(*args):
x1, x2, scale, offset, pertoken_scale, bias, output_dtype, is_a4w4, is_a8w4_int, is_a8w4_float, y_scale = args
if is_a8w4_int:
torch._check(
x1.dtype == torch.int8,
lambda: f"x1's type should be torch.int8 in A8W4, but x1.dtype is {str(x1.dtype)} {ops_error(ErrCode.TYPE)}",
)
torch._check(
x2.dtype == torch.int32,
lambda: f"x2's type should be torch.int32 in A8W4, but x2.dtype is {str(x2.dtype)} {ops_error(ErrCode.TYPE)}",
)
torch._check(
scale.dtype == torch.int64,
lambda: f"scale's type should be torch.int64 in A8W4, \
but scale.dtype is {str(scale.dtype)} {ops_error(ErrCode.TYPE)}",
)
if offset is not None:
torch._check(
offset.dtype == torch.float32,
lambda: f"offset's type should be torch.float32 in A8W4, \
but offset.dtype is {str(offset.dtype)} {ops_error(ErrCode.TYPE)}",
)
if pertoken_scale is not None:
torch._check(
pertoken_scale.dtype == torch.float32,
lambda: f"pertoken_scale's type should be torch.float32 in A8W4, \
but pertoken_scale.dtype is {str(pertoken_scale.dtype)} {ops_error(ErrCode.TYPE)}",
)
if bias is not None:
torch._check(
bias.dtype == torch.int32,
lambda: f"bias's type should be torch.int32 in A8W4, \
but bias.dtype is {str(bias.dtype)} {ops_error(ErrCode.TYPE)}",
)
if output_dtype is not None:
torch._check(
output_dtype == TORCH_DTYPE_MAP[torch.float16] or output_dtype == TORCH_DTYPE_MAP[torch.bfloat16],
lambda: f"output_dtype's type should be torch.int32 or torch.bfloat16 in A8W4, \
but output_dtype.dtype is {npu_dtype_to_str(output_dtype)} {ops_error(ErrCode.TYPE)}",
)
else:
if offset is not None:
torch._check(
offset.dtype == torch.float32,
lambda: f"offset's type supported for float32, \
but offset.dtype is {str(offset.dtype)} {ops_error(ErrCode.TYPE)}",
)
if bias is not None:
quant_matmul_bias_dtype_check(bias, pertoken_scale, output_dtype)
if is_a8w4_float and y_scale is not None:
torch._check(
y_scale.dtype == torch.int64,
lambda: "y_scale's type supported for int64, but y_scale.dtype is " + str(y_scale.dtype) + ops_error(ErrCode.TYPE),
)
def quant_matmul_scale_offset_out_check(scale, offset, pertoken_scale, output_dtype, is_a4w4):
    """Cross-check the ``scale`` dtype, the output dtype and the a4w4 mode.

    ``output_dtype`` is compared against the integer codes stored in
    ``TORCH_DTYPE_MAP``:
      * a bfloat16 ``scale`` requires a bfloat16 or int32 output;
      * an int32 output requires a bfloat16 or float32 ``scale``;
      * a4w4 requires a float16 output.

    ``offset`` and ``pertoken_scale`` are accepted but not inspected.
    """
    bf16_code = TORCH_DTYPE_MAP[torch.bfloat16]
    int32_code = TORCH_DTYPE_MAP[torch.int32]

    if scale.dtype == torch.bfloat16:
        torch._check(
            output_dtype in [bf16_code, int32_code],
            lambda: "When scale's dtype is bfloat16, output_dtype must be bfloat16 or int32, but output_dtype is " +
                    npu_dtype_to_str(output_dtype) + ops_error(ErrCode.TYPE),
        )

    if output_dtype == int32_code:
        torch._check(
            scale.dtype in [torch.bfloat16, torch.float32],
            lambda: "When output_dtype is int32, scale's dtype must be bfloat16 or float32, but scale's dtype is " +
                    str(scale.dtype) + ops_error(ErrCode.TYPE),
        )

    if not is_a4w4:
        return
    torch._check(
        output_dtype == TORCH_DTYPE_MAP[torch.float16],
        lambda: "When input's dtype is int32, output_dtype must be float16, but output_dtype is " +
                npu_dtype_to_str(output_dtype) + ops_error(ErrCode.TYPE),
    )
def quant_matmul_group_sizes_check(*args):
x1, x2, scale, pertoken_scale, group_sizes, x1_dtype, x2_dtype, scale_dtype, pertoken_scale_dtype, is_a8w4_float = args
if not is_a8w4_float and pertoken_scale is not None and pertoken_scale.dim() >= 2 and scale.dim() >= 2:
if pertoken_scale_dtype is not None and pertoken_scale_dtype == torch_npu.float8_e8m0fnu:
pertoken_scale_k_idx = pertoken_scale.dim() - 2
scale_k_idx = scale.dim() - 3
else:
pertoken_scale_k_idx = pertoken_scale.dim() - 1
scale_k_idx = scale.dim() - 2
torch._check(
(pertoken_scale.size(pertoken_scale_k_idx) == scale.size(scale_k_idx)),
lambda: "In mx, B-B, G-B quantification, k dimension of scale and pertoken_scale must be equal, \
please check the sizes of scale and pertoken_scale" + ops_error(ErrCode.VALUE),
)
if group_sizes is None:
return
torch._check(
len(group_sizes) == 3,
lambda: "group_sizes's length must be 3, please check group_sizes's length" + ops_error(ErrCode.VALUE),
)
if is_a8w4_float:
torch._check(
(group_sizes[0] == 0 and group_sizes[1] == 0 and group_sizes[2] == 32) or \
(group_sizes[0] == 1 and group_sizes[1] == 1 and group_sizes[2] == 32),
lambda: "when the dtype of input is A8W4, group_sizes's value must be [0,0,32] or [1,1,32], please check group_sizes's value" + ops_error(ErrCode.VALUE),
)
return
is_a8w8_int = x1_dtype is None and x2_dtype is None and x1.dtype == torch.int8 and x2.dtype == torch.int8
if is_a8w8_int:
torch._check(
(group_sizes[0] == 0 and group_sizes[1] == 0 or group_sizes[2] == 0) or \
(group_sizes[0] == 1 and group_sizes[1] == 128 and group_sizes[2] == 128),
lambda: "when the dtype of input is int8, group_sizes's value must be 0 or [1,128,128], please check group_sizes's value" + ops_error(ErrCode.VALUE),
)
if pertoken_scale is None:
torch._check(
group_sizes[0] == 0 or group_sizes[1] == 0 or group_sizes[2] == 0,
lambda: "when the pertoken_scale is None, group_sizes's value must be 0, please check group_sizes's value" + ops_error(ErrCode.VALUE),
)
group_input_dtype_lst = [torch.uint8, torch.bits8, torch.float8_e4m3fn, torch.float8_e5m2, torch.int8]
group_scale_dtype_lst = [torch.float32]
has_group = (group_sizes[0] > 1 or group_sizes[1] > 1 or group_sizes[2] > 1)
if group_sizes is not None and has_group:
torch._check(
(scale_dtype is not None and pertoken_scale_dtype is not None) or (scale.dtype in group_scale_dtype_lst or pertoken_scale.dtype in group_scale_dtype_lst),
lambda: "When group_sizes's value is not 0, scale_dtype and pertoken_scale_dtype are None, dtype of scale and pertoken_scale must be both float32, but " +
"scale's dtype is " + str(scale.dtype) + " pertoken_scale's dtype is " + str(pertoken_scale.dtype) + ops_error(ErrCode.TYPE),
)
torch._check(
(x1_dtype is not None and x2_dtype is not None) or (x1.dtype in group_input_dtype_lst or x2.dtype in group_input_dtype_lst),
lambda: "When group_sizes's value is not 0, x1_dtype and x2_dtype are None, dtype of input must be uint8, float8_e4m3fn, float8_e5m2 , int8 or int32, but x1's dtype is " +
str(x1.dtype) + " x2's dtype is " + str(x2.dtype) + ops_error(ErrCode.TYPE),
)
if group_sizes[0] > 1:
torch._check(
pertoken_scale.dim() >= 2 and pertoken_scale.size(pertoken_scale.dim() - 2) == math.ceil(x1.size(x1.dim() - 2) / group_sizes[0]),
lambda: "When group_sizes[0] > 1, ceil(x1.size(-2) / group_sizes[0]) must be equal to " +
"pertoken_scale's size(-2), please check your input" + ops_error(ErrCode.VALUE),
)
torch._check(
group_sizes[1] == group_sizes[0] and group_sizes[1] == 128,
lambda: "When group_sizes[1] > 1, group_sizes[1] must be equal to group_sizes[2] and must be equal to 128" + ops_error(ErrCode.VALUE),
)
if group_sizes[2] > 1:
group_k_support_lst = [32, 128]
torch._check(
group_sizes[2] in group_k_support_lst,
lambda: "When group_sizes[2] > 1, group_sizes[2] must be equal to 32 or 128, but group_sizes[2] is " +
str(group_sizes[2]) + ops_error(ErrCode.VALUE),
)
if group_sizes[1] > 1:
torch._check(
scale.dim() >= 2 and scale.size(scale.dim() - 1) == math.ceil(x2.size(x2.dim() - 1) / group_sizes[1]),
lambda: "When group_sizes[2] > 1, ceil(x2.size(-1) / group_sizes[2]) must be equal to scale's size(-1), " +
"please check your input" + ops_error(ErrCode.VALUE),
)
torch._check(
group_sizes[1] == group_sizes[2] and group_sizes[1] == 128,
lambda: "When group_sizes[1] > 1, group_sizes[1] must be equal to group_sizes[2] and must be equal to 128" + ops_error(ErrCode.VALUE),
)
@impl(m, "obfuscation_calculate")
def obfuscation_calculate_meta(fd, x, param, cmd):
    """Meta kernel for ``obfuscation_calculate``: the output mirrors ``x``."""
    result = torch.empty_like(x)
    return result
@impl(m, "obfuscation_finalize")
def obfuscation_finalize_meta(fd_to_close):
    """Meta kernel for ``obfuscation_finalize``: the output mirrors ``fd_to_close``."""
    result = torch.empty_like(fd_to_close)
    return result
@impl(m, "npu_quant_matmul")
def npu_quant_matmul_meta(x1, x2, scale, *, offset=None, pertoken_scale=None, bias=None, output_dtype=None,
                          x1_dtype=None, x2_dtype=None, pertoken_scale_dtype=None, scale_dtype=None,
                          group_sizes=None, y_scale=None):
    """Meta kernel for ``npu_quant_matmul``.

    Builds the output shape (broadcast batch dimensions followed by M and
    N), runs the bias / dtype / group-size / shape validators, and returns an
    empty tensor allocated from whichever operand has more dimensions
    (``x2`` on a tie).

    Output dtype is int8 when ``output_dtype`` is None.  Otherwise it is the
    scalar type mapped from ``output_dtype``, which must be a key of
    ``TORCH_DTYPE_MAP`` or ``torch.uint8``.

    Raises:
        RuntimeError: a validator fails or ``output_dtype`` is unsupported.
    """
    int4_per_int32 = 8
    fp4_per_fp32 = 8

    x1_rank = x1.dim()
    x2_rank = x2.dim()
    out_rank = max(x1_rank, x2_rank)
    if x1_rank > x2_rank:
        longer, shorter = x1, x2
    else:
        longer, shorter = x2, x1
    rank_gap = out_rank - min(x1_rank, x2_rank)

    is_a4w4 = x1.dtype == torch.int32 and x2.dtype == torch.int32
    is_mxfp4_valid = x1_dtype == torch_npu.float4_e2m1fn_x2 and x2_dtype == torch_npu.float4_e2m1fn_x2
    is_a8w4_int = x1.dtype == torch.int8 and x2.dtype == torch.int32
    is_a8w4_float = x1.dtype == torch.float8_e4m3fn and (x2_dtype == torch_npu.float4_e2m1fn_x2 or x2.dtype == torch.float32)

    batch_val = 1
    out_shape = []
    transpose_x1 = False
    transpose_x2 = False

    if is_a8w4_int:
        # An int32 x2 is widened by a factor of eight along N.
        out_shape = [x1.shape[0], x2.shape[1] * int4_per_int32]
    else:
        # Broadcast the leading (batch) dimensions of the two operands.
        for axis in range(out_rank - 2):
            short_dim = shorter.size(axis - rank_gap) if axis >= rank_gap else 1
            long_dim = longer.size(axis)
            torch._check(
                not (short_dim > 1 and long_dim > 1 and short_dim != long_dim),
                lambda: "the batch shape cannot be broadcast" + ops_error(ErrCode.VALUE),
            )
            broadcast_dim = max(short_dim, long_dim)
            batch_val = batch_val * broadcast_dim
            out_shape.append(broadcast_dim)

        if is_mxfp4_valid:
            fp4_per_int8 = 2
            transpose_x1 = is_transpose_last_two_dims(x1)
            transpose_x2 = is_transpose_last_two_dims(x2)
            x1_rows = x1.size(x1_rank - 2)
            x2_cols = x2.size(x2_rank - 1)
            real_m = x1_rows * fp4_per_int8 if transpose_x1 else x1_rows
            real_n = x2_cols if transpose_x2 else x2_cols * fp4_per_int8
            out_shape.extend([real_m, real_n])
        else:
            m_dim = x1.size(x1.dim() - 2)
            transpose_x2 = x1.size(x1.dim() - 1) == x2.size(x2.dim() - 2)
            n_dim = x2.size(x2.dim() - 1)
            if is_a4w4 and not transpose_x2:
                n_dim = x2.size(x2.dim() - 1) * int4_per_int32
            elif is_a8w4_float and x2.dtype == torch.float32 and pertoken_scale is None:
                n_dim = x2.size(x2.dim() - 1) * fp4_per_fp32
            out_shape.extend([m_dim, n_dim])

    if bias is not None:
        if bias.dim() == 3:
            torch._check(
                len(out_shape) == 3,
                lambda: "when bias dim is 3, out dim need to be 3" + ops_error(ErrCode.TYPE),
            )
        bias_shape_check(x2, bias, batch_val, is_a4w4, is_a8w4_float, transpose_x2, is_mxfp4_valid)

    quant_matmul_scale_offset_out_check(scale, offset, pertoken_scale, output_dtype, is_a4w4)
    quant_matmul_extra_dtype_check(x1, x2, scale, pertoken_scale,
                                   x1_dtype, x2_dtype, scale_dtype, is_a8w4_float, pertoken_scale_dtype)
    quant_matmul_group_sizes_check(x1, x2, scale, pertoken_scale, group_sizes,
                                   x1_dtype, x2_dtype, scale_dtype, pertoken_scale_dtype, is_a8w4_float)
    quant_matmul_dtype_check(x1, x2, scale, offset, pertoken_scale, bias, output_dtype,
                             is_a4w4, is_a8w4_int, is_a8w4_float, y_scale)
    quant_matmul_shape_check(x1, x2, scale, offset, pertoken_scale, is_a4w4, transpose_x1, transpose_x2,
                             is_a8w4_int, is_a8w4_float, group_sizes, is_mxfp4_valid)

    if output_dtype is None:
        tensor_dtype = torch.int8
    else:
        tensor_dtype = TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP.get(output_dtype)
        dtype_supported = tensor_dtype is not None and (tensor_dtype in TORCH_DTYPE_MAP.keys() or tensor_dtype == torch.uint8)
        if not dtype_supported:
            raise RuntimeError("Not supported output dtype is " + npu_dtype_to_str(output_dtype))
    return longer.new_empty(tuple(out_shape), dtype=tensor_dtype)
@impl(m, "npu_matmul_compress_dequant")
def npu_matmul_compress_dequant_meta(x1, x2, compress_index, bias, scale, *, offsetW=None, offsetX=None):
    """Meta kernel for ``npu_matmul_compress_dequant``.

    Checks the rank and dtype of each tensor argument, then returns an
    uninitialised float16 tensor on the ``meta`` device with shape
    ``(x1.shape[0], bias.shape[1])``.  ``offsetW`` and ``offsetX`` are
    accepted but not inspected here.
    """

    def _require(condition, message):
        # All validation failures of this op are reported with ErrCode.VALUE;
        # the full message is only built when the check fails.
        torch._check(condition, lambda: message + ops_error(ErrCode.VALUE))

    _require(x1.dim() == 2, "the x1 dim support only 2")
    _require(x1.dtype == torch.int8, "the x1 dtype support only int8")

    _require(x2.dim() == 1, "the x2 dim support only 1")
    _require(x2.dtype == torch.int8, "the x2 dtype support only int8")

    _require(compress_index.dim() == 1, "the compress_index dim support only 1")
    _require(compress_index.dtype == torch.int8, "the compress_index dtype support only int8")

    _require(bias.dim() == 2, "the bias dim support only 2")
    _require(bias.dtype == torch.int32, "the bias dtype support only int32")

    _require(scale.dim() == 2, "the scale dim support only 2")
    _require(scale.dtype == torch.uint64, "the scale dtype support only uint64")

    return torch.empty((x1.shape[0], bias.shape[1]), dtype=torch.float16, device='meta')
@impl(m, "npu_quant_matmul_dequant")
def npu_quant_matmul_dequant_meta(x, quantized_weight, weight_scale, *,
                                  bias=None, x_scale=None, x_offset=None, smooth_scale=None, quant_mode="pertoken"):
    """Meta kernel for ``npu_quant_matmul_dequant``.

    Validates ``x`` (2-D float16), ``quantized_weight`` (2-D int8) and
    ``weight_scale`` (1-D float) together with their mutual shape
    constraints, then returns an uninitialised ``meta`` tensor of shape
    ``(x.shape[0], weight_scale.shape[0])`` with the dtype of ``x``.
    The keyword-only arguments are not inspected here.
    """

    def _require(condition, message):
        # The message (including the error-code suffix) is built lazily, only on failure.
        torch._check(condition, lambda: message + ops_error(ErrCode.VALUE))

    # Rank / dtype constraints of the individual tensors.
    _require(x.dim() == 2, "the x dim support only 2")
    _require(x.dtype == torch.float16, "the x dtype support only float16")
    _require(quantized_weight.dim() == 2, "the quantized_weight dim support only 2")
    _require(quantized_weight.dtype == torch.int8, "the quantized_weight dtype support only int8")
    _require(weight_scale.dim() == 1, "the weight_scale dim support only 1")
    _require(weight_scale.dtype == torch.float, "the weight_scale dtype support only float")

    # Cross-tensor shape agreement.
    _require(x.shape[1] == quantized_weight.shape[1],
             "x shape[1] not equal to quantized_weight shape[1]")
    _require(weight_scale.shape[0] == quantized_weight.shape[0],
             "weight_scale shape[0] not equal to quantized_weight shape[0]")

    result_shape = (x.shape[0], weight_scale.shape[0])
    return torch.empty(result_shape, dtype=x.dtype, device='meta')
@impl(m, "npu_quant_matmul_reduce_sum")
def npu_quant_matmul_reduce_sum_meta(x1, x2, *, x1_scale=None, x2_scale=None):
    """Meta kernel for ``npu_quant_matmul_reduce_sum``.

    Requires 3-D ``x1`` and ``x2`` with ``x1.size(2) == x2.size(1)``, a 2-D
    ``x1_scale`` and a 1-D ``x2_scale`` (both mandatory despite the ``None``
    defaults).  Returns an uninitialised bfloat16 tensor of shape
    ``(x1.size(1), x2.size(2))`` on the device of ``x1``.
    """
    # The rank messages must refer to x1 / x2: the names ``x`` and ``w`` do
    # not exist in this scope and would raise NameError when a check fails.
    torch._check(x1.dim() == 3, lambda: f"x1 dim must be 3, but got {x1.dim()}.")
    torch._check(x2.dim() == 3, lambda: f"x2 dim must be 3, but got {x2.dim()}.")
    torch._check(x1.size(2) == x2.size(1), lambda: "K dim of x1 must be same as x2.")

    torch._check(x1_scale is not None, lambda: "x1_scale should not be None.")
    torch._check(x1_scale.dim() == 2, lambda: f"x1_scale dim must be 2, but got {x1_scale.dim()}.")
    torch._check(x2_scale is not None, lambda: "x2_scale should not be None.")
    torch._check(x2_scale.dim() == 1, lambda: f"x2_scale dim must be 1, but got {x2_scale.dim()}.")

    dst_shape = (x1.size(1), x2.size(2))
    return torch.empty(dst_shape, dtype=torch.bfloat16, device=x1.device)
@impl(m, "npu_quant_matmul_gelu")
def npu_quant_matmul_gelu_meta(x1, x2, x1_scale, x2_scale, *, bias=None, approximate="gelu_erf"):
    """Meta kernel for ``npu_quant_matmul_gelu``.

    Validates the quantisation mode (A4W4: int32/quint4x2 inputs, or A8W8:
    int8 inputs), the scales, the broadcastability of the batch dimensions,
    the k-dimension agreement and the optional bias, then returns an
    uninitialised output tensor.

    The output shape is the broadcast batch dims followed by ``(m, n)``;
    ``n`` is multiplied by 8 when ``x2`` is int32 in A4W4 mode.  The output
    dtype is bfloat16 when ``x2_scale`` is bfloat16, otherwise float16.
    """
    INT4_IN_INT32 = 8
    LAST_SECOND_DIM_INDEX = 2

    torch._check(
        approximate in ("gelu_tanh", "gelu_erf"),
        lambda: f"approximate must be 'gelu_tanh' or 'gelu_erf', but got {approximate} {ops_error(ErrCode.PARAM)}",
    )

    int4_carriers = (torch.int32, torch.quint4x2)
    is_a4w4 = x1.dtype in int4_carriers and x2.dtype in int4_carriers
    is_a8w8 = x1.dtype == torch.int8 and x2.dtype == torch.int8
    torch._check(
        is_a4w4 or is_a8w8,
        lambda: f"Only A4W4 (int4/int32) or A8W8 (int8) quantization is supported, "
                f"but got x1.dtype={x1.dtype}, x2.dtype={x2.dtype} {ops_error(ErrCode.TYPE)}",
    )

    torch._check(x1_scale is not None, lambda: "x1_scale should not be None.")
    torch._check(x2_scale is not None, lambda: "x2_scale should not be None.")
    torch._check(x1_scale.dim() == 1, lambda: f"x1_scale dim must be 1, but got {x1_scale.dim()}.")
    torch._check(x2_scale.dim() == 1, lambda: f"x2_scale dim must be 1, but got {x2_scale.dim()}.")

    # Broadcast the leading (batch) dimensions; the shorter operand is
    # treated as if it were left-padded with dimensions of size 1.
    x1_dim_num = x1.dim()
    x2_dim_num = x2.dim()
    out_dim_num = max(x1_dim_num, x2_dim_num)
    if x1_dim_num > x2_dim_num:
        shape_long, shape_short = x1, x2
    else:
        shape_long, shape_short = x2, x1
    valid_offset = out_dim_num - min(x1_dim_num, x2_dim_num)

    batch_val = 1
    dim_list = []
    for axis in range(out_dim_num - LAST_SECOND_DIM_INDEX):
        if axis < valid_offset:
            short_dim = 1
        else:
            short_dim = shape_short.size(axis - valid_offset)
        long_dim = shape_long.size(axis)
        torch._check(
            not (short_dim > 1 and long_dim > 1 and short_dim != long_dim),
            lambda: "the batch shape cannot be broadcast" + ops_error(ErrCode.VALUE),
        )
        cur_batch_val = max(short_dim, long_dim)
        batch_val *= cur_batch_val
        dim_list.append(cur_batch_val)

    x1_m_dim = x1.size(x1_dim_num - LAST_SECOND_DIM_INDEX)
    x1_k_dim = x1.size(x1_dim_num - 1)
    x2_k_dim = x2.size(x2_dim_num - LAST_SECOND_DIM_INDEX)
    x2_n_dim = x2.size(x2_dim_num - 1)

    if is_a4w4:
        # In A4W4 mode each operand is either int32 or quint4x2, so these two
        # flags enumerate all four dtype combinations.
        x1_is_int32 = x1.dtype == torch.int32
        x2_is_int32 = x2.dtype == torch.int32
        if x1_is_int32 and x2_is_int32:
            torch._check(
                x1_k_dim == x2_k_dim,
                lambda: f"A4W4 (int32): k dim of x1 ({x1_k_dim}) must equal k dim of x2 ({x2_k_dim}) {ops_error(ErrCode.VALUE)}",
            )
        elif not x1_is_int32 and not x2_is_int32:
            torch._check(
                x1_k_dim == x2_k_dim,
                lambda: f"A4W4 (quint4x2): k dim of x1 ({x1_k_dim}) must equal k dim of x2 ({x2_k_dim}) {ops_error(ErrCode.VALUE)}",
            )
        elif x1_is_int32:
            torch._check(
                x1_k_dim * INT4_IN_INT32 == x2_k_dim,
                lambda: f"A4W4 (int32/quint4x2): k dim of x1 ({x1_k_dim}) * 8 must equal k dim of x2 ({x2_k_dim}) {ops_error(ErrCode.VALUE)}",
            )
        else:
            torch._check(
                x1_k_dim == x2_k_dim * INT4_IN_INT32,
                lambda: f"A4W4 (quint4x2/int32): k dim of x1 ({x1_k_dim}) must equal k dim of x2 ({x2_k_dim}) * 8 {ops_error(ErrCode.VALUE)}",
            )
        if x2_is_int32:
            # The last dim of an int32 x2 is scaled by INT4_IN_INT32 to obtain n.
            x2_n_dim = x2_n_dim * INT4_IN_INT32
    else:
        torch._check(
            x1_k_dim == x2_k_dim,
            lambda: f"A8W8: k dim of x1 ({x1_k_dim}) must equal k dim of x2 ({x2_k_dim}) {ops_error(ErrCode.VALUE)}",
        )

    dim_list.extend([x1_m_dim, x2_n_dim])

    torch._check(
        x1_scale.size(0) == x1_m_dim,
        lambda: f"x1_scale size(0) must equal to x1's m dimension ({x1_m_dim}), but got {x1_scale.size(0)} {ops_error(ErrCode.VALUE)}",
    )
    torch._check(
        x2_scale.size(0) == 1 or x2_scale.size(0) == x2_n_dim,
        lambda: f"x2_scale size(0) must be 1 or equal to x2's n dimension ({x2_n_dim}), but got {x2_scale.size(0)} {ops_error(ErrCode.VALUE)}",
    )

    if bias is not None:
        # First constrain the bias rank per quantisation mode ...
        if is_a4w4:
            torch._check(
                bias.dim() == 1,
                lambda: f"A4W4 quantization only supports 1D bias, but got bias.dim()={bias.dim()} {ops_error(ErrCode.VALUE)}",
            )
        else:
            torch._check(
                bias.dim() == 1 or bias.dim() == 3,
                lambda: f"A8W8 quantization supports 1D or 3D bias, but got bias.dim()={bias.dim()} {ops_error(ErrCode.VALUE)}",
            )
        # ... then validate its extents; the 3-D form is only reachable for A8W8.
        if bias.dim() == 1:
            torch._check(
                bias.size(0) == x2_n_dim,
                lambda: f"bias size(0) must equal to x2's n dimension ({x2_n_dim}), but got {bias.size(0)} {ops_error(ErrCode.VALUE)}",
            )
        else:
            torch._check(
                len(dim_list) == 3,
                lambda: f"when bias dim is 3, out dim need to be 3 {ops_error(ErrCode.TYPE)}",
            )
            torch._check(
                bias.size(0) == batch_val and bias.size(1) == 1 and bias.size(2) == x2_n_dim,
                lambda: f"bias shape must be ({batch_val}, 1, {x2_n_dim}), but got {tuple(bias.shape)} {ops_error(ErrCode.VALUE)}",
            )

    if x2_scale.dtype == torch.bfloat16:
        output_dtype = torch.bfloat16
    else:
        output_dtype = torch.float16
    output_device = 'meta' if x1.device.type == 'meta' else x1.device
    return torch.empty(tuple(dim_list), dtype=output_dtype, device=output_device)
@impl(m, "npu_quant_grouped_matmul_dequant")
def npu_quant_grouped_matmul_dequant_meta(x, quantized_weight, weight_scale, group_list, *,
                                          bias=None, x_scale=None, x_offset=None, smooth_scale=None, quant_mode="pertoken"):
    """Meta kernel for ``npu_quant_grouped_matmul_dequant``.

    Validates ``x`` (2-D float16), ``quantized_weight`` (3-D int8) and
    ``weight_scale`` (2-D float) and their mutual shape constraints, then
    returns an uninitialised ``meta`` tensor of shape
    ``(x.shape[0], weight_scale.shape[1])`` with the dtype of ``x``.
    ``group_list`` and the keyword-only arguments are not inspected here.
    """

    def _require(condition, message):
        # The message (including the error-code suffix) is built lazily, only on failure.
        torch._check(condition, lambda: message + ops_error(ErrCode.VALUE))

    _require(x.dim() == 2, "the x dim support only 2")
    _require(x.dtype == torch.float16, "the x dtype support only float16")
    _require(quantized_weight.dim() == 3, "the quantized_weight dim support only 3")
    _require(quantized_weight.dtype == torch.int8, "the quantized_weight dtype support only int8")
    _require(weight_scale.dim() == 2, "the weight_scale dim support only 2")
    _require(weight_scale.dtype == torch.float, "the weight_scale dtype support only float")

    # The comparison is against the LAST dim of the 3-D weight, so the message
    # must name shape[2] (it previously said shape[1]).
    _require(x.shape[1] == quantized_weight.shape[2],
             "x shape[1] not equal to quantized_weight shape[2]")
    _require(weight_scale.shape[0] == quantized_weight.shape[0],
             "weight_scale shape[0] not equal to quantized_weight shape[0]")
    _require(weight_scale.shape[1] == quantized_weight.shape[1],
             "weight_scale shape[1] not equal to quantized_weight shape[1]")

    return torch.empty((x.shape[0], weight_scale.shape[1]), dtype=x.dtype, device='meta')
@impl(m, "npu_transpose_batchmatmul")
def npu_transpose_batchmatmul_meta(input_, weight, *, bias=None, scale=None,
                                   perm_x1=None, perm_x2=None, perm_y=None,
                                   batch_split_factor=1):
    """Meta kernel for ``npu_transpose_batchmatmul``.

    Missing permutations default to ``perm_x1=[0, 1, 2]``,
    ``perm_x2=[0, 1, 2]`` and ``perm_y=[1, 0, 2]``.  After validating the
    permutations, the input/weight dtypes and that ``bias`` is ``None``, the
    function returns an uninitialised tensor created from ``input_``:

    * shape ``(M, batchM, N)`` with the dtype of ``input_`` by default;
    * shape ``(M, 1, batchM * N)`` with dtype int8 when ``scale`` is given;
    * shape ``(batch_split_factor, M, batchM * N // batch_split_factor)``
      whenever ``batch_split_factor > 1`` (this takes precedence).
    """
    perm_x1 = perm_x1 or [0, 1, 2]
    perm_x2 = perm_x2 or [0, 1, 2]
    perm_y = perm_y or [1, 0, 2]

    perm_x1_ok = ((perm_x1[0] == 0 and perm_x1[1] == 1 and perm_x1[2] == 2) or
                  (perm_x1[0] == 1 and perm_x1[1] == 0 and perm_x1[2] == 2))
    torch._check(
        perm_x1_ok,
        lambda: "perm_x1 should be [0, 1, 2] or [1, 0, 2]" + ops_error(ErrCode.VALUE),
    )

    # NOTE(review): this is a lexicographic string comparison; a version such
    # as "8.10.0" would compare lower than "8.5.0" -- confirm the format
    # returned by get_cann_version().
    allows_x2_transpose = get_cann_version() >= "8.5.0"
    perm_x2_ok = perm_x2[0] == 0 and perm_x2[1] == 1 and perm_x2[2] == 2
    if allows_x2_transpose:
        perm_x2_ok = perm_x2_ok or (perm_x2[0] == 0 and perm_x2[1] == 2 and perm_x2[2] == 1)
        perm_x2_hint = "perm_x2 should be [0, 1, 2] or [0, 2, 1]"
    else:
        perm_x2_hint = "perm_x2 should be [0, 1, 2]"
    torch._check(
        perm_x2_ok,
        lambda: perm_x2_hint + ops_error(ErrCode.VALUE),
    )

    perm_y_ok = perm_y[0] == 1 and perm_y[1] == 0 and perm_y[2] == 2
    torch._check(
        perm_y_ok,
        lambda: "perm_y should be [1, 0, 2]" + ops_error(ErrCode.VALUE),
    )

    supported_dtypes = (torch.float16, torch.float32, torch.bfloat16)
    torch._check(
        input_.dtype in supported_dtypes,
        lambda: "input's type supported for float16, float32 and bfloat16, but now is " + str(input_.dtype) + ops_error(ErrCode.TYPE),
    )
    torch._check(
        weight.dtype in supported_dtypes,
        lambda: "weight's type supported for float16, float32 and bfloat16, but now is " + str(weight.dtype) + ops_error(ErrCode.TYPE),
    )
    torch._check(
        bias is None,
        lambda: "The bias is not supported in TransposeBatchMatMul" + ops_error(ErrCode.TYPE),
    )

    # Locate the logical axes through the permutations.
    rows = input_.size(perm_x1.index(1))
    batch = input_.size(perm_x1.index(0))
    cols = weight.size(perm_x2.index(2))

    out_dtype = input_.dtype if scale is None else torch.int8
    if batch_split_factor > 1:
        out_shape = (batch_split_factor, rows, batch * cols // batch_split_factor)
    elif scale is not None:
        out_shape = (rows, 1, batch * cols)
    else:
        out_shape = (rows, batch, cols)
    return input_.new_empty(out_shape, dtype=out_dtype)
@impl(m, "npu_transpose_quant_batchmatmul")
def npu_transpose_quant_batchmatmul_meta(input_, weight, dtype, bias=None, x1_scale=None, x2_scale=None, group_sizes=None,
                                         perm_x1=None, perm_x2=None, perm_y=None, batch_split_factor=1):
    """Meta kernel for ``npu_transpose_quant_batchmatmul``.

    Returns an uninitialised tensor created from ``input_`` with shape
    ``(M, batch_m, N)``, or
    ``(batch_split_factor, M, batch_m * N // batch_split_factor)`` when
    ``batch_split_factor > 1``.  ``dtype`` is an enum value that must resolve
    to float16 or bfloat16, otherwise ``RuntimeError`` is raised.
    """
    # NOTE(review): perm_x1 / perm_x2 default to None but are dereferenced
    # unconditionally below -- confirm callers always pass them.
    rows = input_.size(perm_x1.index(1))
    batch = input_.size(perm_x1.index(0))
    cols = weight.size(perm_x2.index(2))

    if batch_split_factor > 1:
        out_shape = (batch_split_factor, rows, batch * cols // batch_split_factor)
    else:
        out_shape = (rows, batch, cols)

    out_dtype = TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP.get(dtype)
    # An unknown enum value resolves to None and is rejected here as well.
    if out_dtype not in (torch.float16, torch.bfloat16):
        raise RuntimeError("Not supported output dtype is " + npu_dtype_to_str(dtype))
    return input_.new_empty(out_shape, dtype=out_dtype)
@impl(m, "npu_trans_quant_param")
def npu_trans_quant_param_meta(scale, offset=None, round_mode=0):
    """Meta kernel for ``npu_trans_quant_param``.

    ``scale`` must be 1-D or 2-D with a leading dimension of 1, and
    ``round_mode`` must be 0 or 1.  Returns an uninitialised int64 tensor
    created from ``scale``:

    * 1-D ``scale``: length ``max(scale.size(0), offset.size(0))`` (just
      ``scale.size(0)`` without ``offset``); when both lengths differ from 1
      they must be equal.
    * 2-D ``scale``: same shape as ``scale``; ``offset``, if given, must have
      exactly that shape.
    """
    scale_dim_num = scale.dim()
    torch._check(
        scale_dim_num == 1 or (scale_dim_num == 2 and scale.size(0) == 1),
        lambda: "the scale shape support only (1, ) and (1, n)" + ops_error(ErrCode.VALUE),
    )
    torch._check(
        round_mode == 0 or round_mode == 1,
        # str() is required here: concatenating an int to a str raises
        # TypeError and would hide the intended error message.
        lambda: "round_mode should be 0 or 1, but round_mode is " + str(round_mode) + ops_error(ErrCode.VALUE),
    )

    if scale_dim_num != 1:
        if offset is not None:
            torch._check(
                scale.size() == offset.size(),
                lambda: "when the input shape of scale is (1, n), shape of scale and offset should be equal" + ops_error(ErrCode.VALUE),
            )
        return scale.new_empty(scale.size(), dtype=torch.int64)

    scale_first_dim = scale.size(0)
    dim_max = scale_first_dim
    if offset is not None:
        offset_first_dim = offset.size(0)
        dim_max = max(dim_max, offset_first_dim)
        if offset_first_dim != 1 and scale_first_dim != 1:
            torch._check(
                offset_first_dim == scale_first_dim,
                lambda: "offset first dim should be equal to scale first dim if none of them are equal to one" + ops_error(ErrCode.VALUE),
            )
    # Explicit one-element tuple; ``(dim_max)`` is merely a parenthesised int.
    return scale.new_empty((dim_max,), dtype=torch.int64)
@impl(m, "npu_quantize")
def npu_quantize_meta(self, scales, zero_points, dtype, axis=1, div_mode=True):
    """Meta kernel for ``npu_quantize``.

    ``dtype`` is an enum value resolved through
    ``TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP`` (unknown values fall back
    to int8).  For most targets the result has the shape of ``self`` and the
    mapped dtype.  For quint4x2 the last dimension must be a multiple of 8
    and the result is int32 with the last dimension divided by 8.
    """
    torch_dtype = TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP.get(dtype, torch.int8)

    # Resolved dtype -> dtype of the same-shaped output tensor.
    same_shape_targets = {
        torch.quint8: torch.uint8,
        torch.uint8: torch.uint8,
        torch.qint32: torch.int32,
        torch.int32: torch.int32,
        torch.int8: torch.int8,
        torch.float8_e4m3fn: torch.float8_e4m3fn,
        torch.float8_e5m2: torch.float8_e5m2,
    }
    target = same_shape_targets.get(torch_dtype)
    if target is not None:
        return torch.empty_like(self, dtype=target)

    # NOTE(review): 290 is mapped to torch.uint8 in the enum map shown at the
    # top of the file, so this branch looks unreachable -- confirm intent.
    if dtype == 290:
        return torch.empty_like(self, dtype=torch.bits8)

    if torch_dtype != torch.quint4x2:
        return torch.empty_like(self, dtype=torch.int8)

    dim_num = self.dim()
    last_size = self.size(dim_num - 1)
    if last_size % 8:
        raise RuntimeError("If dtype is quint4x2, the last dim of input must be divided by 8" +
                           ops_error(ErrCode.NOT_SUPPORT))
    output_shape = [self.size(axis_idx) for axis_idx in range(dim_num - 1)]
    output_shape.append(last_size // 8)
    return self.new_empty(output_shape, dtype=torch.int32)
@impl(m, "npu_group_quant")
def npu_group_quant_meta(x, scale, group_index, *, offset=None, dst_dtype=None):
    """Meta kernel for ``npu_group_quant``.

    Returns an uninitialised tensor shaped like ``x`` with dtype uint8
    (``dst_dtype`` quint8) or int8 (qint8 and every other value, including
    ``None``).  For quint4x2 the last dimension of ``x`` must be a multiple
    of 8 and the result is int32 with that dimension divided by 8.
    """
    if dst_dtype == torch.quint8:
        return torch.empty_like(x, dtype=torch.uint8)
    if dst_dtype == torch.qint8:
        return torch.empty_like(x, dtype=torch.int8)
    if dst_dtype != torch.quint4x2:
        # Fallback for every remaining dst_dtype.
        return torch.empty_like(x, dtype=torch.int8)

    rank = x.dim()
    last_size = x.size(rank - 1)
    if last_size % 8:
        raise RuntimeError("If dst_dtype is quint4x2, last dim must be divisible by 8" +
                           ops_error(ErrCode.NOT_SUPPORT))
    packed_shape = [x.size(axis) for axis in range(rank - 1)]
    packed_shape.append(last_size // 8)
    return x.new_empty(packed_shape, dtype=torch.int32)
@impl(m, "npu_dynamic_quant")
def npu_dynamic_quant(input_dummy, *, smooth_scales=None, group_index=None,
                      dst_type=1, quant_mode="pertoken", dst_type_max=0.0):
    """Meta kernel for ``npu_dynamic_quant``; returns ``(output, scale)``.

    ``scale`` is float32.  Its shape is ``[1]`` for ``quant_mode ==
    "pertensor"``; otherwise it is the leading ``dim - 2`` sizes of the input
    followed by the last size (``"perchannel"``) or the second-to-last size
    (any other mode).

    ``output`` has the input's shape with the dtype selected by ``dst_type``
    (int8 fallback), except for quint4x2, where it is int32 and shaped as the
    scale shape plus ``last_dim // 8`` (the last dim must be a multiple of 8).
    """
    rank = input_dummy.dim()
    last_axis = rank - 1

    scale_shape = [input_dummy.size(axis) for axis in range(rank - 2)]
    kept_axis = last_axis if quant_mode == "perchannel" else rank - 2
    scale_shape.append(input_dummy.size(kept_axis))

    if quant_mode == "pertensor":
        scale = input_dummy.new_empty([1], dtype=torch.float32)
    else:
        scale = input_dummy.new_empty(scale_shape, dtype=torch.float32)

    torch_dtype = TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP.get(dst_type, torch.int8)
    if torch_dtype == torch.quint4x2:
        if input_dummy.size(last_axis) % 8:
            raise RuntimeError("If dst_dtype is quint4x2, the last dim of input must be divisible by 8" +
                               ops_error(ErrCode.PARAM))
        # NOTE(review): the packed output shape is derived from the scale
        # shape, which in "perchannel" mode ends with the last input size
        # rather than the second-to-last one -- confirm this is intended.
        packed_shape = scale_shape + [input_dummy.size(last_axis) // 8]
        output = input_dummy.new_empty(packed_shape, dtype=torch.int32)
    elif dst_type == 290:
        output = torch.empty_like(input_dummy, dtype=torch.uint8)
    elif torch_dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
        output = torch.empty_like(input_dummy, dtype=torch_dtype)
    else:
        output = torch.empty_like(input_dummy, dtype=torch.int8)
    return (output, scale)
@impl(m, "npu_dynamic_quant_asymmetric")
def npu_dynamic_quant_asymmetric(input_dummy, *, smooth_scales=None, group_index=None,
                                 dst_type=1, quant_mode="pertoken", dst_type_max=0.0):
    """Meta kernel for ``npu_dynamic_quant_asymmetric``; returns ``(output, scale, offset)``.

    ``scale`` and ``offset`` are float32 and share one shape: ``[1]`` for
    ``quant_mode == "pertensor"``; otherwise the leading ``dim - 2`` sizes of
    the input followed by the last size (``"perchannel"``) or the
    second-to-last size (any other mode).

    ``output`` has the input's shape with the dtype selected by ``dst_type``
    (int8 fallback), except for quint4x2, where it is int32 and shaped as the
    scale shape plus ``last_dim // 8`` (the last dim must be a multiple of 8).
    """
    rank = input_dummy.dim()
    last_axis = rank - 1

    scale_offset_shape = [input_dummy.size(axis) for axis in range(rank - 2)]
    kept_axis = last_axis if quant_mode == "perchannel" else rank - 2
    scale_offset_shape.append(input_dummy.size(kept_axis))

    aux_shape = [1] if quant_mode == "pertensor" else scale_offset_shape
    scale = input_dummy.new_empty(aux_shape, dtype=torch.float32)
    offset = input_dummy.new_empty(aux_shape, dtype=torch.float32)

    torch_dtype = TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP.get(dst_type, torch.int8)
    if torch_dtype == torch.quint4x2:
        if input_dummy.size(last_axis) % 8:
            raise RuntimeError("If dst_dtype is quint4x2, the last dim of input must be divisible by 8" +
                               ops_error(ErrCode.PARAM))
        # NOTE(review): the packed output shape is derived from the
        # scale/offset shape, which in "perchannel" mode ends with the last
        # input size -- confirm this is intended.
        packed_shape = scale_offset_shape + [input_dummy.size(last_axis) // 8]
        output = input_dummy.new_empty(packed_shape, dtype=torch.int32)
    elif dst_type == 290:
        output = torch.empty_like(input_dummy, dtype=torch.uint8)
    elif torch_dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
        output = torch.empty_like(input_dummy, dtype=torch_dtype)
    else:
        output = torch.empty_like(input_dummy, dtype=torch.int8)
    return (output, scale, offset)
@impl(m, "npu_dynamic_mx_quant")
def npu_dynamic_mx_quant(input_dummy, *, axis=-1, round_mode="rint", dst_type=296, block_size=32,
                         scale_alg=0, dst_type_max=0.0):
    """Meta kernel for ``npu_dynamic_mx_quant``; returns ``(output, mxscale)``.

    Raises ``RuntimeError`` when ``axis`` is outside ``[-dim, dim - 1]``, when
    ``block_size`` is not a positive multiple of 32 no greater than 1024, or
    when ``scale_alg`` is not one of 0, 1, 2.

    ``mxscale`` is uint8 with the input's shape plus a trailing dimension of
    2, where the ``axis`` dimension is replaced by
    ``ceil(ceil(size / block_size) / 2)``.

    ``output`` has the input's shape for the float8 destination types
    (e5m2 / e4m3fn, also selected by ``dst_type`` 291 / 292).  For any other
    ``dst_type`` it is uint8 with the last dimension halved, and the last
    dimension must be even.
    """
    dim_num = input_dummy.dim()
    if axis < -dim_num or axis >= dim_num:
        raise RuntimeError("Parameter axis is out of input dimension range [{0}, {1}]".format(-dim_num, dim_num - 1) +
                           ops_error(ErrCode.PARAM))
    if not (block_size % 32 == 0 and block_size > 0 and block_size <= 1024):
        raise RuntimeError("Parameter block_size must be divisible by 32 and no greater than 1024, greater than 0" +
                           ops_error(ErrCode.PARAM))
    if scale_alg not in [0, 1, 2]:
        # f-string so the offending value is actually interpolated; the list
        # of expected values matches the membership test above.
        raise RuntimeError(f"Invalid scale_alg value: {scale_alg}. Expected 0, 1 or 2." +
                           ops_error(ErrCode.PARAM))

    axis_change = axis if axis >= 0 else axis + dim_num

    mxscale_shape = [input_dummy.size(dim) for dim in range(dim_num)]
    mxscale_shape.append(2)
    # Number of blocks along the quantised axis, then rounded up to pairs
    # (the trailing dimension of size 2 holds each pair).
    dim_size = int(math.ceil(mxscale_shape[axis_change] / block_size))
    dim_size = (dim_size + 2 - 1) // 2
    mxscale_shape[axis_change] = dim_size

    torch_dtype = TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP.get(dst_type, torch.int8)
    if torch_dtype == torch.float8_e5m2 or dst_type == 291:
        output = torch.empty_like(input_dummy, dtype=torch.float8_e5m2)
    elif torch_dtype == torch.float8_e4m3fn or dst_type == 292:
        output = torch.empty_like(input_dummy, dtype=torch.float8_e4m3fn)
    else:
        if input_dummy.size(dim_num - 1) % 2:
            raise RuntimeError("If output dtype is float4_e2m1 or float4_e1m2, "
                               "the last dim of input must be divisible by 2, " +
                               ops_error(ErrCode.PARAM))
        output_shape = [input_dummy.size(dim) for dim in range(dim_num - 1)]
        output_shape.append(input_dummy.size(dim_num - 1) // 2)
        output = input_dummy.new_empty(output_shape, dtype=torch.uint8)

    mxscale = input_dummy.new_empty(mxscale_shape, dtype=torch.uint8)
    return (output, mxscale)
@impl(m, "npu_dynamic_dual_level_mx_quant")
def npu_dynamic_dual_level_mx_quant(input_dummy, *, smooth_scale=None, round_mode="rint"):
    """Meta kernel for ``npu_dynamic_dual_level_mx_quant``.

    Returns ``(output, level0_scale, level1_scale)``:

    * ``output``: uint8, input shape with the last dimension halved (the last
      dimension must be even, otherwise ``RuntimeError``).
    * ``level0_scale``: float32, input shape with the last dimension replaced
      by ``ceil(last / 512)``.
    * ``level1_scale``: uint8, input shape plus a trailing dimension of 2,
      with the original last dimension replaced by
      ``ceil(ceil(last / 32) / 2)``.
    """
    level0_block_size = 512
    level1_block_size = 32

    rank = input_dummy.dim()
    last_axis = rank - 1
    sizes = [input_dummy.size(axis) for axis in range(rank)]

    level0_scale_shape = list(sizes)
    level1_scale_shape = sizes + [2]

    level0_scale_shape[last_axis] = int(math.ceil(level0_scale_shape[last_axis] / level0_block_size))

    level1_blocks = int(math.ceil(level1_scale_shape[last_axis] / level1_block_size))
    level1_scale_shape[last_axis] = (level1_blocks + 2 - 1) // 2

    if input_dummy.size(last_axis) % 2:
        raise RuntimeError("If output dtype is float4_e2m1, "
                           "the last dim of input must be divisible by 2, " +
                           ops_error(ErrCode.PARAM))

    output_shape = [input_dummy.size(axis) for axis in range(last_axis)]
    output_shape.append(input_dummy.size(last_axis) // 2)

    output = input_dummy.new_empty(output_shape, dtype=torch.uint8)
    level0_scale = input_dummy.new_empty(level0_scale_shape, dtype=torch.float32)
    level1_scale = input_dummy.new_empty(level1_scale_shape, dtype=torch.uint8)
    return (output, level0_scale, level1_scale)
@impl(m, "npu_grouped_dynamic_mx_quant")
def npu_grouped_dynamic_mx_quant(x, group_index, *, round_mode="rint", dst_type=23, blocksize=32):
    """Meta kernel for ``npu_grouped_dynamic_mx_quant``; returns ``(output, mxscale)``.

    Requires a 2-D ``x``, a 1-D ``group_index`` and ``blocksize == 32``.
    ``dst_type`` must resolve to float8_e5m2 or float8_e4m3fn.  ``output`` has
    the shape of ``x`` in that dtype; ``mxscale`` is uint8 with shape
    ``[x.shape[0] // 2 // blocksize + group_index.shape[0], x.shape[-1], 2]``.
    Every violation raises ``RuntimeError``.
    """
    if x is None or group_index is None:
        raise RuntimeError("Input x and group_index should must not be None" + ops_error(ErrCode.VALUE))
    if x.dim() != 2:
        raise RuntimeError("Input x must be 2-dimensional, got dimNum " +
                           str(x.dim()) + ops_error(ErrCode.VALUE))
    if group_index.dim() != 1:
        raise RuntimeError("Input group_index must be 1-dimensional, got dimNum " +
                           str(group_index.dim()) + ops_error(ErrCode.VALUE))
    if blocksize != 32:
        raise RuntimeError("Parameter blocksize only supports 32, got " +
                           str(blocksize) + ops_error(ErrCode.PARAM))

    scale_rows = x.shape[0] // 2 // blocksize + group_index.shape[0]
    mxscale_shape = [scale_rows, x.shape[-1], 2]

    # Resolve the enum once; an unknown value yields None and is rejected below.
    out_dtype = TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP.get(dst_type)
    if out_dtype not in (torch.float8_e5m2, torch.float8_e4m3fn):
        raise RuntimeError("Parameter dst_type only supports torch.float8_e5m2(23), torch.float8_e4m3fn(24), "
                           "got " + str(dst_type) + ops_error(ErrCode.PARAM))

    output = torch.empty_like(x, dtype=out_dtype)
    mxscale = x.new_empty(mxscale_shape, dtype=torch.uint8)
    return (output, mxscale)
@impl(m, "npu_moe_compute_expert_tokens")
def npu_moe_compute_expert_tokens_meta(sorted_experts, num_experts=1):
    """Meta kernel for ``npu_moe_compute_expert_tokens``.

    Returns an uninitialised int32 ``meta`` tensor with ``num_experts``
    elements; ``sorted_experts`` is not inspected.
    """
    template = torch.zeros(num_experts, dtype=torch.int32, device='meta')
    result = torch.empty_like(template)
    return result
@impl(m, "npu_anti_quant")
def npu_anti_quant_meta(x, scale, *, offset=None, dst_dtype=None, src_dtype=None):
    """Meta kernel for ``npu_anti_quant``.

    The output dtype is float16 when ``dst_dtype`` is ``None``, otherwise the
    dtype ``dst_dtype`` resolves to in
    ``TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP``.  For float8 / uint8 inputs
    ``scale`` (and ``offset`` if given) must be float32.  An int32 ``x`` must
    have at least one dimension and yields an output whose last dimension is
    multiplied by 8; every other input yields an output shaped like ``x``.
    """
    if dst_dtype is None:
        y_dtype = torch.float16
    else:
        y_dtype = TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP.get(dst_dtype)

    if x.dtype in (torch.float8_e5m2, torch.float8_e4m3fn, torch.uint8):
        params_are_float32 = (scale.dtype == torch.float32 and
                              (offset is None or offset.dtype == torch.float32))
        if not params_are_float32:
            raise RuntimeError("When x datatype is hifloat8, float8_e5m2 or float8_e4m3fn, scale_dtype and offset_dtype is only support float32" +
                               ops_error(ErrCode.NOT_SUPPORT))

    if x.dtype != torch.int32:
        return torch.empty_like(x, dtype=y_dtype)

    x_shape = x.size()
    if len(x_shape) == 0:
        raise RuntimeError("Not supported for x is scalar when x dtype is int32" + ops_error(ErrCode.NOT_SUPPORT))
    # Same leading sizes as x, last dimension expanded eightfold.
    y_shape = (*(x_shape[:-1]), x_shape[-1] * 8)
    expanded = x.new_empty(y_shape, dtype=y_dtype)
    return torch.empty_like(expanded)
@impl(m, "npu_kronecker_quant")
def npu_kronecker_quant_meta(x, kronecker_p1, kronecker_p2, clip_ratio=1.0, dst_dtype=None):
    """Meta kernel for ``npu_kronecker_quant``; returns ``(output, scale)``.

    ``dst_dtype`` defaults to ``torch.int32`` and must be either that or
    ``torch_npu.float4_e2m1fn_x2``.

    * float4_e2m1fn_x2: ``x`` must be 3-D; ``output`` is uint8 of shape
      ``[x.size(0), x.size(-1) * x.size(-2) // 2]`` and ``scale`` is uint8 of
      shape ``[x.size(0), ceil(x.size(-1) * x.size(-2) / 64), 2]``.
    * int32: the last dimension of ``x`` must be a multiple of 8; ``output``
      is int32 with the last dimension divided by 8 and ``scale`` is float32
      of shape ``[x.size(0)]``.
    """
    if dst_dtype is None:
        dst_dtype = torch.int32
    if dst_dtype != torch.int32 and dst_dtype != torch_npu.float4_e2m1fn_x2:
        raise RuntimeError("the dtype of dst_dtype must be int32, or mxfp4" + ops_error(ErrCode.NOT_SUPPORT))

    rank = x.dim()
    if dst_dtype == torch_npu.float4_e2m1fn_x2:
        if rank != 3:
            raise RuntimeError("the dim num of input x must be 3" + ops_error(ErrCode.NOT_SUPPORT))
        packed_cols = x.size(rank - 1) * x.size(rank - 2) // 2
        align_base = 64
        aligned_blocks = (x.size(rank - 1) * x.size(rank - 2) + align_base - 1) // align_base
        quantized = x.new_empty([x.size(0), packed_cols], dtype=torch.uint8)
        scales = x.new_empty([x.size(0), aligned_blocks, 2], dtype=torch.uint8)
        return quantized, scales

    if x.size(rank - 1) % 8:
        raise RuntimeError("last dim of input x must be divisible by 8" + ops_error(ErrCode.NOT_SUPPORT))
    output_shape = [x.size(axis) for axis in range(rank - 1)]
    if dst_dtype == torch.int32:
        output_shape.append(x.size(rank - 1) // 8)
    scale_shape = [x.size(0)]
    return x.new_empty(output_shape, dtype=torch.int32), x.new_empty(scale_shape, dtype=torch.float32)
@impl(m, "npu_kv_rmsnorm_rope_cache")
def npu_kv_rmsnorm_rope_cache_meta(kv, gamma, cos, sin, index, k_cache, ckv_cache, *, k_rope_scale=None,
                                   c_kv_scale=None, k_rope_offset=None, c_kv_offset=None, v=None, epsilon=1e-5,
                                   cache_mode='Norm', is_output_kv=False):
    """Meta kernel for ``npu_kv_rmsnorm_rope_cache``.

    Requires 4-D ``kv`` and ``cos`` and a 1-D ``gamma``; a ``v`` with at least
    one dimension must be 4-D, share ``kv``'s dtype and match its first three
    sizes.  Violations raise ``RuntimeError``.

    Returns ``(empty_like(k_cache), empty_like(ckv_cache), k_rope, c_kv)``.
    ``k_rope`` and ``c_kv`` use ``kv``'s dtype/device and its first three
    sizes; their last sizes are ``cos.size(3)`` / ``gamma.size(0)`` when ``v``
    is ``None`` and ``kv.size(3)`` / ``v.size(3)`` otherwise.
    """
    if kv.dim() != 4:
        raise RuntimeError("4D tensor expected for input kv" + ops_error(ErrCode.PARAM))
    # NOTE(review): a non-None 0-dim ``v`` skips this validation but is still
    # indexed with ``v.size(3)`` further down -- confirm that case cannot occur.
    if v is not None and v.dim() > 0:
        if v.dtype != kv.dtype:
            raise RuntimeError("v MUST have same data type as kv!" + ops_error(ErrCode.PARAM))
        if v.dim() != 4:
            raise RuntimeError("4D tensor expected for input v" + ops_error(ErrCode.PARAM))
        if any(v.size(axis) != kv.size(axis) for axis in range(3)):
            raise RuntimeError("v MUST have same token shape as kv!" + ops_error(ErrCode.PARAM))
    if gamma.dim() != 1:
        raise RuntimeError("1D tensor expected for input gamma" + ops_error(ErrCode.PARAM))
    if cos.dim() != 4:
        raise RuntimeError("4D tensor expected for input cos" + ops_error(ErrCode.PARAM))

    leading_sizes = [kv.size(axis) for axis in range(kv.dim() - 1)]
    if v is None:
        rope_width, ckv_width = cos.size(3), gamma.size(0)
    else:
        rope_width, ckv_width = kv.size(3), v.size(3)
    k_rope_size = leading_sizes + [rope_width]
    c_kv_size = leading_sizes + [ckv_width]

    return (torch.empty_like(k_cache), torch.empty_like(ckv_cache),
            torch.empty(k_rope_size, dtype=kv.dtype, device=kv.device),
            torch.empty(c_kv_size, dtype=kv.dtype, device=kv.device))
@impl(m, "npu_kv_rmsnorm_rope_cache_v2")
def npu_kv_rmsnorm_rope_cache_v2_meta(kv, gamma, cos, sin, index, k_cache, ckv_cache, *, k_rope_scale=None,
                                      c_kv_scale=None, k_rope_offset=None, c_kv_offset=None, v=None, epsilon=1e-5,
                                      cache_mode='Norm', is_output_kv=False, k_cache_dtype=None, ckv_cache_dtype=None):
    """Meta kernel for ``npu_kv_rmsnorm_rope_cache_v2``.

    Requires 4-D ``kv`` and ``cos`` and a 1-D ``gamma``; a ``v`` with at least
    one dimension must be 4-D, share ``kv``'s dtype and match its first three
    sizes.  Violations raise ``RuntimeError``.

    Returns ``(k_rope, c_kv)``, both with ``kv``'s dtype/device and its first
    three sizes; their last sizes are ``cos.size(3)`` / ``gamma.size(0)`` when
    ``v`` is ``None`` and ``kv.size(3)`` / ``v.size(3)`` otherwise.  The cache
    tensors are not part of the return value.
    """
    if kv.dim() != 4:
        raise RuntimeError("4D tensor expected for input kv" + ops_error(ErrCode.PARAM))
    # NOTE(review): a non-None 0-dim ``v`` skips this validation but is still
    # indexed with ``v.size(3)`` further down -- confirm that case cannot occur.
    if v is not None and v.dim() > 0:
        if v.dtype != kv.dtype:
            raise RuntimeError("v MUST have same data type as kv!" + ops_error(ErrCode.PARAM))
        if v.dim() != 4:
            raise RuntimeError("4D tensor expected for input v" + ops_error(ErrCode.PARAM))
        if any(v.size(axis) != kv.size(axis) for axis in range(3)):
            raise RuntimeError("v MUST have same token shape [B,N,S] as kv!" + ops_error(ErrCode.PARAM))
    if gamma.dim() != 1:
        raise RuntimeError("1D tensor expected for input gamma" + ops_error(ErrCode.PARAM))
    if cos.dim() != 4:
        raise RuntimeError("4D tensor expected for input cos" + ops_error(ErrCode.PARAM))

    leading_sizes = [kv.size(axis) for axis in range(kv.dim() - 1)]
    if v is None:
        rope_width, ckv_width = cos.size(3), gamma.size(0)
    else:
        rope_width, ckv_width = kv.size(3), v.size(3)
    k_rope_size = leading_sizes + [rope_width]
    c_kv_size = leading_sizes + [ckv_width]

    return (torch.empty(k_rope_size, dtype=kv.dtype, device=kv.device),
            torch.empty(c_kv_size, dtype=kv.dtype, device=kv.device))
@impl(m, "npu_kv_rmsnorm_rope_cache_v2_functional")
def npu_kv_rmsnorm_rope_cache_v2_functional_meta(kv, gamma, cos, sin, index, k_cache, ckv_cache, *,
                                                 k_rope_scale=None, c_kv_scale=None, k_rope_offset=None,
                                                 c_kv_offset=None, v=None, epsilon=1e-5, cache_mode='Norm',
                                                 is_output_kv=False, k_cache_dtype=None, ckv_cache_dtype=None):
    """Meta kernel for ``npu_kv_rmsnorm_rope_cache_v2_functional``.

    Requires 4-D ``kv`` and ``cos`` and a 1-D ``gamma``; a ``v`` with at least
    one dimension must be 4-D, share ``kv``'s dtype and match its first three
    sizes.  Violations raise ``RuntimeError``.

    Returns ``(k_rope, c_kv, empty_like(k_cache), empty_like(ckv_cache))``.
    ``k_rope`` and ``c_kv`` use ``kv``'s dtype/device and its first three
    sizes; their last sizes are ``cos.size(3)`` / ``gamma.size(0)`` when ``v``
    is ``None`` and ``kv.size(3)`` / ``v.size(3)`` otherwise.
    """
    if kv.dim() != 4:
        raise RuntimeError("4D tensor expected for input kv" + ops_error(ErrCode.PARAM))
    # NOTE(review): a non-None 0-dim ``v`` skips this validation but is still
    # indexed with ``v.size(3)`` further down -- confirm that case cannot occur.
    if v is not None and v.dim() > 0:
        if v.dtype != kv.dtype:
            raise RuntimeError("v MUST have same data type as kv!" + ops_error(ErrCode.PARAM))
        if v.dim() != 4:
            raise RuntimeError("4D tensor expected for input v" + ops_error(ErrCode.PARAM))
        if any(v.size(axis) != kv.size(axis) for axis in range(3)):
            raise RuntimeError("v MUST have same token shape as kv!" + ops_error(ErrCode.PARAM))
    if gamma.dim() != 1:
        raise RuntimeError("1D tensor expected for input gamma" + ops_error(ErrCode.PARAM))
    if cos.dim() != 4:
        raise RuntimeError("4D tensor expected for input cos" + ops_error(ErrCode.PARAM))

    leading_sizes = [kv.size(axis) for axis in range(kv.dim() - 1)]
    if v is None:
        rope_width, ckv_width = cos.size(3), gamma.size(0)
    else:
        rope_width, ckv_width = kv.size(3), v.size(3)
    k_rope_size = leading_sizes + [rope_width]
    c_kv_size = leading_sizes + [ckv_width]

    return (torch.empty(k_rope_size, dtype=kv.dtype, device=kv.device),
            torch.empty(c_kv_size, dtype=kv.dtype, device=kv.device),
            torch.empty_like(k_cache), torch.empty_like(ckv_cache))
@impl(m, "npu_qkv_rms_norm_rope_cache")
def npu_qkv_rms_norm_rope_cache_meta(qkv, q_gamma, k_gamma, cos, sin, index, q_out, k_cache, v_cache, qkv_size, head_nums,
                                     *, k_scale=None, v_scale=None, k_offset=None, v_offset=None, epsilon=1e-6,
                                     cache_mode='PA_NZ', is_output_qkv=False):
    """Meta kernel for ``npu_qkv_rms_norm_rope_cache``.

    ``qkv_size`` must have 4 entries and ``head_nums`` 3 entries; ``qkv``,
    ``cos`` and ``sin`` must be 2-D and ``q_gamma`` / ``k_gamma`` 1-D.
    Violations raise ``RuntimeError``.

    Returns three tensors with ``qkv``'s dtype and device.  When
    ``is_output_qkv`` is true their shapes are
    ``[qkv.size(0), head_nums[i] * qkv_size[3]]`` for ``i`` in 0..2; otherwise
    all three are 0-dim placeholders.
    """
    if qkv_size is None:
        raise RuntimeError("qkv_size must not be None" + ops_error(ErrCode.PARAM))
    if head_nums is None:
        raise RuntimeError("head_nums must not be None" + ops_error(ErrCode.PARAM))
    if len(qkv_size) != 4:
        raise RuntimeError("qkv_size must be length 4 [B, S, N, D]" + ops_error(ErrCode.PARAM))
    if len(head_nums) != 3:
        raise RuntimeError("head_nums must be length 3 [n_q, n_k, n_v]" + ops_error(ErrCode.PARAM))

    # (tensor, required rank, error text), checked in this order.
    rank_rules = (
        (qkv, 2, "2D tensor expected for input qkv"),
        (q_gamma, 1, "1D tensor expected for input q_gamma"),
        (k_gamma, 1, "1D tensor expected for input k_gamma"),
        (cos, 2, "2D tensor expected for input cos"),
        (sin, 2, "2D tensor expected for input sin"),
    )
    for tensor, required_rank, error_text in rank_rules:
        if tensor.dim() != required_rank:
            raise RuntimeError(error_text + ops_error(ErrCode.PARAM))

    q_out_before_quant_size = [qkv.size(0), head_nums[0] * qkv_size[3]]
    k_out_before_quant_size = [qkv.size(0), head_nums[1] * qkv_size[3]]
    v_out_before_quant_size = [qkv.size(0), head_nums[2] * qkv_size[3]]

    if not is_output_qkv:
        # The pre-quant outputs were not requested: return 0-dim placeholders.
        return (torch.empty([], dtype=qkv.dtype, device=qkv.device),
                torch.empty([], dtype=qkv.dtype, device=qkv.device),
                torch.empty([], dtype=qkv.dtype, device=qkv.device))

    return (torch.empty(q_out_before_quant_size, dtype=qkv.dtype, device=qkv.device),
            torch.empty(k_out_before_quant_size, dtype=qkv.dtype, device=qkv.device),
            torch.empty(v_out_before_quant_size, dtype=qkv.dtype, device=qkv.device))
@impl(m, "npu_qkv_rms_norm_rope_cache_functional")
def npu_qkv_rms_norm_rope_cache_functional_meta(qkv, q_gamma, k_gamma, cos, sin, index, q_out, k_cache, v_cache, qkv_size, head_nums,
                                                *, k_scale=None, v_scale=None, k_offset=None, v_offset=None, epsilon=1e-6,
                                                cache_mode='PA_NZ', is_output_qkv=False):
    """Meta implementation of ``npu_qkv_rms_norm_rope_cache_functional``.

    Performs the same argument validation as the non-functional variant
    (``qkv_size`` of length 4, ``head_nums`` of length 3, tensor ranks).

    Returns:
        A 6-tuple: three "before quant" q/k/v tensors (shape
        ``[qkv.size(0), head_nums[i] * qkv_size[3]]`` when ``is_output_qkv``
        is true, zero-dimensional placeholders otherwise) followed by empty
        tensors shaped like ``q_out``, ``k_cache`` and ``v_cache``.

    Raises:
        RuntimeError: if any validation fails.
    """
    if qkv_size is None:
        raise RuntimeError("qkv_size must not be None" + ops_error(ErrCode.PARAM))
    if head_nums is None:
        raise RuntimeError("head_nums must not be None" + ops_error(ErrCode.PARAM))
    if len(qkv_size) != 4:
        raise RuntimeError("qkv_size must be length 4 [B, S, N, D]" + ops_error(ErrCode.PARAM))
    if len(head_nums) != 3:
        raise RuntimeError("head_nums must be length 3 [n_q, n_k, n_v]" + ops_error(ErrCode.PARAM))

    # (tensor, required rank, message) -- evaluated in this exact order.
    rank_requirements = (
        (qkv, 2, "2D tensor expected for input qkv"),
        (q_gamma, 1, "1D tensor expected for input q_gamma"),
        (k_gamma, 1, "1D tensor expected for input k_gamma"),
        (cos, 2, "2D tensor expected for input cos"),
        (sin, 2, "2D tensor expected for input sin"),
    )
    for tensor, required_rank, message in rank_requirements:
        if tensor.dim() != required_rank:
            raise RuntimeError(message + ops_error(ErrCode.PARAM))

    rows = qkv.size(0)
    head_dim = qkv_size[3]
    before_quant_shapes = [[rows, head_count * head_dim] for head_count in head_nums]
    if not is_output_qkv:
        # Zero-dimensional placeholders when the q/k/v outputs are not requested.
        before_quant_shapes = [[], [], []]

    before_quant = [torch.empty(shape, dtype=qkv.dtype, device=qkv.device) for shape in before_quant_shapes]
    return (*before_quant, torch.empty_like(q_out), torch.empty_like(k_cache), torch.empty_like(v_cache))
@impl(m, "npu_apply_rotary_pos_emb")
def npu_apply_rotary_pos_emb_meta(query, key, cos, sin, layout=1, rotary_mode='half'):
    """Meta implementation: return empty tensors shaped and typed like ``query`` and ``key``."""
    query_out = torch.empty_like(query)
    key_out = torch.empty_like(key)
    return (query_out, key_out)
@impl(m, "npu_quant_conv2d")
def npu_quant_conv2d(input_, weight, scale, strides, pads, dilations,
                     groups=1, offset_x=0, round_mode='rint', output_dtype=None,
                     bias=None, offset=None, input_dtype=None, weight_dtype=None):
    """Meta implementation of ``npu_quant_conv2d``.

    Validates the ranks, shapes and dtypes of ``input_``, ``weight``,
    ``scale`` and (when given) ``bias``, plus the ``pads``/``strides``/
    ``dilations``/``groups``/``offset_x`` attributes, then computes the output
    shape ``[N, Cout, Hout, Wout]``.

    Returns:
        An empty tensor created via ``scale.new_empty`` with the computed
        shape. Its dtype is looked up from ``output_dtype`` for the
        float16/bfloat16/float32 enum values, or ``torch.uint8`` when
        ``output_dtype`` is ``torch_npu.hifloat8``.

    Raises:
        RuntimeError: if any check fails or ``output_dtype`` is unsupported.
    """
    input_shape = input_.size()
    weight_shape = weight.size()
    scale_shape = scale.size()
    input_dim = input_.dim()
    weight_dim = weight.dim()
    scale_dim = scale.dim()

    def check_basic_inputs_dim_shape():
        # Rank checks come first so the index accesses below are valid.
        torch._check(
            input_dim == weight_dim and weight_dim == INPUTS_DIM_LIMIT_QUANTCONV2D,
            lambda: "input dim or weight dim is not equal to 4, but now input dim is " + str(input_dim) +
                    ", and weight dim is " + str(weight_dim) + ops_error(ErrCode.VALUE),
        )
        torch._check(
            scale_dim == 1,
            lambda: "scale dim is not equal to 1, but now scale dim is " + str(scale_dim) + ops_error(ErrCode.VALUE),
        )
        torch._check(
            input_shape[1] == weight_shape[1] * groups,
            lambda: "input cin should equal to weight cin * groups, but now input cin is " + str(input_shape[1]) +
                    ", weight cin is " + str(weight_shape[1]) + ", and groups is " + str(groups) +
                    ops_error(ErrCode.VALUE),
        )
        torch._check(
            input_shape[1] % groups == 0,
            lambda: "input cin should be an integer multiple of groups, but now input cin is " + str(input_shape[1]) +
                    ", and groups is " + str(groups) + ops_error(ErrCode.VALUE),
        )
        torch._check(
            weight_shape[0] % groups == 0,
            lambda: "cout should be an integer multiple of groups, but now cout is " + str(weight_shape[0]) +
                    ", and groups is " + str(groups) + ops_error(ErrCode.VALUE),
        )
        torch._check(
            scale_shape[0] == weight_shape[0],
            lambda: "scale shape should equal to cout, but now scale shape is " + str(scale_shape[0]) +
                    ", and cout is " + str(weight_shape[0]) + ops_error(ErrCode.VALUE),
        )

    def check_basic_inputs_dtype():
        # input_dtype and weight_dtype must be given together or not at all.
        torch._check(
            (input_dtype is not None and weight_dtype is not None) or (input_dtype is None and weight_dtype is None),
            lambda: "input_dtype and weight_dtype are only support both None or not None, " +
                    "but got input_dtype: " + str(input_dtype) + " and weight_dtype: " +
                    str(weight_dtype) + ops_error(ErrCode.TYPE))
        if input_dtype is not None:
            torch._check((input_dtype == torch_npu.hifloat8 and weight_dtype == torch_npu.hifloat8),
                         lambda: "input_dtype and weight_dtype are only support torch_npu.hifloat8, " +
                                 "but got input_dtype: " + str(input_dtype) +
                                 " and weight_dtype: " + str(weight_dtype) +
                                 ops_error(ErrCode.TYPE))
            torch._check(((input_.dtype == torch.int8 or input_.dtype == torch.uint8) and
                          (weight.dtype == torch.int8 or weight.dtype == torch.uint8)),
                         lambda: "input and weight tensor dtype must be torch.int8 or torch.uint8 " +
                                 "when input_dtype and weight_dtype is torch_npu.hifloat8, " +
                                 "but got input tensor dtype: " + str(input_.dtype) + " and weight tensor dtype: " +
                                 str(weight.dtype) + ops_error(ErrCode.TYPE))
        else:
            torch._check(
                ((input_.dtype == torch.int8 and weight.dtype == torch.int8) or
                 (input_.dtype == torch.float8_e4m3fn and weight.dtype == torch.float8_e4m3fn)),
                lambda: "input.dtype and weight.dtype should be torch.int8 or torch.float8_e4m3fn " +
                        "when not enable hifloat8 calculation, but got input.dtype: " + str(input_.dtype) +
                        " and weight.dtype is " + str(weight.dtype) + ops_error(ErrCode.TYPE)
            )
        torch._check(
            scale.dtype == torch.int64,
            lambda: "scale.dtype should be torch.int64, but scale.dtype is " + str(scale.dtype) +
                    ops_error(ErrCode.TYPE),
        )
        torch._check(
            (output_dtype is not None),
            lambda: "output_dtype can not be None " + ops_error(ErrCode.TYPE)
        )
        if input_.dtype == torch.int8 and input_dtype != torch_npu.hifloat8:
            torch._check(
                output_dtype == TORCH_DTYPE_MAP[torch.float16],
                lambda: "output_dtype should be torch.float16 when input.dtype is torch.int8, but now dtype is " +
                        str(output_dtype) + ops_error(ErrCode.TYPE),
            )
        elif (input_.dtype == torch.float8_e4m3fn):
            torch._check(
                (output_dtype == TORCH_DTYPE_MAP[torch.float16] or
                 output_dtype == TORCH_DTYPE_MAP[torch.bfloat16] or
                 output_dtype == TORCH_DTYPE_MAP[torch.float32]),
                lambda: "output_dtype should be one of "
                        "[torch.float16, torch.bfloat16, torch.float32] "
                        "when input.dtype is torch.float8_e4m3fn, but now output_dtype is " +
                        str(output_dtype) + ops_error(ErrCode.TYPE),
            )
        if (input_dtype == torch_npu.hifloat8):
            torch._check((output_dtype == torch_npu.hifloat8 or
                          output_dtype == TORCH_DTYPE_MAP[torch.float16] or
                          output_dtype == TORCH_DTYPE_MAP[torch.bfloat16] or
                          output_dtype == TORCH_DTYPE_MAP[torch.float32]),
                         lambda: "output_dtype should be one of " +
                                 "[torch.float16, torch.bfloat16, torch.float32, torch_npu.hifloat8] " +
                                 "when input_dtype is torch_npu.hifloat8, but now output_dtype is " +
                                 str(output_dtype) + ops_error(ErrCode.TYPE)
                         )

    def check_bias_dim_shape_dtype():
        bias_dim = bias.dim()
        bias_shape = bias.size()
        torch._check(
            bias_dim == 1,
            lambda: "bias dim is not equal to 1, but now bias dim is " + str(bias_dim) + ops_error(ErrCode.VALUE),
        )
        if input_.dtype == torch.int8 and input_dtype != torch_npu.hifloat8:
            torch._check(
                bias.dtype == torch.int32,
                lambda: "bias.dtype should be torch.int32 when input.dtype is torch.int8, but bias.dtype is " +
                        str(bias.dtype) + ops_error(ErrCode.TYPE),
            )
        elif (input_dtype == torch_npu.hifloat8 or input_.dtype == torch.float8_e4m3fn):
            torch._check(
                bias.dtype == torch.float32,
                lambda: "bias.dtype should be torch.float32 when input_dtype is " +
                        "torch_npu.hifloat8 or input.dtype is float8_e4m3fn, but bias.dtype is " +
                        str(bias.dtype) + ops_error(ErrCode.TYPE),
            )
        torch._check(
            bias_shape[0] == weight_shape[0],
            lambda: "bias shape should equal to cout, but now bias shape is " + str(bias_shape[0]) + ", and cout is " +
                    str(weight_shape[0]) + ops_error(ErrCode.VALUE),
        )

    def check_attrs():
        pads_dim = len(pads)
        strides_dim = len(strides)
        dilations_dim = len(dilations)
        torch._check(
            pads_dim == ATTR_DIM_LIMIT_QUANTCONV2D and strides_dim == ATTR_DIM_LIMIT_QUANTCONV2D and
            dilations_dim == ATTR_DIM_LIMIT_QUANTCONV2D,
            lambda: "attrs's dim should be 2, but pads dim is " + str(pads_dim) + ", strides dim is "
                    + str(strides_dim) + ", dilations dim is " + str(dilations_dim) + ops_error(ErrCode.VALUE),
        )
        torch._check(
            pads[0] >= 0 and pads[1] >= 0,
            lambda: "pads's value should large or equal to 0, but pads is " + str(pads[0]) + ", "
                    + str(pads[1]) + ops_error(ErrCode.VALUE),
        )
        torch._check(
            strides[0] > 0 and strides[1] > 0,
            lambda: "strides's value should large than 0, but strides is " + str(strides[0]) + ", "
                    + str(strides[1]) + ops_error(ErrCode.VALUE),
        )
        torch._check(
            dilations[0] > 0 and dilations[1] > 0,
            lambda: "dilations's value should large than 0, but dilations is " + str(dilations[0]) + ", "
                    + str(dilations[1]) + ops_error(ErrCode.VALUE),
        )
        torch._check(
            groups >= 1,
            lambda: "groups should large than 0, but now " + str(groups) + ops_error(ErrCode.VALUE),
        )
        torch._check(
            offset_x <= 127 and offset_x >= -128,
            lambda: "offset_x should be [-128,127], but offset_x is " + str(offset_x) + ops_error(ErrCode.VALUE),
        )

    check_basic_inputs_dim_shape()
    check_basic_inputs_dtype()
    if bias is not None:
        check_bias_dim_shape_dtype()
    check_attrs()

    nout = input_shape[0]
    cout = weight_shape[0]
    # Output spatial extents computed from padding, dilation, kernel size and stride.
    hout = (input_shape[2] + pads[0] * 2 - dilations[0] * (weight_shape[2] - 1) - 1) // strides[0] + 1
    wout = (input_shape[3] + pads[1] * 2 - dilations[1] * (weight_shape[3] - 1) - 1) // strides[1] + 1
    torch._check(
        hout > 0 and wout > 0,
        lambda: "ho, wo should larger than 0, but now ho is " + str(hout) + ", and wo is " + str(wout) +
                ops_error(ErrCode.VALUE),
    )

    output_dim_list = [nout, cout, hout, wout]
    if output_dtype == TORCH_DTYPE_MAP[torch.float16] or \
            output_dtype == TORCH_DTYPE_MAP[torch.bfloat16] or \
            output_dtype == TORCH_DTYPE_MAP[torch.float32]:
        return scale.new_empty(tuple(output_dim_list), dtype=TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP[output_dtype])
    elif output_dtype == torch_npu.hifloat8:
        # hifloat8 results are allocated as uint8 tensors.
        return scale.new_empty(tuple(output_dim_list), dtype=torch.uint8)
    else:
        # Carry the error-code suffix like every other dtype error in this function.
        raise RuntimeError("output_dtype should be one of " +
                           "[torch.float16, torch.bfloat16, torch.float32, torch_npu.hifloat8], but got " +
                           str(output_dtype) + ops_error(ErrCode.TYPE))
@impl(m, "npu_linear")
def npu_linear_meta(input_, weight, bias=None):
    """Meta implementation: empty ``(input_.size(0), weight.size(0))`` tensor derived from ``input_``."""
    out_shape = (input_.size(0), weight.size(0))
    return input_.new_empty(out_shape)
@impl(m, "npu_moe_finalize_routing")
def npu_moe_finalize_routing_meta(expanded_permuted_rows, skip1, skip2_optional, bias, scales, expanded_src_to_dst_row,
                                  expert_for_source_row, drop_pad_mode=0):
    """Meta implementation of ``npu_moe_finalize_routing``.

    Without ``scales`` the output mirrors ``expanded_permuted_rows``.
    Otherwise it is ``(scales.size(0), cols)`` where ``cols`` is taken from
    dimension 2 of ``expanded_permuted_rows`` for ``drop_pad_mode`` 1 or 3 and
    from dimension 1 for every other mode.
    """
    if scales is None:
        return torch.empty_like(expanded_permuted_rows)

    row_count = scales.size(0)
    column_axis = 2 if drop_pad_mode in (1, 3) else 1
    col_count = expanded_permuted_rows.size(column_axis)
    return expanded_permuted_rows.new_empty((row_count, col_count))
# Register npu_prefetch with torch.fx's has_side_effect.
has_side_effect(torch.ops.npu.npu_prefetch.default)


@impl(m, "npu_prefetch")
def npu_prefetch_meta(self, dependency, max_size, offset=0):
    """Meta implementation of ``npu_prefetch``: validate arguments only; returns ``None``.

    Raises:
        RuntimeError: if ``max_size`` is not positive or ``offset`` is negative.
    """
    torch._check(max_size > 0, lambda: f"The max_size should be greater than zero, but got {max_size}.")
    torch._check(offset >= 0, lambda: f"The offset should be nonnegative, but got {offset}.")
@impl(m, "npu_swiglu")
def npu_swiglu_meta(x, dim=-1):
    """Meta implementation of ``npu_swiglu``.

    Returns an empty tensor with ``x``'s dtype and device whose shape equals
    ``x``'s shape except that dimension ``dim`` is halved (rounded down).
    """
    output_size = list(x.size())
    # Integer floor division keeps the size an integer throughout (no float
    # round-trip), matching how npu_dequant_swiglu_quant_meta halves sizes.
    output_size[dim] = output_size[dim] // 2
    return torch.empty(output_size, dtype=x.dtype, device=x.device)
@impl(m, "npu_swiglu_backward")
def npu_swiglugrad_meta(y, x, dim=-1):
    """Meta implementation of ``npu_swiglu_backward``: the result mirrors ``x``."""
    grad_x = torch.empty_like(x)
    return grad_x
def rope_quant_kvcache(x, cos, k_cache, v_cache, size_splits, kv_output=False):
torch._check(
x.dim() == 3 or x.dim() == 2,
lambda: f"The x's dim should be 2 or 3, but got {x.dim()}.",
)
torch._check(
k_cache.dim() == 4,
lambda: f"The k_cache's dim should be 4, but got {k_cache.dim()}.",
)
num_size_splits = len(size_splits)
torch._check(
num_size_splits == 3,
lambda: f"The size_splits should be 3, but got {num_size_splits}.",
)
torch._check(
size_splits[0] >= 0,
lambda: f"size_splits[0] should not less than 0, but got {size_splits[0]}.",
)
batch = x.size(0)
seqlen = x.size(1)
k_headdim = k_cache.size(2)
hidden_size = k_cache.size(3)
q_headdim = 0
if hidden_size != 0:
q_headdim = size_splits[0] // hidden_size
out_q_size = [batch, seqlen, q_headdim, hidden_size] if x.dim() == 3 else [batch, q_headdim, hidden_size]
out_k_size = [0]
out_v_size = [0]
if kv_output:
out_k_size = [batch, seqlen, k_headdim, hidden_size] if x.dim() == 3 else [batch, k_headdim, hidden_size]
out_v_size = [batch, seqlen, k_headdim, hidden_size] if x.dim() == 3 else [batch, k_headdim, hidden_size]
return (torch.empty(out_q_size, dtype=cos.dtype, device=x.device),
torch.empty(out_k_size, dtype=cos.dtype, device=x.device),
torch.empty(out_v_size, dtype=cos.dtype, device=x.device),
k_cache, v_cache)
@impl(m, "npu_swiglu_quant")
def npu_swiglu_quant_meta(x, smooth_scales=None, offsets=None, group_index=None, activate_left=False, quant_mode=0,
                          group_list_type=0, dst_type=torch.int8):
    """Meta implementation of ``npu_swiglu_quant``.

    Returns:
        ``(y, scale)``: ``y`` has ``x``'s shape with the last dimension halved
        (rounded down) and dtype ``dst_type``; ``scale`` is float32 with
        ``x``'s shape minus the last dimension. Both live on ``x``'s device.
    """
    leading_dims = [x.size(i) for i in range(x.dim() - 1)]
    # Integer floor division avoids a float round-trip when halving the size.
    y_size = leading_dims + [x.size(x.dim() - 1) // 2]
    scale_size = list(leading_dims)
    return (torch.empty(y_size, dtype=dst_type, device=x.device),
            torch.empty(scale_size, dtype=torch.float32, device=x.device))
@impl(m, "npu_dequant_swiglu_quant")
def npu_dequant_swiglu_quant_meta(x, weight_scale=None, activation_scale=None, bias=None, quant_scale=None,
                                  quant_offset=None, group_index=None, activate_left=False, quant_mode=0,
                                  dst_type=None, round_mode=None, activate_dim=None, swiglu_mode=0, clamp_limit=7.0,
                                  glu_alpha=1.702, glu_bias=1.0):
    """Meta implementation of ``npu_dequant_swiglu_quant``.

    Returns:
        ``(y, scale)``: ``y`` has ``x``'s shape with the activation dimension
        halved, and dtype resolved from ``dst_type`` (``torch.int8`` when the
        enum value is unknown). ``scale`` is float32 and covers all but the
        last dimension, with the same halving applied.
    """
    # Substitute defaults for arguments passed as None.
    dst_type = 1 if dst_type is None else dst_type
    round_mode = 0 if round_mode is None else round_mode
    activate_dim = -1 if activate_dim is None else activate_dim

    rank = x.dim()
    # Normalise a negative activation dimension to a non-negative index.
    select_dim = activate_dim if activate_dim >= 0 else activate_dim + rank

    y_size = [x.size(axis) // 2 if axis == select_dim else x.size(axis) for axis in range(rank)]
    # The scale shape follows the same rule over every dimension but the last.
    scale_size = [x.size(axis) // 2 if axis == select_dim else x.size(axis) for axis in range(rank - 1)]

    dst_torch_dtype = TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP.get(dst_type, torch.int8)
    if dst_torch_dtype == torch.uint8 and dst_type != torch_npu.hifloat8:
        # uint8-backed outputs other than hifloat8 get their last dimension halved again.
        y_size[-1] = y_size[-1] // 2

    y_out = torch.empty(y_size, dtype=dst_torch_dtype, device=x.device)
    scale_out = torch.empty(scale_size, dtype=torch.float32, device=x.device)
    return (y_out, scale_out)
@impl(m, "npu_swiglu_mx_quant")
def npu_swiglu_mx_quant_meta(x, group_index=None, activate_dim=-1, activate_left=False,
                             swiglu_mode=0, clamp_limit=7.0, glu_alpha=1.702, glu_bias=1.0,
                             group_mode=0, axis=-1, dst_type=296, round_mode="rint",
                             scale_alg=0, max_dtype_value=0):
    """Meta implementation of ``npu_swiglu_mx_quant``.

    Returns:
        ``(y, scale)``. ``y`` has ``x``'s shape with the activation dimension
        halved; when the resolved dtype is ``torch.uint8`` and ``dst_type`` is
        not ``torch_npu.hifloat8`` its last dimension is halved again.
        ``scale`` is a uint8 tensor whose ``axis`` dimension is reduced to a
        per-block count (block size 64) and which carries an extra trailing
        dimension of size 2.
    """
    activate_dim = activate_dim if activate_dim is not None else -1
    # Normalise negative dimension indices.
    select_dim = activate_dim if activate_dim >= 0 else activate_dim + x.dim()
    quant_dim = axis if axis >= 0 else axis + x.dim()

    swish_num = 2
    block_size = 64

    y_size = [x.size(i) // swish_num if i == select_dim else x.size(i) for i in range(x.dim())]
    # The scale shape starts out identical to the activation output shape.
    scale_size = list(y_size)

    if group_index is None or quant_dim == x.dim() - 1:
        quant_size = int(math.ceil(scale_size[quant_dim] / block_size))
    else:
        # Grouped quantisation along a non-last axis: whole blocks plus one entry per group.
        quant_size = int(math.floor(scale_size[quant_dim] / block_size) + group_index.shape[0])
    scale_size[quant_dim] = quant_size
    scale_size.append(swish_num)

    # The fallback must be a torch dtype: it is compared against torch.uint8
    # and passed as ``dtype=`` to torch.empty below.
    dst_torch_dtype = TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP.get(dst_type, torch.uint8)
    if dst_torch_dtype == torch.uint8 and dst_type != torch_npu.hifloat8:
        y_size[-1] = y_size[-1] // swish_num
    return (torch.empty(y_size, dtype=dst_torch_dtype, device=x.device),
            torch.empty(scale_size, dtype=torch.uint8, device=x.device))
@impl(m, "npu_clipped_swiglu")
def npu_clipped_swiglu_meta(x, group_index=None, dim=-1, alpha=1.702, limit=7.0, bias=1.0, interleaved=True):
    """Meta implementation of ``npu_clipped_swiglu``.

    Returns an empty tensor with ``x``'s dtype and device whose shape equals
    ``x``'s shape except that dimension ``dim`` is halved (rounded down).
    """
    output_size = list(x.size())
    # Integer floor division keeps the size an integer (no float round-trip).
    output_size[dim] = output_size[dim] // 2
    return torch.empty(output_size, dtype=x.dtype, device=x.device)
@impl(m, "npu_fused_causal_conv1d_functional")
def npu_fused_causal_conv1d_functional_meta(x, weight, conv_states, *, query_start_loc=None, cache_indices=None,
                                            initial_state_mode=None, bias=None, num_accepted_tokens=None,
                                            activation_mode="None", pad_slot_id=-1, run_mode=0, residual_connection=0):
    """Meta implementation: return empty tensors mirroring ``x`` and ``conv_states``."""
    y_out = torch.empty_like(x)
    conv_states_out = torch.empty_like(conv_states)
    return y_out, conv_states_out
@impl(m, "npu_fused_causal_conv1d")
def npu_fused_causal_conv1d_meta(x, weight, conv_states, *, query_start_loc=None, cache_indices=None,
                                 initial_state_mode=None, bias=None, num_accepted_tokens=None,
                                 activation_mode="None", pad_slot_id=-1, run_mode=0, residual_connection=0):
    """Meta implementation: return an empty tensor mirroring ``x``."""
    y_out = torch.empty_like(x)
    return y_out
@impl(m, "npu_dequant_rope_quant_kvcache")
def npu_dequant_rope_quant_kvcache_meta(x, cos, sin, k_cache, v_cache, indices, scale_k, scale_v, size_splits, *,
                                        offset_k=None, offset_v=None, weight_scale=None, activation_scale=None,
                                        bias=None, quant_mode=0, input_layout="BSND", kv_output=False,
                                        cache_mode="contiguous"):
    """Meta implementation of ``npu_dequant_rope_quant_kvcache``.

    Requires ``x`` to be int32, then delegates the shape computation to
    ``rope_quant_kvcache``.
    """
    torch._check(x.dtype == torch.int32, lambda: f"The x's dtype should be Int32, but got {x.dtype}.")
    outputs = rope_quant_kvcache(x, cos, k_cache, v_cache, size_splits, kv_output=kv_output)
    return outputs
@impl(m, "npu_rope_quant_kvcache")
def npu_rope_quant_kvcache_meta(x, cos, sin, k_cache, v_cache, indices, scale_k, scale_v, size_splits, *, offset_k=None,
                                offset_v=None, quant_mode=0, input_layout="BSND", kv_output=False, cache_mode="contiguous"):
    """Meta implementation of ``npu_rope_quant_kvcache``; delegates to ``rope_quant_kvcache``."""
    outputs = rope_quant_kvcache(x, cos, k_cache, v_cache, size_splits, kv_output=kv_output)
    return outputs
@impl(m, "npu_dequant_bias")
def npu_dequant_bias_meta(x, weight_scale, activation_scale, bias, output_dtype=None):
    """Meta implementation of ``npu_dequant_bias``.

    Returns an empty tensor shaped like ``x`` with dtype ``output_dtype``
    (``torch.float16`` when it is ``None``).

    Raises:
        RuntimeError: if ``output_dtype`` is neither float16 nor bfloat16.
    """
    resolved_dtype = torch.float16 if output_dtype is None else output_dtype
    if resolved_dtype not in (torch.float16, torch.bfloat16):
        raise RuntimeError("Only supported output_dtype is float16 and bfloat16" + ops_error(ErrCode.NOT_SUPPORT))
    return torch.empty_like(x, dtype=resolved_dtype)
@impl(m, "npu_interleave_rope")
def npu_interleave_rope_meta(x, cos, sin):
    """Meta implementation of ``npu_interleave_rope``: the result mirrors ``x``."""
    y_out = torch.empty_like(x)
    return y_out
@impl(m, "npu_batch_gather_matmul")
def npu_batch_gather_matmul_meta(self, x, weight_b, indices, weight_a=None,
                                 layer_idx=0, scale=1e-3, y_offset=0, y_slice_size=-1):
    """Meta implementation of ``npu_batch_gather_matmul``: the result mirrors ``self``."""
    y_out = torch.empty_like(self)
    return y_out
@impl(m, "npu_batch_gather_matmul_")
def npu_batch_gather_matmul__meta(self, x, weight_b, indices, weight_a=None,
                                  layer_idx=0, scale=1e-3, y_offset=0, y_slice_size=-1):
    """Meta implementation of ``npu_batch_gather_matmul_``: hands back ``self`` itself."""
    # No new tensor is allocated; the input object is returned as-is.
    return self
@impl(m, "npu_gather_backward")
def npu_gather_backward__meta(grad, self_size, dim, index, sparse_grad):
    """Meta implementation: empty tensor of shape ``self_size`` with ``grad``'s dtype and device."""
    grad_self = torch.empty(self_size, dtype=grad.dtype, device=grad.device)
    return grad_self
@impl(m, "npu_moe_token_permute_with_routing_map")
def npu_moe_token_permute_with_routing_map_meta(tokens, routing_map, *, probs=None, num_out_tokens=None, drop_and_pad=False):
    """Meta implementation of ``npu_moe_token_permute_with_routing_map``.

    ``num_out_tokens`` defaults to ``tokens.size(0)`` and is rounded down to
    a multiple of ``routing_map.size(1)`` (``drop_and_pad``) or
    ``routing_map.size(0)`` (otherwise).

    Returns:
        ``(permuted_tokens, permuted_probs, sorted_indices)``: a
        ``(out_token, tokens.size(1))`` tensor with ``tokens``' dtype, a
        ``(out_token,)`` tensor with ``probs``' dtype (``None`` when ``probs``
        is not given), and a ``(out_token,)`` int32 tensor.
    """
    requested = tokens.size(0) if num_out_tokens is None else num_out_tokens
    multiple = routing_map.size(1 if drop_and_pad else 0)
    # Round the requested count down to a whole multiple.
    out_token = requested // multiple * multiple

    device = tokens.device
    permuted_tokens = torch.empty((out_token, tokens.size(1)), dtype=tokens.dtype, device=device)
    sorted_indices = torch.empty((out_token,), dtype=torch.int32, device=device)
    if probs is None:
        permuted_probs = None
    else:
        permuted_probs = torch.empty((out_token,), dtype=probs.dtype, device=device)
    return permuted_tokens, permuted_probs, sorted_indices
@impl(m, "npu_moe_token_permute_with_routing_map_grad")
def npu_moe_token_permute_with_routing_map_grad_meta(permuted_token_out_grad, probs_grad, sorted_indices, routing_map, experts_num, tokens_num, drop_and_pad):
    """Meta implementation of ``npu_moe_token_permute_with_routing_map_grad``.

    Returns:
        ``(tokens_grad, probs_grad_out)``: a
        ``(tokens_num, permuted_token_out_grad.size(1))`` tensor, and a
        ``(tokens_num, experts_num)`` tensor with ``probs_grad``'s dtype
        (``None`` when ``probs_grad`` is not given).
    """
    device = permuted_token_out_grad.device
    tokens_grad = torch.empty(
        (tokens_num, permuted_token_out_grad.size(1)),
        dtype=permuted_token_out_grad.dtype,
        device=device,
    )
    if probs_grad is None:
        return tokens_grad, None
    probs_grad_out = torch.empty((tokens_num, experts_num), dtype=probs_grad.dtype, device=device)
    return tokens_grad, probs_grad_out
@impl(m, "npu_moe_re_routing")
def npu_moe_re_routing_meta(tokens, expert_token_num_per_rank, per_token_scales=None, expert_token_num_type=1, idx_type=0):
    """Meta implementation of ``npu_moe_re_routing``.

    Returns a 4-tuple of empty tensors on ``tokens``' device:
      1. same shape and dtype as ``tokens``;
      2. same shape and dtype as ``per_token_scales``, or a float32
         ``(tokens.size(0),)`` tensor when it is not given;
      3. int32 ``(tokens.size(0),)``;
      4. ``(expert_token_num_per_rank.size(1),)`` with that tensor's dtype.
    """
    device = tokens.device
    tokens_shape = [tokens.size(axis) for axis in range(tokens.dim())]

    if per_token_scales is not None:
        scales_shape = [per_token_scales.size(axis) for axis in range(per_token_scales.dim())]
        scales_dtype = per_token_scales.dtype
    else:
        scales_shape = [tokens.size(0)]
        scales_dtype = torch.float32

    token_idx_shape = [tokens.size(0)]
    expert_num_shape = [expert_token_num_per_rank.size(1)]

    return (torch.empty(tokens_shape, dtype=tokens.dtype, device=device),
            torch.empty(scales_shape, dtype=scales_dtype, device=device),
            torch.empty(token_idx_shape, dtype=torch.int32, device=device),
            torch.empty(expert_num_shape, dtype=expert_token_num_per_rank.dtype, device=device))
@impl(m, "npu_attention_worker_combine")
def npu_attention_worker_combine(schedule_context, expert_scales, layer_id, hidden_size, token_dtype=0, need_schedule=0):
    """Meta implementation of ``npu_attention_worker_combine``.

    Returns:
        ``(y, next_layer_id)`` on ``schedule_context``'s device: ``y`` is
        ``(expert_scales.size(0), hidden_size)`` in bfloat16 when
        ``token_dtype == 1`` and half precision otherwise; ``next_layer_id``
        is an int32 ``(layer_id.size(0),)`` tensor.
    """
    device = schedule_context.device
    y_dtype = torch.bfloat16 if token_dtype == 1 else torch.half
    y_shape = [expert_scales.size(0), hidden_size]
    next_layer_id_shape = [layer_id.size(0)]
    return (torch.empty(y_shape, dtype=y_dtype, device=device),
            torch.empty(next_layer_id_shape, dtype=torch.int32, device=device))
@impl(m, "npu_add_rms_norm_quant")
def npu_add_rms_norm_quant(x1, x2, gamma, scales1, zero_points1=None, beta=None, scales2=None, zero_points2=None, axis=-1, epsilon=1e-06, div_mode=True, dst_type=None):
    """Meta implementation of ``npu_add_rms_norm_quant``.

    Returns three empty tensors shaped like ``x1`` on ``x1``'s device: two
    with the dtype resolved from ``dst_type`` (enum value 1 when ``None``;
    ``torch.int8`` when the value is unknown) and one with ``x1``'s dtype.

    Raises:
        RuntimeError: if ``axis`` is not -1.
    """
    torch._check(axis == -1, lambda: f"axis should be -1, but got {axis}.")

    if dst_type is None:
        dst_type = 1
    quant_dtype = TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP.get(dst_type, torch.int8)

    shape = x1.size()
    device = x1.device
    y1 = torch.empty(shape, dtype=quant_dtype, device=device)
    y2 = torch.empty(shape, dtype=quant_dtype, device=device)
    x_out = torch.empty(shape, dtype=x1.dtype, device=device)
    return (y1, y2, x_out)
@impl(m, "npu_attention_update")
def npu_attention_update_meta(lse, local_out, update_type):
    """Meta implementation of ``npu_attention_update``.

    Args:
        lse: indexable collection; its first element fixes the second output.
        local_out: indexable collection; its first element fixes the first output.
        update_type: unused here.

    Returns:
        ``(out, lse_out)``: empty tensors matching the size, dtype and device
        of ``local_out[0]`` and ``lse[0]`` respectively.
    """
    ref = local_out[0]
    ref_lse = lse[0]
    return (torch.empty(ref.size(), dtype=ref.dtype, device=ref.device),
            torch.empty(ref_lse.size(), dtype=ref_lse.dtype, device=ref_lse.device))
@impl(m, "npu_mrope")
def npu_mrope_meta(positions, query, key, cos_sin_cache, head_size, *, mrope_section=None, rotary_mode='half', cache_mode='default'):
    """Meta implementation of ``npu_mrope``: the outputs mirror ``query`` and ``key``."""
    query_out = torch.empty_like(query)
    key_out = torch.empty_like(key)
    return (query_out, key_out)
@impl(m, "npu_gather_sparse_index")
def npu_gather_sparse_index(inputs, index):
    """Meta implementation of ``npu_gather_sparse_index``.

    The output shape is ``index``'s shape followed by ``inputs``' shape
    without its first dimension; a zero-dimensional ``inputs`` keeps its own
    shape. The result uses ``inputs``' dtype and device.

    Raises:
        RuntimeError: if ``inputs.dim() + index.dim() - 1`` exceeds
            ``NPU_TENSOR_DIM_LIMIT``.
    """
    output_dim = inputs.dim() + index.dim() - 1
    torch._check(
        output_dim <= NPU_TENSOR_DIM_LIMIT,
        lambda: f"input.dim() + index.dim() - 1 must not greater than 8, but got {output_dim}.",
    )

    input_size = inputs.size()
    if inputs.dim() == 0:
        # Scalar input: the output keeps the input's (empty) shape.
        return torch.empty(input_size, dtype=inputs.dtype, device=inputs.device)

    index_size = index.size()
    output_size = [index_size[axis] for axis in range(index.dim())]
    output_size.extend(input_size[axis] for axis in range(1, inputs.dim()))
    return torch.empty(output_size, dtype=inputs.dtype, device=inputs.device)
@impl(m, "npu_top_k_top_p")
def npu_top_k_top_p_meta(logits, p, k):
    """Meta implementation of ``npu_top_k_top_p``: the result mirrors ``logits``."""
    logits_out = torch.empty_like(logits)
    return logits_out
@impl(m, "npu_moe_token_permute")
def npu_moe_token_permute_meta(tokens, indices, num_out_tokens=None, padded_mode=False):
    """Meta implementation of ``npu_moe_token_permute``.

    The number of output rows is ``min(num_out_tokens, indices.numel())`` for
    a positive ``num_out_tokens``; otherwise ``num_out_tokens`` (0 when
    ``None``) is added to ``indices.numel()``.

    Returns:
        ``(permuted_tokens, sorted_indices)``: a ``(rows, tokens.size(1))``
        tensor with ``tokens``' dtype and an int32 tensor with
        ``indices.numel()`` elements, both on ``tokens``' device.

    Raises:
        RuntimeError: if ``tokens`` is not 2-D or ``indices`` is not 1-D/2-D.
    """
    torch._check(tokens.dim() == 2, lambda: f"The dims of input tokens should be 2 dimensional, but got {tokens.dim()}-dimensional.")
    torch._check(indices.dim() == 1 or indices.dim() == 2, lambda: f"The dims of input indices should be 2 or 1 dimensional, but got {indices.dim()}-dimensional.")

    requested = num_out_tokens if num_out_tokens is not None else 0
    total_indices = indices.numel()
    if requested > 0:
        rows = min(requested, total_indices)
    else:
        # Zero or negative requests are added to the flattened index count.
        rows = requested + total_indices

    permuted_tokens = torch.empty((rows, tokens.size(1)), dtype=tokens.dtype, device=tokens.device)
    sorted_indices = torch.empty(indices.numel(), dtype=torch.int32, device=tokens.device)
    return permuted_tokens, sorted_indices
@impl(m, "npu_moe_token_unpermute")
def npu_moe_token_unpermute_meta(permuted_tokens, sorted_indices, probs=None, padded_mode=False, restore_shape=None):
    """Meta implementation of ``npu_moe_token_unpermute``.

    Returns:
        An empty ``(sorted_indices.size(0) // topk, permuted_tokens.size(-1))``
        tensor with ``permuted_tokens``' dtype and device, where ``topk`` is
        ``probs.size(1)`` or 1 when ``probs`` is not given.

    Raises:
        RuntimeError: if ``probs`` (when given) or ``permuted_tokens`` is not
            2-D, or ``sorted_indices`` is not 1-D.
    """
    DEFAULT_TOPK = 1
    if probs is not None:
        # Build the message from probs.dim() directly; tensors have no
        # ``.value()`` accessor, so the previous message lambda raised
        # AttributeError instead of reporting the bad rank.
        torch._check(probs.dim() == 2, lambda: f"The dims of input probs should be 2 dimensional, but got {probs.dim()}-dimensional.")
    torch._check(permuted_tokens.dim() == 2, lambda: f"The dims of input permuted_tokens should be 2 dimensional, but got {permuted_tokens.dim()}-dimensional.")
    torch._check(sorted_indices.dim() == 1, lambda: f"The dims of input sorted_indices should be 1 dimensional, but got {sorted_indices.dim()}-dimensional.")

    topk = DEFAULT_TOPK if probs is None else probs.size(1)
    output_shape = (sorted_indices.size(0) // topk, permuted_tokens.size(-1))
    return torch.empty(output_shape, dtype=permuted_tokens.dtype, device=permuted_tokens.device)
@impl(m, "npu_moe_token_permute_grad")
def npu_moe_token_permute_grad_meta(tokens, grad_permuted_tokens, indices, sorted_indices, padded_mode=False):
    """Meta implementation of ``npu_moe_token_permute_grad``.

    Returns an empty tensor with the 2-D shape, dtype and device of ``tokens``.

    Raises:
        RuntimeError: if any input has an unexpected rank.
    """
    torch._check(tokens.dim() == 2, lambda: f"The dims of input tokens should be 2 dimensional, but got {tokens.dim()}-dimensional.")
    torch._check(grad_permuted_tokens.dim() == 2, lambda: f"The dims of input grad_permuted_tokens should be 2 dimensional, but got {grad_permuted_tokens.dim()}-dimensional.")
    torch._check(indices.dim() == 1 or indices.dim() == 2, lambda: f"The dims of input indices should be 2 or 1 dimensional, but got {indices.dim()}-dimensional.")
    torch._check(sorted_indices.dim() == 1, lambda: f"The dims of input sorted_indices should be 1 dimensional, but got {sorted_indices.dim()}-dimensional.")

    num_rows, num_cols = tokens.shape
    grad_tokens = torch.empty((num_rows, num_cols), dtype=tokens.dtype, device=tokens.device)
    return grad_tokens
@impl(m, "npu_moe_token_unpermute_grad")
def npu_moe_token_unpermute_grad_meta(permuted_tokens, grad_unpermuted_tokens, sorted_indices, probs=None, padded_mode=False, restore_shape=None):
    """Meta implementation of ``npu_moe_token_unpermute_grad``.

    Returns:
        ``(grad_permuted_tokens, grad_probs)``: empty tensors mirroring
        ``permuted_tokens`` and ``probs``; ``grad_probs`` is ``None`` when
        ``probs`` is not given.

    Raises:
        RuntimeError: if any input has an unexpected rank.
    """
    torch._check(permuted_tokens.dim() == 2, lambda: f"The dims of input permuted_tokens should be 2 dimensional, but got {permuted_tokens.dim()}-dimensional.")
    torch._check(grad_unpermuted_tokens.dim() == 2, lambda: f"The dims of input grad_unpermuted_tokens should be 2 dimensional, but got {grad_unpermuted_tokens.dim()}-dimensional.")
    torch._check(sorted_indices.dim() == 1, lambda: f"The dims of input sorted_indices should be 1 dimensional, but got {sorted_indices.dim()}-dimensional.")

    grad_permuted_tokens = torch.empty_like(permuted_tokens)
    if probs is None:
        return grad_permuted_tokens, None
    return grad_permuted_tokens, torch.empty_like(probs)
@impl(m, "npu_grouped_matmul_swiglu_quant")
def npu_grouped_matmul_swiglu_quant_meta(x, weight, group_list, weight_scale, x_scale, *, bias=None, offset=None):
    """Meta implementation of ``npu_grouped_matmul_swiglu_quant``.

    Returns:
        ``(output, output_scale, output_offset)`` on ``x``'s device: an int8
        ``(x.size(0), weight.size(2) // 2)`` tensor, a float32
        ``(x.size(0),)`` tensor and a zero-dimensional float32 tensor.
    """
    rows = x.size(0)
    half_n = weight.size(2) // 2
    device = x.device
    output = torch.empty([rows, half_n], dtype=torch.int8, device=device)
    output_scale = torch.empty([rows], dtype=torch.float32, device=device)
    output_offset = torch.empty([], dtype=torch.float32, device=device)
    return output, output_scale, output_offset
@impl(m, "npu_grouped_matmul_swiglu_quant_v2")
def npu_grouped_matmul_swiglu_quant_v2_meta(x, weight, weight_scale, x_scale, group_list, *, smooth_scale=None,
                                            weight_assist_matrix=None, bias=None, dequant_mode=0, dequant_dtype=0, quant_mode=0, quant_dtype=1,
                                            group_list_type=0, tuning_config=None, x_dtype=None, weight_dtype=None, weight_scale_dtype=None, x_scale_dtype=None):
    """Meta implementation of ``npu_grouped_matmul_swiglu_quant_v2``.

    Validates that ``weight`` and ``weight_scale`` each hold exactly one
    element and that the dtype arguments are among the supported values, then
    picks the output and output-scale shapes from ``quant_dtype``,
    ``dequant_mode`` and the detected input kind.

    Returns:
        ``(output, output_scale)``. Both stay ``None`` when no branch of the
        selection ladder below matches.

    Raises:
        RuntimeError: if any ``torch._check`` fails.
        KeyError: in the ``dequant_mode == 0`` branch, if ``quant_dtype`` is
            not a key of ``TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP``.
    """
    torch._check(
        len(weight) == 1,
        lambda: f"The size of weight should be 1, current size is {len(weight)}.",
    )
    torch._check(
        len(weight_scale) == 1,
        lambda: f"The size of weight_scale should be 1, current size is {len(weight_scale)}.",
    )
    if x_dtype is not None:
        torch._check(
            x_dtype == torch_npu.float4_e2m1fn_x2 or x_dtype == torch_npu.hifloat8,
            lambda: "The optional parameter x_dtype only supports torch_npu.float4_e2m1fn_x2, torch_npu.hifloat8, or None, but the actual value is " + npu_dtype_to_str(x_dtype),
        )
    if weight_dtype is not None:
        torch._check(
            weight_dtype == torch_npu.float4_e2m1fn_x2 or weight_dtype == torch_npu.hifloat8,
            lambda: "The optional parameter weight_dtype only supports torch_npu.float4_e2m1fn_x2, torch_npu.hifloat8, or None, but the actual value is " + npu_dtype_to_str(weight_dtype),
        )
    if weight_scale_dtype is not None:
        torch._check(
            weight_scale_dtype == torch_npu.float8_e8m0fnu,
            lambda: "The weight_scale_dtype only supports float8_e8m0fnu for now, but the actual value is " + npu_dtype_to_str(weight_scale_dtype),
        )
    # The tensor dtype must be fp8, unless the x_dtype / weight_dtype override says fp4.
    torch._check(x.dtype == torch.float8_e5m2 or x.dtype == torch.float8_e4m3fn or x_dtype == torch_npu.float4_e2m1fn_x2,
        lambda: "The x only supports torch.float8_e5m2/torch.float8_e4m3fn/torch_npu.float4_e2m1fn_x2 for now, but the actual value is " + npu_dtype_to_str(x.dtype),
    )
    torch._check(weight[0].dtype == torch.float8_e5m2 or weight[0].dtype == torch.float8_e4m3fn or weight_dtype == torch_npu.float4_e2m1fn_x2,
        lambda: "The weight only supports torch.float8_e5m2/torch.float8_e4m3fn/torch_npu.float4_e2m1fn_x2 for now, but the actual value is " + npu_dtype_to_str(weight[0].dtype),
    )
    if x_scale_dtype is not None:
        torch._check(
            x_scale_dtype == torch_npu.float8_e8m0fnu,
            lambda: "The x_scale_dtype only supports float8_e8m0fnu for now, but the actual value is " + npu_dtype_to_str(x_scale_dtype),
        )
    torch._check(quant_dtype == 1 or quant_dtype == TORCH_DTYPE_MAP[torch.float8_e5m2] or quant_dtype == TORCH_DTYPE_MAP[torch.float8_e4m3fn]
        or quant_dtype == torch_npu.float4_e2m1fn_x2 or quant_dtype == torch_npu.hifloat8,
        lambda: "quant_dtype only supports torch.int8, torch.float8_e5m2, torch.float8_e4m3fn, torch_npu.float4_e2m1fn_x2, torch_npu.hifloat8 for now, but it is " + npu_dtype_to_str(quant_dtype),
    )
    # NOTE(review): the default dequant_dtype=0 is not one of the accepted
    # enum values (1, 6, 15, 5 per TORCH_DTYPE_MAP), so this check appears to
    # reject calls that rely on the default -- confirm this is intended.
    torch._check(dequant_dtype == TORCH_DTYPE_MAP[torch.int8] or dequant_dtype == TORCH_DTYPE_MAP[torch.float32] or dequant_dtype == TORCH_DTYPE_MAP[torch.bfloat16] or dequant_dtype == TORCH_DTYPE_MAP[torch.float16],
        lambda: "dequant_dtype only supports torch.int8, torch.float32, torch.bfloat16, torch.float16 for now, but it is " + npu_dtype_to_str(dequant_dtype),
    )
    batch_size = x.size(0)
    dim_n = 2
    n = weight[0].size(dim_n)
    # Both operands carry a torch fp8 tensor dtype.
    is_a8w8_input = (x.dtype == torch.float8_e5m2 or x.dtype == torch.float8_e4m3fn) and \
        (weight[0].dtype == torch.float8_e5m2 or weight[0].dtype == torch.float8_e4m3fn)
    # Both operands are declared fp4 through the x_dtype / weight_dtype overrides.
    is_a4w4_input = False
    if x_dtype is not None and weight_dtype is not None:
        is_a4w4_input = x_dtype == torch_npu.float4_e2m1fn_x2 and weight_dtype == torch_npu.float4_e2m1fn_x2
    FP4_IN_INT8 = 2
    # True when x's last dimension equals the weight's second-to-last dimension.
    weight_trans = (x.size(-1) == weight[0].size(-2))
    mxfp_multi_base_size = 2
    mxfp_divisor_size = 64
    output_n = n // mxfp_multi_base_size
    # True division: this is a float and is passed through math.ceil where it is used.
    output_scale_n = n // mxfp_multi_base_size / mxfp_divisor_size
    output_n_new = ((n // mxfp_multi_base_size) * FP4_IN_INT8)
    output_scale_n_new = (math.ceil(n * FP4_IN_INT8 // mxfp_multi_base_size / mxfp_divisor_size))
    output_shape = None
    output_scale_shape = None
    # NOTE: this int8 branch is a standalone ``if``; the ladder below still
    # runs afterwards and may overwrite its result.
    if weight[0].dtype == torch.int8:
        output_shape = torch.empty([batch_size, n // mxfp_multi_base_size], dtype=torch.int8, device=x.device)
        output_scale_shape = torch.empty([batch_size], dtype=torch.float32, device=x.device)
    if quant_dtype == TORCH_DTYPE_MAP[torch.float8_e5m2] and dequant_mode == 2:
        if is_a8w8_input:
            output_shape = torch.empty([batch_size, output_n], dtype=torch.float8_e5m2, device=x.device)
            output_scale_shape = torch.empty([batch_size, math.ceil(output_scale_n), mxfp_multi_base_size], dtype=torch.uint8, device=x.device)
        elif is_a4w4_input:
            if not weight_trans:
                output_shape = torch.empty([batch_size, output_n_new], dtype=torch.float8_e5m2, device=x.device)
                output_scale_shape = torch.empty([batch_size, output_scale_n_new, mxfp_multi_base_size], dtype=torch.uint8, device=x.device)
            else:
                output_shape = torch.empty([batch_size, output_n], dtype=torch.float8_e5m2, device=x.device)
                output_scale_shape = torch.empty([batch_size, math.ceil(output_scale_n), mxfp_multi_base_size], dtype=torch.uint8, device=x.device)
    elif quant_dtype == TORCH_DTYPE_MAP[torch.float8_e4m3fn] and dequant_mode == 2:
        if is_a8w8_input:
            output_shape = torch.empty([batch_size, output_n], dtype=torch.float8_e4m3fn, device=x.device)
            output_scale_shape = torch.empty([batch_size, math.ceil(output_scale_n), mxfp_multi_base_size], dtype=torch.uint8, device=x.device)
        elif is_a4w4_input:
            if not weight_trans:
                output_shape = torch.empty([batch_size, output_n_new], dtype=torch.float8_e4m3fn, device=x.device)
                output_scale_shape = torch.empty([batch_size, output_scale_n_new, mxfp_multi_base_size], dtype=torch.uint8, device=x.device)
            else:
                output_shape = torch.empty([batch_size, output_n], dtype=torch.float8_e4m3fn, device=x.device)
                output_scale_shape = torch.empty([batch_size, math.ceil(output_scale_n), mxfp_multi_base_size], dtype=torch.uint8, device=x.device)
    elif quant_dtype == torch_npu.float4_e2m1fn_x2 and dequant_mode == 2:
        if is_a4w4_input:
            if not weight_trans:
                output_shape = torch.empty([batch_size, output_n], dtype=torch.uint8, device=x.device)
                output_scale_shape = torch.empty([batch_size, math.ceil(output_scale_n_new), mxfp_multi_base_size], dtype=torch.uint8, device=x.device)
            else:
                output_shape = torch.empty([batch_size, output_n // FP4_IN_INT8], dtype=torch.uint8, device=x.device)
                output_scale_shape = torch.empty([batch_size, math.ceil(output_scale_n), mxfp_multi_base_size], dtype=torch.uint8, device=x.device)
    elif dequant_mode == 0:
        out_dtype = TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP[quant_dtype]
        # NOTE(review): out_dtype is the value looked up from the map above,
        # which already maps enum 290 to torch.uint8, so this comparison
        # against torch_npu.hifloat8 looks redundant -- confirm.
        if out_dtype == torch_npu.hifloat8:
            out_dtype = torch.uint8
        output_shape = torch.empty([batch_size, n // mxfp_multi_base_size], dtype=out_dtype, device=x.device)
        output_scale_shape = torch.empty([batch_size], dtype=torch.float32, device=x.device)
    return output_shape, output_scale_shape
@impl(m, "npu_recurrent_gated_delta_rule")
def npu_recurrent_gated_delta_rule_meta(
    query,
    key,
    value,
    state,
    *,
    beta=None,
    scale=None,
    actual_seq_lengths=None,
    ssm_state_indices=None,
    num_accepted_tokens=None,
    g=None,
    gk=None,
):
    """Meta kernel for ``npu_recurrent_gated_delta_rule``.

    Only ``value`` is inspected here: it must be 3-D. The result is an
    uninitialised bfloat16 tensor with the same three sizes as ``value``,
    placed on ``value``'s device. The remaining arguments are accepted but
    not read.
    """
    torch._check(value.dim() == 3, lambda: f"valueTensor dim must be 3, but got {value.dim()}.")
    dims = tuple(value.size(axis) for axis in range(3))
    return torch.empty(dims, dtype=torch.bfloat16, device=value.device)
@impl(m, "npu_recurrent_gated_delta_rule_functional")
def npu_recurrent_gated_delta_rule_functional_meta(
    query,
    key,
    value,
    state,
    *,
    beta=None,
    scale=None,
    actual_seq_lengths=None,
    ssm_state_indices=None,
    num_accepted_tokens=None,
    g=None,
    gk=None,
):
    """Meta kernel for ``npu_recurrent_gated_delta_rule_functional``.

    Requires a 4-D ``state`` and a 3-D ``value``. Returns ``(out, final_state)``
    where ``out`` has the sizes of ``value`` and ``final_state`` has the sizes
    of ``state``; both are uninitialised bfloat16 tensors on the device of the
    tensor they mirror.
    """
    torch._check(state.dim() == 4, lambda: f"state dim must be 4, but got {state.dim()}.")
    torch._check(value.dim() == 3, lambda: f"valueTensor dim must be 3, but got {value.dim()}.")
    state_dims = tuple(state.size(axis) for axis in range(4))
    value_dims = tuple(value.size(axis) for axis in range(3))
    final_state = torch.empty(state_dims, dtype=torch.bfloat16, device=state.device)
    out = torch.empty(value_dims, dtype=torch.bfloat16, device=value.device)
    return out, final_state
@impl(m, "npu_moe_token_unpermute_with_routing_map")
def npu_moe_token_unpermute_with_routing_map(
    permuted_tokens,
    sorted_indices,
    restore_shape,
    *,
    probs=None,
    routing_map=None,
    drop_and_pad=False,
):
    """Meta kernel: the output is 2-D, sized by the first two entries of ``restore_shape``.

    Dtype and device are taken from ``permuted_tokens``.
    """
    rows = restore_shape[0]
    cols = restore_shape[1]
    return torch.empty([rows, cols], dtype=permuted_tokens.dtype, device=permuted_tokens.device)
@impl(m, "_npu_moe_token_unpermute_with_routing_map")
def _npu_moe_token_unpermute_with_routing_map(
    permuted_tokens,
    sorted_indices,
    restore_shape,
    *,
    probs=None,
    routing_map=None,
    drop_and_pad=False,
):
    """Meta kernel returning ``(unpermuted_tokens, out_index, permuted_token_id, permute_probs)``.

    ``unpermuted_tokens`` is sized by the first two entries of ``restore_shape``
    and follows ``permuted_tokens`` in dtype/device. The two index outputs mirror
    ``sorted_indices``. ``permute_probs`` has the shape of ``sorted_indices`` with
    the dtype/device of ``probs``, or is ``None`` when ``probs`` is not given.
    """
    index_shape = sorted_indices.shape
    index_options = dict(dtype=sorted_indices.dtype, device=sorted_indices.device)

    unpermuted_tokens = torch.empty(
        [restore_shape[0], restore_shape[1]],
        dtype=permuted_tokens.dtype,
        device=permuted_tokens.device,
    )
    out_index = torch.empty(index_shape, **index_options)
    permuted_token_id = torch.empty(index_shape, **index_options)
    permute_probs = (
        torch.empty(index_shape, dtype=probs.dtype, device=probs.device)
        if probs is not None
        else None
    )
    return unpermuted_tokens, out_index, permuted_token_id, permute_probs
@impl(m, "npu_moe_token_unpermute_with_routing_map_grad")
def npu_moe_token_unpermute_with_routing_map_grad(
    unpermuted_tokens_grad,
    out_index,
    permuted_token_id,
    routing_map,
    permuted_tokens,
    probs,
    drop_and_pad,
    restore_shape,
):
    """Meta kernel returning ``(permuted_tokens_grad, probs_grad)``.

    The token gradient has shape ``[out_index.shape[0], unpermuted_tokens_grad.shape[1]]``.
    The probs gradient has the shape of ``probs`` (or is ``None`` when ``probs`` is
    ``None``). Both take dtype and device from ``unpermuted_tokens_grad``.
    """
    grad_options = dict(dtype=unpermuted_tokens_grad.dtype, device=unpermuted_tokens_grad.device)
    tokens_grad = torch.empty(
        [out_index.shape[0], unpermuted_tokens_grad.shape[1]], **grad_options
    )
    if probs is None:
        return tokens_grad, None
    probs_grad = torch.empty(probs.shape, **grad_options)
    return tokens_grad, probs_grad
@impl(m, "npu_dynamic_block_quant")
def npu_dynamic_block_quant_meta(
    x,
    *,
    min_scale=0.0,
    round_mode="rint",
    dst_type=1,
    row_block_size=1,
    col_block_size=128,
    dst_type_max=0.0,
):
    """Meta kernel returning ``(y, scale)`` for block-wise dynamic quantization.

    ``y`` has the shape of ``x``; its dtype is looked up from ``dst_type`` in
    ``TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP`` (``torch.float8_e5m2`` when the
    key is absent). ``scale`` is float32 and has the shape of ``x`` with the last
    two axes ceil-divided by ``row_block_size`` and ``col_block_size``.

    Raises:
        RuntimeError: if ``x`` is neither 2-D nor 3-D.
    """
    out_dtype = TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP.get(dst_type, torch.float8_e5m2)
    y = torch.empty(x.shape, dtype=out_dtype, device=x.device)

    scale_dims = list(x.shape)
    rank = len(scale_dims)
    if rank not in (2, 3):
        raise RuntimeError(f"Expected x to have 2 or 3 dimensions, but got {x.dim()}.")

    # Rows/columns are always the last two axes: (0, 1) for 2-D, (1, 2) for 3-D.
    row_axis = rank - 2
    col_axis = rank - 1
    scale_dims[row_axis] = math.ceil(scale_dims[row_axis] / row_block_size)
    scale_dims[col_axis] = math.ceil(scale_dims[col_axis] / col_block_size)

    scale = torch.empty(torch.Size(scale_dims), dtype=torch.float32, device=x.device)
    return y, scale
@impl(m, "npu_gather_pa_kv_cache_functional")
def npu_gather_pa_kv_cache_functional_meta(
    key_cache,
    value_cache,
    block_tables,
    seq_lens,
    key,
    value,
    *,
    seq_offset=None,
    is_seq_lens_cumsum=False,
):
    """Meta kernel returning ``(key_out, value_out)``.

    Each output is a fresh meta-device tensor with the shape and dtype of the
    corresponding ``key`` / ``value`` argument. The cache and table arguments
    are not read here.
    """

    def _meta_like(reference):
        return reference.new_empty(reference.shape, dtype=reference.dtype, device='meta')

    return (_meta_like(key), _meta_like(value))
@impl(m, "npu_sim_exponential_")
def npu_sim_exponential__meta(self, lambd=1, generator=None):
    """Meta kernel: return an uninitialised tensor laid out like ``self``.

    ``lambd`` and ``generator`` are not read here.
    """
    # NOTE(review): the trailing underscore in the op name suggests an in-place op;
    # confirm against the schema that returning a fresh tensor (not ``self``) is intended.
    placeholder = torch.empty_like(self)
    return placeholder
@impl(m, "npu_grouped_matmul_add")
def npu_grouped_matmul_add_meta(
    y,
    x,
    weight,
    group_list,
    *,
    transpose_x=True,
    transpose_weight=False,
    group_type=2,
):
    """Meta kernel: validate ``group_type`` and hand back ``y`` unchanged.

    Only ``group_type == 2`` is accepted; any other value fails the
    ``torch._check``. No new tensor is allocated.
    """

    def _unsupported_group_type():
        return f"group_type only supports 2, but got {group_type} {ops_error(ErrCode.VALUE)}"

    torch._check(group_type == 2, _unsupported_group_type)
    return y
@impl(m, "npu_cross_entropy_loss")
def npu_cross_entropy_loss_meta(
    input_,
    target,
    weight=None,
    reduction="mean",
    ignore_index=-100,
    label_smoothing=0.0,
    lse_square_scale_for_zloss=0.0,
    return_zloss=False,
):
    """Meta kernel returning ``(loss, log_prob, zloss, lse_for_zloss)``.

    * ``loss`` / ``zloss``: shape ``[input_.shape[0]]`` when ``reduction == "none"``,
      otherwise ``[1]``.
    * ``log_prob``: same shape as ``input_``.
    * ``lse_for_zloss``: shape ``[input_.shape[0]]``.

    All four outputs use the dtype and device of ``input_``.
    """
    per_sample_shape = [input_.shape[0]]
    loss_shape = per_sample_shape if reduction == "none" else [1]

    def _alloc(dims):
        return torch.empty(dims, dtype=input_.dtype, device=input_.device)

    loss = _alloc(loss_shape)
    log_prob = _alloc(input_.shape)
    zloss = _alloc(loss_shape)
    lse_for_zloss = _alloc(per_sample_shape)
    return (loss, log_prob, zloss, lse_for_zloss)
@impl(m, "npu_cross_entropy_loss_backward")
def npu_cross_entropy_loss_backward_meta(
    grad_loss,
    log_prob,
    target,
    weight=None,
    grad_zloss=None,
    lse_for_zloss=None,
    reduction='mean',
    ignore_index=-100,
    label_smoothing=0.0,
    lse_square_scale_for_zloss=0.0
):
    """Meta kernel: the result mirrors ``log_prob`` (via ``torch.empty_like``).

    No other argument influences the output.
    """
    return torch.empty_like(log_prob)
@impl(m, "npu_apply_adam_w.out")
def npu_apply_adam_w_meta(
    beta1_power,
    beta2_power,
    lr,
    weight_decay,
    beta1,
    beta2,
    epsilon,
    grad,
    max_grad_norm,
    amsgrad,
    maximize,
    *,
    out,
):
    """Meta kernel for the ``.out`` overload: return the first three entries of ``out``.

    Nothing is allocated; the caller-provided ``out`` items are passed through as a
    3-tuple.
    """
    return tuple(out[position] for position in range(3))
@impl(m, "npu_conv2d")
def npu_conv2d_meta(input_, weight, bias, strides, pads, dilations, groups):
    """Meta kernel for a 2-D convolution: compute the output shape only.

    The result has shape ``[N, C_out, H_out, W_out]`` where ``N`` is
    ``input_.size(0)``, ``C_out`` is ``weight.size(0)`` and each spatial extent is
    ``(in + 2 * pad - dilation * (kernel - 1) - 1) // stride + 1``. ``pads`` is
    indexed per spatial axis (symmetric padding). ``bias`` and ``groups`` are not
    read. Fails the ``torch._check`` when either spatial extent is not positive.
    """
    in_dims = input_.size()
    kernel_dims = weight.size()
    nout = in_dims[0]
    cout = kernel_dims[0]

    def _extent(axis):
        # ``axis`` is 0 for height, 1 for width; spatial sizes start at index 2.
        padded = in_dims[axis + 2] + pads[axis] * 2
        window = dilations[axis] * (kernel_dims[axis + 2] - 1) + 1
        return (padded - window) // strides[axis] + 1

    hout = _extent(0)
    wout = _extent(1)
    torch._check(
        hout > 0 and wout > 0,
        lambda: "ho, wo should larger than 0, but now ho is "
        + str(hout)
        + ", and wo is "
        + str(wout)
        + ops_error(ErrCode.VALUE),
    )
    return torch.empty((nout, cout, hout, wout), dtype=input_.dtype, device=input_.device)
@impl(m, "npu_conv2d_backward")
def npu_conv2d_backward_meta(x, grad_output, weight, stride, padding, dilation, groups, output_mask):
    """Meta kernel returning three gradients as meta-device tensors.

    * first: shape and dtype of ``x``
    * second: shape and dtype of ``weight``
    * third: shape ``(weight.size(0),)`` with the dtype of ``x``, or ``None`` when
      ``output_mask[2]`` is falsy

    Only ``output_mask[2]`` is consulted; the first two results are always created.
    """
    out_channels = weight.size(0)
    grad_x = torch.empty(x.size(), dtype=x.dtype, device='meta')
    grad_w = torch.empty(weight.size(), dtype=weight.dtype, device='meta')
    grad_b = torch.empty((out_channels,), dtype=x.dtype, device='meta') if output_mask[2] else None
    return (grad_x, grad_w, grad_b)
# Register the op with torch.fx as side-effectful (it returns nothing).
has_side_effect(torch.ops.npu.npu_attention_to_ffn.default)


@impl(m, "npu_attention_to_ffn")
def npu_attention_to_ffn_meta(
    x,
    session_id,
    micro_batch_id,
    layer_id,
    expert_ids,
    expert_rank_table,
    group,
    world_size,
    ffn_token_info_table_shape,
    ffn_token_data_shape,
    attn_token_info_table_shape,
    moe_expert_num,
    scales=None,
    active_mask=None,
    quant_mode=0,
    sync_flag=0,
    ffn_start_rank_id=0,
):
    """Meta kernel for ``npu_attention_to_ffn``: no outputs, so nothing to infer."""
    return None
# Register the op with torch.fx as side-effectful (it returns nothing).
has_side_effect(torch.ops.npu.npu_ffn_to_attention.default)


@impl(m, "npu_ffn_to_attention")
def npu_ffn_to_attention_meta(
    x,
    session_ids,
    micro_batch_ids,
    token_ids,
    expert_offsets,
    actual_token_num,
    group,
    world_size,
    token_info_table_shape,
    token_data_shape,
    attn_rank_table=None,
):
    """Meta kernel for ``npu_ffn_to_attention``: no outputs, so nothing to infer."""
    return None
@impl(m, "repeat_interleave_backward_int")
def npu_repeat_interleave_backward_int_meta(grad, x, repeats, dim=None):
    """Meta kernel: the result mirrors ``x`` (via ``torch.empty_like``).

    ``grad``, ``repeats`` and ``dim`` do not affect the output.
    """
    return torch.empty_like(x)
@impl(m, "npu_dynamic_mx_quant_with_dual_axis")
def npu_dynamic_mx_quant_with_dual_axis(input_dummy, *, round_mode="rint", dst_type=296, scale_alg=0):
    """Meta kernel returning ``(y1, mxscale1, y2, mxscale2)``.

    Scale shapes: each is the input shape with a trailing ``2`` appended, and one
    axis replaced by ``ceil(ceil(size / 32) / 2)`` -- the last axis for
    ``mxscale1``, the second-to-last axis for ``mxscale2``. Both are uint8.

    Quantized outputs:
      * float8_e5m2 / float8_e4m3fn destination: same shape as the input with that
        dtype.
      * any other ``dst_type``: uint8 with the last axis halved (the error message
        names float4_e2m1 / float4_e1m2); the last input axis must then be even.

    Raises:
        RuntimeError: if ``scale_alg`` is not 0, if the input has fewer than two
            dimensions, or if the halved-output path is taken with an odd last axis.
    """
    if scale_alg != 0:
        raise RuntimeError("Invalid scale_alg value: {0}. Expected 0.".format(scale_alg) +
                           ops_error(ErrCode.PARAM))

    dim_num = input_dummy.dim()
    # Without this guard the second-to-last index (-2 + dim_num) goes negative for
    # 1-D inputs and silently addresses the appended trailing ``2`` instead.
    if dim_num < 2:
        raise RuntimeError("Expected input to have at least 2 dimensions, but got {0}. ".format(dim_num) +
                           ops_error(ErrCode.PARAM))

    block_size = 32

    def _paired_block_count(extent):
        # Number of 32-element blocks, then grouped in pairs (rounded up).
        blocks = int(math.ceil(extent / block_size))
        return (blocks + 1) // 2

    in_shape = list(input_dummy.shape)
    last_axis = dim_num - 1
    second_to_last_axis = dim_num - 2

    mxscale1_shape = in_shape + [2]
    mxscale1_shape[last_axis] = _paired_block_count(in_shape[last_axis])
    mxscale2_shape = in_shape + [2]
    mxscale2_shape[second_to_last_axis] = _paired_block_count(in_shape[second_to_last_axis])

    torch_dtype = TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP.get(dst_type, torch.int8)
    if torch_dtype == torch.float8_e5m2 or dst_type == torch_npu.float8_e5m2:
        y1 = torch.empty_like(input_dummy, dtype=torch.float8_e5m2)
        y2 = torch.empty_like(input_dummy, dtype=torch.float8_e5m2)
    elif torch_dtype == torch.float8_e4m3fn or dst_type == torch_npu.float8_e4m3fn:
        y1 = torch.empty_like(input_dummy, dtype=torch.float8_e4m3fn)
        y2 = torch.empty_like(input_dummy, dtype=torch.float8_e4m3fn)
    else:
        if in_shape[last_axis] % 2:
            raise RuntimeError("If output dtype is float4_e2m1 or float4_e1m2, "
                               "the last dim of input must be divisible by 2, " +
                               ops_error(ErrCode.PARAM))
        # Output is uint8 with the last axis halved.
        packed_shape = in_shape[:-1] + [in_shape[last_axis] // 2]
        y1 = input_dummy.new_empty(packed_shape, dtype=torch.uint8)
        y2 = input_dummy.new_empty(packed_shape, dtype=torch.uint8)

    mxscale1 = input_dummy.new_empty(mxscale1_shape, dtype=torch.uint8)
    mxscale2 = input_dummy.new_empty(mxscale2_shape, dtype=torch.uint8)
    return (y1, mxscale1, y2, mxscale2)
# Register the op with torch.fx as side-effectful (it returns nothing).
has_side_effect(torch.ops.npu.save_npugraph_tensor.default)


@impl(m, "save_npugraph_tensor")
def save_npugraph_tensor_meta(input, *, save_path=None):
    """Meta kernel for ``save_npugraph_tensor``: no outputs, so nothing to infer."""
    return None
# Register the tensor-list overload with torch.fx as side-effectful (it returns nothing).
has_side_effect(torch.ops.npu.save_npugraph_tensor.tensorlist)


@impl(m, "save_npugraph_tensor.tensorlist")
def save_npugraph_tensor_tensorlist_meta(input, *, save_name=None, save_dir=None, suffix=None):
    """Meta kernel for ``save_npugraph_tensor.tensorlist``: no outputs, so nothing to infer."""
    return None
@impl(m, "npu_grouped_dynamic_block_quant")
def npu_grouped_dynamic_block_quant_meta(
    x,
    group_list,
    *,
    min_scale=0.0,
    round_mode="rint",
    dst_type=torch.float8_e5m2,
    row_block_size=1,
    col_block_size=128,
    group_list_type=0,
):
    """Meta kernel returning ``(y, scale)`` for grouped block-wise dynamic quantization.

    ``y`` has the shape of ``x``; its dtype is looked up from ``dst_type`` in
    ``TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP`` (``torch.float8_e5m2`` when the
    key is absent). ``scale`` is float32; relative to ``x`` its row axis becomes
    ``int(rows / row_block_size + group_list.shape[0])`` and its column axis
    becomes ``ceil(cols / col_block_size)``.

    Raises:
        RuntimeError: if ``x`` is neither 2-D nor 3-D.
    """
    out_dtype = TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP.get(dst_type, torch.float8_e5m2)
    y = torch.empty(x.shape, dtype=out_dtype, device=x.device)

    scale_dims = list(x.shape)
    rank = len(scale_dims)
    if rank != 2 and rank != 3:
        raise RuntimeError(f"Expected x to have 2 or 3 dimensions, but got {x.dim()}.")

    # Rows/columns are always the last two axes: (0, 1) for 2-D, (1, 2) for 3-D.
    row_axis, col_axis = rank - 2, rank - 1
    scale_dims[row_axis] = int(scale_dims[row_axis] / row_block_size + group_list.shape[0])
    scale_dims[col_axis] = int(math.ceil(scale_dims[col_axis] / col_block_size))

    scale = torch.empty(torch.Size(scale_dims), dtype=torch.float32, device=x.device)
    return y, scale
@impl(m, "npu_deformable_conv2d")
def npu_deformable_conv2d_meta(input, weight, offset, bias, kernel_size, stride, padding,
                               dilation=None, groups=1, deformable_groups=1):
    """Meta kernel returning ``(out, deform_out)`` for deformable 2-D convolution.

    ``stride`` and ``dilation`` are read at indices 2 and 3; ``padding`` is read as
    four values, with indices 0/1 added to the height and 2/3 to the width.
    ``dilation`` defaults to ``[1, 1, 1, 1]``.

    Returns:
        out: ``(N, out_c, out_h, out_w)``
        deform_out: ``(N, in_c, out_h * k_h, out_w * k_w)``
        Both are created with ``input.new_empty``.
    """
    if dilation is None:
        dilation = [1, 1, 1, 1]

    min_rank = 4
    torch._check(input.dim() >= min_rank, lambda: f"input has to be more than {min_rank}D, but got Tensor of dimension {input.dim()}.")
    torch._check(offset.dim() >= min_rank, lambda: f"offset has to be more than {min_rank}D, but got Tensor of dimension {offset.dim()}.")
    torch._check(len(stride) >= min_rank, lambda: f"stride must have at least 4 elements.")
    torch._check(len(dilation) >= min_rank, lambda: f"dilation must have at least 4 elements.")

    batch, in_channels, in_h, in_w = input.shape
    out_channels, _, k_h, k_w = weight.shape

    # Effective kernel extent once dilation is applied.
    span_h = (k_h - 1) * dilation[2] + 1
    span_w = (k_w - 1) * dilation[3] + 1
    out_h = (in_h + padding[0] + padding[1] - span_h) // stride[2] + 1
    out_w = (in_w + padding[2] + padding[3] - span_w) // stride[3] + 1

    out = input.new_empty((batch, out_channels, out_h, out_w))
    deform_out = input.new_empty((batch, in_channels, out_h * k_h, out_w * k_w))
    return out, deform_out
@impl(m, "npu_ps_roi_pooling")
def npu_ps_roi_pooling_meta(x, rois, spatial_scale, group_size, output_dim):
    """Meta kernel: output shape is ``(B * R, output_dim, group_size, group_size)``.

    ``B`` and ``R`` are the first and third sizes of ``rois``, which is unpacked
    as exactly three dimensions. The result is created with ``x.new_empty``.
    """
    min_rois_rank = 3
    torch._check(
        rois.dim() >= min_rois_rank,
        lambda: f"rois only supports at least {min_rois_rank}D tensors, rois got: {rois.dim()}D."
    )
    batch_size, _, rois_num = rois.shape
    pooled_shape = (batch_size * rois_num, output_dim, group_size, group_size)
    return x.new_empty(pooled_shape)
@impl(m, "npu_convolution")
def npu_convolution_meta(input, weight, bias, stride, padding, dilation, groups):
    """Meta kernel for forward convolution: infer ``(N, C_out, *spatial)``.

    ``N`` is ``input.size(0)`` and ``C_out`` is ``weight.size(0)``. Each spatial
    extent is ``(in + 2 * pad - dilation * (kernel - 1) - 1) // stride + 1`` and
    must be positive. Spatial axes are those from index 2 onward, walked in
    lockstep with ``stride``, ``padding`` and ``dilation``. ``bias`` and
    ``groups`` are not read.
    """
    torch._check(input.dim() >= 4, lambda: f"Convolution input must be at least 4D.")
    first_spatial = 2
    result_dims = [input.size(0), weight.size(0)]
    per_axis = zip(
        input.shape[first_spatial:], weight.shape[first_spatial:], stride, padding, dilation
    )
    for axis, (in_extent, k_extent, step, pad, dil) in enumerate(per_axis):
        extent = (in_extent + 2 * pad - dil * (k_extent - 1) - 1) // step + 1
        torch._check(extent > 0, lambda: f"Invalid output size at dim {axis}: {extent}.")
        result_dims.append(extent)
    return input.new_empty(tuple(result_dims))
@impl(m, "npu_convolution_transpose")
def npu_convolution_transpose_meta(input, weight, bias, padding, output_padding, stride, dilation, groups):
    """Meta kernel for transposed convolution: infer ``(N, C_out, *spatial)``.

    ``N`` is ``input.size(0)`` and ``C_out`` is ``weight.size(1) * groups``. Each
    spatial extent is
    ``(in - 1) * stride - 2 * pad + dilation * (kernel - 1) + output_padding + 1``
    and must be positive. ``bias`` is not read.
    """
    torch._check(input.dim() >= 4, lambda: f"Convolution input must be at least 4D.")
    first_spatial = 2
    result_dims = [input.size(0), weight.size(1) * groups]
    per_axis = zip(
        input.shape[first_spatial:],
        weight.shape[first_spatial:],
        stride,
        padding,
        output_padding,
        dilation,
    )
    for axis, (in_extent, k_extent, step, pad, extra, dil) in enumerate(per_axis):
        extent = ((in_extent - 1) * step - 2 * pad + dil * (k_extent - 1) + extra + 1)
        torch._check(extent > 0, lambda: f"Invalid output size at dim {axis}: {extent}.")
        result_dims.append(extent)
    return input.new_empty(tuple(result_dims))
@impl(m_aten, "_convolution")
def convolution_meta(input_tensor, weight, bias, stride, padding, dilation, transposeds,
                     output_padding, groups, benchmark, deterministic, cudnn_enabled, *args):
    """Meta kernel for ``aten::_convolution``: infer the output shape only.

    Works for any number of spatial dimensions (input rank >= 3) and for both
    forward and transposed convolution:

    * forward:    ``C_out = weight.shape[0]`` and
      ``out = (in + 2 * pad - dilation * (kernel - 1) - 1) // stride + 1``
    * transposed: ``C_out = weight.shape[1] * groups`` and
      ``out = (in - 1) * stride - 2 * pad + dilation * (kernel - 1) + output_padding + 1``

    ``stride``, ``padding``, ``dilation`` (and ``output_padding`` when transposed)
    may each be a scalar, a one-element sequence, or a sequence with one entry per
    spatial dimension. ``bias``, ``benchmark``, ``deterministic``,
    ``cudnn_enabled`` and ``*args`` do not affect the shape.

    Returns:
        An uninitialised meta-device tensor of shape ``(N, C_out, *spatial_out)``
        with the dtype of ``input_tensor``.
    """
    num_spatial = input_tensor.dim() - 2
    torch._check(
        num_spatial >= 1 and weight.dim() == input_tensor.dim(),
        lambda: f"Expected input and weight of equal rank >= 3, but got input "
                f"{input_tensor.dim()}D and weight {weight.dim()}D.",
    )

    def _per_axis(param, name):
        # Scalars and one-element sequences apply to every spatial axis.
        values = list(param) if isinstance(param, (tuple, list)) else [param]
        if len(values) == 1:
            values = values * num_spatial
        torch._check(
            len(values) == num_spatial,
            lambda: f"Expected {name} to have 1 or {num_spatial} elements, but got {len(values)}.",
        )
        return values

    strides = _per_axis(stride, "stride")
    pads = _per_axis(padding, "padding")
    dilations = _per_axis(dilation, "dilation")

    batch_size = input_tensor.shape[0]
    spatial_in = input_tensor.shape[2:]
    kernel = weight.shape[2:]

    if transposeds:
        # Transposed weights are laid out (C_in, C_out / groups, *kernel).
        out_pads = _per_axis(output_padding, "output_padding")
        out_channels = weight.shape[1] * groups
        spatial_out = [
            (size - 1) * step - 2 * pad + dil * (k - 1) + extra + 1
            for size, k, step, pad, dil, extra in zip(
                spatial_in, kernel, strides, pads, dilations, out_pads
            )
        ]
    else:
        out_channels = weight.shape[0]
        spatial_out = [
            (size + 2 * pad - dil * (k - 1) - 1) // step + 1
            for size, k, step, pad, dil in zip(spatial_in, kernel, strides, pads, dilations)
        ]

    output_shape = (batch_size, out_channels, *spatial_out)
    return torch.empty(output_shape, dtype=input_tensor.dtype, device='meta')
@impl(m, "batch_norm_reduce")
def batch_norm_reduce_meta(self, eps):
    """Meta kernel returning ``(out_sum, out_square_sum)``.

    Both are 1-D meta-device tensors of length ``self.size(1)`` with the dtype of
    ``self``. ``eps`` is not read.
    """
    channels = self.size(1)

    def _channel_buffer():
        return torch.empty(channels, dtype=self.dtype, device='meta')

    return (_channel_buffer(), _channel_buffer())
@impl(m, "kl_div_backward")
def kl_div_backward_meta(grad_out, x, target, reduction="mean", log_target=False):
    """Meta kernel: the result mirrors ``x`` (via ``torch.empty_like``).

    No other argument influences the output.
    """
    return torch.empty_like(x)
@impl(m, "npu_dropout_with_add_softmax")
def npu_dropout_with_add_softmax_meta(x1, x2, alpha, prod, axis):
    """Meta kernel returning three tensors derived from ``x1 + x2``.

    * first: the sum flattened to 1-D and cast to uint8
    * second: softmax of the sum along ``axis``
    * third: an uninitialised tensor laid out like the softmax result

    ``alpha`` and ``prod`` are not read here.
    """
    summed = x1 + x2
    softmax_out = summed.softmax(axis=axis)
    dropout_out = torch.empty_like(softmax_out)
    flat_uint8 = summed.flatten().to(torch.uint8)
    return (flat_uint8, softmax_out, dropout_out)
@impl(m, "l1_loss_backward")
def npu_l1_loss_backward_meta(grad_output, input, target, reduction):
    """Meta kernel: the result mirrors ``input`` (via ``torch.empty_like``)."""
    grad_input = torch.empty_like(input)
    return grad_input
@impl(m, "npu_linear_backward")
def npu_npu_linear_backward_meta(grad_output, input1, input2):
    """Meta kernel returning a pair of tensors mirroring ``input1`` and ``input2``.

    ``grad_output`` is not read.
    """
    grad_first = torch.empty_like(input1)
    grad_second = torch.empty_like(input2)
    return grad_first, grad_second
@impl(m, "npu_dynamic_block_mx_quant")
def npu_dynamic_block_mx_quant(input_dummy, *, round_mode="rint", dst_type=296, scale_alg=0, dst_type_max=0.0):
    """Meta kernel returning ``(y, scale1, scale2)``.

    Scale shapes: each is the input shape with a trailing ``2`` appended, and one
    axis replaced by ``ceil(ceil(size / 32) / 2)`` -- the last axis for ``scale1``,
    the second-to-last axis for ``scale2``. Both are uint8.

    Quantized output ``y``:
      * float8_e5m2 / float8_e4m3fn destination: same shape as the input with that
        dtype.
      * any other ``dst_type``: uint8 with the last axis halved (the error message
        names float4_e2m1 / float4_e1m2); the last input axis must then be even.

    ``round_mode``, ``scale_alg`` and ``dst_type_max`` are not read here.

    Raises:
        RuntimeError: if the input has fewer than two dimensions, or if the
            halved-output path is taken with an odd last axis.
    """
    dim_num = input_dummy.dim()
    # Without this guard the second-to-last index (-2 + dim_num) goes negative for
    # 1-D inputs and silently addresses the appended trailing ``2`` instead.
    if dim_num < 2:
        raise RuntimeError("Expected input to have at least 2 dimensions, but got {0}. ".format(dim_num) +
                           ops_error(ErrCode.PARAM))

    block_size = 32

    def _paired_block_count(extent):
        # Number of 32-element blocks, then grouped in pairs (rounded up).
        blocks = int(math.ceil(extent / block_size))
        return (blocks + 1) // 2

    in_shape = list(input_dummy.shape)
    last_axis = dim_num - 1
    second_to_last_axis = dim_num - 2

    scale1_shape = in_shape + [2]
    scale1_shape[last_axis] = _paired_block_count(in_shape[last_axis])
    scale2_shape = in_shape + [2]
    scale2_shape[second_to_last_axis] = _paired_block_count(in_shape[second_to_last_axis])

    torch_dtype = TORCH_DTYPE_ENUM_VALUE_TO_SCALAR_TYPE_MAP.get(dst_type, torch.int8)
    if torch_dtype == torch.float8_e5m2 or dst_type == torch_npu.float8_e5m2:
        y = torch.empty_like(input_dummy, dtype=torch.float8_e5m2)
    elif torch_dtype == torch.float8_e4m3fn or dst_type == torch_npu.float8_e4m3fn:
        y = torch.empty_like(input_dummy, dtype=torch.float8_e4m3fn)
    else:
        if in_shape[last_axis] % 2:
            raise RuntimeError("If output dtype is float4_e2m1 or float4_e1m2, "
                               "the last dim of input must be divisible by 2, " +
                               ops_error(ErrCode.PARAM))
        # Output is uint8 with the last axis halved.
        packed_shape = in_shape[:-1] + [in_shape[last_axis] // 2]
        y = input_dummy.new_empty(packed_shape, dtype=torch.uint8)

    scale1 = input_dummy.new_empty(scale1_shape, dtype=torch.uint8)
    scale2 = input_dummy.new_empty(scale2_shape, dtype=torch.uint8)
    return (y, scale1, scale2)