import os
from copy import deepcopy
from pathlib import Path
from unittest.mock import patch
import numpy as np
import pytest
from experimental.optix.config.config import (
DecodeContext,
ErrorPatternConfig,
ErrorType,
HealthCheckConfig,
MindieConfig,
OptimizerConfigField,
_get_mindie_config_paths,
_repair_ternary_factories_with_priority,
map_param_with_value,
resolve_priority,
update_optimizer_value,
)
# Default MindIE service config file and its backup; both live in the same
# ``conf`` directory and are what the path-resolver tests compare against.
_MINDIE_CONF_DIR = Path("/usr/local/Ascend/mindie/latest/mindie-service/conf")
DEFAULT_MINDIE_CONFIG = _MINDIE_CONF_DIR / "config.json"
DEFAULT_MINDIE_BACKUP = _MINDIE_CONF_DIR / "config_bak.json"
def field(
    name,
    dtype,
    min_=0,
    max_=0,
    value=0,
    dtype_param=None,
    config_position=None,
    constant=None,
):
    """Create an ``OptimizerConfigField`` for tests from terse arguments.

    ``min_`` and ``max_`` are forwarded as the ``min`` and ``max`` keywords.
    A falsy ``config_position`` is replaced with ``"Test.<name>"``.
    """
    position = config_position if config_position else f"Test.{name}"
    options = {
        "name": name,
        "config_position": position,
        "min": min_,
        "max": max_,
        "dtype": dtype,
        "value": value,
        "dtype_param": dtype_param,
        "constant": constant,
    }
    return OptimizerConfigField(**options)
def schedule_fields():
    """Return a fresh list of the ten ``BackendConfig.ScheduleConfig`` test fields.

    Every call builds new field objects (and new ``dtype_param`` lists), so
    callers may mutate the result freely.
    """
    # Columns: name, dtype, min, max, dtype_param, config_position.
    specs = (
        ("max_batch_size", "int", 25, 300, None, "BackendConfig.ScheduleConfig.maxBatchSize"),
        ("max_prefill_batch_size", "int", 1, 25, None, "BackendConfig.ScheduleConfig.maxPrefillBatchSize"),
        ("prefill_time_ms_per_req", "int", 0, 1000, None, "BackendConfig.ScheduleConfig.prefillTimeMsPerReq"),
        ("decode_time_ms_per_req", "int", 0, 1000, None, "BackendConfig.ScheduleConfig.decodeTimeMsPerReq"),
        ("support_select_batch", "bool", 0, 1, None, "BackendConfig.ScheduleConfig.supportSelectBatch"),
        ("max_prefill_token", "int", 4096, 409600, None, "BackendConfig.ScheduleConfig.maxPrefillTokens"),
        (
            "max_queue_deloy_microseconds",
            "int",
            500,
            1000000,
            None,
            "BackendConfig.ScheduleConfig.maxQueueDelayMicroseconds",
        ),
        ("prefill_policy_type", "enum", 0, 1, [0, 1, 3], "BackendConfig.ScheduleConfig.prefillPolicyType"),
        ("decode_policy_type", "enum", 0, 1, [0, 1, 3], "BackendConfig.ScheduleConfig.decodePolicyType"),
        # The ratio field names the field it is relative to via dtype_param.
        ("max_preempt_count", "ratio", 0, 1, "max_batch_size", "BackendConfig.ScheduleConfig.maxPreemptCount"),
    )
    return [
        field(name, dtype, low, high, dtype_param=param, config_position=position)
        for name, dtype, low, high, param, position in specs
    ]
def pd_share_fields():
    """Return the ``default_p_rate`` / ``default_d_rate`` field pair.

    ``default_d_rate`` is a ``share`` field whose ``dtype_param`` names
    ``default_p_rate``.
    """
    p_rate = field("default_p_rate", "int", 1, 3, 1, config_position="default_p_rate")
    d_rate = field(
        "default_d_rate",
        "share",
        1,
        3,
        dtype_param="default_p_rate",
        config_position="default_d_rate",
    )
    return [p_rate, d_rate]
