"""Transformers model patches for NPU optimization.
Patches:
- Qwen3RMSNorm: torch_npu.npu_rms_norm (transformers >= 4.51.0, including v5.x)
- Qwen3RoPE: torch_npu.npu_rotary_mul (transformers >= 4.51.0, including v5.x)
- Qwen3VLTextRMSNorm: torch_npu.npu_rms_norm for Qwen3-VL text model (transformers >= 4.57.0, including v5.x)
- Qwen3VLRoPE: torch_npu.npu_rotary_mul for Qwen3-VL vision rotary (transformers >= 4.57.0, including v5.x)
- Qwen3VL_get_placeholder_mask: fixed get_placeholder_mask for Qwen3-VL (transformers >= 4.57.0, including v5.x)
- BaseImageProcessorFast_preprocess: force device to NPU in image preprocessor (transformers 4.57.x)
- Qwen2VL_preprocess: optimized vision preprocessing for Qwen3-VL using Qwen2-VL processor (transformers 4.57.x)
- FlashAttention: mask caching via ATTN_MASK_NPU_CACHE (transformers 4.51.x, 4.52.x only)
- FlashAttentionVarlen: mask caching via ATTN_MASK_NPU_CACHE (transformers 4.51.x, 4.52.x only)
Version compatibility (source verified):
Qwen3: 4.51.x+ and 5.x (v5 removed position_ids from apply_rotary_pos_emb, patch compatible)
Qwen3-VL: 4.57.x+ and 5.x (vision patches compatible with the same NPU ops)
FlashAttention: 4.51.x/4.52.x need patch (no mask caching); 4.53.x+ built-in caching, no patch needed
"""
import importlib
import os
import re
from typing import List, Optional

from mx_driving.patcher.patch import AtomicPatch, BasePatch, Patch, with_imports
from mx_driving.patcher.version import get_version
def _get_transformers_version() -> Optional[str]:
    """Look up the installed ``transformers`` version through ``get_version``.

    The annotation allows ``None``; presumably that is what ``get_version``
    yields when the package is missing (confirm in mx_driving.patcher.version).
    """
    package_name = "transformers"
    return get_version(package_name)
def _parse_version(version_str: str) -> List[int]:
try:
parts = version_str.split('.')
return [int(p.split('+')[0].split('-')[0]) for p in parts if p]
except (ValueError, AttributeError):
return []
def _is_transformers_compatible_for_qwen3() -> bool:
    """Return True when the installed transformers can host the Qwen3 patches.

    Qwen3 landed in transformers 4.51.0. Version 5 removed ``position_ids`` from
    ``apply_rotary_pos_emb``; the replacement here accepts that argument but
    never reads it, so every 5.x (and later) release is accepted as well.
    """
    version = _get_transformers_version()
    parts = _parse_version(version) if version else []
    if len(parts) < 2:
        return False
    # Lexicographic (major, minor) comparison: 4.51+ or any major >= 5.
    return tuple(parts[:2]) >= (4, 51)
def _is_transformers_compatible_for_preprocess() -> bool:
    """Return True when the image-preprocess patches may be applied.

    Only the major/minor pair is compared, so any 4.57.x release qualifies.
    """
    version = _get_transformers_version()
    parts = _parse_version(version) if version else []
    if len(parts) < 2:
        return False
    return tuple(parts[:2]) == (4, 57)
def _is_transformers_compatible_for_qwen3_vl() -> bool:
    """Return True for transformers 4.57.0+ (including any 5.x), where Qwen3-VL exists."""
    version = _get_transformers_version()
    parts = _parse_version(version) if version else []
    if len(parts) < 2:
        return False
    return tuple(parts[:2]) >= (4, 57)
def _is_transformers_compatible_for_flash_attention() -> bool:
    """Return True only for transformers 4.51.x and 4.52.x.

    Those releases rebuild the causal mask on every flash-attention call, so the
    ATTN_MASK_NPU_CACHE-backed replacement is worthwhile. From 4.53.x on the
    mask is cached upstream and the patch is skipped.
    """
    version = _get_transformers_version()
    parts = _parse_version(version) if version else []
    if len(parts) < 2:
        return False
    return (4, 51) <= tuple(parts[:2]) < (4, 53)
