# Registry of the golden (reference) functions defined in this file.
#   "kernel": kernel name             -> name of the golden function for it
#   "e2e":    qualified operator name -> name of the golden function for it
# NOTE(review): the consumer of this mapping is not visible in this file;
# presumably a test harness resolves the function names -- confirm.
__golden__ = {
    "kernel": {
        "grouped_matmul_add": "grouped_matmul_add_golden",
    },
    "e2e": {
        "torch_npu.npu_grouped_matmul_add_": "torch_grouped_matmul_add_golden",
    },
}
import ml_dtypes
import torch
import numpy as np
def grouped_matmul_add_golden(x, weight, group_list, y, transpose_x: bool = True, transpose_weight: bool = False,
                              group_type: int = 2, group_list_type: int = 0, **kwargs):
    """Reference result for ``grouped_matmul_add``: per-group ``x_g^T @ weight_g`` added onto ``y``.

    ``x`` and ``weight`` are cut along their first axis at the boundaries given
    by ``group_list``. Each slice of ``x`` has its last two axes swapped before
    being multiplied with the matching slice of ``weight``.

    Args:
        x: NumPy array; cast to float32 for the computation.
        weight: NumPy array; cast to float32 for the computation.
        group_list: Group boundaries along the first axis. Interpreted as
            per-group sizes when ``group_list_type == 1``, otherwise as
            cumulative end offsets.
        y: NumPy array added to the result. The concatenated group products
            are reshaped to ``y.shape`` first.
        transpose_x: Accepted for signature compatibility; not read here.
        transpose_weight: Accepted for signature compatibility; not read here.
        group_type: Accepted for signature compatibility; not read here.
        group_list_type: ``1`` means ``group_list`` holds sizes (see above).
        **kwargs: Must contain ``output_dtypes``; its first entry is the dtype
            each group product is cast to before the addition.

    Returns:
        ``y`` plus the per-group products, as a NumPy array.

    Raises:
        KeyError: If ``output_dtypes`` is missing from ``kwargs``.
    """
    out_dtype = kwargs['output_dtypes'][0]
    if group_list_type == 1:
        # Sizes -> cumulative end offsets.
        group_list = np.cumsum(group_list)

    x = torch.from_numpy(x.astype(np.float32))
    weight = torch.from_numpy(weight.astype(np.float32))
    # Output tile shape is the same for every group, so compute it once.
    m, n = x.shape[-1], weight.shape[-1]

    outs = []
    pre = 0
    for cur in group_list:
        if cur == pre:
            # An empty group contributes an all-zero tile.
            outs.append(np.zeros((m, n), dtype=out_dtype))
        else:
            x_g = x[pre:cur, :].transpose(-1, -2)
            weight_g = weight[pre:cur, :]
            outs.append(torch.matmul(x_g, weight_g).numpy().astype(out_dtype))
        pre = cur

    if outs:
        real_out = np.concatenate(outs, axis=0).reshape(y.shape)
    else:
        # No groups at all: nothing is accumulated, so the result is just ``y``
        # (previously this path crashed calling ``.reshape`` on a list).
        real_out = np.zeros(y.shape, dtype=out_dtype)
    return real_out + y