def clone_with_values(fields, values):
    """Deep-copy ``fields`` and assign ``values`` positionally to the copies.

    Pairing is done with ``zip``: copies without a matching value keep their
    copied ``value`` and surplus values are ignored. The originals are never
    modified.
    """
    copies = list(map(deepcopy, fields))
    for target, new_value in zip(copies, values):
        target.value = new_value
    return copies
def derive(fields, values, support_select_is_false=False, context=None):
    """Run ``update_optimizer_value`` on value-patched copies of ``fields``.

    The copies (not the originals) carry ``values`` and are returned after the
    update so tests can inspect the resulting ``value`` attributes.
    """
    working = clone_with_values(fields, values)
    templates = tuple(fields)
    update_optimizer_value(templates, tuple(working), support_select_is_false, context)
    return working
def pair_fields(product=32, policy="balanced", priority=None, tp_candidates=None, pp_candidates=None):
    """Return a ``(tp, pp, dp)`` tuple of fields for priority-repair tests.

    ``tp`` and ``pp`` are enum fields (defaults ``[1, 2, 4, 8]`` and
    ``[1, 2, 4]`` when no candidates are given); ``dp`` is a
    ``ternary_factories`` field targeting both. The ``priority`` key is only
    added to the rule when ``priority`` is not ``None``.
    """
    rule = {
        "target_names": ["tp", "pp"],
        "product": product,
        "dtype": "int",
        "priority_policy": policy,
        **({} if priority is None else {"priority": priority}),
    }
    tp_choices = tp_candidates if tp_candidates else [1, 2, 4, 8]
    pp_choices = pp_candidates if pp_candidates else [1, 2, 4]
    tp_field = field("tp", "enum", 0, 1, dtype_param=tp_choices)
    pp_field = field("pp", "enum", 0, 1, dtype_param=pp_choices)
    dp_field = field("dp", "ternary_factories", 0, 0, dtype_param=rule)
    return (tp_field, pp_field, dp_field)
# Each case: raw optimizer parameters, the fields they are mapped onto, and
# the converted values expected for the leading fields.
_SCHEDULE_MAPPING_CASES = [
    (
        np.array([26.7, 12.3, 999.9, 500.0, 0.6, 40960.0, 750000.0]),
        schedule_fields()[:7],
        [26, 12, 999, 500, True, 40960, 750000],
    ),
    (
        np.array([24.9, 0.0, 0.0, 0.0, 0.4, 4095.9, 499.9, -1.0, 2.0, 1.1]),
        schedule_fields(),
        [24, 1, 0, 0, False, 4095, 499, 0],
    ),
]


@pytest.mark.parametrize("params,fields,expected", _SCHEDULE_MAPPING_CASES)
def test_map_param_converts_schedule_fields(params, fields, expected):
    """Compare the values of the first ``len(expected)`` mapped fields."""
    mapped = map_param_with_value(params, fields)
    leading_values = [item.value for item in mapped[: len(expected)]]
    assert leading_values == expected
def test_map_param_selects_numeric_enum_segments():
    """Both policy enum fields (``schedule_fields()[7:9]``) are expected to map to 0."""
    params = np.array([0.0, 0.3, 0.6, 1.0])
    policy_fields = schedule_fields()[7:9]
    mapped = map_param_with_value(params, policy_fields)
    selected = [item.value for item in mapped]
    assert selected == [0, 0]
def test_ratio_field_uses_resolved_target_value():
    """A ratio of 0.5 against a ``max_batch_size`` pinned to 100 must give 50."""
    batch_size_options = {
        "value": 100,
        "constant": 100,
        "config_position": "BackendConfig.ScheduleConfig.maxBatchSize",
    }
    batch_size_field = field("max_batch_size", "int", **batch_size_options)
    # Index 9 of schedule_fields() is the ``max_preempt_count`` ratio field.
    preempt_ratio_field = schedule_fields()[9]
    mapped = map_param_with_value(np.array([0.5]), [batch_size_field, preempt_ratio_field])
    assert mapped[1].value == 50
def test_share_field_keeps_complementary_rate():
    """With parameters ``[1, 2]`` the ``share`` field is expected to resolve to 3."""
    mapped = map_param_with_value(np.array([1, 2]), pd_share_fields())
    share_field = mapped[1]
    assert share_field.value == 3
