import re
from typing import Optional, Tuple
import torch
from transformers import AutoConfig, AutoModel
from tensor_cast.layers.attention import flash_attention_forward
from tensor_cast.transformers.custom_model_registry import (
ModelProfile,
register_custom_model,
register_model_profile,
)
from tensor_cast.transformers.model import TransformerModel
from tensor_cast.transformers.transformations import (
maybe_enable_mtp,
maybe_reuse_layers,
patch_mla,
patch_moe,
patch_rotary_emb,
quantize_model,
shard_model,
wrap_model,
)
from ..utils import replace_module
from .bailing_moe_hf.configuration_bailing_moe_v2 import BailingMoeV2Config
from .bailing_moe_hf.modeling_bailing_moe_v2 import BailingMoeV2Model
AutoConfig.register("bailing_moe", BailingMoeV2Config)
AutoModel.register(BailingMoeV2Config, BailingMoeV2Model)
register_model_profile(
ModelProfile(
model_type="bailing_moe",
moe_module_name="BailingMoeV2SparseMoeBlock",
moe_gate_returns_raw_logits=False,
custom_expert_module_type=None,
)
)
@register_custom_model("bailing_moe")
def _(model: TransformerModel):
model = wrap_model(model)
model = maybe_enable_mtp(model)
model = maybe_reuse_layers(model)
model = patch_rotary_emb(model)
if model.model_config.attention_cls is not None:
model.attention_by_layers = {}
for i in range(model.num_hidden_layers):
model.attention_by_layers[i] = model.model_config.attention_cls()
named_modules = list(model._inner.named_modules())
for name, module in named_modules:
if re.match("BailingMoe.*Attention", type(module).__name__):
adapter = BailingMoeV2AttentionAdapter(module, model.attention_by_layers)
replace_module(model, name, adapter)
model = patch_mla(model)
model = patch_moe(model)
model = quantize_model(model)
model = shard_model(model)
return model
class BailingMoeV2AttentionAdapter(torch.nn.Module):
"""
Adapter for BailingMoeV2Attention
This adapter wraps the original attention module and:
1. Extracts QKV computation from the original module
2. Applies necessary transformations (RoPE, QK norm, etc.)
3. Calls tensorcast's attention ops
4. Applies output projection
Args:
original_attention: The original BailingMoeV2Attention module to wrap
attention_by_layers: Dictionary mapping layer_idx to AttentionTensorCast instances
"""
def __init__(self, original_attention: torch.nn.Module, attention_by_layers: dict):
super().__init__()
self.original_attention = original_attention
self.attention_by_layers = attention_by_layers
self.layer_idx = original_attention.layer_idx
self.hidden_size = original_attention.hidden_size
self.config = original_attention.config
self.num_heads = original_attention.num_heads
self.head_dim = original_attention.head_dim
self.num_key_value_heads = original_attention.num_key_value_heads
self.num_key_value_groups = original_attention.num_key_value_groups
self.q_proj = torch.nn.Linear(
self.hidden_size,
self.num_heads * self.head_dim,
bias=self.config.use_qkv_bias,
)
self.k_proj = torch.nn.Linear(
self.hidden_size,
self.num_key_value_heads * self.head_dim,
bias=self.config.use_qkv_bias,
)
self.v_proj = torch.nn.Linear(
self.hidden_size,
self.num_key_value_heads * self.head_dim,
bias=self.config.use_qkv_bias,
)
self.o_proj = torch.nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=self.config.use_bias)
if hasattr(original_attention, "query_layernorm"):
self.query_layernorm = original_attention.query_layernorm
if hasattr(original_attention, "key_layernorm"):
self.key_layernorm = original_attention.key_layernorm
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[torch.Tensor] = None,
output_attentions: bool = False,
use_cache: bool = False,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Forward pass that computes QKV and calls tensorcast attention.
"""
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
key_states = self.k_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
value_states = self.v_proj(hidden_states).view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
if hasattr(self, "query_layernorm") and self.config.use_qk_norm:
query_states = self.query_layernorm(query_states)
key_states = self.key_layernorm(key_states)
if position_embeddings is not None:
cos, sin = position_embeddings
apply_rotary_pos_emb = self._get_apply_rotary_pos_emb()
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if past_key_value is not None:
cache_kwargs = {}
if position_embeddings is not None:
cache_kwargs = {"sin": sin, "cos": cos}
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
attn_output, attn_weights = flash_attention_forward(
self.original_attention,
query_states,
key_states,
value_states,
attention_mask,
attention_by_layers=self.attention_by_layers,
**kwargs,
)
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights, past_key_value
def _get_apply_rotary_pos_emb(self):
"""
Get the apply_rotary_pos_emb function from the original module's namespace.
"""
import sys
module_name = self.original_attention.__class__.__module__
if module_name in sys.modules:
module = sys.modules[module_name]
if hasattr(module, "apply_rotary_pos_emb"):
return module.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
rotary_dim = cos.shape[-1]
q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
def rotate_half(x):
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
q_embed = torch.cat([q_embed, q_pass], dim=-1)
k_embed = torch.cat([k_embed, k_pass], dim=-1)
return q_embed, k_embed
return apply_rotary_pos_emb