from typing import Tuple, List, Optional
import math
import torch
from . import register_ops
from ..utils import ParametersInvalid
def laser_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
atten_mask: torch.Tensor | None = None,
alibi_mask: torch.Tensor | None = None,
drop_mask: torch.Tensor | None = None,
scale_value: float | torch.Tensor = 1.0,
head_num: int = 2,
input_layout: str = "BNSD",
keep_prob: float = 1.0,
pre_tokens: int = 2147483647,
next_tokens: int = 1,
is_high_precision: bool = True,
) -> tuple[torch.Tensor, torch.Tensor]:
return getattr(torch.ops.mindiesd, "la")(
query=query,
key=key,
value=value,
atten_mask=atten_mask,
alibi_mask=alibi_mask,
drop_mask=drop_mask,
scale_value=scale_value,
head_num=head_num,
input_layout=input_layout,
keep_prob=keep_prob,
pre_tokens=pre_tokens,
next_tokens=next_tokens,
is_highPrecision=is_high_precision,
)
@register_ops.register_mindie_fake_op("la")
def attention_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
atten_mask: torch.Tensor = None,
alibi_mask: torch.Tensor = None,
drop_mask: torch.Tensor = None,
scale_value: float = 1.0,
head_num: int = 2,
input_layout: str = "BNSD",
keep_prob: float = 1.0,
pre_tokens: int = 2147483647,
next_tokens: int = 1,
is_high_precision: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
softmax_log_max_sum = torch.empty(
[query.shape[0], query.shape[1], query.shape[2]], device=query.device, dtype=query.dtype
)
output = torch.empty_like(query)
return softmax_log_max_sum, output
def laser_attention_preprocess(
query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, align_len: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return getattr(torch.ops.mindiesd, "la_preprocess")(query, key, value, align_len)
@register_ops.register_mindie_fake_op("la_preprocess")
def attention_preprocess_fake(
query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, align_len: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
batch_size = query.shape[0]
head_num = query.shape[2] if query.dim() == 4 else query.shape[1]
head_dim = query.shape[3] if query.dim() == 4 else query.shape[2]
q_seq_len = query.shape[1] if query.dim() == 4 else query.shape[0]
k_seq_len = key.shape[1] if key.dim() == 4 else key.shape[0]
v_seq_len = value.shape[1] if value.dim() == 4 else value.shape[0]
def pad_length(length):
return (length + align_len - 1) // align_len * align_len
q_padded_seq_len = pad_length(q_seq_len)
k_padded_seq_len = pad_length(k_seq_len)
v_padded_seq_len = pad_length(v_seq_len)
def create_padded_tensor(tensor, padded_seq_len):
if tensor.dim() == 4:
return torch.empty(
[batch_size, padded_seq_len, head_num, head_dim],
device=tensor.device,
dtype=tensor.dtype,
)
else:
return torch.empty([padded_seq_len, head_num, head_dim], device=tensor.device, dtype=tensor.dtype)
out_query = create_padded_tensor(query, q_padded_seq_len)
out_key = create_padded_tensor(key, k_padded_seq_len)
out_value = create_padded_tensor(value, v_padded_seq_len)
return out_query, out_key, out_value
def rain_fusion_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
select_idx: torch.Tensor,
select_num_idx: torch.Tensor,
blockshape: List[int],
attn_mask: Optional[torch.Tensor] = None,
actual_seq_qlen: Optional[List[int]] = None,
actual_seq_kvlen: Optional[List[int]] = None,
block_table: Optional[torch.Tensor] = None,
q_input_layout: str = "TND",
kv_input_layout: str = "TND",
head_num: int = 1,
mask_type: int = 0,
scale: float = 1.0,
inner_precise: int = 1,
block_size: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
return getattr(torch.ops.mindiesd, "rainfusionattention")(
query=query,
key=key,
value=value,
select_idx=select_idx,
select_num_idx=select_num_idx,
blockshape=blockshape,
attn_mask=attn_mask,
actual_seq_qlen=actual_seq_qlen,
actual_seq_kvlen=actual_seq_kvlen,
block_table=block_table,
q_input_layout=q_input_layout,
kv_input_layout=kv_input_layout,
head_num=head_num,
mask_type=mask_type,
scale=scale,
inner_precise=inner_precise,
block_size=block_size,
)
@register_ops.register_mindie_fake_op("rainfusionattention")
def rain_fusion_attention_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
select_idx: torch.Tensor,
select_num_idx: torch.Tensor,
blockshape: List[int],
attn_mask: Optional[torch.Tensor] = None,
actual_seq_qlen: Optional[List[int]] = None,
actual_seq_kvlen: Optional[List[int]] = None,
block_table: Optional[torch.Tensor] = None,
q_input_layout: str = "TND",
kv_input_layout: str = "TND",
head_num: int = 1,
mask_type: int = 0,
scale: float = 1.0,
inner_precise: int = 1,
block_size: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
softmax_lse = torch.empty([query.shape[0], query.shape[1], query.shape[2]], device=query.device, dtype=query.dtype)
output = torch.empty_like(query)
return output, softmax_lse
def sparse_block_estimate(
query: torch.Tensor,
key: torch.Tensor,
actual_seq_lengths: Optional[List[int]] = None,
actual_seq_lengths_kv: Optional[List[int]] = None,
input_layout: str = "BNSD",
stride: int = 8,
sparse_size: int = 128,
num_heads: int = 1,
num_key_value_heads: int = 1,
scale_value: float = 1.0,
threshold: float = 1.0,
causal: bool = False,
keep_sink: bool = True,
keep_recent: bool = True,
row_sparse: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
return getattr(torch.ops.mindiesd, "sparse_block_estimate")(
query=query,
key=key,
actual_seq_lengths=actual_seq_lengths,
actual_seq_lengths_kv=actual_seq_lengths_kv,
input_layout=input_layout,
stride=stride,
sparse_size=sparse_size,
num_heads=num_heads,
num_key_value_heads=num_key_value_heads,
scale_value=scale_value,
threshold=threshold,
causal=causal,
keep_sink=keep_sink,
keep_recent=keep_recent,
row_sparse=row_sparse,
)
@register_ops.register_mindie_fake_op("sparse_block_estimate")
def sparse_block_estimate_fake(
query: torch.Tensor,
key: torch.Tensor,
actual_seq_lengths: Optional[List[int]] = None,
actual_seq_lengths_kv: Optional[List[int]] = None,
input_layout: str = "BNSD",
stride: int = 8,
sparse_size: int = 128,
num_heads: int = 1,
num_key_value_heads: int = 1,
scale_value: float = 1.0,
threshold: float = 1.0,
causal: bool = False,
keep_sink: bool = True,
keep_recent: bool = True,
row_sparse: float = 1.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
b, nq, s, d = 0, 0, 0, 0
if input_layout == "BNSD":
b, nq, s, d = query.shape
elif input_layout == "BSND":
b, s, nq, d = query.shape
elif input_layout == "BSH":
b, s = query.shape[0], query.shape[1]
nq = num_heads
else:
raise ParametersInvalid(f"The input_layout only supports 'BNSD', 'BSND', 'BSH', but got {input_layout}")
seqlen_sparse = int((s + sparse_size - 1) / sparse_size)
seqlen_sparse_align32 = (seqlen_sparse + 31) / 32 * 32
sparse_mask_shape = (b, nq, seqlen_sparse, seqlen_sparse_align32)
sparse_count_table_shape = (b, nq, seqlen_sparse)
sparse_mask = torch.empty(sparse_mask_shape, device=query.device, dtype=torch.int8)
sparse_count_table = torch.empty(sparse_count_table_shape, device=query.device, dtype=torch.int32)
return sparse_mask, sparse_count_table
def ada_block_sparse_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
sparse_mask: torch.Tensor,
sparse_count_table: torch.Tensor,
input_layout: str = "BNSD",
sparse_size: int = 128,
num_heads: int = 1,
num_key_value_heads: int = 1,
scale_value: float = 1.0,
causal: bool = False,
inner_precise: int = 1,
pre_tokens: int = 214748647,
next_tokens: int = 214748647,
actual_seq_lengths: Optional[List[int]] = None,
actual_seq_lengths_kv: Optional[List[int]] = None,
) -> torch.Tensor:
return getattr(torch.ops.mindiesd, "ada_block_sparse_attention")(
query=query,
key=key,
value=value,
sparse_mask=sparse_mask,
sparse_count_table=sparse_count_table,
input_layout=input_layout,
sparse_size=sparse_size,
num_heads=num_heads,
num_key_value_heads=num_key_value_heads,
scale_value=scale_value,
causal=causal,
inner_precise=inner_precise,
pre_tokens=pre_tokens,
next_tokens=next_tokens,
actual_seq_lengths=actual_seq_lengths,
actual_seq_lengths_kv=actual_seq_lengths_kv,
)
@register_ops.register_mindie_fake_op("ada_block_sparse_attention")
def ada_block_sparse_attention_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
sparse_mask: torch.Tensor,
sparse_count_table: torch.Tensor,
input_layout: str = "BNSD",
sparse_size: int = 128,
num_heads: int = 1,
num_key_value_heads: int = 1,
scale_value: float = 1.0,
causal: bool = False,
inner_precise: int = 1,
pre_tokens: int = 214748647,
next_tokens: int = 214748647,
actual_seq_lengths: Optional[List[int]] = None,
actual_seq_lengths_kv: Optional[List[int]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
output = torch.empty_like(query)
return output
def block_sparse_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
block_sparse_mask: Optional[torch.Tensor] = None,
block_shape: List[int] = None,
q_input_layout: str = "BNSD",
kv_input_layout: str = "BNSD",
num_key_value_heads: int = 1,
scale_value: float = 1.0,
inner_precise: int = 4,
actual_seq_lengths: Optional[List[int]] = None,
actual_seq_lengths_kv: Optional[List[int]] = None,
softmax_lse_flag: int = 0,
q_dequant_scale: Optional[torch.Tensor] = None,
k_dequant_scale: Optional[torch.Tensor] = None,
v_dequant_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
if block_shape is None:
block_shape = [128, 128]
kwargs = dict(
query=query,
key=key,
value=value,
block_sparse_mask=block_sparse_mask,
block_shape=block_shape,
q_input_layout=q_input_layout,
kv_input_layout=kv_input_layout,
num_key_value_heads=num_key_value_heads,
scale_value=scale_value,
inner_precise=inner_precise,
softmax_lse_flag=softmax_lse_flag,
)
if actual_seq_lengths is not None:
kwargs["actual_seq_lengths"] = actual_seq_lengths
if actual_seq_lengths_kv is not None:
kwargs["actual_seq_lengths_kv"] = actual_seq_lengths_kv
if q_dequant_scale is not None:
kwargs["q_dequant_scale"] = q_dequant_scale
kwargs["k_dequant_scale"] = k_dequant_scale
kwargs["v_dequant_scale"] = v_dequant_scale
return getattr(torch.ops.mindiesd, "block_sparse_attention")(**kwargs)
@register_ops.register_mindie_fake_op("block_sparse_attention")
def block_sparse_attention_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
block_sparse_mask: Optional[torch.Tensor] = None,
block_shape: List[int] = None,
q_input_layout: str = "BNSD",
kv_input_layout: str = "BNSD",
num_key_value_heads: int = 1,
scale_value: float = 1.0,
inner_precise: int = 4,
actual_seq_lengths: Optional[List[int]] = None,
actual_seq_lengths_kv: Optional[List[int]] = None,
softmax_lse_flag: int = 0,
q_dequant_scale: Optional[torch.Tensor] = None,
k_dequant_scale: Optional[torch.Tensor] = None,
v_dequant_scale: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
out_dtype = torch.bfloat16 if q_dequant_scale is not None else query.dtype
attention_out = torch.empty(query.shape, device=query.device, dtype=out_dtype)
if q_input_layout == "TND":
lse_shape = [query.shape[0], query.shape[1], 1]
else:
lse_shape = [query.shape[0], query.shape[1], query.shape[2], 1]
softmax_lse = torch.empty(lse_shape, device=query.device, dtype=torch.float32)
return attention_out, softmax_lse
def adaln(
x: torch.Tensor,
scale: torch.Tensor,
shift: torch.Tensor,
weight: torch.Tensor | None = None,
bias: torch.Tensor | None = None,
epsilon: float = 1e-05,
) -> torch.Tensor:
return getattr(torch.ops.mindiesd, "adaln")(
x=x, scale=scale, shift=shift, weight=weight, bias=bias, epsilon=epsilon
)
@register_ops.register_mindie_fake_op("adaln")
def adaln_fake(
x: torch.Tensor,
scale: torch.Tensor,
shift: torch.Tensor,
weight: torch.Tensor | None = None,
bias: torch.Tensor | None = None,
epsilon: float = 1e-05,
) -> torch.Tensor:
return torch.empty_like(x)
def layernorm(
x: torch.Tensor,
normalized_shape: List[int],
weight: torch.Tensor | None = None,
bias: torch.Tensor | None = None,
eps: float = 1e-05,
impl_mode: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return getattr(torch.ops.mindiesd, "layernorm")(
input=x,
normalized_shape=normalized_shape,
weight=weight,
bias=bias,
eps=eps,
impl_mode=impl_mode,
)
@register_ops.register_mindie_fake_op("layernorm")
def layernorm_fake(
x: torch.Tensor,
normalized_shape: List[int],
weight: torch.Tensor | None = None,
bias: torch.Tensor | None = None,
eps: float = 1e-05,
impl_mode: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
output = torch.empty_like(x)
norm_ndim = len(normalized_shape)
input_ndim = x.dim()
if norm_ndim > input_ndim:
raise ParametersInvalid(
f"normalized_shape must fit within input dimensions, but got "
f"normalized_shape={normalized_shape} (ndim={norm_ndim}) and input.dim()={input_ndim}"
)
begin_axis = input_ndim - norm_ndim
m = math.prod(x.shape[:begin_axis])
if m <= 0:
mean_shape = [m]
else:
mean_shape = list(x.shape[:begin_axis]) + [1] * norm_ndim
mean_out = torch.empty(mean_shape, dtype=x.dtype, device=x.device)
rstd_out = torch.empty(mean_shape, dtype=x.dtype, device=x.device)
return output, mean_out, rstd_out
def adaln_v2(
x: torch.Tensor,
scale: torch.Tensor,
shift: torch.Tensor,
weight: torch.Tensor | None = None,
bias: torch.Tensor | None = None,
epsilon: float = 1e-05,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
return getattr(torch.ops.mindiesd, "adaln_v2")(
x=x, scale=scale, shift=shift, weight=weight, bias=bias, epsilon=epsilon
)
@register_ops.register_mindie_fake_op("adaln_v2")
def adaln_v2_fake(
x: torch.Tensor,
scale: torch.Tensor,
shift: torch.Tensor,
weight: torch.Tensor | None = None,
bias: torch.Tensor | None = None,
epsilon: float = 1e-05,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
b, s, _ = x.shape
mean_out = torch.empty((b, s, 1), dtype=x.dtype, device=x.device)
rstd_out = torch.empty((b, s, 1), dtype=x.dtype, device=x.device)
return torch.empty_like(x), mean_out, rstd_out