from typing import Optional
import torch
import torch.nn.functional as F
from ..model_config import MlaConfig
from ..parallel_group import ParallelGroup, ParallelGroupManager
from ..utils import exact_division
from .attention import AttentionMetadataBase
from . import COLWISE_LINEAR, ROWWISE_LINEAR
from .mla import (
DeepseekSparseAttention,
MultiheadLatentAttentionTensorCast,
_resolve_sparse_topk_limit,
)
from .mtp import MultiTokenPredictorLayer
from .utils import apply_static_quant_linear, ModelWrapperBase
def has_deepseek_v4_hash_routing(gate: torch.nn.Module, moe_layer_idx: Optional[int]) -> bool:
gate_hash = getattr(gate, "hash", None)
if gate_hash is not None:
return bool(gate_hash)
return moe_layer_idx is not None and moe_layer_idx < 3
def compute_v4_gate_scores(
gate: torch.nn.Module,
hidden_states: torch.Tensor,
) -> tuple[torch.Tensor, float, bool]:
gate_weight = _extract_gate_weight(gate)
score_func = str(getattr(gate, "score_func", "sqrtsoftplus"))
route_scale = float(getattr(gate, "route_scale", 1.0))
scores = F.linear(hidden_states.float(), gate_weight.float())
if score_func == "softmax":
scores = scores.softmax(dim=-1)
elif score_func == "sigmoid":
scores = scores.sigmoid()
else:
scores = F.softplus(scores).sqrt()
return scores, route_scale, score_func != "softmax"
def route_deepseek_v4_gate(
gate: torch.nn.Module,
hidden_states: torch.Tensor,
top_k: int,
input_ids: Optional[torch.Tensor] = None,
moe_layer_idx: Optional[int] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
scores, route_scale, normalize_weights = compute_v4_gate_scores(gate, hidden_states)
use_hash_routing = has_deepseek_v4_hash_routing(gate, moe_layer_idx)
topk_weights, topk_indices = route_v4_gate_tail(
gate,
top_k,
use_hash_routing,
scores,
normalize_weights,
route_scale,
input_ids,
)
return topk_indices, topk_weights.to(hidden_states.dtype)
def route_v4_gate_tail(
gate: torch.nn.Module,
top_k: int,
hash_routing: bool,
scores: torch.Tensor,
normalize_weights: bool,
route_scale: float,
input_ids: Optional[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
if hash_routing:
if input_ids is None:
raise ValueError("DeepSeek V4 hash routing requires input_ids")
tid2eid = _extract_tid2eid(gate)
if tid2eid is None:
raise ValueError("DeepSeek V4 hash routing requires gate.tid2eid")
topk_weights, topk_indices = torch.ops.tensor_cast.moe_gating_top_k_hash(
scores,
top_k,
normalize_weights,
route_scale,
input_ids,
tid2eid,
)
else:
bias = _extract_gate_bias(gate)
topk_weights, topk_indices = torch.ops.tensor_cast.moe_gating_top_k(
scores,
top_k,
normalize_weights,
route_scale,
bias,
)
gate_dtype = gate.weight.dtype if hasattr(gate, "weight") and hasattr(gate.weight, "dtype") else scores.dtype
return topk_weights.to(gate_dtype), topk_indices
def _extract_gate_weight(gate: torch.nn.Module) -> torch.Tensor:
weight = getattr(gate, "weight", None)
if weight is not None:
return weight.data if hasattr(weight, "data") else weight
for param in gate.parameters(recurse=True):
if param.ndim == 2:
return param.data
raise AttributeError("Hash-routing MoE gate must expose a 2D weight tensor for cost modeling")
def _extract_tid2eid(gate: torch.nn.Module) -> torch.Tensor | None:
tid2eid = getattr(gate, "tid2eid", None)
if tid2eid is None:
return None
return tid2eid.data if hasattr(tid2eid, "data") else tid2eid
def _extract_gate_bias(gate: torch.nn.Module) -> torch.Tensor | None:
bias = getattr(gate, "bias", None)
if bias is None:
bias = getattr(gate, "e_score_correction_bias", None)
if bias is None:
return None
return bias.data if hasattr(bias, "data") else bias
def get_window_topk_idxs(
window_size: int,
batch_size: int,
seq_length: int,
device: torch.device,
is_decode: bool = False,
) -> torch.Tensor:
W = int(window_size)
sl = int(seq_length)
width = W if is_decode else min(sl, W)
return torch.arange(width, device=device, dtype=torch.long).view(1, 1, -1).expand(int(batch_size), sl, -1)
def get_compress_topk_idxs(
ratio: int,
batch_size: int,
seq_length: int,
device: torch.device,
) -> torch.Tensor:
R = int(ratio)
sl = int(seq_length)
width = max(sl // R, 1)
return torch.arange(width, device=device, dtype=torch.long).view(1, 1, -1).expand(int(batch_size), sl, -1)
class DeepseekV4SparseAttention(DeepseekSparseAttention):
"""V4 sparse attention wrapper (covers Flash and Pro).
The cost-modeling forward here mirrors
`deepseek-ai/DeepSeek-V4-Flash/inference/model.py:Attention.forward` step-by-step.
For every layer it preserves the same structural stages:
1. ``q = q_norm(wq_a(x))``
2. ``q = wq_b(q).view(...)`` followed by explicit per-head
``rsqrt(mean(q^2)+eps)`` normalization
3. ``apply_rotary_emb(q[..., -rd:])`` on the Q branch
4. ``kv = kv_norm(wkv(x))`` over the FULL shared `head_dim`
5. ``apply_rotary_emb(kv[..., -rd:])`` on the shared-KV branch
6. Window top-k index materialization (every layer): inlined
host-side arange/clamp/where in `forward`. Prefill is op-for-op
with `model.py:255 get_window_topk_idxs`; decode uses a single
arange of the same shape and value range (cost-equivalent).
7. (ratio==4 only) Indexer wrapper, which internally runs the
indexer-local `compressor` (head_dim = `index_head_dim`) as its
first stage and then `quant_lightning_indexer` for the learned
compressed top-k — mirroring the reference `Indexer.forward`
8. (ratio==128 only) Compress top-k index materialization (inlined
in `forward`; prefill mirrors `model.py:269 get_compress_topk_idxs`
op-for-op, decode uses an `arange + offset` matching the
reference if-branch)
9. ``cat(window_topk, compress_topk).int()`` (skipped on ratio==0)
10. `scatter_nd_update_mla` writes the post-RoPE `kv_window_entry`
into the sliding-window cache and returns a functional handle
11. (ratio>0 only) `compressor` for the shared coarse-grain KV
12. `sparse_attn_sharedkv` over [window | compressed] memory
13. ``apply_rotary_emb(o[..., -rd:], inverse=True)`` on the output
14. Group-wise ``einsum("bsgd,grd->bsgr", o, wo_a)`` followed by
``wo_b(o.flatten(2))``
Compared to V3.2 (which fuses RMS+matmul+rope into `mlapo`/`mlapo_quant`
and writes KV via `concat_and_cache_mla`), V4 (Flash/Pro) deliberately keeps
these stages as separate ops on NPU. This wrapper therefore also
emits them as separate ops and DOES NOT call `mlapo` or
`concat_and_cache_mla`. The sliding-window KV write from
`model.py:520-533` is preserved here as `scatter_nd_update_mla`: the post-
RoPE `kv_window_entry` is written into `kv_cache`, and the op returns a
functional handle to the updated cache that is then fed into
`sparse_attn_sharedkv`. This builds an explicit
`wkv -> kv_norm -> KV-RoPE -> cat -> scatter -> sparse_attn` data edge,
matching the reference's `self.kv_cache[:bsz] = kv; o = sparse_attn_sharedkv(
q, self.kv_cache, ...)` pattern, and ensuring the full KV branch
(including KV-RoPE and the cache write) is accounted for in the modeled
runtime cost.
Layer-policy specifics:
- ratio == 0 (layers 0/1) : window-only, no Compressor/Indexer
- ratio == 4 (even layers >= 2) : window + Compressor + Indexer
- ratio == 128 (odd layers >= 3) : window + Compressor only
Critically, ratio=4 layers issue TWO `compressor` ops — one for the
indexer-local KV cache (head_dim = `index_head_dim`, e.g. 128) and one
for the shared coarse-grain KV stream (head_dim = `head_dim`, e.g. 512).
"""
def _setup_kv_b_decomposition(self, tp_group: ParallelGroup) -> None:
return None
def _quantize_kv_b_decomposition(self) -> None:
return None
@staticmethod
def _local_linear_out_features(module: torch.nn.Module) -> int:
out_features = getattr(module, "out_features_per_partition", None)
if out_features is None:
out_features = getattr(module, "out_features", None)
if out_features is None and hasattr(module, "_inner"):
out_features = getattr(module._inner, "out_features", None)
if out_features is None:
raise AttributeError(f"Unable to resolve local out_features from {type(module).__name__}")
return int(out_features)
@classmethod
def build_tp_plan_extras(cls, prefix: str, params: dict, config_info) -> dict[str, tuple[str, dict]]:
from .mla import tp_plan_module_path
return {
tp_plan_module_path(prefix, "self_attn.indexer.wq_b"): (COLWISE_LINEAR, dict(params)),
tp_plan_module_path(prefix, "self_attn.indexer.weights_proj"): (
COLWISE_LINEAR,
{
**dict(params),
"head_num": getattr(config_info, "index_n_heads"),
},
),
}
@classmethod
def build_o_proj_tp_plan_extras(cls, prefix: str, params: dict, config_info) -> dict[str, tuple[str, dict]]:
from .mla import tp_plan_module_path
return {
tp_plan_module_path(prefix, "self_attn.wo_a"): (COLWISE_LINEAR, {**dict(params), "dim": 1}),
tp_plan_module_path(prefix, "self_attn.o_proj"): (
ROWWISE_LINEAR,
{
**dict(params),
"tp_group": params["global_tp_group"],
"global_tp_group": params["tp_group"],
},
),
}
def __init__(
self,
mla_config: MlaConfig,
mla_module: torch.nn.Module,
tp_group: ParallelGroup,
decode_only: bool = False,
parallel_group_manager: Optional["ParallelGroupManager"] = None,
):
MultiheadLatentAttentionTensorCast.__init__(self, mla_config, mla_module, tp_group, decode_only)
self.compress_ratio = getattr(self._inner, "compress_ratio", 0)
self.use_indexer = bool(getattr(self._inner, "use_indexer", False))
self.use_compressor = bool(getattr(self._inner, "use_compressor", False))
self.hc_mult = getattr(getattr(self._inner, "config", None), "hc_mult", 1)
inner_config = getattr(self._inner, "config", None)
self.window_size = int(
getattr(self._inner, "window_size", None) or getattr(inner_config, "window_size", None) or 128
)
self.o_proj_tp_group = (
parallel_group_manager.o_proj_tp_group
if parallel_group_manager is not None
and getattr(parallel_group_manager, "o_proj_tp_group", None) is not None
else tp_group
)
self._head_dim = int(getattr(self._inner, "head_dim"))
self._qk_rope_head_dim = int(getattr(self._inner, "qk_rope_head_dim"))
self._n_groups = int(getattr(self._inner, "n_groups", 1))
if self.o_proj_tp_group.world_size > self._n_groups:
raise RuntimeError(
f"Skipped unsupported DeepSeek V4 parallel configuration: "
f"o_proj_tp={self.o_proj_tp_group.world_size}, o_groups={self._n_groups}. "
"Grouped O projection in the Flash/Pro model assumes o_proj_tp <= o_groups. "
"If you have set other parallel configurations, please wait for those results."
)
self._n_local_groups = exact_division(self._n_groups, self.o_proj_tp_group.world_size)
self._o_lora_rank = int(getattr(self._inner, "o_lora_rank", self._head_dim))
self._n_local_heads = exact_division(self._local_linear_out_features(self.q_b_proj), self._head_dim)
self._per_group_in_dim = exact_division(self._n_local_heads * self._head_dim, self._n_local_groups)
self._index_head_dim: Optional[int] = None
self.indexer = None
if self.use_indexer and getattr(self._inner, "indexer", None) is not None:
self.indexer = DeepseekV4SparseAttentionIndexer(
self._inner.indexer,
topk_limit=_resolve_sparse_topk_limit(
self._inner.indexer,
config=getattr(self._inner, "config", None),
),
tp_group=tp_group,
compress_ratio=self.compress_ratio,
)
self._index_head_dim = int(self.indexer.head_dim)
@property
def qk_rope_head_dim(self) -> int:
return self._qk_rope_head_dim
def _scatter_window_kv_prefill(
self,
kv_window_entry: torch.Tensor,
kv_cache: Optional[torch.Tensor],
slot_mapping: Optional[torch.Tensor],
sl: int,
meta_seq_lens: Optional[torch.Tensor],
) -> tuple[torch.Tensor, torch.Tensor]:
"""Write window KV entries into the sliding-window cache during prefill.
Returns (kv_for_attn, kv_attn_handle) where ``kv_attn_handle`` is the
post-write functional cache handle that must be wired into
``sparse_attn_sharedkv`` to keep the entire KV producer chain live.
Write semantics:
- ``sl <= W``: single scatter of the full window entry
- ``sl > W`` with ``cutoff == 0``: tail-``W`` fills the cache exactly
- ``sl > W`` with ``cutoff > 0``: two-scatter split to match the
circular-buffer semantics of the reference implementation
The final ``kv_attn_handle + tensor * 0`` expression in the caller is
NOT a no-op: it binds the two tensors to the same graph node so that
``torch.compile``'s dead-code elimination cannot prune the upstream
KV producer chain (wkv → kv_norm → RoPE → scatter). Both operands
are live because their sum is consumed by ``sparse_attn_sharedkv``.
"""
W = int(self.window_size)
kv_for_attn = kv_window_entry
kv_attn_handle = kv_window_entry
if kv_cache is None:
return kv_for_attn, kv_attn_handle
if sl <= W:
kv_cache = torch.ops.tensor_cast.scatter_nd_update_mla(
kv_window_entry,
kv_cache,
slot_mapping,
sl,
meta_seq_lens,
)
else:
cutoff = sl % W
kv_window_tail = kv_window_entry[:, -W:]
if cutoff == 0:
kv_cache = torch.ops.tensor_cast.scatter_nd_update_mla(
kv_window_tail,
kv_cache,
slot_mapping,
W,
meta_seq_lens,
)
else:
first, second = kv_window_tail.split([W - cutoff, cutoff], dim=1)
kv_cache = torch.ops.tensor_cast.scatter_nd_update_mla(
first,
kv_cache,
slot_mapping,
W - cutoff,
meta_seq_lens,
)
kv_cache = torch.ops.tensor_cast.scatter_nd_update_mla(
second,
kv_cache,
slot_mapping,
cutoff,
meta_seq_lens,
)
kv_attn_handle = kv_cache
return kv_for_attn, kv_attn_handle
def forward(
self,
hidden_states: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
attention_mask: Optional[torch.Tensor],
kv_cache_unused: Optional[torch.Tensor] = None,
attention_meta: Optional[AttentionMetadataBase] = None,
**kwargs,
) -> tuple[torch.Tensor, None]:
kv_cache_by_layers = kwargs.pop("kv_cache_by_layers", None)
indexer_cache_by_layers = kwargs.pop("indexer_cache_by_layers", None)
kv_cache = kv_cache_by_layers[self.layer_idx] if kv_cache_by_layers else None
indexer_cache = (
indexer_cache_by_layers[self.layer_idx]
if indexer_cache_by_layers and self.layer_idx in indexer_cache_by_layers
else None
)
batch_size, seq_length = hidden_states.shape[:-1]
cos, sin = position_embeddings
rd = self._qk_rope_head_dim
head_dim = self._head_dim
n_local_heads = self._n_local_heads
q_a = apply_static_quant_linear(hidden_states, self.q_a_proj)
q_a_normed = torch.ops.tensor_cast.rms_norm(
q_a,
self.q_a_layernorm.weight.data,
getattr(self.q_a_layernorm, "variance_epsilon", 1e-6),
)
q_proj = apply_static_quant_linear(q_a_normed, self.q_b_proj)
n_local_heads = exact_division(q_proj.shape[-1], head_dim)
q_states = q_proj.view(batch_size, seq_length, n_local_heads, head_dim)
q_states *= torch.rsqrt(q_states.square().mean(-1, keepdim=True) + 1e-6).to(q_states.dtype)
torch.ops.tensor_cast.apply_rope_inplace(q_states, cos, sin, True, False, rd)
kv_input = hidden_states + q_states[..., :1, :1].reshape(batch_size, seq_length, 1).to(hidden_states.dtype) * 0
kv = apply_static_quant_linear(kv_input, self.kv_a_proj_with_mqa).view(batch_size, seq_length, head_dim)
kv_normed = torch.ops.tensor_cast.rms_norm(
kv,
self.kv_a_layernorm.weight.data,
getattr(self.kv_a_layernorm, "variance_epsilon", 1e-6),
)
torch.ops.tensor_cast.apply_rope_inplace(kv_normed, cos, sin, True, False, rd)
kv_nope_quant, _ = torch.ops.tensor_cast.dynamic_quantize_symmetric(
kv_normed[..., :-rd],
[-1],
scale_dtype=torch.float32,
out_dtype=torch.float8_e4m3fn,
)
kv_normed[..., :-rd] = kv_nope_quant.to(kv_normed.dtype)
kv_window_entry = kv_normed
W = int(self.window_size)
sl = int(seq_length)
device = hidden_states.device
meta_query_lens = attention_meta.query_lens if attention_meta is not None else None
meta_seq_lens = attention_meta.seq_lens if attention_meta is not None else None
is_decode = sl == int(meta_query_lens.shape[0]) if meta_query_lens is not None else sl == 1
topk_indices = get_window_topk_idxs(W, batch_size, sl, device, is_decode)
if self.compress_ratio:
R = int(self.compress_ratio)
if self.use_indexer and self.indexer is not None and indexer_cache is not None:
indexer_hidden_states = hidden_states + kv_window_entry[..., :1].to(hidden_states.dtype) * 0
indexer_q_a_normed = q_a_normed + kv_window_entry[..., :1].to(q_a_normed.dtype) * 0
compress_topk_indices = self.indexer(
indexer_hidden_states,
indexer_q_a_normed,
position_embeddings,
indexer_cache,
attention_meta=attention_meta,
)
else:
compress_topk_indices = get_compress_topk_idxs(R, batch_size, sl, device)
topk_indices = torch.cat([topk_indices, compress_topk_indices], dim=-1)
topk_indices = topk_indices.int()
attn_sink = getattr(
self._inner,
"attention_sink",
torch.empty(0, dtype=q_states.dtype, device=q_states.device),
)
softmax_scale = float(getattr(self._inner, "softmax_scale", getattr(self._inner, "scaling", head_dim**-0.5)))
slot_mapping = attention_meta.slot_mapping if attention_meta is not None else None
if not is_decode:
kv_for_attn, kv_attn_handle = self._scatter_window_kv_prefill(
kv_window_entry,
kv_cache,
slot_mapping,
sl,
meta_seq_lens,
)
if self.use_compressor and kv_cache is not None:
kv_compress, kv_cache = torch.ops.tensor_cast.compressor(
hidden_states,
kv_cache,
self.compress_ratio,
head_dim,
rd,
False,
meta_seq_lens,
meta_query_lens,
)
kv_for_attn = torch.cat([kv_for_attn, kv_compress], dim=1)
kv_attn_handle = kv_cache
kv_for_attn = kv_attn_handle + kv_for_attn.reshape(-1)[0].to(kv_attn_handle.dtype) * 0
attn_output = torch.ops.tensor_cast.sparse_attn_sharedkv(
q_states,
kv_for_attn,
attn_sink,
topk_indices,
softmax_scale,
head_dim,
)
else:
kv_for_attn = kv_window_entry
if kv_cache is not None:
kv_cache = torch.ops.tensor_cast.scatter_nd_update_mla(
kv_window_entry,
kv_cache,
slot_mapping,
sl,
meta_seq_lens,
)
kv_for_attn = kv_cache
if self.use_compressor and kv_cache is not None:
_, kv_cache = torch.ops.tensor_cast.compressor(
hidden_states,
kv_cache,
self.compress_ratio,
head_dim,
rd,
False,
meta_seq_lens,
meta_query_lens,
)
kv_for_attn = kv_cache
attn_output = torch.ops.tensor_cast.sparse_attn_sharedkv(
q_states,
kv_for_attn,
attn_sink,
topk_indices,
softmax_scale,
head_dim,
)
o_view = attn_output
torch.ops.tensor_cast.apply_rope_inplace(o_view, cos, sin, True, True, rd)
per_group_in_dim = exact_division(n_local_heads * head_dim, self._n_local_groups)
o_grouped = o_view.reshape(batch_size, seq_length, self._n_local_groups, per_group_in_dim)
wo_a_weight, _, _ = MultiheadLatentAttentionTensorCast.extract_qparams(self._inner.wo_a)
wo_a_grouped = wo_a_weight.reshape(self._n_local_groups, self._o_lora_rank, per_group_in_dim).to(
o_grouped.dtype
)
o_grouped = torch.einsum("bsgd,grd->bsgr", o_grouped, wo_a_grouped)
attn_output = self.o_proj(o_grouped.flatten(2))
return attn_output.to(hidden_states.dtype), None
class DeepseekV4SparseAttentionIndexer(ModelWrapperBase):
"""Wrapper for the ratio==4 learned Indexer path in V4 (Flash/Pro).
Mirrors reference `Indexer.forward` (deepseek-ai/DeepSeek-V4-Flash/inference/model.py:402-433)
so simulated execution cost tracks the reference's runtime. `rotate_activation`
and `fp4_act_quant` on q have no standalone tensor_cast op; their cost is
accounted for inside `quant_lightning_indexer` instead of as separate trace
events.
"""
def __init__(
self,
indexer,
topk_limit: Optional[int] = None,
tp_group: Optional[ParallelGroup] = None,
compress_ratio: int = 0,
):
super().__init__(indexer)
self._topk_limit = _resolve_sparse_topk_limit(indexer, topk_limit=topk_limit)
self.tp_group = tp_group
self.compress_ratio = int(compress_ratio)
@property
def num_heads(self) -> int:
return self._inner.num_heads
@property
def num_local_heads(self) -> int:
out_features = getattr(self.wq_b, "out_features_per_partition", None)
if out_features is None:
out_features = getattr(self.wq_b, "out_features", None)
if out_features is None and hasattr(self.wq_b, "_inner"):
out_features = getattr(self.wq_b._inner, "out_features", None)
if out_features is None:
raise AttributeError(f"Unable to resolve local out_features from {type(self.wq_b).__name__}")
return exact_division(int(out_features), self.head_dim)
@property
def head_dim(self) -> int:
return self._inner.head_dim
@property
def topk_limit(self) -> int:
return self._topk_limit
def forward(
self,
hidden_states: torch.Tensor,
qa_normed: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor],
indexer_cache: torch.Tensor,
attention_meta: Optional[AttentionMetadataBase] = None,
):
cos, sin = position_embeddings
seq_lens_meta = attention_meta.seq_lens if attention_meta is not None else None
query_lens_meta = attention_meta.query_lens if attention_meta is not None else None
batch_size, seq_length = hidden_states.shape[:-1]
rd = int(self.qk_rope_head_dim)
q_proj = apply_static_quant_linear(qa_normed, self.wq_b)
num_local_heads = exact_division(q_proj.shape[-1], self.head_dim)
q_states = q_proj.view(batch_size, seq_length, num_local_heads, self.head_dim)
torch.ops.tensor_cast.apply_rope_inplace(q_states, cos, sin, True, False, rd)
if self.compress_ratio:
_, indexer_cache = torch.ops.tensor_cast.compressor(
hidden_states,
indexer_cache,
self.compress_ratio,
int(self.head_dim),
rd,
True,
seq_lens_meta,
query_lens_meta,
)
weights = apply_static_quant_linear(hidden_states, self.weights_proj) * (
float(self.head_dim) ** -0.5 * float(self.num_heads) ** -0.5
)
return torch.ops.tensor_cast.quant_lightning_indexer(
q_states,
weights,
indexer_cache,
int(self.topk_limit),
int(self.tp_group.world_size) if self.tp_group is not None else 1,
seq_lens_meta,
query_lens_meta,
)
class HyperConnectedMultiTokenPredictorLayer(MultiTokenPredictorLayer):
"""MTP layer for V4 family with Hyper-Connection (HC).
The main model output is already HC-reduced to [B,S,D] but the MTP block
expects HC-expanded [B,S,Hc,D]. This subclass bridges the shape with HC-expand
at entry and HC-head reduction at exit, mirroring the reference MTPBlock
semantics (each MTPBlock owns its own hc_head params).
"""
def __init__(self, hf_config, mtp_block: torch.nn.Module):
super().__init__(hf_config, mtp_block)
self.hc_mult = int(getattr(mtp_block, "hc_mult", 1) or 1)
self.hc_eps = float(getattr(mtp_block, "hc_eps", 1e-6))
hc_dim = self.hc_mult * hf_config.hidden_size
self.hc_head_fn = torch.nn.Parameter(torch.empty(self.hc_mult, hc_dim, dtype=torch.float32))
self.hc_head_base = torch.nn.Parameter(torch.empty(self.hc_mult, dtype=torch.float32))
self.hc_head_scale = torch.nn.Parameter(torch.empty(1, dtype=torch.float32))
def forward(
self,
inputs_embeds: torch.Tensor,
position_ids: torch.Tensor,
previous_hidden_states: torch.Tensor,
position_embeddings: Optional[torch.Tensor] = None,
**kwargs,
):
inputs_embeds = self.emb_norm(inputs_embeds)
previous_hidden_states = self.hidden_norm(previous_hidden_states)
hidden_states = self.linear_proj(torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
hidden_states = hidden_states.unsqueeze(2).repeat(1, 1, self.hc_mult, 1)
hidden_states = self.mtp_block(
hidden_states,
position_ids=position_ids,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = torch.ops.tensor_cast.hc_head(
hidden_states,
self.hc_head_fn,
self.hc_head_scale,
self.hc_head_base,
self.hc_mult,
self.hc_eps,
)
return hidden_states