def test_error_pattern_config_accepts_custom_and_empty_sets():
    """``ErrorPatternConfig`` stores custom pattern maps and empty maps as given."""
    fatal = {ErrorType.OUT_OF_MEMORY: ["custom OOM pattern"]}
    retryable = {ErrorType.NETWORK_ERROR: ["custom network pattern"]}
    custom = ErrorPatternConfig(fatal_patterns=fatal, retryable_patterns=retryable)
    empty = ErrorPatternConfig(fatal_patterns={}, retryable_patterns={})

    assert custom.fatal_patterns[ErrorType.OUT_OF_MEMORY] == ["custom OOM pattern"]
    assert custom.retryable_patterns[ErrorType.NETWORK_ERROR] == ["custom network pattern"]
    for patterns in (empty.fatal_patterns, empty.retryable_patterns):
        assert patterns == {}
def test_health_check_config_defaults_and_overrides():
    """Check ``HealthCheckConfig`` default pattern sets and explicit overrides."""
    default = HealthCheckConfig()
    service_errors = ErrorPatternConfig(
        fatal_patterns={ErrorType.DEVICE_ERROR: ["device fault"]},
        retryable_patterns={},
    )
    benchmark_errors = ErrorPatternConfig(
        fatal_patterns={},
        retryable_patterns={ErrorType.IO_ERROR: ["disk full"]},
    )
    custom = HealthCheckConfig(
        service_errors=service_errors,
        benchmark_errors=benchmark_errors,
        log_snippet_length=300,
    )

    # Defaults.
    default_service = default.service_errors
    assert isinstance(default_service, ErrorPatternConfig)
    assert ErrorType.OUT_OF_MEMORY in default_service.fatal_patterns
    assert ErrorType.NETWORK_ERROR in default_service.retryable_patterns
    assert default.benchmark_errors.fatal_patterns == {}
    assert ErrorType.IO_ERROR in default.benchmark_errors.retryable_patterns

    # Overrides.
    snippet_only = HealthCheckConfig(log_snippet_length=500)
    assert snippet_only.log_snippet_length == 500
    assert custom.service_errors.fatal_patterns[ErrorType.DEVICE_ERROR] == ["device fault"]
    assert custom.benchmark_errors.retryable_patterns[ErrorType.IO_ERROR] == ["disk full"]
    assert custom.log_snippet_length == 300
@patch.object(Path, "is_file", return_value=True)
def test_mindie_config_paths_use_default_when_available(mock_is_file):
    """With ``Path.is_file`` reporting True, the default config/backup pair is returned."""
    resolved = _get_mindie_config_paths()
    assert resolved == (DEFAULT_MINDIE_CONFIG, DEFAULT_MINDIE_BACKUP)
@patch.object(Path, "is_file", return_value=False)
def test_mindie_config_paths_fallback_to_default_without_env(mock_is_file, monkeypatch):
    """Without ``MIES_INSTALL_PATH`` and with no file present, the defaults are still returned."""
    monkeypatch.delenv("MIES_INSTALL_PATH", raising=False)
    resolved = _get_mindie_config_paths()
    assert resolved == (DEFAULT_MINDIE_CONFIG, DEFAULT_MINDIE_BACKUP)
@patch("experimental.optix.config.config._get_mindie_config_paths")
def test_mindie_config_defaults_are_bound_from_path_resolver(mock_get_paths):
    """``MindieConfig()`` takes its config paths from the patched path resolver."""
    resolved_config = Path("/test/config.json")
    resolved_backup = Path("/test/config_bak.json")
    mock_get_paths.return_value = (resolved_config, resolved_backup)

    mindie = MindieConfig()

    assert mindie.process_name == "mindie, mindie-llm, mindieservice_daemon, mindie_llm"
    assert mindie.output == Path("mindie")
    assert mindie.config_path == resolved_config
    assert mindie.config_bak_path == resolved_backup
    # target_field must be a non-empty list.
    assert isinstance(mindie.target_field, list)
    assert mindie.target_field