def torch_grouped_matmul_add_golden(self_ref, x, weight, group_list, *, transpose_x=None, transpose_weight=None,
                                    group_type=None, group_list_type=None, **kwargs):
    """Golden for the in-place grouped matmul-add: returns ``self_ref`` plus the grouped product.

    Pipeline: convert every input to NumPy, derive cumulative group offsets,
    optionally swap the last two axes of ``x`` / ``weight``, widen both to a
    higher-precision type, split them per group, multiply group by group, cast
    the products to the dtype of ``self_ref``, merge them and add ``self_ref``.

    Args:
        self_ref: Accumulator; its dtype decides the output cast.
        x: Left operand.
        weight: Right operand.
        group_list: Group description, interpreted according to ``group_list_type``.
        transpose_x: The last two axes of ``x`` are swapped when this equals 1.
        transpose_weight: The last two axes of ``weight`` are swapped when this equals 1.
        group_type: Forwarded to the split and merge helpers.
        group_list_type: Forwarded to ``group_list_diff_cumsum_process``.
        **kwargs: Ignored.

    Returns:
        A torch tensor: ``self_ref`` plus the merged per-group products.

    Raises:
        Exception: If the last cumulative offset exceeds ``x.shape[0]``
            (measured before any transpose).
    """
    acc_np = torch_to_numpy(self_ref)
    x_np = torch_to_numpy(x)
    weight_np = torch_to_numpy(weight)
    groups_np = torch_to_numpy(group_list)

    _, _, group_list_cumsum = group_list_diff_cumsum_process(group_list_type, groups_np)
    if group_list_cumsum[-1] > x_np.shape[0]:
        raise Exception("Sum of grouplist ({}) exceeds the first dim of x[0]:({})".format(group_list_cumsum[-1], x_np.shape[0]))

    # Transposing does not change the dtype, so the dtype can be read at
    # conversion time.
    x_hp = input_convert_to_high_precision(transpose_last_two_dim(x_np, transpose_x), x_np.dtype)
    weight_hp = input_convert_to_high_precision(transpose_last_two_dim(weight_np, transpose_weight), weight_np.dtype)

    x_parts, w_parts, _ = gmm_non_quant_data_split_per_expect(x_hp, weight_hp, group_type, group_list_cumsum)
    products = gmm_non_quant_cpu_golden_calculation(x_parts, w_parts)
    products = output_dtype_post_process(products, acc_np.dtype)
    merged = golden_output_process(products, 2, group_type)

    return torch.add(torch.from_numpy(acc_np), merged[0])
def group_list_diff_cumsum_process(group_list_type, param_group_list, group_idx=None):
    """Normalise a group list into a tensor form, per-group sizes and cumulative offsets.

    Args:
        group_list_type: Encoding of ``param_group_list``:
            ``0`` - already cumulative end offsets;
            ``2`` - sizes paired with the indices in ``group_idx``;
            anything else - plain per-group sizes.
        param_group_list: The group list in the encoding above.
        group_idx: Group indices, only read for type ``2``; defaults to ``[0]``.

    Returns:
        ``(group_tensor_new, count, cumsum)`` where ``count`` holds per-group
        sizes and ``cumsum`` the cumulative offsets. For type ``2`` both are
        ordered by ascending ``group_idx`` and ``group_tensor_new`` is an int64
        tensor of ``(index, size)`` rows.
    """
    group_idx = [0] if group_idx is None else group_idx

    if group_list_type == 0:
        # Input is cumulative: sizes are the first entry followed by the differences.
        sizes = np.insert(np.diff(param_group_list), 0, param_group_list[0])
        as_tensor = torch.tensor(param_group_list)
        return as_tensor, sizes, param_group_list

    if group_list_type == 2:
        sizes_tensor = torch.tensor(param_group_list)
        idx_tensor = torch.tensor(group_idx)
        paired = torch.stack((idx_tensor, sizes_tensor), dim=1).to(torch.int64)
        # Stable sort of positions by their group index.
        ordered = sorted(enumerate(group_idx), key=lambda pair: pair[1])
        sizes_by_idx = [param_group_list[pos] for pos, _ in ordered]
        return paired, list(sizes_by_idx), np.cumsum(sizes_by_idx)

    offsets = np.cumsum(param_group_list)
    return torch.tensor(param_group_list), param_group_list, offsets
def transpose_last_two_dim(input, transpose_flag):
    """Swap the last two axes of ``input`` when requested, dispatching on its container type.

    Args:
        input: A ``torch.Tensor`` or ``np.ndarray``.
        transpose_flag: Forwarded to the type-specific helper, which only
            transposes when the flag equals 1.

    Returns:
        Whatever the type-specific helper returns.

    Raises:
        TypeError: If ``input`` is neither a torch tensor nor a NumPy array.
    """
    if isinstance(input, torch.Tensor):
        return torch_transpose_last_two_dim(input, transpose_flag)
    if isinstance(input, np.ndarray):
        return np_transpose_last_two_dim(input, transpose_flag)
    raise TypeError("Unsupported dtype!")
def np_transpose_last_two_dim(input, transpose_flag):
    """Return ``input`` with its last two axes swapped when ``transpose_flag == 1``.

    Any other flag value (including ``0`` and ``None``) returns ``input`` itself.
    """
    if transpose_flag != 1:
        return input
    rank = input.ndim
    # Identity order for the leading axes, then the final pair reversed.
    permutation = [*range(rank - 2), rank - 1, rank - 2]
    return np.transpose(input, axes=permutation)
