"""
-------------------------------------------------------------------------
This file is part of the MindStudio project.
Copyright (c) 2025 Huawei Technologies Co.,Ltd.
MindStudio is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
-------------------------------------------------------------------------
"""
"""
DistributedTaskScheduler(分波调度器)的单元测试
测试范围:
- submit 分波与依赖/前缀/parallel 冲突语义(通过默认 wave backend 检视)
- 单卡场景下的基本执行流程
"""
import unittest
from typing import Any, Dict, List, Tuple
from unittest.mock import MagicMock
import torch
import torch.nn as nn
from msmodelslim.utils.distributed import DistributedTaskScheduler
from test.cases.utils.distributed.dts_test_internals import (
_DtsMultiRankParallelWaveScheduler,
_DtsSequentialWaveScheduler,
dts_waves,
)
from msmodelslim.utils.exception import SchemaValidateError
from test.cases.utils.distributed.dts_test_utils import build_dts_dependency_mock_model
class TestDistributedTaskScheduler(unittest.TestCase):
    """Unit tests for DistributedTaskScheduler.

    Wave contents are inspected through the ``dts_waves`` test helper; the
    model fixture comes from ``build_dts_dependency_mock_model``.
    """

    def setUp(self):
        """Build the mock model and a worker function that records its payloads."""
        self.mock_model = build_dts_dependency_mock_model()
        self.executed_payloads: List[str] = []

        def mock_worker_fn(payload: Any) -> str:
            """Record ``str(payload)`` and return ``"result_<payload>"`` (no exec_ctx injection)."""
            self.executed_payloads.append(str(payload))
            return f"result_{payload}"

        self.mock_worker_fn = mock_worker_fn

    def test_first_submit_creates_one_wave(self):
        """The first submit lands in the first wave."""
        scheduler = DistributedTaskScheduler(self.mock_model)
        scheduler.submit(self.mock_worker_fn, args=("task0",), dependencies=["module1"])
        self.assertEqual(len(dts_waves(scheduler)), 1)

    def test_disjoint_deps_merge_into_same_wave_when_parallel_unchanged(self):
        """Tasks with no dependency conflict and the same ``parallel`` flag join the current wave."""
        scheduler = DistributedTaskScheduler(self.mock_model)
        scheduler.submit(self.mock_worker_fn, args=("task1",), dependencies=["module1", "module2"])
        scheduler.submit(self.mock_worker_fn, args=("task2",), dependencies=["module3"])
        waves = dts_waves(scheduler)
        self.assertEqual(len(waves), 1)
        self.assertEqual(len(waves[0]._tasks), 2)

    def test_overlapping_deps_start_new_wave(self):
        """A dependency that overlaps the current wave starts a new wave."""
        scheduler = DistributedTaskScheduler(self.mock_model)
        scheduler.submit(self.mock_worker_fn, args=("task1",), dependencies=["module1", "module2"])
        scheduler.submit(self.mock_worker_fn, args=("task2",), dependencies=["module1", "module3"])
        self.assertEqual(len(dts_waves(scheduler)), 2)

    def test_prefix_dependency_conflict_splits_waves_both_directions(self):
        """A dotted-path prefix relation is a conflict in both directions (parent->child and child->parent)."""
        # Parent path first, then a child path.
        scheduler = DistributedTaskScheduler(self.mock_model)
        scheduler.submit(self.mock_worker_fn, args=("task1",), dependencies=["model.layers.0.self_attn"])
        scheduler.submit(self.mock_worker_fn, args=("task2",), dependencies=["model.layers.0.self_attn.q_proj"])
        self.assertEqual(len(dts_waves(scheduler)), 2)

        # Child path first, then its parent path.
        scheduler2 = DistributedTaskScheduler(self.mock_model)
        scheduler2.submit(self.mock_worker_fn, args=("task_deep",), dependencies=["model.layers.0.self_attn.q_proj"])
        scheduler2.submit(self.mock_worker_fn, args=("task_parent",), dependencies=["model.layers.0.self_attn"])
        self.assertEqual(len(dts_waves(scheduler2)), 2)

    def test_dependencies_conflict_detects_subtree_prefix_with_trie(self):
        """The prefix trie reports a conflict when the new path is a parent of a path in the wave."""
        scheduler = DistributedTaskScheduler(self.mock_model)
        scheduler.submit(self.mock_worker_fn, args=("task1",), dependencies=["model.layers.0.self_attn.q_proj"])
        wave = dts_waves(scheduler)[-1]
        self.assertTrue(wave.dependencies_conflict_with_wave(["model.layers.0.self_attn"]))

    def test_dependencies_conflict_with_wave_multi_dependency_input(self):
        """The conflict check returns True if any element of the input dependency list conflicts."""
        scheduler = DistributedTaskScheduler(self.mock_model)
        scheduler.submit(self.mock_worker_fn, args=("task1",), dependencies=["model.layers.0.self_attn"])
        wave = dts_waves(scheduler)[-1]
        # Only the last entry (a child of the registered path) conflicts.
        self.assertTrue(
            wave.dependencies_conflict_with_wave(
                ["model.layers.1.mlp", "model.norm", "model.layers.0.self_attn.q_proj"]
            )
        )
        # No entry shares a prefix with the registered path.
        self.assertFalse(
            wave.dependencies_conflict_with_wave(
                ["model.layers.1.mlp", "model.layers.2.mlp", "model.embed_tokens"]
            )
        )

    def test_parallel_category_change_splits_wave_even_when_deps_disjoint(self):
        """A different ``parallel`` flag splits the wave even when dependencies are disjoint."""
        scheduler = DistributedTaskScheduler(self.mock_model)
        scheduler.submit(self.mock_worker_fn, args=("task1",), dependencies=["m1"], parallel=True)
        scheduler.submit(self.mock_worker_fn, args=("task2",), dependencies=["m2"], parallel=False)
        self.assertEqual(len(dts_waves(scheduler)), 2)

    def test_submission_conflicts_with_wave_includes_parallel_key(self):
        """The unified conflict check on the last wave treats a ``parallel`` mismatch as a conflict.

        This holds regardless of whether the dependencies intersect.
        """
        scheduler = DistributedTaskScheduler(self.mock_model)
        scheduler.submit(self.mock_worker_fn, args=("task1",), dependencies=["m1"], parallel=True)
        wave = dts_waves(scheduler)[-1]
        self.assertTrue(wave.submission_conflicts_with_wave(["m2"], False))
        self.assertFalse(wave.submission_conflicts_with_wave(["m2"], True))
        self.assertTrue(wave.submission_conflicts_with_wave(["m1", "m3"], True))

    def test_submit_splits_waves_when_parallel_changes_disjoint_deps(self):
        """Disjoint deps with different ``parallel`` flags give two waves, both on the sequential backend."""
        scheduler = DistributedTaskScheduler(self.mock_model)
        with scheduler:
            scheduler.submit(self.mock_worker_fn, args=("a",), dependencies=["m1"], parallel=True)
            scheduler.submit(self.mock_worker_fn, args=("b",), dependencies=["m2"], parallel=False)
            waves = dts_waves(scheduler)
            self.assertEqual(len(waves), 2)
            self.assertIsInstance(waves[0], _DtsSequentialWaveScheduler)
            self.assertIsInstance(waves[1], _DtsSequentialWaveScheduler)

    def test_scheduler_disable_parallel_coerces_to_local_wave(self):
        """With ``disable_parallel=True`` a default submit lands in a local wave and its spec has parallel=False."""
        scheduler = DistributedTaskScheduler(self.mock_model, disable_parallel=True)
        with scheduler:
            scheduler.submit(self.mock_worker_fn, args=("a",), dependencies=["m1"])
            waves = dts_waves(scheduler)
            self.assertEqual(len(waves), 1)
            self.assertIsInstance(waves[0], _DtsSequentialWaveScheduler)
            self.assertFalse(waves[0]._tasks[0].spec.parallel)

    def test_submit_creates_expected_waves_when_dependencies_conflict_across_tasks(self):
        """submit groups tasks into the expected waves, checking every wave's contents."""
        scheduler = DistributedTaskScheduler(self.mock_model)
        with scheduler:
            scheduler.submit(self.mock_worker_fn, args=("task1",), dependencies=["m1"])
            scheduler.submit(self.mock_worker_fn, args=("task2",), dependencies=["m2"])
            # Conflicts with wave 0 on m1 -> wave 1.
            scheduler.submit(self.mock_worker_fn, args=("task3",), dependencies=["m1", "m3"])
            # Disjoint from wave 1 -> joins wave 1.
            scheduler.submit(self.mock_worker_fn, args=("task4",), dependencies=["m2", "m4"])
            # Conflicts with wave 1 on m3 -> wave 2.
            scheduler.submit(self.mock_worker_fn, args=("task5",), dependencies=["m3", "m5"])

            waves = dts_waves(scheduler)
            self.assertEqual(len(waves), 3)
            self.assertEqual([len(w._tasks) for w in waves], [2, 2, 1])
            self.assertEqual(waves[0]._tasks[0].spec.dependencies, ["m1"])
            self.assertEqual(waves[0]._tasks[1].spec.dependencies, ["m2"])
            self.assertEqual(waves[1]._tasks[0].spec.dependencies, ["m1", "m3"])
            self.assertEqual(waves[1]._tasks[1].spec.dependencies, ["m2", "m4"])
            # The third wave was previously only counted, never inspected.
            self.assertEqual(waves[2]._tasks[0].spec.dependencies, ["m3", "m5"])

    def test_submit_raises_when_scheduler_closed(self):
        """Submitting to a scheduler that has been closed raises RuntimeError."""
        scheduler = DistributedTaskScheduler(self.mock_model)
        with scheduler:
            scheduler.submit(self.mock_worker_fn, args=("task1",), dependencies=["m1"])
        # Outside the ``with`` block the scheduler is expected to be closed.
        with self.assertRaises(RuntimeError) as context:
            scheduler.submit(self.mock_worker_fn, args=("task2",), dependencies=["m2"])
        self.assertIn("closed", str(context.exception).lower())

    def test_submit_raises_when_dependency_path_not_under_model(self):
        """An invalid ``dependencies`` path fails at submit time with SchemaValidateError."""
        scheduler = DistributedTaskScheduler(self.mock_model)

        def noop():
            return None

        with scheduler:
            with self.assertRaises(SchemaValidateError) as ctx:
                scheduler.submit(noop, dependencies=["no_such_submodule"])
            self.assertIn("invalid dependency path", str(ctx.exception).lower())

    def test_submit_generates_same_semantic_hash_when_tensor_values_differ(self):
        """Tensors that differ only in values (same metadata) yield the same submit semantic hash."""
        t1 = torch.ones(2, 3, dtype=torch.float32)
        t2 = torch.zeros(2, 3, dtype=torch.float32)

        scheduler = DistributedTaskScheduler(self.mock_model)
        with scheduler:
            scheduler.submit(self.mock_worker_fn, args=(t1,), dependencies=["m1"])
            scheduler.submit(self.mock_worker_fn, args=(t2,), dependencies=["m2"])
            tasks = dts_waves(scheduler)[0]._tasks
            # Tensor arguments must still produce a non-empty hash.
            self.assertNotEqual(tasks[0].spec.semantic_hash, "")
            self.assertNotEqual(tasks[1].spec.semantic_hash, "")

        # With identical (empty) dependencies, only the tensor values differ between the two submits.
        scheduler2 = DistributedTaskScheduler(self.mock_model)
        with scheduler2:
            scheduler2.submit(self.mock_worker_fn, args=(t1,), dependencies=[])
            scheduler2.submit(self.mock_worker_fn, args=(t2,), dependencies=[])
            tasks2 = dts_waves(scheduler2)[0]._tasks
            self.assertEqual(tasks2[0].spec.semantic_hash, tasks2[1].spec.semantic_hash)

    def test_submit_raises_for_unstable_arg_type_in_hash(self):
        """Arguments that cannot be stably serialized fail fast at submit time."""
        scheduler = DistributedTaskScheduler(self.mock_model)
        with scheduler:
            with self.assertRaises(SchemaValidateError) as ctx:
                scheduler.submit(self.mock_worker_fn, args=(object(),), dependencies=["m1"])
            self.assertIn("unsupported value type", str(ctx.exception).lower())

    def test_run_executes_all_tasks_when_single_rank(self):
        """Basic single-rank execution; empty dependencies avoid module resolution."""
        scheduler = DistributedTaskScheduler(self.mock_model)
        with scheduler:
            for payload in ("task1", "task2", "task3"):
                scheduler.submit(self.mock_worker_fn, args=(payload,), dependencies=[])
            records = scheduler.run()
        self.assertEqual(len(records), 3)
        self.assertEqual([r.task_id for r in records], ["w0_t0", "w0_t1", "w0_t2"])
        self.assertEqual(sorted(self.executed_payloads), ["task1", "task2", "task3"])

    def test_run_returns_empty_records_when_no_tasks_submitted(self):
        """Running with no submitted tasks returns no records and creates no waves."""
        scheduler = DistributedTaskScheduler(self.mock_model)
        with scheduler:
            records = scheduler.run()
        self.assertEqual(len(records), 0)
        self.assertEqual(len(dts_waves(scheduler)), 0)

    def test_current_wave_deps_accumulates_when_tasks_submitted_to_same_wave(self):
        """The last wave's prefix trie accumulates the dependency paths of every task it accepts."""
        scheduler = DistributedTaskScheduler(self.mock_model)
        self.assertEqual(len(dts_waves(scheduler)), 0)

        scheduler.submit(self.mock_worker_fn, args=("task1",), dependencies=["m1", "m2"])
        self.assertEqual(len(dts_waves(scheduler)), 1)
        self.assertEqual(dts_waves(scheduler)[-1].registered_dependency_paths(), {"m1", "m2"})

        scheduler.submit(self.mock_worker_fn, args=("task2",), dependencies=["m3"])
        self.assertEqual(len(dts_waves(scheduler)), 1)
        self.assertEqual(dts_waves(scheduler)[-1].registered_dependency_paths(), {"m1", "m2", "m3"})

        # m1 conflicts, so a new wave starts with only m1 registered.
        scheduler.submit(self.mock_worker_fn, args=("task3",), dependencies=["m1"])
        self.assertEqual(len(dts_waves(scheduler)), 2)
        self.assertEqual(dts_waves(scheduler)[-1].registered_dependency_paths(), {"m1"})

    def test_submit_puts_tasks_in_same_wave_when_dependencies_empty(self):
        """Tasks with empty dependency lists never conflict and share one wave."""
        scheduler = DistributedTaskScheduler(self.mock_model)
        with scheduler:
            scheduler.submit(self.mock_worker_fn, args=("task1",), dependencies=[])
            scheduler.submit(self.mock_worker_fn, args=("task2",), dependencies=[])
            waves = dts_waves(scheduler)
            self.assertEqual(len(waves), 1)
            self.assertEqual(len(waves[0]._tasks), 2)

    def test_submit_distributes_tasks_into_waves_when_complex_dependency_pattern(self):
        """A quantize/smooth/rotate submission pattern yields waves of sizes [8, 2, 1]."""
        scheduler = DistributedTaskScheduler(self.mock_model)
        with scheduler:
            # Wave 0: disjoint q/k projections for layers 0-3 (8 tasks).
            for i in range(4):
                scheduler.submit(
                    self.mock_worker_fn, args=(f"quantize_layer{i}_q",), dependencies=[f"layer.{i}.q_proj"]
                )
                scheduler.submit(
                    self.mock_worker_fn, args=(f"quantize_layer{i}_k",), dependencies=[f"layer.{i}.k_proj"]
                )
            # Wave 1: conflicts with wave 0 on layer.0.q_proj.
            scheduler.submit(
                self.mock_worker_fn, args=("smooth_layer0_q",), dependencies=["layer.0.q_proj", "layer.0.o_proj"]
            )
            # Disjoint from wave 1, so it joins wave 1.
            scheduler.submit(
                self.mock_worker_fn, args=("smooth_layer0_k",), dependencies=["layer.0.k_proj", "layer.0.v_proj"]
            )
            # Wave 2: conflicts with wave 1 on layer.0.o_proj.
            scheduler.submit(self.mock_worker_fn, args=("rotate_layer0_o",), dependencies=["layer.0.o_proj"])

            waves = dts_waves(scheduler)
            self.assertEqual(len(waves), 3)
            self.assertEqual([len(w._tasks) for w in waves], [8, 2, 1])

    def test_submit_splits_waves_in_mixed_prefix_and_parallel_scenario(self):
        """Prefix conflicts combined with ``parallel`` switches produce the expected wave layout."""
        scheduler = DistributedTaskScheduler(self.mock_model)
        with scheduler:
            # Wave 0: t1, t2 (disjoint, parallel=True).
            scheduler.submit(self.mock_worker_fn, args=("t1",), dependencies=["model.layers.0.self_attn"], parallel=True)
            scheduler.submit(self.mock_worker_fn, args=("t2",), dependencies=["layer.1.q_proj"], parallel=True)
            # Wave 1: t3 is a child of t1's path (prefix conflict).
            scheduler.submit(
                self.mock_worker_fn,
                args=("t3",),
                dependencies=["model.layers.0.self_attn.q_proj"],
                parallel=True,
            )
            # Wave 2: t4, t5 (parallel switches to False).
            scheduler.submit(self.mock_worker_fn, args=("t4",), dependencies=["layer.2.q_proj"], parallel=False)
            scheduler.submit(self.mock_worker_fn, args=("t5",), dependencies=["layer.3.q_proj"], parallel=False)
            # Wave 3: t6 (parallel switches back to True).
            scheduler.submit(self.mock_worker_fn, args=("t6",), dependencies=["layer.0"], parallel=True)
            # Wave 4: t7 is a child of t6's path (prefix conflict).
            scheduler.submit(self.mock_worker_fn, args=("t7",), dependencies=["layer.0.q_proj"], parallel=True)

            waves = dts_waves(scheduler)
            self.assertEqual(len(waves), 5)
            self.assertEqual([len(w._tasks) for w in waves], [2, 1, 2, 1, 1])

    def test_submit_splits_waves_in_mixed_parent_child_and_disjoint_batches(self):
        """Wave layout stays stable when parent/child conflicts mix with batches of disjoint tasks."""
        scheduler = DistributedTaskScheduler(self.mock_model)
        with scheduler:
            # Wave 0: three disjoint parallel tasks.
            scheduler.submit(
                self.mock_worker_fn, args=("w0_1",), dependencies=["model.layers.0.self_attn.q_proj"], parallel=True
            )
            scheduler.submit(self.mock_worker_fn, args=("w0_2",), dependencies=["layer.1.q_proj"], parallel=True)
            scheduler.submit(self.mock_worker_fn, args=("w0_3",), dependencies=["layer.2.q_proj"], parallel=True)
            # Wave 1: w1_1 is the parent of w0_1's path; w1_2 is disjoint and joins it.
            scheduler.submit(
                self.mock_worker_fn, args=("w1_1",), dependencies=["model.layers.0.self_attn"], parallel=True
            )
            scheduler.submit(self.mock_worker_fn, args=("w1_2",), dependencies=["layer.3.q_proj"], parallel=True)
            # Wave 2: parallel switches to False.
            scheduler.submit(self.mock_worker_fn, args=("w2_1",), dependencies=["layer.2.k_proj"], parallel=False)
            scheduler.submit(self.mock_worker_fn, args=("w2_2",), dependencies=["layer.3.k_proj"], parallel=False)
            # Wave 3: parallel switches back to True.
            scheduler.submit(self.mock_worker_fn, args=("w3_1",), dependencies=["module1"], parallel=True)
            scheduler.submit(self.mock_worker_fn, args=("w3_2",), dependencies=["module2"], parallel=True)

            waves = dts_waves(scheduler)
            self.assertEqual(len(waves), 4)
            self.assertEqual([len(w._tasks) for w in waves], [3, 2, 2, 2])