@patch("experimental.optix.config.config._get_mindie_config_paths")
def test_mindie_config_allows_custom_output(mock_get_paths):
    """An explicit ``output`` argument is kept on the config object."""
    mock_get_paths.return_value = (Path("/test/config.json"), Path("/test/config_bak.json"))
    mindie = MindieConfig(output=Path("/custom/output"))
    assert mindie.output == Path("/custom/output")
@pytest.mark.parametrize(
    "kwargs,expected",
    [
        # Equal bounds: the shared bound is expected as the constant.
        ({"min_": 100, "max_": 100, "dtype": "int"}, {"constant": 100, "min": 100, "max": 100}),
        # Explicit constant: both bounds are expected to collapse onto it.
        ({"min_": 0, "max_": 100, "dtype": "int", "constant": 50}, {"constant": 50, "min": 50, "max": 50}),
    ],
)
def test_optimizer_field_constant_normalization(kwargs, expected):
    """Check ``constant``/``min``/``max`` on the constructed field."""
    item = field("test_field", config_position="test.position", **kwargs)
    observed = {key: getattr(item, key) for key in ("constant", "min", "max")}
    assert observed == expected
def test_optimizer_field_rejects_invalid_bounds():
    """Building a field with ``min`` 100 and ``max`` 0 must raise ``ValueError``."""
    inverted_bounds = {"min_": 100, "max_": 0}
    with pytest.raises(ValueError, match="min.*max"):
        field("test_field", "int", config_position="test.position", **inverted_bounds)
def _int_0_100(name):
    """Return an int field bounded to ``[0, 100]``."""
    return field(name, "int", 0, 100)


def _enum_1_2_4_8(name):
    """Return an enum field with candidates ``[1, 2, 4, 8]``."""
    return field(name, "enum", 0, 1, dtype_param=[1, 2, 4, 8])


@pytest.mark.parametrize(
    "item,value,expected",
    [
        (_int_0_100("bounded"), 50, 50),
        (_int_0_100("lower"), -10, 0),
        (_int_0_100("upper"), 150, 100),
        (_enum_1_2_4_8("enum_exact"), 2, 2),
        (_enum_1_2_4_8("enum_next"), 3, 4),
        (_enum_1_2_4_8("enum_floor"), 0, 1),
    ],
)
def test_find_available_value_uses_bounds_or_enum_candidates(item, value, expected):
    """``find_available_value`` must return the expected in-range value or candidate."""
    available = item.find_available_value(value)
    assert available == expected
def test_convert_dtype_uses_field_dtype():
    """``convert_dtype`` parses a string according to the field's ``dtype``."""
    int_field = field("int_field", "int", config_position="test.position")
    assert int_field.convert_dtype("42") == 42

    float_field = field("float_field", "float", config_position="test.position")
    assert float_field.convert_dtype("3.14") == pytest.approx(3.14)
def _derived_update_case(sources, derived, dtype_param, values, expected, value=0):
    """Assemble one ``(fields, values, index, expected)`` parametrize row.

    ``sources`` holds ``(name, min, max)`` triples for the int source fields;
    ``derived`` is the ``(name, dtype)`` pair of the computed field, which is
    appended last, so ``index`` always points at it.
    """
    derived_name, derived_dtype = derived
    fields = [field(source_name, "int", low, high) for source_name, low, high in sources]
    fields.append(field(derived_name, derived_dtype, value=value, dtype_param=dtype_param))
    return fields, values, len(fields) - 1, expected


_TP_PP_SOURCES = (("tp", 1, 8), ("pp", 1, 4))
_SEQ_BATCH_SOURCES = (("seq_len", 128, 4096), ("batch_size", 1, 64))
_A_B_SOURCES = (("a", 1, 10), ("b", 1, 10))
_DP_FACTORIES = ("dp", "ternary_factories")
_C_TIMES = ("c", "ternary_times")