def torch_transpose_last_two_dim(input, transpose_flag):
    """Return ``input`` with its last two dims swapped when ``transpose_flag == 1``.

    Any other flag value returns ``input`` itself.

    Raises:
        ValueError: If a transpose is requested on a tensor with fewer than
            two dimensions.
    """
    if transpose_flag != 1:
        return input
    rank = input.dim()
    if rank < 2:
        raise ValueError("Tensor must have at least 2 dimensions.")
    # Leading dims keep their order; only the final pair is exchanged.
    order = (*range(rank - 2), rank - 1, rank - 2)
    return input.permute(*order)
def input_convert_to_high_precision(input, dtype):
    """Widen ``input`` for the reference computation, dispatching on its container type.

    Args:
        input: A ``torch.Tensor`` or ``np.ndarray``.
        dtype: Source dtype identifier, forwarded to the type-specific helper.

    Returns:
        Whatever the type-specific helper returns.

    Raises:
        TypeError: If ``input`` is neither a torch tensor nor a NumPy array.
    """
    if isinstance(input, torch.Tensor):
        return torch_convert_to_high_precision(input, dtype)
    if isinstance(input, np.ndarray):
        return np_convert_to_high_precision(input, dtype)
    raise TypeError("Unsupported dtype")
def np_convert_to_high_precision(input, dtype):
    """Convert a NumPy array to a wider torch tensor based on its source dtype.

    The listed low-precision float dtypes become float32 tensors; every other
    dtype becomes an int32 tensor.

    NOTE(review): dtypes that are not listed (e.g. "float32") also take the
    int32 path, which would drop fractional values -- confirm callers never
    pass such inputs.

    Args:
        input: NumPy array to convert.
        dtype: Source dtype identifier, compared against dtype names.

    Returns:
        A ``torch.Tensor`` of dtype float32 or int32.
    """
    low_precision_floats = (
        "float8_e4m3fn", "float8_e5m2", "float4_e2m1", "float4_e1m2", "hifloat8", "float16", "bfloat16",
    )
    if dtype in low_precision_floats:
        widened = input.astype(np.float32)
        return torch.from_numpy(widened)
    if dtype in ("int4",):
        widened = input.astype(np.int32)
        return torch.from_numpy(widened).to(torch.int32)
    return torch.from_numpy(input).to(torch.int32)
def torch_convert_to_high_precision(input, dtype):
    """Cast a torch tensor according to ``dtype`` and return it as a NumPy array.

    ``'int8'`` and ``'int32'`` keep their integer type; every other dtype
    identifier results in a float32 array.

    Args:
        input: Tensor to convert.
        dtype: Source dtype identifier.

    Returns:
        ``np.ndarray`` of dtype int8, int32 or float32.
    """
    if dtype == 'int8':
        target = torch.int8
    elif dtype == 'int32':
        target = torch.int32
    else:
        # All remaining identifiers (float8 variants, hifloat8, bf16, ...) widen to float32.
        target = torch.float32
    return input.to(target).numpy()
def get_x_weight_bias(x_input, w_input, group_type, group_list_cumsum, biass=None, has_bias=0):
    """Split ``x_input`` / ``w_input`` (and optionally the bias) into per-group pieces.

    Args:
        x_input: Left operand, sliceable with two indices.
        w_input: Right operand.
        group_type: ``0`` slices the rows of ``x_input`` and takes
            ``w_input[i]`` for group ``i``; ``2`` slices the columns of
            ``x_input`` and the rows of ``w_input``; any other value keeps
            both operands whole as a single group.
        group_list_cumsum: Cumulative end offset of each group.
        biass: Bias array, only read when ``has_bias == 1``.
        has_bias: ``1`` to split ``biass`` into one chunk per group along axis 0.

    Returns:
        ``(Xs, Ws, Bs)`` lists; ``Bs`` is empty when no bias is requested.
    """
    x_parts, w_parts, bias_parts = [], [], []

    if group_type == 0:
        start = 0
        for idx, stop in enumerate(group_list_cumsum):
            x_parts.append(x_input[start:stop, :])
            w_parts.append(w_input[idx])
            start = stop
    elif group_type == 2:
        start = 0
        for stop in group_list_cumsum:
            x_parts.append(x_input[:, start:stop])
            w_parts.append(w_input[start:stop, :])
            start = stop
    else:
        x_parts.append(x_input)
        w_parts.append(w_input)

    if has_bias == 1:
        group_count = len(group_list_cumsum)
        bias_parts = np.split(biass, group_count, axis=0)
    return x_parts, w_parts, bias_parts