def _is_qwen3_available() -> bool:
try:
importlib.import_module("transformers.models.qwen3.modeling_qwen3")
return True
except ImportError:
return False
def _is_qwen3_vl_available() -> bool:
try:
importlib.import_module("transformers.models.qwen3_vl.modeling_qwen3_vl")
return True
except ImportError:
return False
def _is_flash_attention_available() -> bool:
try:
importlib.import_module("transformers.modeling_flash_attention_utils")
importlib.import_module("transformers.integrations.npu_flash_attention")
return True
except ImportError:
return False
# Causal masks built by get_attn_mask_npu, keyed by device.
ATTN_MASK_NPU_CACHE = {}

# sparse_mode values passed to npu_fusion_attention for causal attention
# (names describe the mask alignment each mode selects).
TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE = 2
DOWN_RIGHT_ALIGNED_CAUSAL_MASK_MODE = 3

# Override with the NPU_FA2_SPARSE_MODE environment variable; defaults to down-right alignment.
SPARSE_MODE = int(os.environ.get("NPU_FA2_SPARSE_MODE", str(DOWN_RIGHT_ALIGNED_CAUSAL_MASK_MODE)))


def get_attn_mask_npu(device):
    """Return the cached 2048x2048 boolean causal mask for ``device``.

    Entries strictly above the main diagonal are True. The mask is built on the
    first request for a device and the same tensor object is returned afterwards.
    """
    import torch

    try:
        return ATTN_MASK_NPU_CACHE[device]
    except KeyError:
        mask = torch.triu(torch.ones([2048, 2048], device=device), diagonal=1).bool()
        ATTN_MASK_NPU_CACHE[device] = mask
        return mask
class Qwen3RMSNorm(Patch):
    """Swap ``Qwen3RMSNorm.forward`` for the fused ``torch_npu.npu_rms_norm`` kernel.

    Applies to transformers >= 4.51.0 (v5.x included) when the Qwen3 modeling
    module imports; torch_npu must be importable.
    """

    name = "qwen3_rmsnorm"
    legacy_name = "qwen3_rmsnorm"
    target_module = "transformers.models.qwen3.modeling_qwen3"

    @staticmethod
    def precheck() -> bool:
        """Return True when the version gate passes and the Qwen3 module imports."""
        if not _is_transformers_compatible_for_qwen3():
            return False
        return _is_qwen3_available()

    @staticmethod
    @with_imports("torch_npu")
    def forward(self, hidden_states):
        """Normalise ``hidden_states`` with the layer's weight and epsilon; keep output 0."""
        outputs = torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)
        return outputs[0]

    @classmethod
    def patches(cls, options=None) -> List[BasePatch]:
        """Return the single AtomicPatch targeting ``Qwen3RMSNorm.forward``."""
        forward_patch = AtomicPatch(
            "transformers.models.qwen3.modeling_qwen3.Qwen3RMSNorm.forward",
            cls.forward,
            precheck=cls.precheck,
        )
        return [forward_patch]
class Qwen3RoPE(Patch):
    """Swap Qwen3's ``apply_rotary_pos_emb`` for ``torch_npu.npu_rotary_mul``.

    ``position_ids`` stays in the signature for callers that pass it, but it is
    never read, so the patch also fits v5 (which dropped that parameter).
    Applies to transformers >= 4.51.0 (v5.x included); needs torch_npu.
    """

    name = "qwen3_rope"
    legacy_name = "qwen3_rope"
    target_module = "transformers.models.qwen3.modeling_qwen3"

    @staticmethod
    def precheck() -> bool:
        """Return True when the version gate passes and the Qwen3 module imports."""
        if not _is_transformers_compatible_for_qwen3():
            return False
        return _is_qwen3_available()

    @staticmethod
    @with_imports("torch_npu")
    def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
        """Rotate ``q`` and ``k`` using ``cos``/``sin`` unsqueezed at ``unsqueeze_dim``."""
        cos_b = cos.unsqueeze(unsqueeze_dim)
        sin_b = sin.unsqueeze(unsqueeze_dim)
        return (
            torch_npu.npu_rotary_mul(q, cos_b, sin_b),
            torch_npu.npu_rotary_mul(k, cos_b, sin_b),
        )

    @classmethod
    def patches(cls, options=None) -> List[BasePatch]:
        """Return the single AtomicPatch targeting Qwen3's ``apply_rotary_pos_emb``."""
        rope_patch = AtomicPatch(
            "transformers.models.qwen3.modeling_qwen3.apply_rotary_pos_emb",
            cls.apply_rotary_pos_emb,
            precheck=cls.precheck,
        )
        return [rope_patch]
