from typing import List, Optional
import torch
from mindspeed_mm.fsdp.utils.device import IS_NPU_AVAILABLE
if IS_NPU_AVAILABLE:
import torch_npu
class GmmFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x, weight, group_list):
ctx.save_for_backward(x, weight)
ctx.group_list = group_list
fwd_output = torch_npu.npu_grouped_matmul([x], [weight], bias=None, group_list=group_list,
split_item=2, group_type=0, group_list_type=1)[0]
return fwd_output
@staticmethod
def backward(ctx, grad_output):
input_tensor, weight = ctx.saved_tensors
group_list = ctx.group_list
weight = torch.transpose(weight, 1, 2)
grad_input = torch_npu.npu_grouped_matmul([grad_output], [weight], bias=None, group_list=group_list,
split_item=2, group_type=0, group_list_type=1)[0]
grad_weight = torch_npu.npu_grouped_matmul([input_tensor.T], [grad_output], bias=None, group_list=group_list,
split_item=3, group_type=2, group_list_type=1)[0]
return grad_input, grad_weight, None
def fused_group_gemm(x, weight, group_list):
output = GmmFunction.apply(x, weight, group_list)
return output
def eager_grouped_matmul(x, weight, group_list):
"""
Grouped matrix multiplication that handles two weight tensor formats.
Args:
inputs: Tensor of shape [batch_size, input_dim]
m_split: Tensor of group sizes that sum to batch_size
weights: Weight tensor of either:
Format 1: [num_groups, input_dim, output_dim] - ready for matmul
Format 2: [num_groups, output_dim, input_dim] - needs transpose
Returns:
Tensor of shape [batch_size, output_dim]
"""
inputs, m_split, weights = x, group_list, weight
batch_size, input_dim = inputs.shape
if weights.shape[1] == input_dim:
output_dim = weights.shape[2]
else:
output_dim = weights.shape[1]
weights = weights.transpose(1, 2)
output_shape = (batch_size, output_dim)
final_hidden_states = torch.zeros(output_shape, dtype=inputs.dtype, device=inputs.device)
group_list = [0] + torch.cumsum(m_split, dim=0).tolist()
for i in range(len(group_list) - 1):
start_idx = group_list[i]
end_idx = group_list[i + 1]
final_hidden_states[start_idx:end_idx, :] = torch.matmul(
inputs[start_idx:end_idx, :],
weights[i]
)
return final_hidden_states
def grouped_matmul(x, weight, group_list, fused=True):
if fused and IS_NPU_AVAILABLE:
return fused_group_gemm(x, weight, group_list)
else:
return eager_grouped_matmul(x, weight, group_list)