def gmm_non_quant_data_split_per_expect(x1, x2, group_type, grouplist, bias=None, has_bias=0):
    """Split the operands (and optional bias) per group; thin alias of ``get_x_weight_bias``."""
    return get_x_weight_bias(x1, x2, group_type, grouplist, biass=bias, has_bias=has_bias)
def gmm_non_quant_cpu_golden_calculation(x, weight, bias=None, has_bias=0):
    """Compute the per-group products; thin alias of ``gmm_non_quant_golden_calculate``."""
    return gmm_non_quant_golden_calculate(x, weight, bias=bias, has_bias=has_bias)
def gmm_non_quant_golden_calculate(x, weight, bias=None, has_bias=0):
    """Multiply each group's ``x[i]`` with ``weight[i]`` and optionally add ``bias[i]``.

    NOTE(review): the product is computed with ``np.matmul`` while the bias is
    added with ``torch.add``; what ``np.matmul`` returns for torch operands
    depends on NumPy/torch interoperability -- confirm for the versions in use.

    Args:
        x: Sequence of per-group left operands.
        weight: Per-group right operands, indexed in step with ``x``.
        bias: Per-group NumPy bias arrays, only read when ``has_bias == 1``.
        has_bias: ``1`` to add the bias to every group product.

    Returns:
        List with one product per entry of ``x``.
    """
    products = []
    for idx, x_group in enumerate(x):
        result = np.matmul(x_group, weight[idx])
        if has_bias == 1:
            result = torch.add(result, torch.from_numpy(bias[idx]))
        products.append(result)
    return products
def output_dtype_post_process(output, dtype):
    """Cast every tensor in ``output`` to the torch dtype selected by ``dtype``.

    ``'bfloat16'``, ``'float16'`` and ``'int8'`` map to the matching torch
    dtype; any other identifier yields float32.

    Args:
        output: Sequence of torch tensors.
        dtype: Requested output dtype identifier.

    Returns:
        New list of cast tensors, in the same order as ``output``.
    """
    if dtype == 'bfloat16':
        target = torch.bfloat16
    elif dtype == 'float16':
        target = torch.float16
    elif dtype == 'int8':
        target = torch.int8
    else:
        target = torch.float32
    return [tensor.to(target) for tensor in output]
def golden_output_process(output, param_split_item, group_type, activation_feature_out=None, dyn_scale_out=None):
    """Merge per-group outputs into a single tensor when the split mode asks for it.

    Args:
        output: Sequence of per-group torch tensors.
        param_split_item: When 2 or 3 the group outputs are merged; for any
            other value ``output`` is returned unchanged.
        group_type: ``0`` concatenates along dim ``-2``; ``2`` stacks along a
            new leading dim.
        activation_feature_out: Accepted for signature compatibility; not read here.
        dyn_scale_out: Accepted for signature compatibility; not read here.

    Returns:
        A one-element list holding the merged tensor, or ``output`` itself
        when no merge is requested.

    Raises:
        TypeError: If a merge is requested with a ``group_type`` other than 0 or 2.
    """
    if param_split_item not in (2, 3):
        return output

    if group_type == 0:
        merged = torch.cat(output, dim=-2)
    elif group_type == 2:
        merged = torch.stack(output, dim=0)
    else:
        raise TypeError("unsupported group type, please check")

    # The former inf/NaN ratio bookkeeping was never used, and its
    # ``shape[1]`` access failed on 1-D merged tensors, so it is removed.
    return [merged]
def torch_to_numpy(x):
if isinstance(x, torch.Tensor):
if x.dtype == torch.float8_e5m2:
x = x.to(torch.float32).numpy().astype(ml_dtypes.float8_e5m2)
elif x.dtype == torch.float8_e4m3fn:
x = x.to(torch.float32).numpy().astype(ml_dtypes.float8_e4m3fn)
elif x.dtype == torch.bfloat16:
x = x.to(torch.float32).numpy().astype(ml_dtypes.bfloat16)
else:
x = x.numpy()
return x