import unittest
from pathlib import Path
import torch
from parameterized import parameterized
from tensor_cast.model_config import ModelConfig, ParallelConfig, QuantConfig
from tensor_cast.transformers.builtin_model.minimax_m2 import shard_qk_norm
from tensor_cast.transformers.model import TransformerModel
from torch import nn
class MiniMaxM2ShardQkNormTestCase(unittest.TestCase):
def setUp(self):
self.model_id = str(Path(__file__).resolve().parents[2] / "assets" / "model_config" / "minimax_m2")
def _build_model(self):
model_config = ModelConfig(
ParallelConfig(),
QuantConfig(),
num_hidden_layers_override=1,
)
return TransformerModel(self.model_id, model_config)
def _get_self_attn(self, model: TransformerModel):
layer = model.unwrap().layers[0]
self_attn = layer
while hasattr(self_attn, "_inner"):
self_attn = self_attn._inner
if hasattr(self_attn, "self_attn"):
self_attn = self_attn.self_attn
return self_attn
def _set_tp_group(self, model: TransformerModel, tp_size: int, tp_rank: int):
tp_group = model.parallel_group_manager.tp_group
tp_group.world_size = tp_size
tp_group.rank_in_group = tp_rank
def _set_qk_norm_weights(
self,
self_attn: nn.Module,
q_requires_grad: bool = True,
k_requires_grad: bool = True,
):
self_attn.q_norm.weight = nn.Parameter(
torch.arange(6144, dtype=torch.float32),
requires_grad=q_requires_grad,
)
self_attn.k_norm.weight = nn.Parameter(
torch.arange(1024, dtype=torch.float32),
requires_grad=k_requires_grad,
)
@parameterized.expand(
[
(
"tp_enabled",
8,
3,
True,
False,
True,
torch.Size([768]),
torch.Size([128]),
torch.arange(2304, 3072, dtype=torch.float32),
torch.arange(384, 512, dtype=torch.float32),
False,
True,
),
(
"gqa_k_norm_rank7",
8,
7,
True,
True,
True,
None,
torch.Size([128]),
None,
torch.arange(896, 1024, dtype=torch.float32),
True,
True,
),
]
)
def test_shard_qk_norm_shards_weights(
self,
_name,
tp_size,
tp_rank,
use_qk_norm,
q_requires_grad,
k_requires_grad,
expected_q_shape,
expected_k_shape,
expected_q_values,
expected_k_values,
expected_q_requires_grad,
expected_k_requires_grad,
):
model = self._build_model()
self_attn = self._get_self_attn(model)
self._set_qk_norm_weights(
self_attn,
q_requires_grad=q_requires_grad,
k_requires_grad=k_requires_grad,
)
self._set_tp_group(model, tp_size=tp_size, tp_rank=tp_rank)
model.hf_config.use_qk_norm = use_qk_norm
result = shard_qk_norm(model)
self.assertIs(result, model)
if expected_q_shape is not None:
self.assertEqual(self_attn.q_norm.weight.shape, expected_q_shape)
if expected_k_shape is not None:
self.assertEqual(self_attn.k_norm.weight.shape, expected_k_shape)
if expected_q_values is not None:
self.assertTrue(torch.equal(self_attn.q_norm.weight.detach(), expected_q_values))
if expected_k_values is not None:
self.assertTrue(torch.equal(self_attn.k_norm.weight.detach(), expected_k_values))
self.assertEqual(
self_attn.q_norm.weight.requires_grad,
expected_q_requires_grad,
)
self.assertEqual(
self_attn.k_norm.weight.requires_grad,
expected_k_requires_grad,
)
@parameterized.expand(
[
("tp_size_one", 1, 0, True),
("qk_norm_disabled", 8, 3, False),
]
)
def test_shard_qk_norm_returns_early(
self,
_name,
tp_size,
tp_rank,
use_qk_norm,
):
model = self._build_model()
self_attn = self._get_self_attn(model)
self._set_qk_norm_weights(self_attn)
self._set_tp_group(model, tp_size=tp_size, tp_rank=tp_rank)
model.hf_config.use_qk_norm = use_qk_norm
original_q_weight = self_attn.q_norm.weight
original_k_weight = self_attn.k_norm.weight
result = shard_qk_norm(model)
self.assertIs(result, model)
self.assertIs(self_attn.q_norm.weight, original_q_weight)
self.assertIs(self_attn.k_norm.weight, original_k_weight)
self.assertEqual(self_attn.q_norm.weight.shape, torch.Size([6144]))
self.assertEqual(self_attn.k_norm.weight.shape, torch.Size([1024]))
if __name__ == "__main__":
unittest.main()