import os
from abc import ABC, abstractmethod
import math
from typing import Optional, Union, List, Sequence, Tuple
import torch
import torch.nn as nn
from torch.nn import Parameter
import torch_npu
from transformers.activations import ACT2FN
from transformers.utils import logging
from module.quantization import QuantizeMethodBase, QuantizationConfig
from .utils import (divide, set_weight_attrs)
logger = logging.get_logger(__name__)
def split_tensor_along_last_dim(
tensor: torch.Tensor,
num_partitions: int,
contiguous_split_chunks: bool = False,
) -> Sequence[torch.Tensor]:
""" Split a tensor along its last dimension.
Arguments:
tensor: input tensor.
num_partitions: number of partitions to split the tensor
contiguous_split_chunks: If True, make each chunk contiguous
in memory.
Returns:
A list of Tensors
"""
last_dim = tensor.dim() - 1
last_dim_size = divide(tensor.size()[last_dim], num_partitions)
tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
if contiguous_split_chunks:
return tuple(chunk.contiguous() for chunk in tensor_list)
return tensor_list
class LinearMethodBase(QuantizeMethodBase):
"""Base class for different (maybe quantized) linear methods."""
@abstractmethod
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
"""Create weights for a linear layer.
The weights will be set as attributes of the layer.
Args:
layer: The layer that is using the LinearMethodBase factory.
input_size_per_partition: Size of the weight input dim on rank X.
output_partition_sizes: Sizes of the output dim of each logical
weight on rank X. E.g., output_partition_sizes for MergedColumnParallelLinear
is a list contains the width of Wq, Wk, Wv on rank X.
input_size: Size of the input dim of the weight across all ranks.
output_size: Size of the output dim of the weight across all ranks.
params_dtype: Datatype of the parameters.
"""
raise NotImplementedError
@abstractmethod
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
"""
Apply the weights in layer to the input tensor.
Expects create_weights to have been called before on the layer.
"""
raise NotImplementedError
class UnquantizedLinearMethod(LinearMethodBase):
"""Linear method without quantization."""
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: list[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
weight = Parameter(torch.empty(sum(output_partition_sizes),
input_size_per_partition,
dtype=params_dtype),
requires_grad=False)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
layer.register_parameter("weight", weight)
set_weight_attrs(weight, extra_weight_attrs)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
**kargs) -> torch.Tensor:
origin_shape = x.size()
x = x.reshape(-1, origin_shape[-1])
out = torch.matmul(x, layer.weight.data)
if bias is not None:
out = out + bias
out = out.view(*origin_shape[:-1], out.shape[-1])
return out
def process_weights_after_loading(self, layer, is_transpose=True, is_nz=True, **kwargs):
weight = layer.weight
if is_transpose:
weight.data = weight.data.transpose(-2, -1)
if is_nz and layer.weight.dtype != torch.float32:
weight.data = torch_npu.npu_format_cast(weight.data.contiguous(), 29)
layer.weight = Parameter(weight, requires_grad=False)
class LinearBase(torch.nn.Module):
"""Base linear layer.
Args:
input_size: input dimension of the linear layer.
output_size: output dimension of the linear layer.
bias: If true, add bias.
skip_bias_add: If true, skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
return_bias: If true, return bias together with outputs in forward pass.
"""
def __init__(
self,
input_size: int,
output_size: int,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional = None,
prefix: str = "",
*,
return_bias: Optional[bool] = False,
):
super().__init__()
self.input_size = input_size
self.output_size = output_size
self.skip_bias_add = skip_bias_add
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.params_dtype = params_dtype
if quant_config is None:
self.quant_method: Optional[
QuantizeMethodBase] = UnquantizedLinearMethod()
else:
self.quant_method = quant_config.get_quant_method(self,
prefix=prefix)
self.return_bias = return_bias
def forward(
self, input_: torch.Tensor
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
raise NotImplementedError
class ColumnParallelLinear(LinearBase):
"""Linear layer with column parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
its second dimension as A = [A_1, ..., A_p].
Args:
input_size: first dimension of matrix A.
output_size: second dimension of matrix A.
bias: If true, add bias.
tp_size: The size of the tensor parallel communication domain
where the current device is located.
tp_rank: The rank index of current device in the located tensor parallel
communication domain.
skip_bias_add: This was added to enable performance optimizations where
bias can be fused with other element-wise operations. we
skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
output_sizes: list of output sizes packed into one output, like for QKV
the list would be size 3.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.q_b_proj)
return_bias: If true, return bias together with outputs in forward pass.
"""
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = False,
tp_size: int = 1,
tp_rank: int = 0,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional = None,
output_sizes: Optional[list[int]] = None,
prefix: str = "",
*,
return_bias: Optional[bool] = False,
):
self.tp_size = tp_size
self.tp_rank = tp_rank
self.input_size_per_partition = input_size
self.output_size_per_partition = divide(output_size, self.tp_size)
self.output_partition_sizes = [self.output_size_per_partition]
if hasattr(self, "output_sizes"):
self.output_partition_sizes = [
divide(output_size, self.tp_size)
for output_size in self.output_sizes
]
super().__init__(input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix,
return_bias=return_bias,
)
if output_sizes is None:
output_sizes = [output_size]
if self.quant_method is None:
raise RuntimeError("quant method cannnot be none")
self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size_per_partition,
output_partition_sizes=self.output_partition_sizes,
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
weight_loader=self.weight_loader)
if bias:
self.bias = Parameter(
torch.empty(self.output_size_per_partition,
dtype=params_dtype))
set_weight_attrs(self.bias, {
"output_dim": 0,
"weight_loader": self.weight_loader,
})
else:
self.register_parameter("bias", None)
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
output_dim = getattr(param, "output_dim", None)
is_sharded_weight = getattr(param, "is_sharded_weight", False)
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
is_sharded_weight = is_sharded_weight or use_bitsandbytes_4bit
param_data = param.data
if output_dim is not None and not is_sharded_weight:
shard_size = param_data.shape[output_dim]
start_idx = self.tp_rank * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)
if param_data.shape != loaded_weight.shape:
raise RuntimeError(
f"Tried to load weights of shape {loaded_weight.shape}"
f"to a parameter of shape {param_data.shape}")
param_data.copy_(loaded_weight)
def forward(
self, input_, dynamic_scale=None, out_dtype=torch.bfloat16
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
bias = self.bias if not self.skip_bias_add else None
if self.quant_method is None:
raise RuntimeError("quant method cannnot be none")
output = self.quant_method.apply(self,
x=input_,
bias=bias,
dynamic_scale=dynamic_scale,
out_dtype=out_dtype)
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
def extra_repr(self) -> str:
s = f"in_features={self.input_size}"
s += f", output_features={self.output_size_per_partition}"
s += f", bias={self.bias is not None}"
s += f", tp_size={self.tp_size}"
return s
class VocabParallelEmbedding(nn.Embedding):
"""Embedding parallelized in the vocabulary dimension.
Adapted from torch.nn.Embedding, note that we pad the vocabulary size to
make sure it is divisible by the number of model parallel NPUs.
Args:
vocab_size: Vocabulary size.
hidden_size: Second dimension of matrix A.
padding_idx: The index of the padding token.
params_dtype: Data type for the parameters.
tp_size: The size of the tensor parallel communication domain
where the current device is located.
tp_rank: The rank index of current device in the located tensor parallel
communication domain.
"""
def __init__(
self,
vocab_size,
hidden_size,
padding_idx,
params_dtype,
tp_size=1,
tp_rank=0,
):
self.tp_size = tp_size
self.tp_rank = tp_rank
self.input_size = vocab_size
self.input_size_per_partition = divide(vocab_size, self.tp_size)
shard_start_idx = self.tp_rank * self.input_size_per_partition
shard_end_idx = shard_start_idx + self.input_size_per_partition
local_padding_idx = None
if padding_idx is not None and shard_start_idx <= padding_idx < shard_end_idx:
local_padding_idx = padding_idx - shard_start_idx
super().__init__(self.input_size_per_partition,
hidden_size,
local_padding_idx,
dtype=params_dtype)
self.output_size_per_partition = hidden_size
self.output_size = hidden_size
self.output_partition_sizes = [hidden_size]
self.params_dtype = params_dtype
self.weight.requires_grad_(False)
set_weight_attrs(self.weight, {"input_dim": 0, "output_dim": 1})
set_weight_attrs(self.weight, {"weight_loader": self.weight_loader})
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
input_dim = getattr(param, "input_dim", None)
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
param_data = param.data
if input_dim is not None and not use_bitsandbytes_4bit:
shard_size = param_data.shape[input_dim]
start_idx = self.tp_rank * shard_size
loaded_weight = loaded_weight.narrow(input_dim, start_idx,
shard_size)
if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)
if param_data.shape != loaded_weight.shape:
raise RuntimeError("param_data.shape != loaded_weight.shape")
param_data.copy_(loaded_weight)
class ReplicatedLinear(LinearBase):
"""Replicated linear layer.
Args:
input_size: input dimension of the linear layer.
output_size: output dimension of the linear layer.
bias: If true, add bias.
skip_bias_add: If true, skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.gate)
return_bias: If true, return bias together with outputs in forward pass.
"""
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = False,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional = None,
prefix: str = "",
*,
return_bias: bool = False,
):
super().__init__(input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix=prefix,
return_bias=return_bias,
)
if self.quant_method is None:
raise RuntimeError("quant method cannnot be none")
self.quant_method.create_weights(self,
self.input_size, [self.output_size],
self.input_size,
self.output_size,
self.params_dtype,
weight_loader=self.weight_loader)
if bias:
self.bias = Parameter(
torch.empty(self.output_size, dtype=self.params_dtype))
set_weight_attrs(self.bias, {
"output_dim": 0,
"weight_loader": self.weight_loader,
})
else:
self.register_parameter("bias", None)
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)
if param.size() != loaded_weight.size():
raise RuntimeError(
f"Tried to load weights of size {loaded_weight.size()}"
f"to a parameter of size {param.size()}")
param.data.copy_(loaded_weight)
def forward(
self, input_: torch.Tensor, dynamic_scale=None, out_dtype=torch.bfloat16
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[Parameter]]]:
bias = self.bias if not self.skip_bias_add else None
if self.quant_method is None:
raise RuntimeError("quant method cannnot be none")
output = self.quant_method.apply(self,
x=input_,
bias=bias,
dynamic_scale=dynamic_scale,
out_dtype=out_dtype)
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
def extra_repr(self) -> str:
s = f"in_features={self.input_size}"
s += f", output_features={self.output_size}"
s += f", bias={self.bias is not None}"
return s
class MergedColumnParallelLinear(LinearBase):
"""Linear layer with column parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
its second dimension as A = [A_1, ..., A_p]. A is concatenated by two
weight matrices along second dimension.
Args:
input_size: first dimension of matrix A.
output_sizes: list of output sizes packed into one output, like for QKV
the list would be size 3.
bias: If true, add bias.
skip_bias_add: This was added to enable performance optimizations where
bias can be fused with other element-wise operations. we
skip adding bias but instead return it.
tp_size: The size of the tensor parallel communication domain
where the current device is located.
tp_rank: The rank index of current device in the located tensor parallel
communication domain.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.mlp.merge_up_gate_proj)
return_bias: If true, return bias together with outputs in forward pass.
"""
def __init__(self,
input_size: int,
output_sizes: List[int],
bias: bool = False,
skip_bias_add: bool = False,
tp_size: int = 1,
tp_rank: int = 0,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional = None,
prefix: str = "",
return_bias: bool = False,
):
self.output_sizes = output_sizes
self.tp_size = tp_size
if not all(output_size % tp_size == 0 for output_size in output_sizes):
raise RuntimeError("All output_sizes must be divisible by tp_size")
self.tp_rank = tp_rank
self.quant_config = quant_config
output_size = sum(output_sizes)
super().__init__(input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix,
return_bias=return_bias,
)
if self.quant_method is None:
raise RuntimeError("self.quant_method must not be None")
self.output_size_per_partition = divide(self.output_size, tp_size)
self.output_partition_sizes = [self.output_size_per_partition]
if hasattr(self, "output_sizes"):
self.output_partition_sizes = [
divide(output_size, self.tp_size)
for output_size in self.output_sizes
]
output_sizes = [output_size]
self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size,
output_partition_sizes=self.output_partition_sizes,
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
weight_loader=self.weight_loader)
if bias:
self.bias = Parameter(
torch.empty(self.output_size_per_partition,
dtype=params_dtype))
set_weight_attrs(self.bias, {
"output_dim": 0,
"weight_loader": self.weight_loader,
})
else:
self.register_parameter("bias", None)
self.throw_dequant = True
def forward(self, input_, dynamic_scale=None, out_dtype=torch.bfloat16):
bias = self.bias if not self.skip_bias_add else None
if self.quant_method is None:
raise RuntimeError("self.quant_method must not be None")
output = self.quant_method.apply(self,
x=input_,
bias=bias,
dynamic_scale=dynamic_scale,
out_dtype=out_dtype)
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
def weight_loader(self,
param: Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[int] = None):
param_data = param.data
output_dim = getattr(param, "output_dim", None)
is_per_block_scale = getattr(param, "is_per_block_scale", False)
needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
if loaded_shard_id is None:
if output_dim is None:
if param_data.shape != loaded_weight.shape:
raise RuntimeError("param_data.shape != loaded_weight.shape")
param_data.copy_(loaded_weight)
return
current_shard_offset = 0
shard_offsets: List[Tuple[int, int, int]] = []
for i, output_size in enumerate(self.output_sizes):
shard_offsets.append((i, current_shard_offset, output_size))
current_shard_offset += output_size
packed_dim = getattr(param, "packed_dim", None)
for shard_id, shard_offset, shard_size in shard_offsets:
if packed_dim == output_dim:
shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor
loaded_weight_shard = loaded_weight.narrow(
output_dim, shard_offset, shard_size)
self.weight_loader(param, loaded_weight_shard, shard_id)
return
if loaded_shard_id >= len(self.output_sizes):
raise RuntimeError("loaded_shard_id must be less than the length of self.output_sizes")
if output_dim is not None:
if is_per_block_scale:
block_size = self.quant_config.weight_block_size[0]
shard_offset = math.ceil(sum(self.output_sizes[:loaded_shard_id]) / block_size) // self.tp_size
shard_size = math.ceil(self.output_sizes[loaded_shard_id] / block_size) // self.tp_size
else:
shard_offset = sum(self.output_sizes[:loaded_shard_id]) // self.tp_size
shard_size = self.output_sizes[loaded_shard_id] // self.tp_size
packed_dim = getattr(param, "packed_dim", None)
if packed_dim == output_dim:
shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
False)
if use_bitsandbytes_4bit:
shard_size = loaded_weight.shape[output_dim]
shard_offset = loaded_weight.shape[output_dim] * \
loaded_shard_id
param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
start_idx = self.tp_rank * shard_size
if not use_bitsandbytes_4bit:
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
else:
ignore_warning = getattr(param, "ignore_warning", False)
if not ignore_warning:
logger.warning(
"Loading a weight without `output_dim` attribute in "
"MergedColumnParallelLinear, assume the weight is "
"the same for all partitions.")
if param_data.shape != loaded_weight.shape:
raise RuntimeError("param_data.shape != loaded_weight.shape")
param_data.copy_(loaded_weight)
class RowParallelLinear(LinearBase):
"""Linear layer with row parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
its first dimension and X along its second dimension as:
- -
| A_1 |
| . |
A = | . | X = [X_1, ..., X_p]
| . |
| A_p |
- -
Arguments:
input_size: first dimension of matrix A.
output_size: second dimension of matrix A.
bias: If true, add bias. Note that bias is not parallelized.
tp_size: The size of the tensor parallel communication domain
where the current device is located.
tp_rank: The rank index of current device in the located tensor parallel
communication domain.
input_is_parallel: If true, we assume that the input is already
split across the Npus and we do not split
again.
skip_bias_add: This was added to enable performance optimization where
bias can be fused with other element-wise operations.
We skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.o_proj)
return_bias: If true, return bias together with outputs in forward pass.
"""
def __init__(self,
input_size: int,
output_size: int,
bias: bool = False,
tp_size: int = 1,
tp_rank: int = 0,
input_is_parallel: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional = None,
prefix: str = "",
return_bias: bool = False):
super().__init__(input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix,
)
if self.quant_method is None:
raise RuntimeError("self.quant_method must not be None")
self.input_is_parallel = input_is_parallel
self.tp_size = tp_size
self.tp_rank = tp_rank
self.input_size_per_partition = divide(input_size, self.tp_size)
self.output_size_per_partition = output_size
self.output_partition_sizes = [output_size]
self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size_per_partition,
output_partition_sizes=self.output_partition_sizes,
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
weight_loader=self.weight_loader)
if bias:
self.bias = Parameter(
torch.empty(self.output_size, dtype=params_dtype))
set_weight_attrs(self.bias, {
"output_dim": 0,
"weight_loader": self.weight_loader,
})
else:
self.register_parameter("bias", None)
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
input_dim = getattr(param, "input_dim", None)
use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
param_data = param.data
if input_dim is not None and not use_bitsandbytes_4bit:
shard_size = param_data.shape[input_dim]
start_idx = self.tp_rank * shard_size
loaded_weight = loaded_weight.narrow(input_dim, start_idx,
shard_size)
if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)
if param_data.shape != loaded_weight.shape:
raise RuntimeError("param_data.shape != loaded_weight.shape")
param_data.copy_(loaded_weight)
def forward(self, input_, dynamic_scale=None, out_dtype=torch.bfloat16):
if self.input_is_parallel:
input_parallel = input_
else:
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.tp_size)
input_parallel = splitted_input[self.tp_rank].contiguous()
if self.quant_method is None:
raise RuntimeError("self.quant_method is None")
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
output = self.quant_method.apply(self,
x=input_parallel,
bias=bias_,
dynamic_scale=dynamic_scale,
out_dtype=out_dtype)
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
def extra_repr(self) -> str:
s = f"input_features={self.input_size_per_partition}"
s += f", output_features={self.output_size}"
s += f", bias={self.bias is not None}"
s += f", tp_size={self.tp_size}"
return s
class NpuLinear(nn.Module):
def __init__(self, in_feature, out_feature, bias: bool = False, dtype: torch.dtype = torch.bfloat16):
super().__init__()
self.weight = nn.Parameter(torch.empty((out_feature, in_feature), dtype=dtype), requires_grad=False)
self.bias = None
if bias is not None and bias:
self.bias = nn.Parameter(torch.empty((out_feature,), dtype=dtype, required_grad=False))
def forward(self, x):
origin_shape = x.size()
x = x.view(-1, origin_shape[-1])
out = torch.matmul(x, self.weight.data)
if self.bias is not None:
out = out + self.bias
out = out.view(*origin_shape[:-1], -1)
return out
class DynamicW8A8Linear(nn.Module):
def __init__(self, in_feature, out_feature, bias=False, offset=False, dtype=torch.bfloat16):
super().__init__()
self.in_feature, self.out_feature = in_feature, out_feature
self.weight = Parameter(torch.ones((out_feature, in_feature), dtype=torch.int8), requires_grad=False)
if bias:
self.bias = Parameter(torch.ones((out_feature,), dtype=dtype))
else:
self.bias = None
self.smooth_scales = Parameter(torch.ones(self.in_feature, dtype=dtype), requires_grad=False)
self.weight_scale = Parameter(torch.ones(self.out_feature, dtype=dtype), requires_grad=False)
if offset:
self.offset = Parameter(torch.ones(self.out_feature, dtype=dtype), requires_grad=False)
else:
self.offset = None
def forward(self, x, dynamic_scale=None, out_dtype=torch.bfloat16):
if dynamic_scale is not None:
x_scale = dynamic_scale
else:
x, x_scale = torch_npu.npu_dynamic_quant(x, smooth_scales=self.smooth_scales)
out_shape = x.size()[:-1] + (self.out_feature, )
x = x.view(-1, x.size(-1))
x_scale = x_scale.view(-1)
x = torch_npu.npu_quant_matmul(x, self.weight,
self.weight_scale.view(-1),
pertoken_scale=None if out_dtype == torch.int32 else x_scale,
bias=self.bias,
output_dtype=out_dtype)
x = x.view(out_shape)
if out_dtype == torch.int32:
return x, x_scale
else:
return x
class DynamicW8A8MoeFFN(nn.Module):
def __init__(self, hidden_size, intermediate_size, expert_num, dtype=torch.bfloat16):
super().__init__()
self.dtype = dtype
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.group_w1_w3 = Parameter(
torch.empty((expert_num, self.intermediate_size * 2, self.hidden_size),
dtype=torch.int8), requires_grad=False)
self.group_w2 = Parameter(
torch.empty((expert_num, self.hidden_size, self.intermediate_size),
dtype=torch.int8), requires_grad=False)
_, self.out_feature_1, self.in_feature_1 = self.group_w1_w3.size()
_, self.out_feature_2, self.in_feature_2 = self.group_w2.size()
self.smooth_scale_1 = Parameter(torch.ones((expert_num, self.in_feature_1), dtype=dtype), requires_grad=False)
self.group_w1_w3_scale = Parameter(
torch.ones(size=(expert_num, self.out_feature_1), dtype=dtype), requires_grad=False)
self.smooth_scale_2 = Parameter(torch.ones((expert_num, self.in_feature_2), dtype=dtype), requires_grad=False)
self.group_w2_scale = Parameter(
torch.ones(size=(expert_num, self.out_feature_2), dtype=dtype), requires_grad=False)
def forward(self, x, expert_tokens, group_list_type=0, pertoken_scale=None):
hidden_size = x.size(-1)
if pertoken_scale is None:
x, pertoken_scale = torch_npu.npu_dynamic_quant(x)
if pertoken_scale.dim() > 1:
pertoken_scale = pertoken_scale.reshape(-1)
x = x.view(-1, hidden_size)
mm1_mm3 = torch_npu.npu_grouped_matmul([x], [self.group_w1_w3],
group_list=expert_tokens, split_item=3,
output_dtype=torch.int32, group_type=0,
group_list_type=group_list_type,
tuning_config=[0]
)[0]
intermediate_h, pertoken_scale = torch_npu.npu_dequant_swiglu_quant(
mm1_mm3, weight_scale=self.group_w1_w3_scale,
quant_scale=self.smooth_scale_2,
group_index=expert_tokens,
activate_left=True,
quant_mode=1,
activation_scale=pertoken_scale
)
out_hidden = torch_npu.npu_grouped_matmul(
[intermediate_h], [self.group_w2], bias=None,
scale=[self.group_w2_scale], per_token_scale=[pertoken_scale],
group_list=expert_tokens, split_item=3,
output_dtype=self.dtype, group_type=0,
group_list_type=group_list_type,
tuning_config=[0]
)[0]
return out_hidden
class QKVParallelLinear(ColumnParallelLinear):
"""Linear layers for the attention's QKV transformation.
Linear layers for the linear transformation of the query, key, and value
vectors in the attention layer. The weight matrix is concatenated along
the output dimension. The layer is parallelized along the head dimension.
When the number of key/value heads is smaller than the number of query
heads (e.g., multi-query/grouped-query attention), the key/value head may
be replicated while the query heads are partitioned.
Args:
hidden_size: input hidden state size of the transformer.
head_size: size of each attention head.
total_num_heads: total number of attention query heads.
total_num_kv_heads: total number of attention key/value heads. If
None, assume total_num_kv_heads = total_num_heads.
bias: If true, add bias.
skip_bias_add: This was added to enable performance optimizations where
bias can be fused with other element-wise operations. we
skip adding bias but instead return it.
tp_size: The size of the tensor parallel communication domain
where the current device is located.
tp_rank: The rank index of current device in the located tensor parallel
communication domain.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj)
return_bias: If true, return bias together with outputs in forward pass.
"""
def __init__(
self,
hidden_size: int,
head_size: int,
total_num_heads: int,
total_num_kv_heads: Optional[int] = None,
bias: bool = True,
skip_bias_add: bool = False,
tp_size: int = 1,
tp_rank: int = 0,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
*,
return_bias: bool = True):
self.hidden_size = hidden_size
self.head_size = head_size
self.total_num_heads = total_num_heads
if total_num_kv_heads is None:
total_num_kv_heads = total_num_heads
self.total_num_kv_heads = total_num_kv_heads
self.tp_size = tp_size
self.tp_rank = tp_rank
self.num_heads = divide(self.total_num_heads, tp_size)
if tp_size >= self.total_num_kv_heads:
self.num_kv_heads = 1
self.num_kv_head_replicas = divide(tp_size,
self.total_num_kv_heads)
else:
self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
self.num_kv_head_replicas = 1
input_size = self.hidden_size
output_size = (self.num_heads +
2 * self.num_kv_heads) * tp_size * self.head_size
self.output_sizes = [
self.num_heads * self.head_size * tp_size,
self.num_kv_heads * self.head_size * tp_size,
self.num_kv_heads * self.head_size * tp_size,
]
super().__init__(input_size=input_size,
output_size=output_size,
bias=bias,
tp_size=tp_size,
tp_rank=tp_rank,
skip_bias_add=skip_bias_add,
params_dtype=params_dtype,
quant_config=quant_config,
prefix=prefix,
return_bias=return_bias,
)
def weight_loader(self,
param: Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[str] = None):
param_data = param.data
output_dim = getattr(param, "output_dim", None)
needs_scalar_to_array = getattr(param, "needs_scalar_to_array", False)
if loaded_shard_id is None:
if output_dim is None:
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
return
shard_offsets = [
("q", 0, self.total_num_heads * self.head_size),
("k", self.total_num_heads * self.head_size,
self.total_num_kv_heads * self.head_size),
("v", (self.total_num_heads + self.total_num_kv_heads) *
self.head_size, self.total_num_kv_heads * self.head_size),
]
packed_dim = getattr(param, "packed_dim", None)
for shard_id, shard_offset, shard_size in shard_offsets:
if packed_dim == output_dim:
shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor
loaded_weight_shard = loaded_weight.narrow(
output_dim, shard_offset, shard_size)
self.weight_loader(param, loaded_weight_shard, shard_id)
return
assert loaded_shard_id in ["q", "k", "v"]
if output_dim is not None:
if loaded_shard_id == "q":
shard_offset = 0
shard_size = self.num_heads * self.head_size
elif loaded_shard_id == "k":
shard_offset = self.num_heads * self.head_size
shard_size = self.num_kv_heads * self.head_size
elif loaded_shard_id == "v":
shard_offset = (self.num_heads +
self.num_kv_heads) * self.head_size
shard_size = self.num_kv_heads * self.head_size
packed_dim = getattr(param, "packed_dim", None)
if packed_dim == output_dim:
shard_size = shard_size // param.pack_factor
shard_offset = shard_offset // param.pack_factor
is_sharded_weight = getattr(param, "is_sharded_weight", False)
is_sharded_weight = is_sharded_weight
param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
if loaded_shard_id == "q":
shard_id = self.tp_rank
else:
shard_id = self.tp_rank // self.num_kv_head_replicas
start_idx = shard_id * shard_size
if not is_sharded_weight:
loaded_weight = loaded_weight.narrow(output_dim, start_idx,
shard_size)
else:
ignore_warning = getattr(param, "ignore_warning", False)
if not ignore_warning:
logger.warning(
"Loading a weight without `output_dim` attribute in "
"QKVParallelLinear, assume the weight is the same "
"for all partitions.")
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)