from functools import wraps
import torch
import torch.distributed as dist
from einops import rearrange
import mindspeed.megatron_adaptor
from mindspeed.core.context_parallel.ulysses_context_parallel import UlyssesContextAttention
from mindspeed.core.context_parallel.ring_context_parallel import ringattn_context_parallel
import opensora
from opensora.utils.config_utils import parse_configs
from opensora.acceleration.parallel_states import (get_sequence_parallel_group,
get_sequence_parallel_group_for_send_recv_overlap)
from opensora.utils.device_utils import is_npu_available
if not is_npu_available():
import xformers.ops
else:
import torch_npu
# Configuration is parsed once, at module load, with training=True. The patched
# functions in this module read `sp_size`, `context_parallel_algo` and
# `use_cp_send_recv_overlap` from it.
cfg = parse_configs(training=True)
def attention_init_wrapper(fn):
    """Decorate an ``__init__`` so it can wrap ``core_attention`` for Ulysses CP.

    The returned initializer first runs ``fn`` unchanged. Afterwards, when
    ``cfg.sp_size > 1`` and ``cfg.context_parallel_algo`` equals
    ``'ulysses_cp_algo'``, it replaces ``self.core_attention`` with a
    ``UlyssesContextAttention`` built around the previous value and the
    sequence-parallel group. If ``self.enable_flashattn`` is false, the
    wrapper is additionally given ``scatter_idx=1`` and ``gather_idx=2``.

    Args:
        fn: The original ``__init__`` to wrap.

    Returns:
        The wrapping initializer (returns ``None``, like ``__init__``).
    """

    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        fn(self, *args, **kwargs)

        ulysses_enabled = cfg.sp_size > 1 and cfg.context_parallel_algo == 'ulysses_cp_algo'
        if not ulysses_enabled:
            return

        sp_group = get_sequence_parallel_group()
        # Explicit scatter/gather indices are only supplied on the
        # non-flash-attention path; otherwise the wrapper's defaults are used.
        index_overrides = {} if self.enable_flashattn else {'scatter_idx': 1, 'gather_idx': 2}
        self.core_attention = UlyssesContextAttention(self.core_attention, sp_group, **index_overrides)

    return wrapper
def _build_ring_cp_params(cp_group):
    """Assemble the parameter dict handed to ``ringattn_context_parallel``.

    Args:
        cp_group: Process group used for context parallelism.

    Returns:
        dict with the keys ``causal``, ``cp_group``, ``cp_size``, ``rank``,
        ``cp_global_ranks``, ``use_cp_send_recv_overlap`` and
        ``cp_group_for_send_recv_overlap``.
    """
    cp_size = dist.get_world_size(cp_group)
    # NOTE(review): the global ranks are derived arithmetically, which assumes
    # every group is made of `cp_size` consecutive global ranks -- confirm
    # against how the sequence-parallel groups are created.
    first_global_rank = (dist.get_rank() // cp_size) * cp_size
    overlap = cfg.use_cp_send_recv_overlap
    return {
        'causal': None,
        'cp_group': cp_group,
        'cp_size': cp_size,
        'rank': dist.get_rank(group=cp_group),
        'cp_global_ranks': [first_global_rank + offset for offset in range(cp_size)],
        'use_cp_send_recv_overlap': overlap,
        # The dedicated overlap group is only looked up when overlap is enabled.
        'cp_group_for_send_recv_overlap': (
            get_sequence_parallel_group_for_send_recv_overlap() if overlap else None
        ),
    }


def core_attention_forward(self, q, k, v):
    """Replacement ``forward`` that dispatches to one of four attention backends.

    Backend selection:
      * ``self.enable_flashattn`` false: plain matmul/softmax attention with
        the softmax evaluated in float32.
      * NPU available and ``q`` is float16/bfloat16:
          - ``cfg.sp_size > 1`` with ``'megatron_cp_algo'``: MindSpeed ring
            attention (``ringattn_context_parallel``);
          - otherwise: ``torch_npu.npu_fusion_attention`` with the "SBH" layout.
      * anything else: ``flash_attn.flash_attn_func``.

    Dropout is applied with probability ``self.attn_drop.p`` only while
    ``self.training`` is true; this now holds for the ring-attention backend
    as well, which previously received ``self.attn_drop.p`` even in eval mode.

    Args:
        q, k, v: Query, key and value tensors. On the NPU paths they are
            rearranged with ``'s b h d -> s b (h d)'``, so they must be 4-D
            with the head axis second-to-last. The plain path contracts the
            last two axes (presumably ``(..., seq, head_dim)`` -- confirm
            with the caller).

    Returns:
        The attention output tensor produced by the selected backend.
    """
    if not self.enable_flashattn:
        input_dtype = q.dtype
        scores = (q * self.scale) @ k.transpose(-2, -1)
        # Softmax runs in float32, then the result is cast back to the input dtype.
        probs = scores.to(torch.float32).softmax(dim=-1).to(input_dtype)
        return self.attn_drop(probs) @ v

    # Single source of truth for dropout across all fused backends.
    dropout_p = self.attn_drop.p if self.training else 0.0

    if is_npu_available() and q.dtype in (torch.float16, torch.bfloat16):
        head_num = q.shape[-2]
        # Both NPU kernels take the heads folded into the hidden axis.
        q, k, v = [rearrange(t, 's b h d -> s b (h d)') for t in (q, k, v)]

        if cfg.sp_size > 1 and cfg.context_parallel_algo == 'megatron_cp_algo':
            cp_para = _build_ring_cp_params(get_sequence_parallel_group())
            return ringattn_context_parallel(q, k, v, head_num, cp_para, self.scale, None, dropout_p)

        return torch_npu.npu_fusion_attention(
            q, k, v, head_num, input_layout="SBH",
            pse=None,
            scale=self.scale,
            pre_tockens=65536,
            next_tockens=65536,
            keep_prob=1. - dropout_p,
            sync=False,
            inner_precise=0,
        )[0]

    # Imported lazily so flash_attn is only required when this path is taken.
    from flash_attn import flash_attn_func
    return flash_attn_func(
        q,
        k,
        v,
        dropout_p=dropout_p,
        softmax_scale=self.scale,
    )
def exe_adaptation():
    """Install the attention patches on ``opensora.models.layers.blocks``.

    ``AttentionWithCp.__init__`` is replaced by its ``attention_init_wrapper``
    decorated version, and ``CoreAttention.forward`` is replaced by
    ``core_attention_forward``.
    """
    blocks = opensora.models.layers.blocks

    attention_cls = blocks.AttentionWithCp
    attention_cls.__init__ = attention_init_wrapper(attention_cls.__init__)

    blocks.CoreAttention.forward = core_attention_forward
# Apply the patches immediately when this module is executed/imported.
exe_adaptation()