from typing import (
Any,
Callable,
ContextManager,
Iterator,
List,
Literal,
NamedTuple,
Optional,
Sequence,
Tuple,
TypeVar,
Union,
overload,
)
from mindspeed.op_builder.builder import MindSpeedOpBuilder, AS_LIBRARY
import torch
import torchair
from torch.library import Library, impl
torch_npu_api_version = None
try:
from torchair import ge
from torchair.ge import Tensor, TensorSpec, DataType
from torchair import register_fx_node_ge_converter
except ImportError:
torch_npu_api_version = 1
from torchair.ge_concrete_graph import ge_apis as ge
from torchair.ge_concrete_graph.fx2ge_converter import register_fx_node_ge_converter
from torchair.ge_concrete_graph.ge_graph import Tensor, TensorSpec, DataType
from torchair.ge_concrete_graph.ge_graph import get_default_ge_graph, next_unique_name
from torchair.ge_concrete_graph.ge_graph import compat_as_bytes
from torchair.ge_concrete_graph.ge_graph import get_invalid_desc
else:
torch_npu_api_version = 2
class FFNOpBuilder(MindSpeedOpBuilder):
OP_NAME = "npu_ffn"
OP_PROTO = "npu_ffn(Tensor x, Tensor weight1, Tensor weight2, str activation, *, Tensor? expert_tokens=None, \
Tensor? expert_tokens_index=None, Tensor? bias1=None, Tensor? bias2=None, Tensor? scale=None, \
Tensor? offset=None, Tensor? deq_scale1=None, Tensor? deq_scale2=None, Tensor? antiquant_scale1=None, \
Tensor? antiquant_scale2=None, Tensor? antiquant_offset1=None, Tensor? antiquant_offset2=None, \
int? inner_precise=None, ScalarType? output_dtype=None) -> Tensor"
def __init__(self):
super(FFNOpBuilder, self).__init__(self.OP_NAME)
self.register_op_proto(self.OP_PROTO)
self.register_op_ir()
def sources(self):
return ['ops/csrc/cann/ffn.cpp']
def include_paths(self):
paths = super().include_paths()
paths += ['ops/csrc/cann/inc']
return paths
def cxx_args(self):
args = super().cxx_args()
args += [
'-Wno-sign-compare',
'-Wno-deprecated-declarations',
'-Wno-return-type',
"-D__FILENAME__='\"$$(notdir $$(abspath $$<))\"'"
]
return args
def register_op_ir(self):
@impl(AS_LIBRARY, "npu_ffn", "Meta")
def npu_ffn_forward(x, weight1, weight2, activation, *, expert_tokens=None, expert_tokens_index=None,
bias1=None, bias2=None, scale=None, offset=None, deq_scale1=None, deq_scale2=None,
antiquant_scale1=None, antiquant_scale2=None, antiquant_offset1=None,
antiquant_offset2=None, inner_precise=0, output_dtype=None):
dim_list = []
for i in range(0, x.dim() - 1):
dim_list.append(x.size(i))
dim_list.append(weight2.size(weight2.dim() - 1))
if x.dtype == torch.int8:
if output_dtype is not None and output_dtype == torch.bfloat16:
return x.new_empty(tuple(dim_list), dtype=torch.bfloat16)
else:
return x.new_empty(tuple(dim_list), dtype=torch.float16)
else:
return x.new_empty(tuple(dim_list))
@register_fx_node_ge_converter(torch.ops.mindspeed.npu_ffn.default)
def convert_npu_ffn(
x: Tensor,
weight1: Tensor,
weight2: Tensor,
activation: str,
*,
expert_tokens: Optional[Tensor] = None,
expert_tokens_index: Optional[Tensor] = None,
bias1: Optional[Tensor] = None,
bias2: Optional[Tensor] = None,
scale: Optional[Tensor] = None,
offset: Optional[Tensor] = None,
deq_scale1: Optional[Tensor] = None,
deq_scale2: Optional[Tensor] = None,
antiquant_scale1: Optional[Tensor] = None,
antiquant_scale2: Optional[Tensor] = None,
antiquant_offset1: Optional[Tensor] = None,
antiquant_offset2: Optional[Tensor] = None,
inner_precise: Optional[int] = 0,
output_dtype: Optional[int] = None,
meta_outputs: TensorSpec = None
):
'''"npu::npu_ffn(Tensor x, Tensor weight1, Tensor weight2, str activation, *, Tensor? expert_tokens=None,
Tensor? expert_tokens_index=None, Tensor? bias1=None, Tensor? bias2=None, Tensor? scale=None,
Tensor? offset=None, Tensor? deq_scale1=None, Tensor? deq_scale2=None,
Tensor? antiquant_scale1=None, Tensor? antiquant_scale2=None, Tensor? antiquant_offset1=None,
Tensor? antiquant_offset2=None, int? inner_precise=None, ScalarType? output_dtype=None)
-> Tensor
"'''
tokens_index_flag = False
if expert_tokens is not None and expert_tokens_index is not None:
raise ValueError("Cannot assign the value to expert_tokens and expert_tokens_index simultaneously!")
elif expert_tokens_index is not None:
tokens_index_flag = True
expert_tokens = expert_tokens_index
y_dtype = -1
if x.dtype == DataType.DT_INT8 and output_dtype is not None:
if output_dtype == torch.float16:
y_dtype = 0
elif output_dtype == torch.bfloat16:
y_dtype = 1
else:
raise NotImplementedError("In the quant scenario, output_dtype should be float16 or bfloat16,"
"otherwise it should be None!")
return FFN(x, weight1, weight2, expert_tokens=expert_tokens, bias1=bias1, bias2=bias2, scale=scale,
offset=offset, deq_scale1=deq_scale1, deq_scale2=deq_scale2, antiquant_scale1=antiquant_scale1,
antiquant_scale2=antiquant_scale2, antiquant_offset1=antiquant_offset1,
antiquant_offset2=antiquant_offset2, activation=activation, inner_precise=inner_precise,
output_dtype=y_dtype, tokens_index_flag=tokens_index_flag)
FFN = None
if torch_npu_api_version == 2:
def FFNV2(x: Tensor,
weight1: Tensor,
weight2: Tensor,
expert_tokens: Optional[Tensor],
bias1: Optional[Tensor],
bias2: Optional[Tensor],
scale: Optional[Tensor],
offset: Optional[Tensor],
deq_scale1: Optional[Tensor],
deq_scale2: Optional[Tensor],
antiquant_scale1: Optional[Tensor],
antiquant_scale2: Optional[Tensor],
antiquant_offset1: Optional[Tensor],
antiquant_offset2: Optional[Tensor],
*,
activation: str,
inner_precise: int = 0,
output_dtype: int = -1,
tokens_index_flag: bool = False):
"""REG_OP(FFN)\n
.INPUT(x, TensorType({DT_INT8, DT_FLOAT16, DT_BF16}))\n
.INPUT(weight1, TensorType({DT_INT8, DT_FLOAT16, DT_BF16, DT_INT4}))\n
.INPUT(weight2, TensorType({DT_INT8, DT_FLOAT16, DT_BF16, DT_INT4}))\n
.OPTIONAL_INPUT(expert_tokens, TensorType({DT_INT64}))\n
.OPTIONAL_INPUT(bias1, TensorType({DT_INT32, DT_FLOAT16, DT_FLOAT}))\n
.OPTIONAL_INPUT(bias2, TensorType({DT_INT32, DT_FLOAT16, DT_FLOAT}))\n
.OPTIONAL_INPUT(scale, TensorType({DT_FLOAT}))\n
.OPTIONAL_INPUT(offset, TensorType({DT_FLOAT}))\n
.OPTIONAL_INPUT(deq_scale1, TensorType({DT_UINT64, DT_BF16}))\n
.OPTIONAL_INPUT(deq_scale2, TensorType({DT_UINT64, DT_BF16}))\n
.OPTIONAL_INPUT(antiquant_scale1, TensorType({DT_FLOAT16, DT_BF16}))\n
.OPTIONAL_INPUT(antiquant_scale2, TensorType({DT_FLOAT16, DT_BF16}))\n
.OPTIONAL_INPUT(antiquant_offset1, TensorType({DT_FLOAT16, DT_BF16}))\n
.OPTIONAL_INPUT(antiquant_offset2, TensorType({DT_FLOAT16, DT_BF16}))\n
.OUTPUT(y, TensorType({DT_FLOAT16, DT_BF16}))\n
.REQUIRED_ATTR(activation, String)\n
.ATTR(inner_precise, Int, 0)\n
.ATTR(output_dtype, Int, -1)\n
.ATTR(tokens_index_flag, Bool, false)\n
"""
y = torchair.ge.custom_op("FFN",
inputs={
"x": x,
"weight1": weight1,
"weight2": weight2,
"expert_tokens": expert_tokens,
"bias1": bias1,
"bias2": bias2,
"scale": scale,
"offset": offset,
"deq_scale1": deq_scale1,
"deq_scale2": deq_scale2,
"antiquant_scale1": antiquant_scale1,
"antiquant_scale2": antiquant_scale2,
"antiquant_offset1": antiquant_offset1,
"antiquant_offset2": antiquant_offset2
},
attrs={
"activation": ge.attr.Str(activation),
"inner_precise": ge.attr.Int(inner_precise),
"output_dtype": ge.attr.Int(output_dtype),
"tokens_index_flag": ge.attr.Bool(tokens_index_flag)
},
outputs=[
"y"
])
return y
FFN = FFNV2
elif torch_npu_api_version == 1:
def FFNV1(x: Tensor,
weight1: Tensor,
weight2: Tensor,
expert_tokens: Optional[Tensor],
bias1: Optional[Tensor],
bias2: Optional[Tensor],
scale: Optional[Tensor],
offset: Optional[Tensor],
deq_scale1: Optional[Tensor],
deq_scale2: Optional[Tensor],
antiquant_scale1: Optional[Tensor],
antiquant_scale2: Optional[Tensor],
antiquant_offset1: Optional[Tensor],
antiquant_offset2: Optional[Tensor],
*,
activation: str,
inner_precise: int = 0,
output_dtype: int = -1,
tokens_index_flag: bool = False,
dependencies=[],
node_name=None):
"""REG_OP(FFN)\n
.INPUT(x, TensorType({DT_INT8, DT_FLOAT16, DT_BF16}))\n
.INPUT(weight1, TensorType({DT_INT8, DT_FLOAT16, DT_BF16, DT_INT4}))\n
.INPUT(weight2, TensorType({DT_INT8, DT_FLOAT16, DT_BF16, DT_INT4}))\n
.OPTIONAL_INPUT(expert_tokens, TensorType({DT_INT64}))\n
.OPTIONAL_INPUT(bias1, TensorType({DT_INT32, DT_FLOAT16, DT_FLOAT}))\n
.OPTIONAL_INPUT(bias2, TensorType({DT_INT32, DT_FLOAT16, DT_FLOAT}))\n
.OPTIONAL_INPUT(scale, TensorType({DT_FLOAT}))\n
.OPTIONAL_INPUT(offset, TensorType({DT_FLOAT}))\n
.OPTIONAL_INPUT(deq_scale1, TensorType({DT_UINT64, DT_BF16}))\n
.OPTIONAL_INPUT(deq_scale2, TensorType({DT_UINT64, DT_BF16}))\n
.OPTIONAL_INPUT(antiquant_scale1, TensorType({DT_FLOAT16, DT_BF16}))\n
.OPTIONAL_INPUT(antiquant_scale2, TensorType({DT_FLOAT16, DT_BF16}))\n
.OPTIONAL_INPUT(antiquant_offset1, TensorType({DT_FLOAT16, DT_BF16}))\n
.OPTIONAL_INPUT(antiquant_offset2, TensorType({DT_FLOAT16, DT_BF16}))\n
.OUTPUT(y, TensorType({DT_FLOAT16, DT_BF16}))\n
.REQUIRED_ATTR(activation, String)\n
.ATTR(inner_precise, Int, 0)\n
.ATTR(output_dtype, Int, -1)\n
.ATTR(tokens_index_flag, Bool, false)\n
"""
op = get_default_ge_graph().op.add()
op.type = "FFN"
op.name = next_unique_name(node_name, "FFN")
for dependency in dependencies:
op.input.append(dependency.controller)
op.input.append(x.tensor)
op.input_desc.add().CopyFrom(x.desc)
op.input_desc[-1].name = "x"
op.input.append(weight1.tensor)
op.input_desc.add().CopyFrom(weight1.desc)
op.input_desc[-1].name = "weight1"
op.input.append(weight2.tensor)
op.input_desc.add().CopyFrom(weight2.desc)
op.input_desc[-1].name = "weight2"
if expert_tokens is not None:
op.input.append(expert_tokens.tensor)
op.input_desc.add().CopyFrom(expert_tokens.desc)
else:
op.input.append('')
op.input_desc.add().CopyFrom(get_invalid_desc())
op.input_desc[-1].name = "expert_tokens"
if bias1 is not None:
op.input.append(bias1.tensor)
op.input_desc.add().CopyFrom(bias1.desc)
else:
op.input.append('')
op.input_desc.add().CopyFrom(get_invalid_desc())
op.input_desc[-1].name = "bias1"
if bias2 is not None:
op.input.append(bias2.tensor)
op.input_desc.add().CopyFrom(bias2.desc)
else:
op.input.append('')
op.input_desc.add().CopyFrom(get_invalid_desc())
op.input_desc[-1].name = "bias2"
if scale is not None:
op.input.append(scale.tensor)
op.input_desc.add().CopyFrom(scale.desc)
else:
op.input.append('')
op.input_desc.add().CopyFrom(get_invalid_desc())
op.input_desc[-1].name = "scale"
if offset is not None:
op.input.append(offset.tensor)
op.input_desc.add().CopyFrom(offset.desc)
else:
op.input.append('')
op.input_desc.add().CopyFrom(get_invalid_desc())
op.input_desc[-1].name = "offset"
if deq_scale1 is not None:
op.input.append(deq_scale1.tensor)
op.input_desc.add().CopyFrom(deq_scale1.desc)
else:
op.input.append('')
op.input_desc.add().CopyFrom(get_invalid_desc())
op.input_desc[-1].name = "deq_scale1"
if deq_scale2 is not None:
op.input.append(deq_scale2.tensor)
op.input_desc.add().CopyFrom(deq_scale2.desc)
else:
op.input.append('')
op.input_desc.add().CopyFrom(get_invalid_desc())
op.input_desc[-1].name = "deq_scale2"
if antiquant_scale1 is not None:
op.input.append(antiquant_scale1.tensor)
op.input_desc.add().CopyFrom(antiquant_scale1.desc)
else:
op.input.append('')
op.input_desc.add().CopyFrom(get_invalid_desc())
op.input_desc[-1].name = "antiquant_scale1"
if antiquant_scale2 is not None:
op.input.append(antiquant_scale2.tensor)
op.input_desc.add().CopyFrom(antiquant_scale2.desc)
else:
op.input.append('')
op.input_desc.add().CopyFrom(get_invalid_desc())
op.input_desc[-1].name = "antiquant_scale2"
if antiquant_offset1 is not None:
op.input.append(antiquant_offset1.tensor)
op.input_desc.add().CopyFrom(antiquant_offset1.desc)
else:
op.input.append('')
op.input_desc.add().CopyFrom(get_invalid_desc())
op.input_desc[-1].name = "antiquant_offset1"
if antiquant_offset2 is not None:
op.input.append(antiquant_offset2.tensor)
op.input_desc.add().CopyFrom(antiquant_offset2.desc)
else:
op.input.append('')
op.input_desc.add().CopyFrom(get_invalid_desc())
op.input_desc[-1].name = "antiquant_offset2"
op.attr["activation"].s = compat_as_bytes(activation)
op.attr["inner_precise"].i = inner_precise
op.attr["output_dtype"].i = output_dtype
op.attr["tokens_index_flag"].b = tokens_index_flag
output_index = 0
op.output_desc.add().name = "y"
y = Tensor(op, output_index)
output_index += 1
return y
FFN = FFNV1
else:
raise ValueError("touch_npu_api_version unsupport!")