from collections.abc import Mapping
from types import MappingProxyType
def is_layer_skipped(
    prefix: str,
    ignored_layers: list[str],
    fused_mapping: Mapping[str, list[str]] = MappingProxyType({})
) -> bool:
    """Return whether the layer named by ``prefix`` is in ``ignored_layers``.

    The last dot-separated component of ``prefix`` is treated as the
    projection name. If that name is a key of ``fused_mapping``, the layer is
    considered fused: each shard name listed for it is substituted for the
    final component, and every resulting shard prefix is looked up in
    ``ignored_layers``. All shards must agree.

    Args:
        prefix: Dotted layer name, e.g. ``"model.layers.0.self_attn.qkv_proj"``.
        ignored_layers: Layer names that are skipped.
        fused_mapping: Maps a fused projection name to the names of its
            shards. Defaults to an empty read-only mapping.

    Returns:
        ``True`` if the layer (or every shard of a fused layer) is listed in
        ``ignored_layers``, ``False`` if none of them is.

    Raises:
        ValueError: If only some shards of a fused layer are listed in
            ``ignored_layers``, or if ``fused_mapping`` lists no shards for
            the fused projection name.
    """
    parent, separator, proj_name = prefix.rpartition(".")

    if proj_name not in fused_mapping:
        return prefix in ignored_layers

    shard_proj_names = fused_mapping[proj_name]
    if not shard_proj_names:
        # Explicit error instead of an assert: asserts vanish under ``-O``
        # and the function would then return None instead of a bool.
        raise ValueError(
            f"Fused layer {prefix} has no shards listed in fused_mapping.")

    # Rebuild each shard prefix from the parent path so that only the final
    # component is substituted; ``str.replace`` would also rewrite earlier
    # occurrences of the projection name inside the path.
    stem = parent + separator
    shard_states = {
        (stem + shard_proj_name) in ignored_layers
        for shard_proj_name in shard_proj_names
    }
    if len(shard_states) > 1:
        raise ValueError(
            f"Detected some but not all shards of {prefix} "
            "are quantized. Ensure all shards of fused layers "
            "use the same precision.")
    return shard_states.pop()
def reshape_mx_scale(scale_tensor):
    """
    Return a view of ``scale_tensor`` whose last dimension is split in two.

    A trailing dimension of size ``k`` becomes two trailing dimensions of
    sizes ``k // 2`` and ``2``; all leading dimensions are kept as they are.
    Intended for the 2D/3D scale tensors consumed by GMM/MM operators.
    """
    leading_dims = scale_tensor.shape[:-1]
    pair_count = scale_tensor.size(-1) // 2
    return scale_tensor.view(*leading_dims, pair_count, 2)