"""Unit tests for the O(P+D) dispatch-compatibility helpers in motor.common.resources.dispatch."""
from dataclasses import dataclass, field
from motor.common.resources.dispatch import (
DispatchPlan,
dispatch_plan_union,
has_compatible_dispatch_pair,
shared_dispatch_plans,
)
# ``.value`` of the two DispatchPlan members exercised by this module; the test
# fixtures place these raw values (not the members themselves) into
# ``dispatch_capabilities``.
CONCURRENT = DispatchPlan.CONCURRENT_ENGINE_SYNC.value
HANDOFF = DispatchPlan.PREFILL_HANDOFF_DECODE.value
@dataclass
class _Inst:
    """Lightweight fake instance used as input to the dispatch helpers under test.

    It exposes ``dispatch_capabilities``, the attribute the tests populate with
    plan values, plus an ``id`` that defaults to 0.
    """

    # Identifier; none of the tests in this module set it explicitly.
    id: int = 0
    # Plan values advertised by this fake; every instance gets its own fresh list.
    dispatch_capabilities: list = field(default_factory=list)
def _pairwise_compatible(prefill_instances, decode_instances):
    """Brute-force oracle: try every (prefill, decode) combination.

    Serves as the O(P*D) reference that the optimized helpers are compared
    against. Returns True as soon as ``shared_dispatch_plans`` yields a truthy
    result for some pair, and False when no pair does (including when either
    side is empty).
    """
    # Materialise once so a one-shot iterable can be re-walked for each prefill.
    decoders = list(decode_instances)
    for prefill in prefill_instances:
        for decode in decoders:
            if shared_dispatch_plans(prefill, decode):
                return True
    return False
def test_dispatch_plan_union_aggregates_and_ignores_unknown_values():
    """dispatch_plan_union merges plans across instances and drops unrecognised values."""
    concurrent_only = _Inst(dispatch_capabilities=[CONCURRENT])
    # "bogus" is not a DispatchPlan value and must not appear in the result.
    handoff_plus_unknown = _Inst(dispatch_capabilities=[HANDOFF, "bogus"])
    both_plans = {
        DispatchPlan.CONCURRENT_ENGINE_SYNC,
        DispatchPlan.PREFILL_HANDOFF_DECODE,
    }
    assert dispatch_plan_union([concurrent_only, handoff_plus_unknown]) == both_plans

    # No instances at all, or an instance without capabilities, yields an empty set.
    no_plans = set()
    assert dispatch_plan_union([]) == no_plans
    assert dispatch_plan_union([_Inst(dispatch_capabilities=[])]) == no_plans
def test_has_compatible_dispatch_pair_matches_pairwise_definition():
    """has_compatible_dispatch_pair must agree with the brute-force pairwise oracle."""

    def inst(*capabilities):
        # Build a fake instance advertising exactly the given plan values.
        return _Inst(dispatch_capabilities=list(capabilities))

    scenarios = (
        # Both sides advertise only CONCURRENT.
        ([inst(CONCURRENT)], [inst(CONCURRENT)]),
        # Prefill advertises only CONCURRENT, decode only HANDOFF.
        ([inst(CONCURRENT)], [inst(HANDOFF)]),
        # Two prefill instances with different plans against a HANDOFF-only decode.
        ([inst(CONCURRENT), inst(HANDOFF)], [inst(HANDOFF)]),
        # Prefill instance with no capabilities.
        ([inst()], [inst(CONCURRENT)]),
        # No prefill instances at all.
        ([], [inst(CONCURRENT)]),
    )
    for prefill_side, decode_side in scenarios:
        optimized = has_compatible_dispatch_pair(prefill_side, decode_side)
        reference = _pairwise_compatible(prefill_side, decode_side)
        assert optimized == reference