@pytest.mark.parametrize(
    "fields,values,index,expected",
    [
        # --- ternary_factories ---
        _derived_update_case(
            _TP_PP_SOURCES,
            _DP_FACTORIES,
            {"target_names": ["tp", "pp"], "product": 16, "dtype": "int"},
            [2, 4, 0],
            2,
        ),
        # Float product and float dtype.
        _derived_update_case(
            _TP_PP_SOURCES,
            ("dp_f", "ternary_factories"),
            {"target_names": ["tp", "pp"], "product": 10.0, "dtype": "float"},
            [2, 2, 0.0],
            2.5,
        ),
        # No "product" key in the rule.
        _derived_update_case(
            _TP_PP_SOURCES,
            _DP_FACTORIES,
            {"target_names": ["tp", "pp"], "dtype": "int"},
            [2, 1, 0],
            1,
        ),
        # Explicit min_value.
        _derived_update_case(
            _TP_PP_SOURCES,
            _DP_FACTORIES,
            {"target_names": ["tp", "pp"], "product": 16, "dtype": "int", "min_value": 1},
            [8, 4, 99],
            1,
        ),
        # Same inputs without min_value; the template field starts at 99.
        _derived_update_case(
            _TP_PP_SOURCES,
            _DP_FACTORIES,
            {"target_names": ["tp", "pp"], "product": 16, "dtype": "int"},
            [8, 4, 99],
            1,
            value=99,
        ),
        # Explicit max_value.
        _derived_update_case(
            _TP_PP_SOURCES,
            _DP_FACTORIES,
            {"target_names": ["tp", "pp"], "product": 64, "dtype": "int", "max_value": 8},
            [1, 1, 0],
            8,
        ),
        # Both min_value and max_value.
        _derived_update_case(
            _TP_PP_SOURCES,
            _DP_FACTORIES,
            {"target_names": ["tp", "pp"], "product": 8, "dtype": "int", "min_value": 1, "max_value": 3},
            [2, 2, 0],
            2,
        ),
        # --- ternary_times ---
        _derived_update_case(
            _SEQ_BATCH_SOURCES,
            ("total_tokens", "ternary_times"),
            {"target_names": ["seq_len", "batch_size"], "product": 2, "dtype": "int"},
            [512, 4, 0],
            4096,
        ),
        _derived_update_case(
            _A_B_SOURCES,
            _C_TIMES,
            {"target_names": ["a", "b"], "product": 1, "dtype": "int"},
            [3, 7, 0],
            21,
        ),
        # No "product" key in the rule.
        _derived_update_case(
            _A_B_SOURCES,
            _C_TIMES,
            {"target_names": ["a", "b"], "dtype": "int"},
            [3, 5, 0],
            15,
        ),
    ],
)
def test_ternary_derived_fields_update_value(fields, values, index, expected):
    """The derived field at ``index`` must hold ``expected`` after ``derive``."""
    updated = derive(fields, values)
    assert updated[index].value == expected
def _int_product_rule(targets, product):
    """Return an int-typed ternary rule over ``targets`` with the given ``product``."""
    return {"target_names": targets, "product": product, "dtype": "int"}


@pytest.mark.parametrize(
    "fields,values,index,original",
    [
        # tp is 0.
        (
            [
                field("tp", "int", 1, 8),
                field("pp", "int", 1, 4),
                field("dp", "ternary_factories", value=99, dtype_param=_int_product_rule(["tp", "pp"], 16)),
            ],
            [0, 4, 99],
            2,
            99,
        ),
        # pp is 0.
        (
            [
                field("tp", "int", 1, 8),
                field("pp", "int", 1, 4),
                field("dp", "ternary_factories", value=88, dtype_param=_int_product_rule(["tp", "pp"], 16)),
            ],
            [4, 0, 88],
            2,
            88,
        ),
        # First source is NaN.
        (
            [
                field("a", "float", 1.0, 10.0, value=float("nan")),
                field("b", "int", 1, 10),
                field("c", "ternary_times", value=999, dtype_param=_int_product_rule(["a", "b"], 2)),
            ],
            [float("nan"), 5, 999],
            2,
            999,
        ),
        # Second source is NaN.
        (
            [
                field("a", "int", 1, 10),
                field("b", "float", 1.0, 10.0, value=float("nan")),
                field("c", "ternary_times", value=777, dtype_param=_int_product_rule(["a", "b"], 3)),
            ],
            [5, float("nan"), 777],
            2,
            777,
        ),
        # The rule names a source field that is not in the field list.
        (
            [
                field("a", "int", 1, 10),
                field("c", "ternary_times", value=777, dtype_param=_int_product_rule(["a", "missing_b"], 2)),
            ],
            [3, 777],
            1,
            777,
        ),
    ],
)
def test_ternary_derived_fields_keep_value_when_source_invalid(fields, values, index, original):
    """The derived field at ``index`` must still hold ``original`` after ``derive``."""
    updated = derive(fields, values)
    assert updated[index].value == original
