import functools
import warnings
import torch
import torch_npu
from torch_npu.utils._error_code import ErrCode, ops_error
# Public names of this module (what ``from <module> import *`` exports).
__all__ = ['npu_fused_attention', 'npu_fused_attention_with_layernorm']
def _exec_once(func):
@functools.wraps(func)
def wrapper_exec_once(*args, **kwargs):
if not wrapper_exec_once.called:
wrapper_exec_once.called = True
func(*args, **kwargs)
wrapper_exec_once.called = False
return wrapper_exec_once
# Expected NPU format code, by position, for the eight tensors compared in
# _is_format_matched: the first five entries are 29 and the last three are 2.
# NOTE(review): 29 and 2 are presumably torch_npu's FRACTAL_NZ and ND format
# ids -- confirm against the torch_npu format enum.
VALID_FORMAT = [29, 29, 29, 29, 29, 2, 2, 2]
def _is_format_matched(input_list):
    """Return True when the NPU formats of ``input_list`` equal ``VALID_FORMAT``.

    Every element is passed to ``torch_npu.get_npu_format``; the resulting list
    is compared with ``VALID_FORMAT`` as a whole, so both the order and the
    number of entries must match.
    """
    actual_formats = [torch_npu.get_npu_format(tensor) for tensor in input_list]
    return actual_formats == VALID_FORMAT
@_exec_once
def _check_compatibility_once(hidden_states, attention_mask, query_kernel, key_kernel, value_kernel,
                              query_bias, key_bias, value_bias, gamma=None, beta=None):
    """Validate the NPU formats and shapes of the fused-attention inputs.

    The function is wrapped by ``_exec_once``, so this body runs only for the
    first call; later calls are no-ops.

    Checks, in this order:
      1. formats of the eight main tensors, via ``_is_format_matched``;
      2. when both ``gamma`` and ``beta`` are given, each must have NPU format 2;
      3. ``hidden_states`` is 2-D, dim 0 is a multiple of 32 and dim 1 is 1024 or 768;
      4. ``attention_mask`` is 4-D, dim 1 equals 1 and dim 2 equals dim 3;
      5. dim 0 of each query/key/value kernel is 1024 or 768.

    Raises:
        RuntimeError: for the first check that fails.
    """
    allowed_hidden_sizes = (1024, 768)

    main_inputs = [hidden_states, attention_mask, query_kernel, key_kernel, value_kernel,
                   query_bias, key_bias, value_bias]
    if not _is_format_matched(main_inputs):
        raise RuntimeError(
            'fused attention check compatibility failed, format not matches' + ops_error(ErrCode.VALUE))

    # gamma/beta are inspected only when both are supplied.
    if gamma is not None and beta is not None:
        if any(torch_npu.get_npu_format(param) != 2 for param in (gamma, beta)):
            raise RuntimeError(
                'fused attention check compatibility failed, gamma or beta format not matches' +
                ops_error(ErrCode.VALUE)
            )

    hidden_shape_ok = (
        len(hidden_states.size()) == 2
        and hidden_states.shape[0] % 32 == 0
        and hidden_states.shape[1] in allowed_hidden_sizes
    )
    if not hidden_shape_ok:
        raise RuntimeError(
            'fused attention check compatibility failed, shape of hidden_states not matches' + ops_error(ErrCode.VALUE)
        )

    mask_shape_ok = (
        len(attention_mask.size()) == 4
        and attention_mask.shape[1] == 1
        and attention_mask.shape[2] == attention_mask.shape[3]
    )
    if not mask_shape_ok:
        raise RuntimeError(
            'fused attention check compatibility failed, shape of attention_mask not matches' + ops_error(ErrCode.VALUE)
        )

    projection_kernels = (query_kernel, key_kernel, value_kernel)
    if any(kernel.shape[0] not in allowed_hidden_sizes for kernel in projection_kernels):
        raise RuntimeError(
            'fused attention check compatibility failed, shape of kernel not matches' + ops_error(ErrCode.VALUE)
        )
def _permute_with_reshape(x, new_shape):
    """Run ``npu_confusion_transpose`` on ``x`` and cast the result to NPU format 29.

    ``x`` is handed to ``torch_npu.npu_confusion_transpose`` together with the
    permutation ``(0, 2, 1, 3)``, the target ``new_shape`` and ``False`` as the
    last argument; the output is then passed through
    ``torch_npu.npu_format_cast(..., 29)``.
    NOTE(review): the trailing ``False`` is presumably the ``transpose_first``
    flag (reshape before transposing) -- confirm against the torch_npu docs.
    """
    permutation = (0, 2, 1, 3)
    transposed = torch_npu.npu_confusion_transpose(x, permutation, new_shape, False)
    return torch_npu.npu_format_cast(transposed, 29)