class Qwen3VLTextRMSNorm(Patch):
    """Swap ``Qwen3VLTextRMSNorm.forward`` (Qwen3-VL text model) for ``torch_npu.npu_rms_norm``.

    Applies to transformers >= 4.57.0 (v5.x included) when the Qwen3-VL
    modeling module imports; torch_npu must be importable.
    """

    name = "qwen3vltext_rmsnorm"
    legacy_name = "qwen3vltext_rmsnorm"
    target_module = "transformers.models.qwen3_vl.modeling_qwen3_vl"

    @staticmethod
    def precheck() -> bool:
        """Return True when the version gate passes and the Qwen3-VL module imports."""
        if not _is_transformers_compatible_for_qwen3_vl():
            return False
        return _is_qwen3_vl_available()

    @staticmethod
    @with_imports("torch_npu")
    def forward(self, hidden_states):
        """Normalise ``hidden_states`` with the layer's weight and epsilon; keep output 0."""
        outputs = torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.variance_epsilon)
        return outputs[0]

    @classmethod
    def patches(cls, options=None) -> List[BasePatch]:
        """Return the single AtomicPatch targeting ``Qwen3VLTextRMSNorm.forward``."""
        forward_patch = AtomicPatch(
            "transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLTextRMSNorm.forward",
            cls.forward,
            precheck=cls.precheck,
        )
        return [forward_patch]
class Qwen3VLRoPE(Patch):
    """Swap Qwen3-VL's ``apply_rotary_pos_emb_vision`` for ``torch_npu.npu_rotary_mul``.

    Applies to transformers >= 4.57.0 (v5.x included); needs torch_npu.
    """

    name = "qwen3_vl_rope_vision"
    legacy_name = "qwen3_vl_rope_vision"
    target_module = "transformers.models.qwen3_vl.modeling_qwen3_vl"

    @staticmethod
    def precheck() -> bool:
        """Return True when the version gate passes and the Qwen3-VL module imports."""
        if not _is_transformers_compatible_for_qwen3_vl():
            return False
        return _is_qwen3_vl_available()

    @staticmethod
    @with_imports("torch_npu")
    def apply_rotary_pos_emb_vision(q, k, cos, sin):
        """Rotate ``q``/``k`` in float32 and cast each result back to its input dtype.

        ``cos`` and ``sin`` get a new axis at -2 before the rotation.
        """
        q_dtype, k_dtype = q.dtype, k.dtype
        cos_f = cos.unsqueeze(-2).float()
        sin_f = sin.unsqueeze(-2).float()
        q_rotated = torch_npu.npu_rotary_mul(q.float(), cos_f, sin_f).to(q_dtype)
        k_rotated = torch_npu.npu_rotary_mul(k.float(), cos_f, sin_f).to(k_dtype)
        return q_rotated, k_rotated

    @classmethod
    def patches(cls, options=None) -> List[BasePatch]:
        """Return the single AtomicPatch targeting ``apply_rotary_pos_emb_vision``."""
        rope_patch = AtomicPatch(
            "transformers.models.qwen3_vl.modeling_qwen3_vl.apply_rotary_pos_emb_vision",
            cls.apply_rotary_pos_emb_vision,
            precheck=cls.precheck,
        )
        return [rope_patch]