def _dp_factories_field(product):
    """Return a ``dp`` ``ternary_factories`` field over ``tp``/``pp`` with ``product``."""
    rule = {"target_names": ["tp", "pp"], "product": product, "dtype": "int"}
    return field("dp", "ternary_factories", dtype_param=rule)


@pytest.mark.parametrize(
    "fields,values,product",
    [
        # Int sources.
        (
            [field("tp", "int", 1, 8), field("pp", "int", 1, 4), _dp_factories_field(16)],
            [8, 4, 0],
            16,
        ),
        # Enum sources.
        (
            [
                field("tp", "enum", 0, 1, dtype_param=[1, 2, 4, 8]),
                field("pp", "enum", 0, 1, dtype_param=[1, 2, 4]),
                _dp_factories_field(16),
            ],
            [8, 4, 0],
            16,
        ),
        # Narrow int ranges.
        (
            [field("tp", "int", 2, 3), field("pp", "int", 2, 3), _dp_factories_field(12)],
            [3, 3, 0],
            12,
        ),
    ],
)
def test_ternary_factories_repair_keeps_product_consistent(fields, values, product):
    """After ``derive``, ``dp`` is positive and ``tp * pp * dp`` equals ``product``."""
    tp_value, pp_value, dp_value = (item.value for item in derive(fields, values))
    assert dp_value > 0
    assert tp_value * pp_value * dp_value == product
def test_ternary_factories_repair_falls_back_to_clamp_for_non_discrete_sources():
    """Float ``tp``/``pp`` of 8.0 and 4.0 with product 16 must leave ``dp`` at 1."""
    rule = {"target_names": ["tp", "pp"], "product": 16, "dtype": "int"}
    float_sources = [
        field("tp", "float", 0.5, 8.0),
        field("pp", "float", 0.5, 4.0),
    ]
    fields = [*float_sources, field("dp", "ternary_factories", dtype_param=rule)]
    derived = derive(fields, [8.0, 4.0, 0])
    assert derived[2].value == 1
def test_ternary_factories_non_divisible_and_unrepairable_combination_raises():
    """``tp=8`` with ``pp`` fixed to the single candidate 3 must raise for product 32."""
    rule = {"target_names": ["tp", "pp"], "product": 32, "dtype": "int"}
    tp_field = field("tp", "int", 1, 1000)
    pp_field = field("pp", "enum", 0, 1, value=3, dtype_param=[3])
    dp_field = field("dp", "ternary_factories", value=99, dtype_param=rule)
    fields = [tp_field, pp_field, dp_field]
    with pytest.raises(ValueError, match="product=32 not divisible by divisor=24"):
        derive(fields, [8, 3, 99])
def test_ternary_factories_map_param_integration():
    """``map_param_with_value`` fills ``dp`` from the two mapped sources."""
    rule = {"target_names": ["tp", "pp"], "product": 16, "dtype": "int"}
    fields = [
        field("tp", "int", 1, 8),
        field("pp", "int", 1, 4),
        field("dp", "ternary_factories", 0, 0, dtype_param=rule),
    ]
    mapped = map_param_with_value(np.array([2.0, 4.0]), fields)
    mapped_values = [item.value for item in mapped]
    assert mapped_values == [2, 4, 2]
def test_ternary_times_map_param_integration():
    """``map_param_with_value`` fills ``total_tokens`` from ``seq_len`` and ``batch_size``."""
    rule = {
        "target_names": ["seq_len", "batch_size"],
        "product": 1,
        "dtype": "int",
    }
    seq_len_field = field("seq_len", "int", 128, 4096)
    batch_size_field = field("batch_size", "int", 1, 64)
    total_tokens_field = field("total_tokens", "ternary_times", 0, 0, dtype_param=rule)
    fields = [seq_len_field, batch_size_field, total_tokens_field]

    mapped = map_param_with_value(np.array([512.0, 4.0]), fields)
    mapped_values = [item.value for item in mapped]
    assert mapped_values == [512, 4, 2048]