class _FusedAttentionWithLayerNorm(torch.autograd.Function):
    """Autograd function chaining the fused layer-norm/QKV kernel with the fused attention-score kernel.

    ``forward`` returns ``(context_layer, norm)``; ``backward`` returns one
    gradient slot per ``forward`` input (12 in total), with ``None`` for
    ``attention_mask``, ``scale`` and ``keep_prob``.
    """

    @staticmethod
    def forward(ctx,
                hidden_states,
                attention_mask,
                query_kernel,
                key_kernel,
                value_kernel,
                query_bias,
                key_bias,
                value_bias,
                gamma,
                beta,
                scale=1,
                keep_prob=0):
        """Emit a deprecation warning, validate inputs (first call only) and run the fused kernels.

        Args:
            ctx: autograd context; receives ``bsnc``, ``scale``, ``keep_prob`` and the saved tensors.
            hidden_states, attention_mask: inputs checked by ``_check_compatibility_once``.
            query_kernel, key_kernel, value_kernel: projection weights.
            query_bias, key_bias, value_bias: projection biases.
            gamma, beta: layer-norm parameters forwarded to the fused layer-norm/QKV kernel.
            scale, keep_prob: forwarded unchanged to ``npu_fused_attention_score_fwd``.

        Returns:
            Tuple ``(context_layer, norm)`` produced by the two fused kernels.

        Raises:
            RuntimeError: from ``_check_compatibility_once`` when the inputs are not compatible.
        """
        warnings.warn("torch_npu.contrib.npu_fused_attention_with_layernorm is deprecated and "
                      "will be removed in future version. Use torch_npu.npu_fusion_attention and "
                      "torch.nn.LayerNorm instead.", FutureWarning)
        _check_compatibility_once(hidden_states, attention_mask, query_kernel,
                                  key_kernel, value_kernel, query_bias,
                                  key_bias, value_bias, gamma, beta)
        # NOTE(review): presumably (batch, seq_len, num_heads, head_dim) with the head
        # size fixed at 64 -- confirm against the kernel documentation.
        ctx.bsnc = [
            attention_mask.shape[0],
            hidden_states.shape[0] // attention_mask.shape[0],
            hidden_states.shape[1] // 64, 64
        ]
        norm, query_layer, key_layer, value_layer, mean, variance = torch_npu.npu_fused_attention_layernorm_qkv_fwd(
            hidden_states, query_kernel, key_kernel, value_kernel, gamma, beta,
            query_bias, key_bias, value_bias, ctx.bsnc[1], ctx.bsnc[2])
        context_layer, softmax_output, dropout_mask = torch_npu.npu_fused_attention_score_fwd(
            query_layer, key_layer, value_layer, attention_mask, scale, keep_prob)
        # Non-tensor hyper-parameters are kept as plain attributes for backward.
        ctx.scale = scale
        ctx.keep_prob = keep_prob
        ctx.save_for_backward(query_kernel, key_kernel, value_kernel,
                              query_layer, key_layer, value_layer,
                              hidden_states, softmax_output, dropout_mask,
                              norm, mean, variance, gamma, beta)
        return context_layer, norm

    @staticmethod
    def backward(ctx, grad_output, grad_norm):
        """Compute gradients for the inputs of ``forward``.

        Args:
            ctx: autograd context populated by ``forward``.
            grad_output: gradient w.r.t. ``context_layer``.
            grad_norm: gradient w.r.t. ``norm``; cast to NPU format 29 before use.

        Returns:
            12-tuple aligned with the ``forward`` inputs: gradients for ``hidden_states``,
            the three kernels, the three biases, ``gamma`` and ``beta``; ``None`` for
            ``attention_mask``, ``scale`` and ``keep_prob``.
        """
        # ``saved_tensors`` is the supported accessor; the ``saved_variables`` alias is
        # deprecated in PyTorch and emits a DeprecationWarning on access.
        q_w, k_w, v_w, q_l, k_l, v_l, h_s, s_o, d_m, norm, mean, variance, gamma, beta = ctx.saved_tensors
        query_grad, key_grad, value_grad = torch_npu.npu_fused_attention_score_grad(
            grad_output, s_o, q_l, k_l, v_l, d_m, ctx.scale, ctx.keep_prob)
        # The QKV projections consumed ``norm`` in forward, so it is the input passed here.
        g_h_s, g_w_q, g_w_k, g_w_v, g_b_q, g_b_k, g_b_v = torch_npu.npu_fused_attention_qkv_grad(
            query_grad, key_grad, value_grad, q_w, k_w, v_w, norm,
            torch_npu.npu_format_cast(grad_norm, 29))
        # Propagate through the layer norm to obtain gradients for hidden_states, gamma and beta.
        g_h_s, g_gamma, g_beta = torch_npu.npu_layernorm_grad(
            g_h_s, h_s, (g_h_s.shape[1],), mean, variance, gamma, beta)
        return g_h_s, None, g_w_q, g_w_k, g_w_v, g_b_q, g_b_k, g_b_v, g_gamma, g_beta, None, None
