"""
-------------------------------------------------------------------------
This file is part of the MindStudio project.
Copyright (c) 2025 Huawei Technologies Co.,Ltd.
MindStudio is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
-------------------------------------------------------------------------
"""
"""
DistributedTaskScheduler(分波调度器)的单元测试
测试范围:
- submit 分波与执行语义
- 单卡场景下的基本执行流程
"""
import unittest
from typing import Any, Dict, List, Tuple
from unittest.mock import MagicMock
import torch
import torch.nn as nn
from msmodelslim.utils.distributed import DistributedTaskScheduler
from test.cases.utils.distributed.dts_test_internals import (
_DtsSequentialWaveScheduler,
dts_waves,
)
from test.cases.utils.distributed.dts_test_utils import build_dts_dependency_mock_model
class TestDistributedTaskSchedulerExecution(unittest.TestCase):
"""DistributedTaskScheduler 执行相关测试"""
def setUp(self):
"""测试前置准备"""
self.mock_model = nn.Linear(10, 10)
self.execution_order: List[Tuple[int, str]] = []
def test_run_executes_waves_in_order_when_multiple_waves_exist(self):
"""测试:多波次按顺序执行(使用空依赖避免模块解析)"""
def fn(payload: Any) -> str:
wave_idx = len(self.execution_order)
self.execution_order.append((wave_idx, str(payload)))
return f"result_{payload}"
scheduler = DistributedTaskScheduler(self.mock_model)
with scheduler:
scheduler.submit(fn, args=("task1",), dependencies=[])
scheduler.submit(fn, args=("task2",), dependencies=[])
scheduler.submit(fn, args=("task3",), dependencies=[])
scheduler.submit(fn, args=("task4",), dependencies=[])
records = scheduler.run()
self.assertEqual(len(records), 4)
self.assertEqual(len(dts_waves(scheduler)), 1)
def test_run_records_task_result_when_worker_returns_value(self):
"""测试:任务结果正确记录在 TaskExecutionRecord 中"""
def fn(payload: Any) -> str:
return f"processed_{payload}"
scheduler = DistributedTaskScheduler(self.mock_model)
with scheduler:
scheduler.submit(fn, args=("task1",), dependencies=[])
scheduler.submit(fn, args=("task2",), dependencies=[])
records = scheduler.run()
self.assertEqual(len(records), 2)
self.assertEqual([r.result for r in records], ["processed_task1", "processed_task2"])
def test_run_records_executor_rank_when_single_rank(self):
"""测试:TaskExecutionRecord 中记录了正确的执行 rank"""
def fn(payload: Any) -> str:
return str(payload)
scheduler = DistributedTaskScheduler(self.mock_model)
with scheduler:
scheduler.submit(fn, args=("task1",), dependencies=[])
records = scheduler.run()
self.assertEqual(len(records), 1)
self.assertEqual(records[0].executor_rank, 0)
def test_submit_creates_expected_waves_when_dependencies_alternate(self):
"""测试:交替依赖模式(A->B->A->B)- 仅验证分波逻辑,不执行"""
fn = lambda payload: str(payload)
scheduler = DistributedTaskScheduler(build_dts_dependency_mock_model())
with scheduler:
scheduler.submit(fn, args=("task1",), dependencies=["A"])
scheduler.submit(fn, args=("task2",), dependencies=["B"])
scheduler.submit(fn, args=("task3",), dependencies=["A", "C"])
scheduler.submit(fn, args=("task4",), dependencies=["B", "D"])
scheduler.submit(fn, args=("task5",), dependencies=["A"])
self.assertEqual(len(dts_waves(scheduler)), 3)
self.assertEqual(len(dts_waves(scheduler)[0]._tasks), 2)
self.assertEqual(len(dts_waves(scheduler)[1]._tasks), 2)
self.assertEqual(len(dts_waves(scheduler)[2]._tasks), 1)
class TestDistributedTaskSchedulerEdgeCases(unittest.TestCase):
"""DistributedTaskScheduler 边界条件测试"""
def setUp(self):
self.mock_model = build_dts_dependency_mock_model()
def test_submit_creates_new_wave_per_task_when_all_dependencies_same(self):
"""测试:所有任务具有相同依赖"""
fn = lambda payload: str(payload)
scheduler = DistributedTaskScheduler(self.mock_model)
with scheduler:
scheduler.submit(fn, args=("task1",), dependencies=["shared"])
scheduler.submit(fn, args=("task2",), dependencies=["shared"])
scheduler.submit(fn, args=("task3",), dependencies=["shared"])
self.assertEqual(len(dts_waves(scheduler)), 3)
for i in range(3):
self.assertEqual(len(dts_waves(scheduler)[i]._tasks), 1)
def test_submit_keeps_single_wave_when_all_dependencies_unique(self):
"""测试:大量任务在同一个波次"""
fn = lambda payload: str(payload)
scheduler = DistributedTaskScheduler(self.mock_model)
num_tasks = 100
with scheduler:
for i in range(num_tasks):
scheduler.submit(fn, args=(f"task{i}",), dependencies=[f"module{i}"],
)
self.assertEqual(len(dts_waves(scheduler)), 1)
self.assertEqual(len(dts_waves(scheduler)[0]._tasks), num_tasks)
def test_submit_creates_many_waves_when_dependencies_always_conflict(self):
"""测试:大量波次"""
fn = lambda payload: str(payload)
scheduler = DistributedTaskScheduler(self.mock_model)
num_waves = 50
with scheduler:
for i in range(num_waves):
scheduler.submit(fn, args=(f"task{i}",), dependencies=["shared_module"],
)
self.assertEqual(len(dts_waves(scheduler)), num_waves)
def test_run_executes_payload_driven_worker_when_dependencies_empty(self):
"""测试:空依赖但有 payload 的任务"""
executed_tasks: List[str] = []
def fn(payload: Any) -> str:
executed_tasks.append(payload.get("key", "no_key"))
return "ok"
scheduler = DistributedTaskScheduler(self.mock_model)
with scheduler:
scheduler.submit(fn, args=({"key": "value1"},), dependencies=[])
scheduler.submit(fn, args=({"key": "value2"},), dependencies=[])
records = scheduler.run()
self.assertEqual(len(executed_tasks), 2)
self.assertIn("value1", executed_tasks)
self.assertIn("value2", executed_tasks)
def test_submit_creates_new_wave_when_any_dependency_conflicts(self):
"""测试:多个依赖全部与当前波次冲突"""
fn = lambda payload: str(payload)
scheduler = DistributedTaskScheduler(self.mock_model)
with scheduler:
scheduler.submit(fn, args=("task1",), dependencies=["m1", "m2", "m3"])
scheduler.submit(fn, args=("task2",), dependencies=["m3", "m4", "m5"])
self.assertEqual(len(dts_waves(scheduler)), 2)
def test_submit_keeps_same_wave_when_no_dependencies_conflict(self):
"""测试:部分依赖但无冲突"""
fn = lambda payload: str(payload)
scheduler = DistributedTaskScheduler(self.mock_model)
with scheduler:
scheduler.submit(fn, args=("task1",), dependencies=["m1"])
scheduler.submit(fn, args=("task2",), dependencies=["m2"])
scheduler.submit(fn, args=("task3",), dependencies=["m3", "m4"])
self.assertEqual(len(dts_waves(scheduler)), 1)
self.assertEqual(len(dts_waves(scheduler)[0]._tasks), 3)
def test_run_does_not_remove_waves_when_completed(self):
"""测试:run 后波次状态保持不变"""
fn = lambda payload: str(payload)
scheduler = DistributedTaskScheduler(self.mock_model)
with scheduler:
scheduler.submit(fn, args=("task1",), dependencies=[])
scheduler.submit(fn, args=("task2",), dependencies=[])
records = scheduler.run()
self.assertEqual(len(dts_waves(scheduler)), 1)
self.assertEqual(len(dts_waves(scheduler)[0]._tasks), 2)
class TestDistributedTaskSchedulerIntegration(unittest.TestCase):
"""DistributedTaskScheduler 集成测试"""
def test_run_matches_impl_behavior_when_no_dependencies(self):
"""测试:DistributedTaskScheduler(wave) 与内部单波调度器的兼容性(使用空依赖)"""
mock_model = nn.Linear(10, 10)
results = []
def fn(payload: Any) -> str:
results.append(str(payload))
return str(payload)
with _DtsSequentialWaveScheduler(mock_model) as impl:
impl.submit(fn, args=("task1",), dependencies=[])
impl.submit(fn, args=("task2",), dependencies=[])
impl_records = impl.run()
with DistributedTaskScheduler(mock_model) as wave_dts:
wave_dts.submit(fn, args=("task1",), dependencies=[])
wave_dts.submit(fn, args=("task2",), dependencies=[])
wave_records = wave_dts.run()
self.assertEqual(len(impl_records), 2)
self.assertEqual(len(wave_records), 2)
self.assertEqual(len(results), 4)
def test_submit_creates_expected_waves_when_mixed_dependency_scenario(self):
"""测试:混合依赖场景(模拟实际量化流程)- 仅验证分波逻辑"""
mock_model = build_dts_dependency_mock_model()
fn = lambda payload: payload
scheduler = DistributedTaskScheduler(mock_model)
with scheduler:
for layer in range(3):
for proj in ["q_proj", "k_proj", "v_proj", "o_proj"]:
scheduler.submit(
fn,
args=({"phase": "weight_quant", "name": f"l{layer}_{proj}"},),
dependencies=[f"layer.{layer}.{proj}"],
)
for layer in range(3):
scheduler.submit(
fn,
args=({"phase": "smooth", "name": f"l{layer}"},),
dependencies=[f"layer.{layer}.q_proj", f"layer.{layer}.o_proj"],
)
for layer in range(3):
scheduler.submit(
fn,
args=({"phase": "rotation", "name": f"l{layer}"},),
dependencies=[f"layer.{layer}.q_proj"],
)
self.assertEqual(len(dts_waves(scheduler)), 3)
self.assertEqual(len(dts_waves(scheduler)[0]._tasks), 12)
self.assertEqual(len(dts_waves(scheduler)[1]._tasks), 3)
self.assertEqual(len(dts_waves(scheduler)[2]._tasks), 3)
wave0_phases = {t.spec.args[0].get("phase") for t in dts_waves(scheduler)[0]._tasks}
wave1_phases = {t.spec.args[0].get("phase") for t in dts_waves(scheduler)[1]._tasks}
wave2_phases = {t.spec.args[0].get("phase") for t in dts_waves(scheduler)[2]._tasks}
self.assertEqual(wave0_phases, {"weight_quant"})
self.assertEqual(wave1_phases, {"smooth"})
self.assertEqual(wave2_phases, {"rotation"})