def _priority_rule(policy=None, priority=None, targets=("tp", "pp")):
    """Build a ``dtype_param`` dict for priority tests.

    ``priority_policy`` and ``priority`` are only included when given, so a
    bare call yields a rule containing ``target_names`` alone.
    """
    rule = {"target_names": list(targets)}
    if policy is not None:
        rule["priority_policy"] = policy
    if priority is not None:
        rule["priority"] = priority
    return rule


@pytest.mark.parametrize(
    "dtype_param,context,expected",
    [
        # Fixed policy: full, absent and partial priority lists.
        (_priority_rule("fixed", priority=["pp", "tp"]), None, ["pp", "tp"]),
        (_priority_rule("fixed"), None, ["tp", "pp"]),
        (_priority_rule("fixed", priority=["tp"]), None, ["tp", "pp"]),
        # Balanced policy without usable particle information.
        (_priority_rule("balanced"), None, ["tp", "pp"]),
        (_priority_rule("balanced"), DecodeContext(), ["tp", "pp"]),
        # Balanced policy: first and last particle of ten.
        (_priority_rule("balanced"), DecodeContext(particle_index=0, n_particles=10), ["tp", "pp"]),
        (_priority_rule("balanced"), DecodeContext(particle_index=9, n_particles=10), ["pp", "tp"]),
        # Same particles at iteration 1.
        (_priority_rule("balanced"), DecodeContext(particle_index=0, n_particles=10, iteration=1), ["pp", "tp"]),
        (_priority_rule("balanced"), DecodeContext(particle_index=9, n_particles=10, iteration=1), ["tp", "pp"]),
        # No policy key at all.
        (_priority_rule(), DecodeContext(particle_index=0, n_particles=4), ["tp", "pp"]),
        # Single target.
        (_priority_rule("balanced", targets=("tp",)), DecodeContext(particle_index=0, n_particles=10), ["tp"]),
    ],
)
def test_resolve_priority_strategies(dtype_param, context, expected):
    """``resolve_priority`` must return the expected target order."""
    resolved = resolve_priority(dtype_param, context)
    assert resolved == expected
@pytest.mark.parametrize(
    "total,forward_indexes,reverse_indexes",
    [
        (10, range(0, 5), range(5, 10)),
        (11, range(0, 6), range(6, 11)),
    ],
)
def test_balanced_priority_splits_particle_population(total, forward_indexes, reverse_indexes):
    """Particles in ``forward_indexes`` get tp-first order, the rest pp-first."""
    dtype_param = {"target_names": ["tp", "pp"], "priority_policy": "balanced"}

    def priority_for(particle_index):
        context = DecodeContext(particle_index=particle_index, n_particles=total)
        return resolve_priority(dtype_param, context)

    forward_orders = [priority_for(i) for i in forward_indexes]
    assert forward_orders == [["tp", "pp"]] * len(forward_indexes)

    reverse_orders = [priority_for(i) for i in reverse_indexes]
    assert reverse_orders == [["pp", "tp"]] * len(reverse_indexes)
def test_balanced_priority_alternates_by_iteration():
    """Particle 0 of 10 gets tp-first on iterations 0/2/4 and pp-first on 1/3/5."""
    dtype_param = {"target_names": ["tp", "pp"], "priority_policy": "balanced"}

    def priority_at(iteration):
        context = DecodeContext(particle_index=0, n_particles=10, iteration=iteration)
        return resolve_priority(dtype_param, context)

    even_orders = [priority_at(iteration) for iteration in (0, 2, 4)]
    assert even_orders == [["tp", "pp"]] * 3

    odd_orders = [priority_at(iteration) for iteration in (1, 3, 5)]
    assert odd_orders == [["pp", "tp"]] * 3