# Public callable: bound to the ``apply`` entry point of _FusedAttentionWithLayerNorm.
npu_fused_attention_with_layernorm = _FusedAttentionWithLayerNorm.apply
class _FusedAttention(torch.autograd.Function):
    """Autograd function computing Q/K/V projections followed by the fused attention-score kernel.

    ``forward`` returns ``context_layer``; ``backward`` returns one gradient
    slot per ``forward`` input (10 in total), with ``None`` for
    ``attention_mask``, ``scale`` and ``keep_prob``.
    """

    @staticmethod
    def forward(ctx,
                hidden_states,
                attention_mask,
                query_kernel,
                key_kernel,
                value_kernel,
                query_bias,
                key_bias,
                value_bias,
                scale=1,
                keep_prob=0):
        """Validate inputs (first call only), project to Q/K/V and run the fused attention kernel.

        Args:
            ctx: autograd context; receives ``bsnc``, ``scale``, ``keep_prob`` and the saved tensors.
            hidden_states, attention_mask: inputs checked by ``_check_compatibility_once``.
            query_kernel, key_kernel, value_kernel: projection weights; their transposes are
                passed to ``torch_npu.npu_linear``.
            query_bias, key_bias, value_bias: projection biases.
            scale, keep_prob: forwarded unchanged to ``npu_fused_attention_score_fwd``.

        Returns:
            ``context_layer`` as produced by ``npu_fused_attention_score_fwd``.

        Raises:
            RuntimeError: from ``_check_compatibility_once`` when the inputs are not compatible.
        """
        _check_compatibility_once(hidden_states, attention_mask, query_kernel,
                                  key_kernel, value_kernel, query_bias,
                                  key_bias, value_bias, None, None)
        # NOTE(review): presumably (batch, seq_len, num_heads, head_dim) with the head
        # size fixed at 64 -- confirm against the kernel documentation.
        ctx.bsnc = [
            attention_mask.shape[0],
            hidden_states.shape[0] // attention_mask.shape[0],
            hidden_states.shape[1] // 64, 64
        ]
        with torch.no_grad():
            query_layer = _permute_with_reshape(
                torch_npu.npu_linear(hidden_states, query_kernel.t(), query_bias),
                (ctx.bsnc[0], ctx.bsnc[1], ctx.bsnc[2], ctx.bsnc[3]))
            key_layer = _permute_with_reshape(
                torch_npu.npu_linear(hidden_states, key_kernel.t(), key_bias),
                (ctx.bsnc[0], ctx.bsnc[1], ctx.bsnc[2], ctx.bsnc[3]))
            value_layer = _permute_with_reshape(
                torch_npu.npu_linear(hidden_states, value_kernel.t(), value_bias),
                (ctx.bsnc[0], ctx.bsnc[1], ctx.bsnc[2], ctx.bsnc[3]))
        context_layer, softmax_output, dropout_mask = torch_npu.npu_fused_attention_score_fwd(
            query_layer, key_layer, value_layer, attention_mask, scale, keep_prob)
        # Non-tensor hyper-parameters are kept as plain attributes for backward.
        ctx.scale = scale
        ctx.keep_prob = keep_prob
        ctx.save_for_backward(query_kernel, key_kernel, value_kernel,
                              query_layer, key_layer, value_layer,
                              hidden_states, softmax_output, dropout_mask)
        return context_layer

    @staticmethod
    def backward(ctx, grad_output):
        """Compute gradients for the inputs of ``forward``.

        Args:
            ctx: autograd context populated by ``forward``.
            grad_output: gradient w.r.t. ``context_layer``.

        Returns:
            10-tuple aligned with the ``forward`` inputs: gradients for ``hidden_states``,
            the three kernels and the three biases; ``None`` for ``attention_mask``,
            ``scale`` and ``keep_prob``.
        """
        # ``saved_tensors`` is the supported accessor; the ``saved_variables`` alias is
        # deprecated in PyTorch and emits a DeprecationWarning on access.
        q_w, k_w, v_w, q_l, k_l, v_l, h_s, s_o, d_m = ctx.saved_tensors
        query_grad, key_grad, value_grad = torch_npu.npu_fused_attention_score_grad(
            grad_output, s_o, q_l, k_l, v_l, d_m, ctx.scale, ctx.keep_prob)
        # The last argument of the QKV-grad kernel is filled with zeros shaped like
        # hidden_states, cast to NPU format 29.
        g_h_s, g_w_q, g_w_k, g_w_v, g_b_q, g_b_k, g_b_v = torch_npu.npu_fused_attention_qkv_grad(
            query_grad, key_grad, value_grad, q_w, k_w, v_w, h_s,
            torch_npu.npu_format_cast(torch.zeros_like(h_s), 29))
        return g_h_s, None, g_w_q, g_w_k, g_w_v, g_b_q, g_b_k, g_b_v, None, None
# Public callable: bound to the ``apply`` entry point of _FusedAttention.
npu_fused_attention = _FusedAttention.apply