import torch
try:
import torch_npu
except ImportError:
torch_npu = None
class _GroupedMatmul(torch.autograd.Function):
    """Autograd wrapper around ``torch_npu.npu_grouped_matmul``.

    Forward multiplies each group ``i`` of consecutive rows of ``input_tensor``
    (group boundaries described by ``m_split`` / ``group_list_type``) by
    ``weights[i].T``. Backward returns gradients for ``input_tensor`` and
    ``weights`` only.
    """

    @staticmethod
    def forward(ctx, input_tensor: torch.Tensor, weights, weights_bias, m_split, group_list_type) -> torch.Tensor:
        """Run the fused grouped matmul and stash what backward needs.

        Args:
            ctx: Autograd context.
            input_tensor: Row-grouped input; expected to be 2-D since backward
                transposes it with ``.T``.
            weights: Tensor indexed by group along dim 0; group ``i`` contributes
                ``rows_i @ weights[i].T``.
            weights_bias: Passed straight through as ``bias``; receives no gradient.
            m_split: Group description as a tensor or a sequence of ints.
            group_list_type: Forwarded to ``npu_grouped_matmul`` to say how
                ``m_split`` is interpreted (per torch_npu docs: 0 = cumulative
                offsets, 1 = per-group counts).

        Returns:
            The single concatenated output tensor of ``npu_grouped_matmul``.
        """
        if isinstance(m_split, torch.Tensor):
            group_list = m_split
        else:
            # Build on the input's device: a bare 'npu' resolves to the *current*
            # NPU, which need not be the device input_tensor lives on.
            group_list = torch.tensor(m_split, device=input_tensor.device, dtype=torch.int64)
        ctx.group_list = group_list
        ctx.group_list_type = group_list_type
        # One transposed 2-D view per group (same views as chunk(...)[k][0].T).
        weights_t = [w.T for w in weights.unbind(0)]
        fwd_output = torch_npu.npu_grouped_matmul(
            [input_tensor], weights_t, bias=weights_bias,
            group_list=ctx.group_list, split_item=2, group_type=0,
            group_list_type=ctx.group_list_type,
        )[0]
        ctx.save_for_backward(input_tensor, weights)
        return fwd_output

    @staticmethod
    def backward(ctx, grad_output):
        """Compute gradients w.r.t. ``input_tensor`` and ``weights``.

        Returns:
            A 5-tuple matching forward's inputs: ``(grad_input, grad_weights,
            None, None, None)``.
        """
        inp, weights = ctx.saved_tensors
        group_list = ctx.group_list
        group_list_type = ctx.group_list_type
        # d(input): rows of group i in grad_output times weights[i]; the stacked
        # weight tensor is passed as a single-element list.
        grad_input = torch_npu.npu_grouped_matmul(
            [grad_output], [weights], bias=None, group_list=group_list,
            split_item=2, group_type=0, group_list_type=group_list_type,
        )[0]
        # d(weights): group_type=2 splits the contraction (row) axis of
        # inp.T @ grad_output per group (per torch_npu docs); each per-group
        # result is transposed back to match weights[i]'s layout.
        grad_weight = torch_npu.npu_grouped_matmul(
            [inp.T], [grad_output], bias=None, group_list=group_list,
            split_item=3, group_type=2, group_list_type=group_list_type,
        )[0]
        grad_weight = torch.stack([w.T for w in grad_weight])
        # NOTE(review): weights_bias gets no gradient even if it requires grad.
        return grad_input, grad_weight, None, None, None
def fused_grouped_matmul(inputs, m_split, weights):
    """Grouped matmul through the fused NPU kernel, without a bias.

    Args:
        inputs: Tensor whose rows are split into consecutive groups.
        m_split: Per-group row counts (a tensor or a sequence of ints).
        weights: Per-group weights, indexed by group along dim 0.

    Returns:
        The grouped-matmul output produced by ``_GroupedMatmul``.
    """
    # group_list_type=1 tells npu_grouped_matmul that m_split holds per-group
    # counts rather than cumulative offsets (per torch_npu docs).
    counts_mode = 1
    no_bias = None
    return _GroupedMatmul.apply(inputs, weights, no_bias, m_split, counts_mode)
def eager_grouped_matmul(inputs, m_split, weights):
    """Reference (non-fused) grouped matmul.

    Rows ``[offsets[i], offsets[i + 1])`` of ``inputs`` (sliced along dim 0) are
    multiplied by ``weights[i].T``, where ``offsets`` is the running sum of
    ``m_split`` starting at 0.

    Args:
        inputs: Tensor whose dim 0 is partitioned into groups; its last dim must
            equal ``weights[i].shape[1]``.
        m_split: Per-group row counts, as a 1-D tensor or a sequence of ints.
        weights: Indexable collection (e.g. a 3-D tensor or a list) of 2-D
            weights; ``weights[i]`` has shape ``(out_features, in_features)``.

    Returns:
        Tensor of shape ``inputs.shape[:-1] + (weights[0].shape[0],)`` with the
        dtype and device of ``inputs``. Rows not covered by ``m_split`` stay zero.
    """
    output_shape = inputs.shape[:-1] + (weights[0].shape[0],)
    final_hidden_states = torch.zeros(output_shape, dtype=inputs.dtype, device=inputs.device)
    # torch.cumsum rejects plain lists/tuples; as_tensor lets callers pass either
    # (a tensor argument is returned unchanged).
    offsets = [0] + torch.as_tensor(m_split).cumsum(dim=0).tolist()
    for i, (start, end) in enumerate(zip(offsets, offsets[1:])):
        final_hidden_states[start:end, ...] = torch.matmul(inputs[start:end, ...], weights[i].T)
    return final_hidden_states