import json
import tempfile
import os
from unittest import mock
import pytest
import torch
from mindie_llm.runtime.utils.distributed.utils import (
get_device_from_ranktable,
even_divide,
set_device
)
from mindie_llm.runtime.utils.helpers.env import ENV
def test_even_divide_success():
    """Check that even_divide returns the exact quotient for divisible inputs."""
    # Each entry is (numerator, denominator, expected quotient).
    cases = (
        (12, 3, 4),
        (100, 1, 100),
    )
    for numerator, denominator, quotient in cases:
        assert even_divide(numerator, denominator) == quotient
def test_even_divide_not_divisible():
    """Check that a non-divisible pair raises ValueError with the expected message."""
    expected_message = "10 is not evenly divisible by 3"
    with pytest.raises(ValueError, match=expected_message):
        even_divide(10, 3)
def test_even_divide_zero_denominator():
    """Check that a zero denominator surfaces as ZeroDivisionError."""
    numerator, denominator = 5, 0
    with pytest.raises(ZeroDivisionError):
        even_divide(numerator, denominator)
@pytest.fixture
def valid_rank_table_data():
    """Provide a rank table with two servers holding two devices each.

    Rank ids "0" through "3" are paired with the same-numbered device ids;
    every id is stored as a string. A fresh dict is built on each call.
    """
    first_server = {
        "device": [
            {"rank_id": "0", "device_id": "0"},
            {"rank_id": "1", "device_id": "1"},
        ]
    }
    second_server = {
        "device": [
            {"rank_id": "2", "device_id": "2"},
            {"rank_id": "3", "device_id": "3"},
        ]
    }
    return {"server_list": [first_server, second_server]}
def test_get_device_from_ranktable_found(valid_rank_table_data):
    """Check that a rank listed in the rank table resolves to its NPU device."""
    with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as table_file:
        json.dump(valid_rank_table_data, table_file)
        table_file.flush()
    table_path = table_file.name
    try:
        resolved = get_device_from_ranktable(rank=2, rank_table=table_path)
        assert resolved == torch.device("npu:2")
    finally:
        # The file was created with delete=False, so remove it by hand.
        os.unlink(table_path)
def test_get_device_from_ranktable_rank_not_found(valid_rank_table_data):
    """Check that a rank absent from the rank table raises ValueError naming it."""
    with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as table_file:
        json.dump(valid_rank_table_data, table_file)
        table_file.flush()
    table_path = table_file.name
    expected_message = "Rank id is not in the rankTableFile.*5"
    try:
        with pytest.raises(ValueError, match=expected_message):
            get_device_from_ranktable(rank=5, rank_table=table_path)
    finally:
        # The file was created with delete=False, so remove it by hand.
        os.unlink(table_path)
def test_get_device_from_ranktable_invalid_json():
    """Check that a malformed rank table file raises json.JSONDecodeError."""
    with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as broken_file:
        broken_file.write("{ invalid json }")
        broken_file.flush()
    broken_path = broken_file.name
    try:
        with pytest.raises(json.JSONDecodeError):
            get_device_from_ranktable(rank=0, rank_table=broken_path)
    finally:
        # The file was created with delete=False, so remove it by hand.
        os.unlink(broken_path)
def test_get_device_from_ranktable_file_not_exist():
    """Check that a rank table path that does not exist raises FileNotFoundError."""
    missing_path = "/non/existent/path.json"
    with pytest.raises(FileNotFoundError):
        get_device_from_ranktable(rank=0, rank_table=missing_path)
@pytest.fixture(autouse=True)
def clean_env():
    """Snapshot ENV.rank_table_file before each test and restore it afterwards.

    Declared autouse so every test in this module gets the reset without
    requesting the fixture explicitly.
    """
    saved_rank_table_file = ENV.rank_table_file
    yield
    # Teardown: put back whatever value was present before the test ran.
    ENV.rank_table_file = saved_rank_table_file
@mock.patch("torch.npu.set_device")
def test_set_device_without_rank_table(mock_set_device):
    """With no rank table configured, set_device selects the explicit npu_id."""
    ENV.rank_table_file = None
    selected = set_device(rank=3, npu_id=5)
    expected = torch.device("npu:5")
    assert selected == expected
    # The patched torch.npu.set_device must be called exactly once with that device.
    mock_set_device.assert_called_once_with(selected)
@mock.patch("torch.npu.set_device")
def test_set_device_default_npu_id(mock_set_device):
    """With no rank table and no npu_id, set_device falls back to the rank."""
    ENV.rank_table_file = None
    selected = set_device(rank=7)
    expected = torch.device("npu:7")
    assert selected == expected
    # The patched torch.npu.set_device must be called exactly once with that device.
    mock_set_device.assert_called_once_with(selected)
@mock.patch("torch.npu.set_device")
def test_set_device_with_rank_table(mock_set_device, valid_rank_table_data):
    """Test set_device uses rank table when ENV.rank_table_file is set.

    npu_id=99 is passed on purpose: the expected device is npu:1, the rank
    table's entry for rank 1, so the explicit npu_id must not be used.
    """
    with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
        json.dump(valid_rank_table_data, f)
        f.flush()
    try:
        ENV.rank_table_file = f.name
        device = set_device(rank=1, npu_id=99)
        assert device == torch.device("npu:1")
        mock_set_device.assert_called_once_with(device)
    finally:
        ENV.rank_table_file = None
        # delete=False leaves the file on disk; remove it so test runs do
        # not accumulate stray temp files.
        os.unlink(f.name)
@mock.patch("torch.npu.set_device")
def test_set_device_rank_not_in_table(mock_set_device):
    """Test set_device raises error if rank not in rank table.

    The rank table written here only lists rank 0, so requesting rank 5
    must raise ValueError with a message mentioning that rank.
    """
    rank_table = {"server_list": [{"device": [{"rank_id": "0", "device_id": "0"}]}]}
    with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
        json.dump(rank_table, f)
        f.flush()
    try:
        ENV.rank_table_file = f.name
        with pytest.raises(ValueError, match="Rank id is not in the rankTableFile.*5"):
            set_device(rank=5)
    finally:
        ENV.rank_table_file = None
        # delete=False leaves the file on disk; remove it so test runs do
        # not accumulate stray temp files.
        os.unlink(f.name)