@pytest.mark.parametrize(
    "params_fields,values,context,expected_name,expected_value",
    [
        # Fixed priority, tp listed first.
        (pair_fields(policy="fixed", priority=["tp", "pp"]), [8, 5, 0], None, "tp", 8),
        # Fixed priority, pp listed first.
        (pair_fields(policy="fixed", priority=["pp", "tp"]), [3, 4, 0], None, "pp", 4),
        # Balanced policy, particle 2 of 10.
        (pair_fields(policy="balanced"), [8, 3, 0], DecodeContext(particle_index=2, n_particles=10), "tp", 8),
        # Balanced policy, particle 7 of 10.
        (pair_fields(policy="balanced"), [3, 4, 0], DecodeContext(particle_index=7, n_particles=10), "pp", 4),
    ],
)
def test_priority_repair_preserves_expected_source(params_fields, values, context, expected_name, expected_value):
    """Repair succeeds, keeps the expected source value and yields a product of 32."""
    runtime_fields = clone_with_values(params_fields, values)
    repair_options = {
        "product": 32,
        "min_val": 1,
        "max_val": None,
        "conv": int,
        "context": context,
    }
    ok = _repair_ternary_factories_with_priority(
        params_fields[2],
        runtime_fields,
        params_fields,
        **repair_options,
    )
    by_name = {item.name: item.value for item in runtime_fields}

    assert ok
    assert by_name[expected_name] == expected_value
    tp_value, pp_value, dp_value = by_name["tp"], by_name["pp"], by_name["dp"]
    assert tp_value * pp_value * dp_value == 32
def test_priority_repair_returns_false_when_no_candidate_combination_is_valid():
    """With tp in ``[4, 8]`` and pp fixed to ``[3]``, repairing to product 32 must return False."""
    params_fields = pair_fields(
        policy="fixed",
        priority=["tp", "pp"],
        tp_candidates=[4, 8],
        pp_candidates=[3],
    )
    dp_template = params_fields[2]
    runtime_fields = clone_with_values(params_fields, [8, 3, 0])
    ok = _repair_ternary_factories_with_priority(
        dp_template,
        runtime_fields,
        params_fields,
        product=32,
        min_val=1,
        max_val=None,
        conv=int,
    )
    assert ok is False
def test_map_param_forwards_decode_context_to_priority_repair():
    """Map a balanced tp/pp/dp triple with an explicit ``DecodeContext``.

    The mapped ``tp`` and ``pp`` must be positive, must divide the configured
    product of 32 (the ``pair_fields`` default) exactly, and ``dp`` must be
    the exact quotient.
    """
    fields = list(pair_fields(policy="balanced"))
    result = map_param_with_value(
        np.array([0.375, 0.375]),
        fields,
        decode_context=DecodeContext(particle_index=0, n_particles=10),
    )
    tp_value, pp_value, dp_value = [item.value for item in result]
    assert tp_value > 0
    assert pp_value > 0
    divisor = tp_value * pp_value
    # Check divisibility explicitly: ``int(32 / divisor)`` truncates, so on
    # its own it would accept a tp/pp pair that does not divide the product.
    assert 32 % divisor == 0
    assert dp_value == 32 // divisor
def test_map_param_without_decode_context_still_repairs_to_consistent_values():
    """Without a decode context, ``tp * pp`` must divide 32 and ``dp`` be the quotient."""
    fields = list(pair_fields(policy="balanced"))
    mapped = map_param_with_value(np.array([0.875, 0.375]), fields)
    tp_value, pp_value, dp_value = (item.value for item in mapped)
    divisor = tp_value * pp_value
    assert 32 % divisor == 0
    assert dp_value == int(32 / divisor)
def test_env_backup_is_restored_for_manual_mindie_path_check(monkeypatch):
    """Resolving paths from ``MIES_INSTALL_PATH`` must leave the variable unchanged.

    ``Path.is_file`` is patched to return False for the duration of the call.
    """
    monkeypatch.setenv("MIES_INSTALL_PATH", "/opt/mindie/latest/bin")
    before = os.environ.get("MIES_INSTALL_PATH")

    with patch.object(Path, "is_file", return_value=False):
        resolved = _get_mindie_config_paths()
    config_path, backup_path = resolved

    after = os.environ.get("MIES_INSTALL_PATH")
    assert after == before
    assert config_path == Path("/opt/mindie/latest/mindie_llm/conf/config.json")
    assert backup_path == Path("/opt/mindie/latest/mindie_llm/conf/config_bak.json")