import torch
import torch_npu
import torch.distributed as dist
import numpy as np
from einops import rearrange
from functools import lru_cache
from scipy.sparse.linalg import eigsh
from mindspeed.ops.npu_ring_attention_update import npu_ring_attention_update
from mindspeed.core.context_parallel import get_args
from mindspeed.core.context_parallel import get_context_parallel_global_ranks
from mindspeed.core.context_parallel.model_parallel_utils import get_context_parallel_for_hybrid_ring_global_ranks
from mindspeed.op_builder import AdaptiveCpOpBuilder
ADAPTIVE_CP_SCHEDULING_INFO = None
ADAPTIVE_CP_SEQ_ORDER = None
CACHED_GRID_MASK = None
CACHED_SEQ = None
CACHED_MASK_LIST = []
CACHED_SCHEDULING = None
COMM_THRESHOLD = 6
ADAPTIVE_CP_DEFAULT_SHAPE = 1024
ADAPTIVE_CP_MASK_LIST_SET_BY_USER = None
ADAPTIVE_CP_GRID_MASK_SET_BY_USER = None
def sbh_to_tnd(x, n):
s, b, h = x.shape
d, t = h // n, int(b * s)
return x.transpose(0, 1).view(t, h).view(t, n, d)
def tnd_to_sbh(x, b):
t, n, d = x.shape
s, h = t // b, int(n * d)
return x.view(b, s, n, d).transpose(0, 1).view(s, b, h)
@lru_cache(maxsize=8)
def get_selection_indices_for_tnd_softmax_update(t, n, sub_seq_len):
full_indices = list(range(t * n))
cur_seq_start_idx = 0
indices = []
seq_start = 0
for seq_len in sub_seq_len:
for i in range(n):
start = seq_start + seq_len * 2 * i + seq_len
end = seq_start + seq_len * 2 * (i + 1)
indices.extend(full_indices[start:end])
seq_start += seq_len * n * 2
return torch.tensor(indices)
def flatten_softmax(x, sub_seq_len):
orig_shape = x.shape
section_len = [s * orig_shape[1] for s in sub_seq_len]
splits = x.view(-1, orig_shape[-1]).split(section_len, dim=0)
merged = [item.view(orig_shape[1], -1, orig_shape[-1]).transpose(0, 1) for item in splits]
merged = torch.cat(merged, dim=0)
return merged
def unflatten_softmax(x, sub_seq_len):
orig_shape = x.shape
section_len = [s * orig_shape[1] for s in sub_seq_len]
splits = x.view(-1, orig_shape[-1]).split(section_len, dim=0)
merged = [item.view(-1, orig_shape[1], orig_shape[-1]).transpose(0, 1) \
.view(-1, orig_shape[-1]) for item in splits]
merged = torch.cat(merged, dim=0)
return merged.view(*orig_shape)
def forward_update_without_fused(prev_attn_out, prev_softmax_max, prev_softmax_sum,
cur_attn_out, cur_softmax_max, cur_softmax_sum, actual_seq_qlen=None, layout='SBH'):
if layout == 'TND':
cur_softmax_max = flatten_softmax(cur_softmax_max, actual_seq_qlen)
cur_softmax_sum = flatten_softmax(cur_softmax_sum, actual_seq_qlen)
prev_softmax_max = flatten_softmax(prev_softmax_max, actual_seq_qlen)
prev_softmax_sum = flatten_softmax(prev_softmax_sum, actual_seq_qlen)
origin_dtype = prev_attn_out.dtype
softmax_max = torch.maximum(prev_softmax_max, cur_softmax_max)
prev_scale = torch.exp(prev_softmax_max - softmax_max)
cur_scale = torch.exp(cur_softmax_max - softmax_max)
prev_softmax_sum_scaled = prev_softmax_sum * prev_scale
cur_softmax_sum_scaled = cur_softmax_sum * cur_scale
softmax_sum = prev_softmax_sum_scaled + cur_softmax_sum_scaled
prev_out_scale = prev_softmax_sum_scaled / softmax_sum
cur_out_scale = cur_softmax_sum_scaled / softmax_sum
if layout == 'SBH':
n = prev_out_scale.shape[1]
h = prev_attn_out.shape[-1]
d = h // n
prev_out_scale = prev_out_scale[..., 0].unsqueeze(3).repeat(1, 1, 1, d)
prev_out_scale = rearrange(prev_out_scale, 'b n s d -> s b (n d)').contiguous()
cur_out_scale = cur_out_scale[..., 0].unsqueeze(3).repeat(1, 1, 1, d)
cur_out_scale = rearrange(cur_out_scale, 'b n s d -> s b (n d)').contiguous()
elif layout == 'TND':
d = prev_attn_out.shape[-1]
prev_out_scale = prev_out_scale[..., 0].unsqueeze(2).repeat(1, 1, d)
cur_out_scale = cur_out_scale[..., 0].unsqueeze(2).repeat(1, 1, d)
attn_out = prev_attn_out * prev_out_scale + cur_attn_out * cur_out_scale
attn_out = attn_out.to(origin_dtype)
if layout == 'TND':
softmax_max = unflatten_softmax(softmax_max, actual_seq_qlen)
softmax_sum = unflatten_softmax(softmax_sum, actual_seq_qlen)
return attn_out, softmax_max, softmax_sum
class RingP2P:
def __init__(self, ring_global_ranks, group, group_for_send_recv_overlap=None, is_backward=False) -> None:
self.group = group
self.group_for_send_recv_overlap = group
if group_for_send_recv_overlap is not None:
self.group_for_send_recv_overlap = group_for_send_recv_overlap
global_rank = dist.get_rank()
ring_rank = ring_global_ranks.index(global_rank)
ring_size = len(ring_global_ranks)
self.next = ring_global_ranks[(ring_rank + 1) % ring_size]
self.prev = ring_global_ranks[(ring_rank + ring_size - 1) % ring_size]
self.ring_rank = ring_rank
if is_backward:
self.next, self.prev = self.prev, self.next
self.send_recv_ops = []
def async_send_recv(self, orig_send_tensor, orig_recv_tensor, shapes=None):
send_tensor, recv_tensor = orig_send_tensor, orig_recv_tensor
enable_mla = isinstance(orig_send_tensor, (list, tuple))
if enable_mla:
if shapes is not None:
raise ValueError("MLA context parallel does not support uneven shapes yet.")
if len(orig_send_tensor) != 2 or len(orig_recv_tensor) != 2:
raise ValueError(
f"Expected tensors of length 2 (k,v), got lengths: "
f"send={len(orig_send_tensor)}, recv={len(orig_recv_tensor)}"
)
k_send, v_send = orig_send_tensor
k_recv, v_recv = orig_recv_tensor
if k_send.shape != k_recv.shape or v_send.shape != v_recv.shape:
raise ValueError(
"Shape mismatch in KV tensors:\n"
f" k_send: {k_send.shape} vs k_recv: {k_recv.shape}\n"
f" v_send: {v_send.shape} vs v_recv: {v_recv.shape}"
)
k_shape, v_shape = k_send.shape, v_send.shape
k_numel = k_send.numel()
send_tensor = torch.cat((k_send.view(-1), v_send.view(-1)), dim=-1)
recv_tensor = torch.cat((k_recv.view(-1), v_recv.view(-1)), dim=-1)
if self.ring_rank % 2 == 0:
if shapes is not None:
send_tensor_shape_list = list(send_tensor.shape)
send_tensor_shape_list[-3] = shapes[0]
send_tensor.resize_(send_tensor_shape_list)
send_op = dist.isend(send_tensor, self.next, self.group)
if shapes is not None:
recv_tensor_shape_list = list(recv_tensor.shape)
recv_tensor_shape_list[-3] = shapes[1]
recv_tensor.resize_(recv_tensor_shape_list)
recv_op = dist.irecv(recv_tensor, self.prev, self.group_for_send_recv_overlap)
self.send_recv_ops.append(send_op)
self.send_recv_ops.append(recv_op)
else:
if shapes is not None:
recv_tensor_shape_list = list(recv_tensor.shape)
recv_tensor_shape_list[-3] = shapes[1]
recv_tensor.resize_(recv_tensor_shape_list)
recv_op = dist.irecv(recv_tensor, self.prev, self.group)
if shapes is not None:
send_tensor_shape_list = list(send_tensor.shape)
send_tensor_shape_list[-3] = shapes[0]
send_tensor.resize_(send_tensor_shape_list)
send_op = dist.isend(send_tensor, self.next, self.group_for_send_recv_overlap)
self.send_recv_ops.append(recv_op)
self.send_recv_ops.append(send_op)
if enable_mla:
orig_recv_tensor[0] = recv_tensor[:k_numel].view(*k_shape)
orig_recv_tensor[1] = recv_tensor[k_numel:].view(*v_shape)
def wait(self):
if len(self.send_recv_ops) > 0:
for op in self.send_recv_ops:
op.wait()
self.send_recv_ops = []
return 1
else:
return 0
def forward_update(prev_attn_out, prev_softmax_max, prev_softmax_sum,
cur_attn_out, cur_softmax_max, cur_softmax_sum, actual_seq_qlen=None, layout='SBH'):
"""
Updates the attention output and softmax statistics for the ring attention mechanism,
with added parameters for enhanced flexibility and extensibility.
This function is designed to update the attention output and related softmax statistics
for a given sequence length in a ring attention mechanism. It handles the merging of
previous and current attention outputs and their corresponding softmax statistics.
The introduction of `actual_seq_qlen` and `layout` parameters allows for greater flexibility
in handling variable sequence lengths and different tensor layouts, respectively.
Parameters:
- prev_attn_out (Tensor): The attention output from the previous process.
- prev_softmax_max (Tensor): The maximum value of the softmax distribution from the previous process.
- prev_softmax_sum (Tensor): The sum of the softmax distribution from the previous process.
- cur_attn_out (Tensor): The attention output from the current process.
- cur_softmax_max (Tensor): The maximum value of the softmax distribution from the current process.
- cur_softmax_sum (Tensor): The sum of the softmax distribution from the current process.
- actual_seq_qlen (Tensor, optional): The actual sequence length for the query. This parameter
is crucial for handling variable-length sequences and ensuring
that the attention mechanism operates correctly under such conditions.
If not provided, it defaults to the length of the current attention output.
- layout (str, optional): The layout format of the input tensors. This parameter allows for the specification
of different tensor layouts, enhancing the function's versatility across various
model architectures. Default is 'SBH', where:
- S: Sequence length
- B: Batch size
- H: Hidden size (number of attention heads)
Returns:
- updated_attn_out (Tensor): The updated attention output after merging previous and current process.
- updated_softmax_max (Tensor): The updated maximum value of the softmax distribution.
- updated_softmax_sum (Tensor): The updated sum of the softmax distribution.
"""
_args = get_args()
if hasattr(_args, 'use_fused_ring_attention_update') and _args.use_fused_ring_attention_update:
def accumulate_list(input_list):
"""
借助numpy库将列表转换为numpy数组进行元素累加,再转换回列表并在开头添加0
"""
np_array = np.array(input_list)
cumsum_result = np.cumsum(np_array)
return torch.tensor([0] + list(cumsum_result), dtype=torch.int64).to(prev_attn_out.device)
if layout == "TND":
actual_seq_qlen = accumulate_list(actual_seq_qlen)
return npu_ring_attention_update(prev_attn_out, prev_softmax_max, prev_softmax_sum, cur_attn_out,
cur_softmax_max, cur_softmax_sum, actual_seq_qlen, layout)
return forward_update_without_fused(prev_attn_out, prev_softmax_max, prev_softmax_sum, cur_attn_out,
cur_softmax_max, cur_softmax_sum, actual_seq_qlen, layout)
def tnd_out_update(q_block_id, kv_block_id, cur_attn_outs, global_attn_outs, q_index, softmax_indices, cur_sub_out_seq_len):
cur_attn_out, cur_softmax_max, cur_softmax_sum = cur_attn_outs[0], cur_attn_outs[1], cur_attn_outs[2]
attn_out, softmax_max, softmax_sum, rng_states = global_attn_outs
layout = 'TND'
if len(cur_attn_outs) > 3:
rng_states[kv_block_id] = (cur_attn_outs[4], cur_attn_outs[5], cur_attn_outs[6])
if q_block_id == kv_block_id:
attn_out = cur_attn_out
softmax_max = cur_softmax_max
softmax_sum = cur_softmax_sum
elif kv_block_id <= q_block_id:
attn_out_updated, softmax_max_updated, softmax_sum_updated = forward_update(
attn_out, softmax_max, softmax_sum,
cur_attn_out, cur_softmax_max, cur_softmax_sum, actual_seq_qlen=cur_sub_out_seq_len, layout=layout
)
attn_out, softmax_max, softmax_sum = attn_out_updated, softmax_max_updated, softmax_sum_updated
else:
n = attn_out.shape[1]
t = attn_out.shape[0]
prev_softmax_max = softmax_max.view(-1, 8)[softmax_indices].view(-1, n, 8)
prev_softmax_sum = softmax_sum.view(-1, 8)[softmax_indices].view(-1, n, 8)
attn_out_updated, softmax_max_updated, softmax_sum_updated = forward_update(
torch.index_select(attn_out, 0, q_index), prev_softmax_max, prev_softmax_sum,
cur_attn_out, cur_softmax_max, cur_softmax_sum, actual_seq_qlen=cur_sub_out_seq_len, layout=layout
)
attn_out.index_copy_(0, q_index, attn_out_updated)
softmax_max = softmax_max.view(-1, 8).index_copy(0, softmax_indices, softmax_max_updated.view(-1, 8)).view(-1, n, 8)
softmax_sum = softmax_sum.view(-1, 8).index_copy(0, softmax_indices, softmax_sum_updated.view(-1, 8)).view(-1, n, 8)
return [attn_out, softmax_max, softmax_sum, rng_states]
def causal_out_update(q_block_id, kv_block_id, cur_attn_outs, global_attn_outs):
cur_attn_out, cur_softmax_max, cur_softmax_sum = cur_attn_outs[0], cur_attn_outs[1], cur_attn_outs[2]
attn_out, softmax_max, softmax_sum, rng_states = global_attn_outs
layout = 'SBH'
if len(cur_attn_outs) > 3:
rng_states[kv_block_id] = (cur_attn_outs[4], cur_attn_outs[5], cur_attn_outs[6])
if q_block_id == kv_block_id:
attn_out = cur_attn_out
softmax_max = cur_softmax_max
softmax_sum = cur_softmax_sum
elif kv_block_id <= q_block_id:
attn_out_updated, softmax_max_updated, softmax_sum_updated = forward_update(
attn_out, softmax_max, softmax_sum,
cur_attn_out, cur_softmax_max, cur_softmax_sum, actual_seq_qlen=None, layout=layout
)
attn_out, softmax_max, softmax_sum = attn_out_updated, softmax_max_updated, softmax_sum_updated
else:
attn_out = attn_out.view(2, attn_out.shape[0] // 2, *attn_out.shape[1:])
softmax_max = softmax_max.view(softmax_max.shape[0], softmax_max.shape[1],
2, softmax_max.shape[2] // 2, softmax_max.shape[-1])
softmax_sum = softmax_sum.view(softmax_sum.shape[0], softmax_sum.shape[1],
2, softmax_sum.shape[2] // 2, softmax_sum.shape[-1])
attn_out_updated, softmax_max_updated, softmax_sum_updated = forward_update(
attn_out[1], softmax_max[:, :, 1, :, :], softmax_sum[:, :, 1, :, :],
cur_attn_out, cur_softmax_max, cur_softmax_sum, actual_seq_qlen=None, layout=layout
)
attn_out[1].copy_(attn_out_updated)
softmax_max[:, :, 1, :, :].copy_(softmax_max_updated)
softmax_sum[:, :, 1, :, :].copy_(softmax_sum_updated)
attn_out = attn_out.view(-1, *attn_out.shape[2:])
softmax_max = softmax_max.view(softmax_max.shape[0], softmax_max.shape[1], -1,
softmax_max.shape[-1])
softmax_sum = softmax_sum.view(softmax_sum.shape[0], softmax_sum.shape[1], -1,
softmax_sum.shape[-1])
return [attn_out, softmax_max, softmax_sum, rng_states]
def general_out_update(q_block_id, kv_block_id, cur_attn_outs, global_attn_outs):
cur_attn_out, cur_softmax_max, cur_softmax_sum = cur_attn_outs[0], cur_attn_outs[1], cur_attn_outs[2]
attn_out, softmax_max, softmax_sum, rng_states = global_attn_outs
layout = 'SBH'
rng_states[kv_block_id] = (cur_attn_outs[4], cur_attn_outs[5], cur_attn_outs[6])
if q_block_id == kv_block_id:
attn_out = cur_attn_out
softmax_max = cur_softmax_max
softmax_sum = cur_softmax_sum
else:
attn_out_updated, softmax_max_updated, softmax_sum_updated = forward_update(
attn_out, softmax_max, softmax_sum,
cur_attn_out, cur_softmax_max, cur_softmax_sum, layout=layout
)
attn_out, softmax_max, softmax_sum = attn_out_updated, softmax_max_updated, softmax_sum_updated
return [attn_out, softmax_max, softmax_sum, rng_states]
class SchedulingInfo:
def __init__(self, round_idx, recv_q_src: int = -1, recv_kv_src: int = -1, recv_o_src: list = None,
send_q_dst=None, send_kv_dst: list = None, send_o_dst: int = -1, comm_unit_limit=6):
self.round_idx = round_idx
self.recv_q_src = recv_q_src
self.recv_kv_src = recv_kv_src
self.recv_o_src = [] if recv_o_src is None else recv_o_src
self.send_q_dst = [] if send_q_dst is None else send_q_dst
self.send_kv_dst = [] if send_kv_dst is None else send_kv_dst
self.send_o_dst = send_o_dst
self.comm_unit_limit = comm_unit_limit
self.cnt_comm_unit_forward = -1
self.check_eligibility()
def check_eligibility(self):
if self.recv_q_src > -1 and self.recv_kv_src > -1:
raise ValueError("only receive one of q and kv in a single round")
self.count_comm_units()
if self.cnt_comm_unit_forward > self.comm_unit_limit:
raise ValueError(f"comm unit exceed limit: round {self.round_idx}, device {torch.npu.current_device()}")
def count_comm_units(self):
sum_recv_units = self.recv_q_src > -1 + (self.recv_kv_src > -1) * 2 + len(self.recv_o_src)
sum_send_units = len(self.send_q_dst) + len(self.send_kv_dst) * 2 + self.send_o_dst > -1
self.cnt_comm_unit_forward = sum_recv_units + sum_send_units
def coarsen_attn_mask_npu(attn_mask, coarse_ratio):
orig_size = attn_mask.shape[0]
attn_mask_reshaped = (~attn_mask)
attn_mask_reshaped = attn_mask_reshaped.view(orig_size // coarse_ratio, coarse_ratio,
orig_size // coarse_ratio, coarse_ratio).permute(0, 2, 1, 3)
coarse_attn_mask = ~torch.any(torch.any(attn_mask_reshaped, dim=3), dim=2)
return coarse_attn_mask
def set_scheduling_info(cp_rank, scheduling):
global ADAPTIVE_CP_SCHEDULING_INFO
if ADAPTIVE_CP_SCHEDULING_INFO is None or get_args().adaptive_cp_dynamic_attn_mask:
ADAPTIVE_CP_SCHEDULING_INFO = process_scheduling_info(cp_rank, scheduling)[1:]
def get_scheduling_info():
if ADAPTIVE_CP_SCHEDULING_INFO is None:
raise RuntimeError("Trying to get scheduling info before setting it, ADAPTIVE_CP_SCHEDULING_INFO is still None")
return ADAPTIVE_CP_SCHEDULING_INFO
def set_remapped_seq_order(seq_order):
global ADAPTIVE_CP_SEQ_ORDER
ADAPTIVE_CP_SEQ_ORDER = seq_order
def get_remapped_seq_order():
if ADAPTIVE_CP_SEQ_ORDER is None:
raise RuntimeError("Trying to get optimized sequence before setting it, ADAPTIVE_CP_SEQ_ORDER is still None")
return ADAPTIVE_CP_SEQ_ORDER
def set_adaptive_cp_mask_list_by_user(mask_list):
global ADAPTIVE_CP_MASK_LIST_SET_BY_USER
ADAPTIVE_CP_MASK_LIST_SET_BY_USER = mask_list
def get_adaptive_cp_mask_list_by_user():
global ADAPTIVE_CP_MASK_LIST_SET_BY_USER
if ADAPTIVE_CP_MASK_LIST_SET_BY_USER is None:
raise RuntimeError("Trying to get mask list before setting it, ADAPTIVE_CP_MASK_LIST_SET_BY_USER is still None")
return ADAPTIVE_CP_MASK_LIST_SET_BY_USER
def generate_adaptive_cp_mask_list_by_user(opt_seq, scheduling_info, cp_rank, cp_size):
mask_list = None
set_adaptive_cp_mask_list_by_user(mask_list)
def set_adaptive_cp_grid_mask_by_user(grid_mask):
global ADAPTIVE_CP_GRID_MASK_SET_BY_USER
ADAPTIVE_CP_GRID_MASK_SET_BY_USER = grid_mask
def get_adaptive_cp_grid_mask_by_user():
global ADAPTIVE_CP_GRID_MASK_SET_BY_USER
if ADAPTIVE_CP_GRID_MASK_SET_BY_USER is None:
raise RuntimeError("Trying to get grid mask before setting it, ADAPTIVE_CP_GRID_MASK_SET_BY_USER is still None")
return ADAPTIVE_CP_GRID_MASK_SET_BY_USER
def generate_adaptive_cp_grid_mask_by_user(cp_size):
grid_mask = None
set_adaptive_cp_grid_mask_by_user(grid_mask)
def process_scheduling_info(local_rank, orig_scheduling, comm_limit=6):
round_num = len(orig_scheduling)
device_num = len(orig_scheduling[0])
processed_scheduling_info = [SchedulingInfo(round_idx=i, comm_unit_limit=comm_limit) for i in range(round_num + 1)]
for rnd_idx in range(round_num):
process_single_scheduling_info(local_rank, device_num, rnd_idx, orig_scheduling[rnd_idx],
processed_scheduling_info)
return processed_scheduling_info
def process_single_scheduling_info(local_rank, device_num, round_idx, round_scheduling_info, processed_scheduling_info):
if get_args().context_parallel_algo == 'adaptive_cp_algo':
rank_list = get_context_parallel_global_ranks()
else:
rank_list = get_context_parallel_for_hybrid_ring_global_ranks()
for execute_device_id, task_id in enumerate(round_scheduling_info):
if task_id == -1:
continue
origin_device_id = rank_list[int(task_id / device_num)]
kv_device_id = rank_list[task_id % device_num]
execute_device_id = rank_list[execute_device_id]
if execute_device_id != origin_device_id:
if execute_device_id == local_rank:
processed_scheduling_info[round_idx].recv_q_src = origin_device_id
processed_scheduling_info[round_idx + 1].send_o_dst = origin_device_id
elif origin_device_id == local_rank:
processed_scheduling_info[round_idx].send_q_dst.append(execute_device_id)
processed_scheduling_info[round_idx + 1].recv_o_src.append(execute_device_id)
else:
if execute_device_id == local_rank:
processed_scheduling_info[round_idx].recv_kv_src = kv_device_id
elif kv_device_id == local_rank:
processed_scheduling_info[round_idx].send_kv_dst.append(execute_device_id)
processed_scheduling_info[round_idx].check_eligibility()
def adaptive_reschedule_task(grid_mask, cp_size):
scheduling_info = []
total_task = torch.sum(grid_mask)
round_idx = 0
next_comm = np.zeros(cp_size)
while total_task > 0:
scheduling_info.append([-1 for _ in range(cp_size)])
cur_comm = next_comm
next_comm = np.zeros(cp_size)
total_task -= execute_scheduling(grid_mask, cp_size, round_idx, cur_comm, next_comm, scheduling_info[round_idx])
round_idx += 1
return scheduling_info
def execute_scheduling(grid_mask, cp_size, round_idx, cur_comm, next_comm, scheduling_info):
count = 0
is_free = np.ones(cp_size)
for device_id in range(cp_size):
row, col = find_kv_task(grid_mask, cp_size, round_idx, cur_comm, device_id, is_free)
if row != -1 and col != -1:
scheduling_info[device_id] = row * cp_size + col
grid_mask[row][col] = 0
count += 1
is_send_q = np.zeros(cp_size, dtype=int)
for device_id in range(cp_size):
if is_free[device_id] == 0:
continue
row, col = find_qo_task(grid_mask, cp_size, cur_comm, next_comm, device_id, is_send_q)
if row != -1 and col != -1:
scheduling_info[device_id] = row * cp_size + col
grid_mask[row][col] = 0
count += 1
return count
def find_kv_task(grid_mask, cp_size, round_idx, cur_comm, device_id, is_free):
is_free[device_id] = 0
row = device_id
col = (device_id + round_idx) % cp_size
if grid_mask[row][col] == 1:
cur_comm[row] = cur_comm[row] + 2
cur_comm[col] = cur_comm[col] + 2
return row, col
for i in range(1, cp_size):
row = device_id
col = (device_id - i + cp_size) % cp_size
if grid_mask[row][col] == 1 and cur_comm[row] <= COMM_THRESHOLD - 2 and cur_comm[col] <= COMM_THRESHOLD - 2:
cur_comm[row] += 2
cur_comm[col] += 2
return row, col
is_free[device_id] = 1
return -1, -1
def find_qo_task(grid_mask, cp_size, cur_comm, next_comm, device_id, is_send_q):
for i in range(1, cp_size):
row = (device_id + i) % cp_size
col = device_id
if grid_mask[row][col] == 1 and cur_comm[row] <= COMM_THRESHOLD - 1 and \
cur_comm[col] <= COMM_THRESHOLD - 1 and is_send_q[row] != 1:
is_send_q[row] = 1
cur_comm[row] += 1
cur_comm[col] += 1
next_comm[row] += 1
next_comm[col] += 1
return row, col
return -1, -1
def clear_global_info():
global CACHED_SEQ, CACHED_GRID_MASK, CACHED_MASK_LIST, CACHED_SCHEDULING, ADAPTIVE_CP_SCHEDULING_INFO
CACHED_SEQ, CACHED_GRID_MASK, CACHED_MASK_LIST, CACHED_SCHEDULING, ADAPTIVE_CP_SCHEDULING_INFO = (None, None, [],
None, None)
class AdaptiveCpOps:
def __init__(self):
self.ops = AdaptiveCpOpBuilder().load()
def coarsen_attn_mask_cpu(self, attn_mask, sampling_ratio):
if not attn_mask.is_contiguous():
attn_mask = attn_mask.contiguous()
mask_size_after_sampling = attn_mask.shape[0] // sampling_ratio
coarse_mask = torch.ones((mask_size_after_sampling, mask_size_after_sampling), dtype=torch.bool)
self.ops.coarsen_mask(attn_mask, mask_size_after_sampling, coarse_mask)
return coarse_mask
def get_grid_mask(self, attn_mask, cp_size):
if not attn_mask.is_contiguous():
attn_mask = attn_mask.contiguous()
if get_args().attention_mask_on_cpu:
grid_mask = torch.ones((cp_size, cp_size), dtype=torch.bool)
self.ops.coarsen_mask(attn_mask, cp_size, grid_mask)
else:
grid_mask = coarsen_attn_mask_npu(attn_mask, attn_mask.shape[0] // cp_size)
grid_mask = ~grid_mask
return grid_mask
def search_kmeans_cpu(self, attn_mask, reduced_mask, cp_size, num_iters=100):
tmp_attn_mask = torch.ones_like(attn_mask)
tmp_grid_mask = torch.ones((cp_size, cp_size), dtype=torch.bool)
optimal_attn_mask = torch.ones_like(attn_mask)
optimal_grid_mask = torch.ones((cp_size, cp_size), dtype=torch.bool)
optimal_num_cluster = [-1]
optimal_sorted_indices = self.ops.search_kmeans(attn_mask, reduced_mask, tmp_attn_mask, tmp_grid_mask,
optimal_grid_mask, optimal_attn_mask,
optimal_num_cluster, cp_size, num_iters)
return optimal_sorted_indices, optimal_grid_mask, optimal_attn_mask, optimal_num_cluster
def adaptive_remap(self, attn_mask, cp_size, truncated_dim=10):
args = get_args()
if attn_mask.dim() != 2 or attn_mask.shape[0] != attn_mask.shape[1]:
raise RuntimeError("Only 2-dimensional self-attention mask supported in adaptive cp")
if args.adaptive_cp_without_coarse:
sampling_ratio = 1
if args.attention_mask_on_cpu:
coarse_mask = attn_mask
else:
coarse_mask = attn_mask.cpu()
else:
if attn_mask.shape[0] % ADAPTIVE_CP_DEFAULT_SHAPE != 0:
raise RuntimeError("Shape of attention mask needs to be a multiple of 1024 if not enable "
"args.adaptive_cp_without_coarse in adaptive cp")
if args.attention_mask_on_cpu:
sampling_ratio = attn_mask.shape[0] // ADAPTIVE_CP_DEFAULT_SHAPE
coarse_mask = self.coarsen_attn_mask_cpu(attn_mask, sampling_ratio)
else:
sampling_ratio = attn_mask.shape[0] // ADAPTIVE_CP_DEFAULT_SHAPE
coarse_mask = coarsen_attn_mask_npu(attn_mask, sampling_ratio).cpu()
coarse_mask_np = coarse_mask.to(torch.float16).numpy()
mean_matrix = np.mean(coarse_mask_np, axis=0)
centered_matrix = (coarse_mask_np - mean_matrix).astype(float)
cov_matrix = np.matmul(centered_matrix.T, centered_matrix)
eigenvalues, eigenvectors = eigsh(cov_matrix, k=truncated_dim, which='LM')
feature_matrix = np.matmul(coarse_mask_np, eigenvectors).tolist()
optimal_seq, optimal_grid_mask, optimal_coarsen_attn_mask, optimal_num_cluster = (
self.search_kmeans_cpu(coarse_mask, feature_matrix, cp_size))
if args.adaptive_cp_without_coarse:
final_opt_seq = optimal_seq
else:
final_opt_seq = sampling_ratio * torch.tensor(optimal_seq)[:, None] + torch.arange(sampling_ratio)
final_opt_seq = final_opt_seq.view(-1).tolist()
optimal_grid_mask = ~optimal_grid_mask
return optimal_grid_mask, final_opt_seq
def get_adaptive_cp_info(self, attn_mask, cp_size):
args = get_args()
global CACHED_GRID_MASK, CACHED_SEQ
if args.attention_mask_on_cpu != (attn_mask.device.type == 'cpu'):
raise RuntimeError("args.attention_mask_on_cpu does not match the device of set attention mask")
if not args.adaptive_cp_only_reschedule:
if args.adaptive_cp_dynamic_attn_mask or CACHED_GRID_MASK is None:
opt_grid_mask, opt_seq = self.adaptive_remap(attn_mask, cp_size)
if not args.adaptive_cp_dynamic_attn_mask:
CACHED_GRID_MASK, CACHED_SEQ = opt_grid_mask, opt_seq
else:
opt_grid_mask, opt_seq = CACHED_GRID_MASK, CACHED_SEQ
else:
opt_seq = list(range(attn_mask.shape[0]))
if args.adaptive_cp_dynamic_attn_mask or CACHED_GRID_MASK is None:
opt_grid_mask = self.get_grid_mask(attn_mask, cp_size)
CACHED_GRID_MASK = opt_grid_mask
else:
opt_grid_mask = CACHED_GRID_MASK
opt_scheduling = adaptive_reschedule_task(opt_grid_mask, cp_size)
return opt_seq, opt_scheduling
def get_mask_list(self, attn_mask, opt_scheduling, opt_seq, cp_rank, cp_size):
args = get_args()
global CACHED_MASK_LIST
if not args.adaptive_cp_dynamic_attn_mask and len(CACHED_MASK_LIST) > 0:
return CACHED_MASK_LIST
round_num = len(opt_scheduling)
grid_size = attn_mask.shape[0] // cp_size
mask_list = []
for rnd_idx in range(round_num):
task_id = opt_scheduling[rnd_idx][cp_rank]
if task_id == -1:
mask_list.append(None)
continue
q_device_id = task_id // cp_size
kv_device_id = task_id % cp_size
if args.attention_mask_on_cpu:
mask_list.append(torch.empty((grid_size, grid_size), dtype=torch.bool, device='cpu'))
if args.adaptive_cp_only_reschedule:
grid_inds = [q_device_id, kv_device_id]
self.ops.get_mask_list_without_remap(attn_mask, mask_list[rnd_idx], grid_inds, cp_size)
else:
q_token_list = opt_seq[grid_size * q_device_id: grid_size * (q_device_id + 1)]
kv_token_list = opt_seq[grid_size * kv_device_id: grid_size * (kv_device_id + 1)]
self.ops.get_mask_list_with_remap(attn_mask, mask_list[rnd_idx], q_token_list, kv_token_list)
else:
q_token_list = opt_seq[grid_size * q_device_id: grid_size * (q_device_id + 1)]
kv_token_list = opt_seq[grid_size * kv_device_id: grid_size * (kv_device_id + 1)]
mask_list.append(attn_mask[q_token_list, :][:, kv_token_list])
if args.attention_mask_on_cpu:
for rnd_idx in range(round_num):
if mask_list[rnd_idx] is not None:
mask_list[rnd_idx] = mask_list[rnd_idx].npu(non_blocking=True)
CACHED_MASK_LIST = mask_list
return mask_list
def round_up(x, n):
return (x + n - 1) // n * n
def divisible(tensor, divisor):
rem = tensor % divisor
return torch.all(rem == 0).item()
def pad_data(actual_seq_len, batch, cp_size, tp_size):
from math import lcm
pad_to = lcm(2, tp_size) * cp_size
if divisible(actual_seq_len, pad_to):
return actual_seq_len
first_seq_len = actual_seq_len[0:1]
per_seq_lens = torch.cat((first_seq_len, torch.diff(actual_seq_len)))
per_seq_lens_padded = round_up(per_seq_lens, pad_to)
actual_seq_len_padded = torch.cumsum(per_seq_lens_padded, dim=0)
paded_total_len = actual_seq_len_padded[-1]
per_seq_lens_cpu = per_seq_lens.cpu()
starts = torch.cat([torch.tensor([0], device='npu'), actual_seq_len_padded[:-1]])
starts_cpu = starts.cpu()
index_ranges = []
for i in range(len(per_seq_lens_cpu)):
start_val = starts_cpu[i]
seq_len_val = per_seq_lens_cpu[i]
index_ranges.append((start_val, start_val + seq_len_val))
all_indices = []
for start_val, end_val in index_ranges:
all_indices.append(torch.arange(start_val, end_val, device='npu'))
indices = torch.cat(all_indices)
def pad(data):
if data is None:
return data
data = data.view(-1)
buffer = torch.zeros(paded_total_len, device='npu', dtype=data.dtype)
buffer[indices] = data[:len(indices)]
return buffer.view((1, -1))
for key in ['tokens', 'labels', 'loss_mask', 'position_ids']:
batch[key] = pad(batch[key])
return actual_seq_len_padded