class Qwen3VL_get_placeholder_mask(Patch):
    """Swap ``Qwen3VLModel.get_placeholder_mask`` for an NPU-friendly implementation.

    Applies to transformers >= 4.57.0 (v5.x included) when the Qwen3-VL
    modeling module imports.
    """

    name = "get_placeholder_mask"
    legacy_name = "get_placeholder_mask"
    target_module = "transformers.models.qwen3_vl.modeling_qwen3_vl"

    @staticmethod
    def precheck() -> bool:
        """Return True when the version gate passes and the Qwen3-VL module imports."""
        if not _is_transformers_compatible_for_qwen3_vl():
            return False
        return _is_qwen3_vl_available()

    @staticmethod
    @with_imports("torch")
    def get_placeholder_mask(
        self,
        input_ids,
        inputs_embeds,
        image_features=None,
        video_features=None,
    ):
        """Locate image and video placeholder positions.

        If ``input_ids`` is given, a position is a placeholder when its id equals
        ``config.image_token_id`` / ``config.video_token_id``. Otherwise every
        position of ``inputs_embeds`` is compared with the embedding of that
        token id and must match across the last dimension.

        Both masks are expanded to the shape of ``inputs_embeds`` and moved to
        its device. When features are supplied, the number of True entries in
        the expanded mask, read via ``.sum().item()``, must equal the features'
        ``numel()``.

        Returns:
            (special_image_mask, special_video_mask)

        Raises:
            ValueError: if a token count does not match its feature count.
        """

        def _match(token_id):
            if input_ids is not None:
                return input_ids == token_id
            token_embedding = self.get_input_embeddings()(
                torch.tensor(token_id, dtype=torch.long, device=inputs_embeds.device)
            )
            return (inputs_embeds == token_embedding).all(-1)

        image_mask = _match(self.config.image_token_id)
        video_mask = _match(self.config.video_token_id)

        n_image_tokens = image_mask.sum()
        image_mask = image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        if image_features is not None and image_mask.sum().item() != image_features.numel():
            raise ValueError(
                f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}"
            )

        n_video_tokens = video_mask.sum()
        video_mask = video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
        if video_features is not None and video_mask.sum().item() != video_features.numel():
            raise ValueError(
                f"Videos features and video tokens do not match: tokens: {n_video_tokens}, features {video_features.shape[0]}"
            )

        return image_mask, video_mask

    @classmethod
    def patches(cls, options=None) -> List[BasePatch]:
        """Return the single AtomicPatch targeting ``Qwen3VLModel.get_placeholder_mask``."""
        mask_patch = AtomicPatch(
            "transformers.models.qwen3_vl.modeling_qwen3_vl.Qwen3VLModel.get_placeholder_mask",
            cls.get_placeholder_mask,
            precheck=cls.precheck,
        )
        return [mask_patch]
class BaseImageProcessorFast_preprocess(Patch):
    """Override BaseImageProcessorFast.preprocess to force device to NPU.

    Every call preprocesses on ``npu:<LOCAL_RANK>`` and ignores the ``device``
    argument the caller passed. ``LOCAL_RANK`` defaults to 0 when it is not set
    (single-process runs).

    Requirements: transformers 4.57.x (the version gate compares major/minor
    only) with the Qwen3-VL modeling module available.
    """

    name = "preprocess"
    legacy_name = "preprocess"
    target_module = "transformers.image_processing_utils_fast"

    @staticmethod
    def precheck() -> bool:
        return _is_transformers_compatible_for_preprocess() and _is_qwen3_vl_available()

    @staticmethod
    @with_imports("os", "transformers")
    def preprocess(self, images, *args, **kwargs):
        """Validate and complete processor kwargs, then run ``_preprocess_image_like_inputs``.

        Args:
            self: the BaseImageProcessorFast instance being patched.
            images: forwarded unchanged.
            *args: forwarded positionally to ``_preprocess_image_like_inputs``.
            **kwargs: processor options. Every key must be in
                ``self._valid_kwargs_names``; missing options are filled from the
                instance attribute of the same name (or None).

        Returns:
            Whatever ``self._preprocess_image_like_inputs`` returns.
        """
        transformers.image_utils.validate_kwargs(
            captured_kwargs=kwargs.keys(), valid_processor_keys=self._valid_kwargs_names
        )
        for kwarg_name in self._valid_kwargs_names:
            kwargs.setdefault(kwarg_name, getattr(self, kwarg_name, None))
        do_convert_rgb = kwargs.pop("do_convert_rgb")
        input_data_format = kwargs.pop("input_data_format")
        # The caller's device is dropped on purpose so preprocessing runs on this
        # process's NPU. LOCAL_RANK is typically set by distributed launchers
        # (e.g. torchrun); fall back to device 0 instead of raising KeyError.
        kwargs.pop("device")
        local_rank = int(os.environ.get("LOCAL_RANK", "0"))
        device = f"npu:{local_rank}"
        kwargs = self._further_process_kwargs(**kwargs)
        self._validate_preprocess_kwargs(**kwargs)
        # data_format is validated above but is not forwarded.
        kwargs.pop("data_format")
        return self._preprocess_image_like_inputs(
            images,
            *args,
            do_convert_rgb=do_convert_rgb,
            input_data_format=input_data_format,
            device=device,
            **kwargs,
        )

    @classmethod
    def patches(cls, options=None) -> List[BasePatch]:
        return [
            AtomicPatch(
                "transformers.image_processing_utils_fast.BaseImageProcessorFast.preprocess",
                cls.preprocess,
                precheck=cls.precheck,
            ),
        ]
