from unittest.mock import MagicMock
import mindspeed.megatron_adaptor
import pytest
import torch
from tests.ut.utils import judge_expression
from mindspeed_mm.models.vision.vlm_attentionmask_for_llm import qwen2vl_get_rope_index
@pytest.fixture
def mock_args():
args = MagicMock()
args.mm.model.vision_start_token_id = 151652
args.mm.model.image_token_id = 151655
args.mm.model.video_token_id = 151656
args.mm.model.image_encoder.vision_encoder.spatial_merge_size = 2
args.mm.model.image_encoder.vision_encoder.tokens_per_second = 2
return args
@pytest.fixture(autouse=True)
def patch_get_args(mock_args, mocker):
mocker.patch("mindspeed_mm.models.vision.vlm_attentionmask_for_llm.get_args", return_value=mock_args)
class TestQwen2VLGetRoPEIndex:
def test_pure_text_input_with_attention_mask(self):
input_ids = torch.tensor([
[151644, 101, 102, 103],
[151644, 104, 105, 151643],
])
attention_mask = torch.tensor([
[1, 1, 1, 1],
[1, 1, 1, 0],
])
pos_ids, deltas = qwen2vl_get_rope_index(
config=None,
input_ids=input_ids,
attention_mask=attention_mask
)
expected_deltas = torch.tensor([[0], [-1]])
judge_expression(torch.equal(deltas, expected_deltas))
expected_pos_ids = torch.tensor([[[0, 1, 2, 3], [0, 1, 2, 1]],
[[0, 1, 2, 3], [0, 1, 2, 1]],
[[0, 1, 2, 3], [0, 1, 2, 1]]])
judge_expression(torch.equal(pos_ids, expected_pos_ids))
def test_pure_text_input_without_attention_mask(self):
input_ids = torch.tensor([[151644, 101, 102, 103]])
pos_ids, deltas = qwen2vl_get_rope_index(
config=None,
input_ids=input_ids
)
expected_deltas = torch.tensor([[0]])
judge_expression(torch.equal(deltas, expected_deltas))
expected_pos_ids = torch.tensor([[[0, 1, 2, 3]],
[[0, 1, 2, 3]],
[[0, 1, 2, 3]]])
judge_expression(torch.equal(pos_ids, expected_pos_ids))
def test_video_and_text_input(self):
input_ids = torch.tensor([[151644, 101, 102, 103, 151652, 151656, 151656, 151656, 151656, 151656, 151656, 151656, 151656,
151656, 151656, 151656, 151656, 151653, 151644, 104, 105, 106]])
image_grid_thw = torch.tensor([[2, 4, 6]])
pos_ids, deltas = qwen2vl_get_rope_index(
config=None,
input_ids=input_ids,
image_grid_thw=image_grid_thw
)
expected_deltas = torch.tensor([[-9]])
judge_expression(torch.equal(deltas, expected_deltas))
expected_pos_ids = torch.tensor([[[0, 1, 2, 3, 4, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 8, 9, 10, 11, 12]],
[[0, 1, 2, 3, 4, 5, 5, 5, 6, 6, 6, 5, 5, 5, 6, 6, 6, 8, 9, 10, 11, 12]],
[[0, 1, 2, 3, 4, 5, 6, 7, 5, 6, 7, 5, 6, 7, 5, 6, 7, 8, 9, 10, 11, 12]]])
judge_expression(torch.equal(pos_ids, expected_pos_ids))