# Commit 9583d925, created 2024-10-23 (commit history).
import torch

from torch.library import impl

from mindspeed.op_builder import QuantGMMOpBuilder

from mindspeed.op_builder.builder import AS_LIBRARY

from mindspeed.ops import gmm



# Names exported by `from <this module> import *`: the two public quant GMM wrappers.
__all__ = ["npu_quant_gmm", "npu_quant_gmm_v2"]





# Module-level builder instance. `_npu_quant_gmm` calls `op_builder.load()` on each
# invocation and uses the returned object's `npu_quant_gmm` entry point.
# NOTE(review): presumably `load()` caches the built extension -- confirm in QuantGMMOpBuilder.
op_builder = QuantGMMOpBuilder()





@impl(AS_LIBRARY, "npu_quant_gmm", "PrivateUse1")
def _npu_quant_gmm(x, weight, scale, *, offset=None, per_token_scale=None, bias=None, group_list=None,
                   group_list_type=0, output_dtype=None, act_type=0):
    """Implementation of the ``npu_quant_gmm`` op registered for the ``PrivateUse1`` key.

    Each tensor argument is wrapped in a list before being handed to the loaded
    kernel: ``x`` and ``weight`` always become one-element lists, while the
    optional ``scale``, ``offset``, ``per_token_scale`` and ``bias`` become an
    empty list when ``None`` and a one-element list otherwise.

    ``output_dtype`` is translated to the integer code passed to the kernel:

    * ``None`` or ``torch.bfloat16`` -> ``1``
    * ``torch.float16`` -> ``0``
    * ``torch.int8`` -> ``-1``

    Returns:
        The first element of the sequence returned by the kernel.

    Raises:
        ValueError: If ``output_dtype`` is none of the values listed above.
    """

    def _wrap_optional(tensor):
        # The kernel call receives lists; an absent tensor is passed as an empty list.
        return [tensor] if tensor is not None else []

    scale_list = _wrap_optional(scale)
    offset_list = _wrap_optional(offset)
    per_token_scale_list = _wrap_optional(per_token_scale)
    bias_list = _wrap_optional(bias)

    # NOTE(review): None shares the bfloat16 code (1) here -- confirm that this is
    # the intended default output type for the kernel.
    if output_dtype is None:
        dtype_code = 1
    else:
        known_codes = ((torch.bfloat16, 1), (torch.float16, 0), (torch.int8, -1))
        for known_dtype, code in known_codes:
            if output_dtype == known_dtype:
                dtype_code = code
                break
        else:
            raise ValueError(f"output_dtype should be int8, float16, bfloat16 or None, but got {output_dtype}")

    kernel_module = op_builder.load()
    results = kernel_module.npu_quant_gmm(
        [x],
        [weight],
        scale_list,
        offset_list,
        per_token_scale_list,
        bias_list,
        group_list,
        group_list_type,
        dtype_code,
        act_type,
    )
    return results[0]





def _npu_quant_gmm_common(x, weight, scale, *, offset=None, per_token_scale=None, bias=None, group_list=None,
                          group_list_type=0, output_dtype=None, act_type=0):
    """Validate the arguments and dispatch to ``torch.ops.mindspeed.npu_quant_gmm``.

    Steps, in order:

    1. Reject the call unless both ``x`` and ``weight`` have dtype ``torch.int8``.
    2. Run ``gmm.npu_gmm_param_verification`` on ``x``, ``weight``, ``bias`` and
       ``group_list`` with ``group_type=0`` and the given ``group_list_type``.
    3. Run ``gmm.check_optional_tensor`` on ``scale``, ``offset`` and
       ``per_token_scale`` against ``x.device``.
    4. Call the registered op, forwarding every argument unchanged.

    Returns:
        Whatever ``torch.ops.mindspeed.npu_quant_gmm`` returns.

    Raises:
        ValueError: If ``x`` or ``weight`` is not ``torch.int8``. The ``gmm``
            helpers may raise as well; their behavior is defined outside this file.
    """
    inputs_are_int8 = x.dtype == torch.int8 and weight.dtype == torch.int8
    if not inputs_are_int8:
        raise ValueError(f"Quant gmm only accept quant case, but got x[{x.dtype}] weight[{weight.dtype}]")

    gmm.npu_gmm_param_verification(
        x,
        weight,
        bias=bias,
        group_list=group_list,
        group_type=0,
        group_list_type=group_list_type,
    )

    # Same check for each optional tensor, in the order scale -> offset -> per_token_scale.
    optional_tensors = (
        (scale, "scale"),
        (offset, "offset"),
        (per_token_scale, "per_token_scale"),
    )
    for tensor, label in optional_tensors:
        gmm.check_optional_tensor(tensor, x.device, label)

    quant_gmm_op = torch.ops.mindspeed.npu_quant_gmm
    return quant_gmm_op(
        x,
        weight,
        scale,
        offset=offset,
        per_token_scale=per_token_scale,
        bias=bias,
        group_list=group_list,
        group_list_type=group_list_type,
        output_dtype=output_dtype,
        act_type=act_type,
    )





def npu_quant_gmm(x, weight, scale, *, offset=None, per_token_scale=None, bias=None, group_list=None,
                  output_dtype=None, act_type=0):
    """Public quant GMM entry point that always uses ``group_list_type=0``.

    All arguments are forwarded unchanged to ``_npu_quant_gmm_common``, which
    performs the validation and invokes the registered op.

    Returns:
        The result of ``_npu_quant_gmm_common``.
    """
    # NOTE(review): group_list_type presumably selects how group_list is
    # interpreted by the kernel -- confirm against the op documentation.
    forwarded = {
        "offset": offset,
        "per_token_scale": per_token_scale,
        "bias": bias,
        "group_list": group_list,
        "group_list_type": 0,
        "output_dtype": output_dtype,
        "act_type": act_type,
    }
    return _npu_quant_gmm_common(x, weight, scale, **forwarded)





def npu_quant_gmm_v2(x, weight, scale, *, offset=None, per_token_scale=None, bias=None, group_list=None,
                     output_dtype=None, act_type=0):
    """Public quant GMM entry point that always uses ``group_list_type=1``.

    Apart from the fixed ``group_list_type`` value, this forwards its arguments
    to ``_npu_quant_gmm_common`` exactly as ``npu_quant_gmm`` does.

    Returns:
        The result of ``_npu_quant_gmm_common``.
    """
    # NOTE(review): group_list_type presumably selects how group_list is
    # interpreted by the kernel -- confirm against the op documentation.
    forwarded = {
        "offset": offset,
        "per_token_scale": per_token_scale,
        "bias": bias,
        "group_list": group_list,
        "group_list_type": 1,
        "output_dtype": output_dtype,
        "act_type": act_type,
    }
    return _npu_quant_gmm_common(x, weight, scale, **forwarded)