class Qwen2VL_preprocess(Patch):
    """Replace Qwen2VLImageProcessorFast._preprocess with an NPU-optimized version.

    Resizes each shape group of images, rescales/normalizes it, and flattens it
    into vision patches. The final patch re-layout is done as a chain of
    view/permute/reshape steps whose widest intermediate has 8 dimensions.
    NOTE(review): presumably chosen to suit NPU kernels; confirm the motivation.

    Requirements: transformers 4.57.x (the version gate compares major/minor
    only) with the Qwen3-VL modeling module available.
    """
    # Patch identifiers and the module that hosts the patched class.
    name = "_preprocess"
    legacy_name = "_preprocess"
    target_module = "transformers.models.qwen2_vl.image_processing_qwen2_vl_fast"

    @staticmethod
    def precheck() -> bool:
        # Gate: transformers 4.57 plus an importable Qwen3-VL modeling module.
        return _is_transformers_compatible_for_preprocess() and _is_qwen3_vl_available()

    @staticmethod
    @with_imports("torch", "transformers")
    def _preprocess(
        self,
        images,
        do_resize,
        size,
        interpolation,
        do_rescale,
        rescale_factor,
        do_normalize,
        image_mean,
        image_std,
        patch_size,
        temporal_patch_size,
        merge_size,
        disable_grouping,
        return_tensors,
        **kwargs,
    ):
        """Resize, rescale/normalize and patchify a batch of images.

        Args:
            images: input images, grouped by shape via ``group_images_by_shape``.
            do_resize / size / interpolation: when ``do_resize``, each group is
                resized to the dims returned by ``smart_resize`` with
                min_pixels=size["shortest_edge"], max_pixels=size["longest_edge"].
            do_rescale / rescale_factor / do_normalize / image_mean / image_std:
                forwarded to ``self.rescale_and_normalize``.
            patch_size: spatial patch edge in pixels.
            temporal_patch_size: frames folded into one patch.
            merge_size: spatial merge-window edge, in patches.
            disable_grouping: forwarded to ``group_images_by_shape``.
            return_tensors: forwarded to ``BatchFeature``.
            **kwargs: accepted and ignored.

        Returns:
            transformers.BatchFeature with "pixel_values" (all images' flattened
            patches concatenated along dim 0) and "image_grid_thw" (one
            [grid_t, grid_h, grid_w] row per image).
        """
        # Pass 1: resize each shape group.
        grouped_images, grouped_images_index = transformers.image_processing_utils_fast.group_images_by_shape(
            images, disable_grouping=disable_grouping
        )
        resized_images_grouped = {}
        for shape, stacked_images in grouped_images.items():
            height, width = stacked_images.shape[-2:]
            if do_resize:
                # factor=patch_size * merge_size: the resized dims are expected to be
                # divisible by it, which the view() calls in pass 2 rely on.
                resized_height, resized_width = transformers.models.qwen2_vl.image_processing_qwen2_vl.smart_resize(
                    height,
                    width,
                    factor=patch_size * merge_size,
                    min_pixels=size["shortest_edge"],
                    max_pixels=size["longest_edge"],
                )
                stacked_images = self.resize(
                    image=stacked_images,
                    size=transformers.image_utils.SizeDict(height=resized_height, width=resized_width),
                    interpolation=interpolation,
                )
            resized_images_grouped[shape] = stacked_images
        resized_images = transformers.image_processing_utils_fast.reorder_images(
            resized_images_grouped, grouped_images_index
        )
        # Pass 2: regroup by the post-resize shape, then normalize and patchify.
        grouped_images, grouped_images_index = transformers.image_processing_utils_fast.group_images_by_shape(
            resized_images, disable_grouping=disable_grouping
        )
        processed_images_grouped = {}
        processed_grids = {}
        for shape, stacked_images in grouped_images.items():
            resized_height, resized_width = stacked_images.shape[-2:]
            patches = self.rescale_and_normalize(
                stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
            )
            # Still images are 4-D; add a frame axis -> (batch, frames, channel, H, W).
            if patches.ndim == 4:
                patches = patches.unsqueeze(1)
            # NOTE(review): the last frame is repeated (temporal_patch_size - 1) times.
            # This yields a multiple of temporal_patch_size only when
            # frames % temporal_patch_size == 1 (e.g. one image with temporal_patch_size=2).
            if patches.shape[1] % temporal_patch_size != 0:
                repeats = patches[:, -1:].repeat(1, temporal_patch_size - 1, 1, 1, 1)
                patches = torch.cat([patches, repeats], dim=1)
            batch_size, grid_t, channel = patches.shape[:3]
            grid_t = grid_t // temporal_patch_size
            grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
            # Notation: tp=temporal_patch_size, m=merge_size, ph/pw = pixel row/col inside a
            # patch, mh/mw = patch row/col inside a merge window.
            # Step 1: keep (batch, grid_t), collapse everything except the innermost
            # width-pixel axis pw, then move pw forward -> (batch, grid_t, pw, rest).
            # Assumes patches is contiguous (view()) -- TODO confirm for all inputs.
            tmp1 = patches.view(
                batch_size,
                grid_t,
                temporal_patch_size
                * channel
                * (grid_h // merge_size)
                * merge_size
                * patch_size
                * (grid_w // merge_size)
                * merge_size,
                patch_size,
            ).permute(0, 1, 3, 2)
            # Step 2: split `rest` into (tp, C, grid_h/m, mh, ph, grid_w/m, mw) with
            # (batch, grid_t, pw) merged in front, then reorder to
            # (batch*grid_t*pw, grid_h/m, grid_w/m, mh, mw, C, tp, ph).
            tmp2 = tmp1.reshape(
                batch_size * grid_t * patch_size,
                temporal_patch_size,
                channel,
                grid_h // merge_size,
                merge_size,
                patch_size,
                grid_w // merge_size,
                merge_size,
            ).permute(0, 3, 6, 4, 7, 2, 1, 5)
            # Step 3: merge to (batch*grid_t*pw, grid_h*grid_w, C*tp*ph).
            tmp3 = tmp2.reshape(
                batch_size * grid_t * patch_size, grid_h * grid_w, channel * temporal_patch_size * patch_size
            )
            # Step 4: split pw back out and move it last -> (batch, grid_t, grid_h*grid_w, C*tp*ph, pw).
            tmp4 = tmp3.view(
                batch_size, grid_t, patch_size, grid_h * grid_w, channel * temporal_patch_size * patch_size
            ).permute(0, 1, 3, 4, 2)
            # One row per patch, features ordered (C, tp, ph, pw); patches ordered
            # grid_t, then merge window (grid_h/m, grid_w/m), then (mh, mw) inside it.
            flatten_patches = tmp4.reshape(
                batch_size,
                grid_t * grid_h * grid_w,
                channel * temporal_patch_size * patch_size * patch_size,
            )
            processed_images_grouped[shape] = flatten_patches
            processed_grids[shape] = [[grid_t, grid_h, grid_w]] * batch_size
        # Restore original image order for patches and grid sizes.
        processed_images = transformers.image_processing_utils_fast.reorder_images(
            processed_images_grouped, grouped_images_index
        )
        processed_grids = transformers.image_processing_utils_fast.reorder_images(processed_grids, grouped_images_index)
        pixel_values = torch.cat(processed_images, dim=0)
        image_grid_thw = torch.tensor(processed_grids)
        return transformers.BatchFeature(
            data={"pixel_values": pixel_values, "image_grid_thw": image_grid_thw}, tensor_type=return_tensors
        )

    @classmethod
    def patches(cls, options=None) -> List[BasePatch]:
        # Single patch: Qwen2VLImageProcessorFast._preprocess -> cls._preprocess.
        return [
            AtomicPatch(
                "transformers.models.qwen2_vl.image_processing_qwen2_vl_fast.Qwen2VLImageProcessorFast._preprocess",
                cls._preprocess,
                precheck=cls.precheck,
            ),
        ]
class FlashAttention(Patch):
    """Replace flash_attn_func with torch_npu.npu_fusion_attention, adding mask caching.

    4.51.x/4.52.x: upstream creates the mask on every call; this patch caches it.
    4.53.x+: upstream already caches the mask, so the patch is not needed.
    Requirements: transformers >= 4.51.0, < 4.53.0, torch_npu
    """

    name = "flash_attention"
    legacy_name = "flash_attention"
    target_module = "transformers.integrations.npu_flash_attention"

    @staticmethod
    def precheck() -> bool:
        return _is_transformers_compatible_for_flash_attention() and _is_flash_attention_available()

    @staticmethod
    @with_imports("torch_npu", "torch")
    def flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, **kwargs):
        """Fused attention for "BSND"-layout q/k/v via ``torch_npu.npu_fusion_attention``.

        Args:
            q, k, v: query/key/value tensors in "BSND" layout. The head count is
                read from dim 2 and the head dim from the last dim.
            dropout_p: dropout probability, passed as keep_prob = 1 - dropout_p.
            softmax_scale: scale applied to the attention logits. None means
                1/sqrt(head_dim), the flash-attn convention.
            causal: if True, use the per-device cached causal mask from
                ``get_attn_mask_npu`` with ``sparse_mode=SPARSE_MODE``.
            **kwargs: accepted for signature compatibility and ignored.

        Returns:
            The attention output (element 0 of npu_fusion_attention's result).
        """
        import math

        keep_prob = 1.0 - dropout_p
        # The upstream transformers NPU implementation resolves a missing scale to
        # 1/sqrt(head_dim). Forwarding None to npu_fusion_attention would not apply
        # that default.
        if softmax_scale is None:
            softmax_scale = 1.0 / math.sqrt(q.shape[-1])
        head_num = q.shape[2]
        if not causal:
            return torch_npu.npu_fusion_attention(
                q, k, v, head_num, "BSND", keep_prob=keep_prob, scale=softmax_scale
            )[0]
        attn_mask_npu = get_attn_mask_npu(q.device)
        return torch_npu.npu_fusion_attention(
            q,
            k,
            v,
            head_num,
            "BSND",
            keep_prob=keep_prob,
            scale=softmax_scale,
            atten_mask=attn_mask_npu,
            sparse_mode=SPARSE_MODE,
        )[0]

    @classmethod
    def patches(cls, options=None) -> List[BasePatch]:
        return [
            AtomicPatch(
                "transformers.modeling_flash_attention_utils.flash_attn_func",
                cls.flash_attn_func,
                precheck=cls.precheck,
                aliases=[
                    "transformers.integrations.npu_flash_attention.flash_attn_func",
                    "transformers.integrations.npu_flash_attention.npu_flash_attn_func",
                    "transformers.modeling_flash_attention_utils.npu_flash_attn_func",
                ],
            ),
        ]
class FlashAttentionVarlen(Patch):
    """Replace flash_attn_varlen_func with torch_npu.npu_fusion_attention, adding mask caching.

    4.51.x/4.52.x: upstream creates the mask on every call; this patch caches it.
    4.53.x+: upstream already caches the mask, so the patch is not needed.
    Requirements: transformers >= 4.51.0, < 4.53.0, torch_npu
    """

    name = "flash_attention_varlen"
    legacy_name = "flash_attention_varlen"
    target_module = "transformers.integrations.npu_flash_attention"

    @staticmethod
    def precheck() -> bool:
        return _is_transformers_compatible_for_flash_attention() and _is_flash_attention_available()

    @staticmethod
    @with_imports("torch_npu", "torch")
    def flash_attn_varlen_func(
        q, k, v, cu_seqlens_q, cu_seqlens_k, dropout_p=0.0, softmax_scale=None, causal=False, **kwargs
    ):
        """Fused variable-length attention for "TND"-layout q/k/v via ``npu_fusion_attention``.

        Args:
            q, k, v: packed query/key/value tensors in "TND" layout. The head count
                is read from dim 1 and the head dim from the last dim.
            cu_seqlens_q, cu_seqlens_k: tensors whose elements from index 1 on are
                passed as ``actual_seq_qlen`` / ``actual_seq_kvlen``. They are
                presumably cumulative offsets starting at 0; TODO confirm with callers.
            dropout_p: dropout probability, passed as keep_prob = 1 - dropout_p.
            softmax_scale: scale applied to the attention logits. None means
                1/sqrt(head_dim), the flash-attn convention.
            causal: if True, use the per-device cached causal mask from
                ``get_attn_mask_npu`` with ``sparse_mode=SPARSE_MODE``.
            **kwargs: accepted and ignored (e.g. max_seqlen_q / max_seqlen_k if a
                caller passes them by keyword).

        Returns:
            The attention output (element 0 of npu_fusion_attention's result).
        """
        import math

        keep_prob = 1.0 - dropout_p
        # The upstream transformers NPU implementation resolves a missing scale to
        # 1/sqrt(head_dim). Forwarding None to npu_fusion_attention would not apply
        # that default.
        if softmax_scale is None:
            softmax_scale = 1.0 / math.sqrt(q.shape[-1])
        head_num = q.shape[1]
        # npu_fusion_attention takes these sequence bounds as host-side tuples.
        actual_seq_qlen = tuple(cu_seqlens_q[1:].cpu().numpy().tolist())
        actual_seq_kvlen = tuple(cu_seqlens_k[1:].cpu().numpy().tolist())
        if not causal:
            return torch_npu.npu_fusion_attention(
                q,
                k,
                v,
                head_num,
                pse=None,
                atten_mask=None,
                scale=softmax_scale,
                keep_prob=keep_prob,
                input_layout="TND",
                actual_seq_qlen=actual_seq_qlen,
                actual_seq_kvlen=actual_seq_kvlen,
            )[0]
        attn_mask_npu = get_attn_mask_npu(q.device)
        return torch_npu.npu_fusion_attention(
            q,
            k,
            v,
            head_num,
            pse=None,
            padding_mask=None,
            atten_mask=attn_mask_npu,
            scale=softmax_scale,
            keep_prob=keep_prob,
            input_layout="TND",
            actual_seq_qlen=actual_seq_qlen,
            actual_seq_kvlen=actual_seq_kvlen,
            sparse_mode=SPARSE_MODE,
        )[0]

    @classmethod
    def patches(cls, options=None) -> List[BasePatch]:
        return [
            AtomicPatch(
                "transformers.modeling_flash_attention_utils.flash_attn_varlen_func",
                cls.flash_attn_varlen_func,
                precheck=cls.precheck,
                aliases=[
                    "transformers.integrations.npu_flash_attention.flash_attn_varlen_func",
                    "transformers.integrations.npu_flash_attention.npu_flash_attn_varlen_func",
                    "transformers.modeling_flash_attention_utils.npu_flash_attn_varlen_func",
                ],
            ),
        ]
class TransformersNPU(Patch):
    """Composite patch that bundles every transformers NPU patch in this module.

    A group is included only when its target module imports:
      - Qwen3: RMSNorm and RoPE.
      - Qwen3-VL: text RMSNorm, vision RoPE, placeholder mask, and image preprocessing.
      - NPU flash attention: fixed-length and varlen.
    Version gating is left to each AtomicPatch's own precheck. For example,
    flash attention applies only on 4.51.x/4.52.x, and Qwen3-VL on >= 4.57.0 or 5.x.

    Requirements: transformers >= 4.51.0, torch_npu
    """

    name = "transformers_npu"
    legacy_name = "transformers_npu"
    target_module = "transformers"

    @classmethod
    def patches(cls, options=None) -> List[BasePatch]:
        """Collect sub-patches group by group, in declaration order."""
        groups = (
            (_is_qwen3_available, (Qwen3RMSNorm, Qwen3RoPE)),
            (
                _is_qwen3_vl_available,
                (
                    Qwen3VLTextRMSNorm,
                    Qwen3VLRoPE,
                    Qwen3VL_get_placeholder_mask,
                    BaseImageProcessorFast_preprocess,
                    Qwen2VL_preprocess,
                ),
            ),
            (_is_flash_attention_available, (FlashAttention, FlashAttentionVarlen)),
        )
        collected = []
        for is_available, patch_classes in groups:
            if not is_available():
                continue
            for patch_cls in patch_classes:
                collected.extend(patch_cls.patches(options))
        return collected