"""
MMEngine patches for NPU adaptation.
Provides:
- OptimizerWrapper: mmengine optimizer wrapper with fused gradient clipping
- Training loop patches with profiling and early stopping support
"""
import importlib
import re
from typing import Dict, List
from mx_driving.patcher.patcher_logger import patcher_logger
from mx_driving.patcher.patch import (
AtomicPatch,
BasePatch,
LegacyPatch,
Patch,
mmcv_version,
)
class OptimizerWrapper(Patch):
    """
    Optimizer wrapper patch for mmengine with gradient clipping support.

    After the original ``OptimWrapper.__init__`` has run, every instance is
    given a ``clip_grads`` callable that delegates to the wrapped optimizer's
    ``clip_grad_norm_fused_`` method.

    Only applies when mmcv 2.x is detected (mmengine exists).
    """

    name = "optimizer_wrapper"
    legacy_name = "optimizer_wrapper"
    target_module = "mmengine"

    @staticmethod
    def _wrap_init(original):
        """Wrap OptimWrapper.__init__ to add fused gradient clipping.

        Args:
            original: The unpatched ``__init__`` callable.

        Returns:
            A replacement ``__init__`` that first calls ``original`` with the
            same arguments and then sets ``self.clip_grads``.
        """
        import functools

        def _get_clip_func(optimizer):
            def clip_func(params, **kwargs):
                # ``params`` is accepted but deliberately not forwarded: only
                # the keyword arguments reach the optimizer's fused routine.
                return optimizer.clip_grad_norm_fused_(**kwargs)

            return clip_func

        # Keep the original constructor's name, docstring and signature
        # visible on the replacement for introspection and debugging.
        @functools.wraps(original)
        def new_init(self, *args, **kwargs):
            original(self, *args, **kwargs)
            # NOTE(review): stock mmengine stores its clipping callable as
            # ``clip_func``; confirm ``clip_grads`` is the attribute that the
            # consuming code actually reads.
            self.clip_grads = _get_clip_func(self.optimizer)

        return new_init

    @classmethod
    def patches(cls, options=None) -> List[BasePatch]:
        """Return the patches that install the wrapped ``__init__``.

        Args:
            options: Unused here; accepted for interface compatibility.

        Returns:
            A single-element list holding the ``AtomicPatch`` for
            ``OptimWrapper.__init__``, guarded by ``mmcv_version.is_v2x``.
        """
        return [
            AtomicPatch(
                "mmengine.optim.optimizer.optimizer_wrapper.OptimWrapper.__init__",
                target_wrapper=cls._wrap_init,
                precheck=lambda: mmcv_version.is_v2x,
            ),
        ]
def _parse_profiler_options(options: Dict):
    """Read the profiler settings from ``options`` and map them to torch_npu values.

    Args:
        options: Patch options. ``profiling_path`` and ``profiling_level`` are
            read by subscript (so they must be present); ``step_ctrl`` is
            optional and defaults to ``(1, 1, 1, 1, 20)``.

    Returns:
        A 5-tuple ``(path, level, activities, profiler_level, step_ctrl)``.
        ``activities`` is ``[NPU]`` for level 0 and ``[NPU, CPU]`` otherwise;
        ``profiler_level`` is ``ProfilerLevel.Level0`` for level 0 and
        ``ProfilerLevel.Level1`` otherwise.

    Raises:
        ValueError: If ``profiling_level`` is below 0 or above 2.
    """
    import torch_npu

    path, level = options["profiling_path"], options["profiling_level"]

    # A path matching the character class is only reported, never rejected.
    if re.search(r'[ +#%&{}\<>*?/$!\'":@`|;=]', path) is not None:
        patcher_logger.warning("profiling path contains illegal character")
    if level < 0 or level > 2:
        raise ValueError("valid profiling levels are integers within range [0, 2]")

    step_ctrl = options.get('step_ctrl', (1, 1, 1, 1, 20))

    profiler_api = torch_npu.profiler
    device_only = level == 0

    # Level 0 records NPU activity alone; higher levels add CPU activity.
    activities = [profiler_api.ProfilerActivity.NPU]
    if not device_only:
        activities.append(profiler_api.ProfilerActivity.CPU)

    if device_only:
        profiler_level = profiler_api.ProfilerLevel.Level0
    else:
        profiler_level = profiler_api.ProfilerLevel.Level1

    return path, level, activities, profiler_level, step_ctrl
def build_mmengine_epoch_train_loop_patch(options: Dict) -> LegacyPatch:
    """Build a patch that replaces ``EpochBasedTrainLoop.run_epoch`` in mmengine.

    The replacement runs one epoch exactly like the loop it replaces (hooks,
    ``model.train()``, ``run_iter`` per batch, epoch counter increment) and
    optionally adds:

    * profiling: when ``options["enable_profiler"]`` is truthy, the batch loop
      runs inside a ``torch_npu.profiler.profile`` context that is stepped
      after every iteration;
    * braking: when ``options["enable_brake"]`` is truthy, the process exits
      with status 0 as soon as ``self._iter`` equals ``options["brake_step"]``.

    Args:
        options: Patch options captured at build time. Reads
            ``enable_profiler``, ``enable_brake``, ``brake_step`` and, when
            profiling is enabled, the keys consumed by
            ``_parse_profiler_options``.

    Returns:
        A ``LegacyPatch`` targeting the ``mmengine`` module.
    """

    def _apply(module, _options):
        # The second argument is ignored; the ``options`` captured by the
        # enclosing builder are used instead.
        import contextlib
        import sys

        import torch_npu

        enable_profiler = bool(options.get("enable_profiler"))
        enable_brake = bool(options.get("enable_brake"))
        # Only consulted when braking is enabled (every use is guarded).
        brake_step = options.get("brake_step") if enable_brake else None

        if enable_profiler:
            path, level, activities, profiler_level, step_ctrl = _parse_profiler_options(options)
            wait, warmup, active, repeat, skip_first = step_ctrl

        def _profiler_context():
            """Return a fresh profiler, or a no-op context when profiling is off."""
            if not enable_profiler:
                return contextlib.nullcontext()
            return torch_npu.profiler.profile(
                activities=activities,
                with_stack=level == 2,
                record_shapes=level > 0,
                profile_memory=level == 2,
                schedule=torch_npu.profiler.schedule(wait, warmup, active, repeat, skip_first),
                experimental_config=torch_npu.profiler._ExperimentalConfig(profiler_level=profiler_level),
                on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(path),
            )

        def run_epoch(self):
            self.runner.call_hook("before_train_epoch")
            self.runner.model.train()
            # One loop serves both modes; a new profiler is created per epoch.
            with _profiler_context() as prof:
                for idx, data_batch in enumerate(self.dataloader):
                    self.run_iter(idx, data_batch)
                    if enable_profiler:
                        prof.step()
                    if enable_brake and self._iter == brake_step:
                        sys.exit(0)
            self.runner.call_hook("after_train_epoch")
            self._epoch += 1

        # Make sure the ``runner`` submodule is loaded before patching it.
        importlib.import_module(f"{module.__name__}.runner")
        module.runner.EpochBasedTrainLoop.run_epoch = run_epoch

    return LegacyPatch(_apply, target_module="mmengine")
def build_mmengine_iter_train_loop_patch(options: Dict) -> LegacyPatch:
    """Build a patch that replaces ``IterBasedTrainLoop.run`` in mmengine.

    The replacement runs the iteration-based training loop (hooks, dataloader
    fast-forward on resume, ``run_iter`` per batch, periodic validation) and
    optionally adds:

    * profiling: when ``options["enable_profiler"]`` is truthy, the loop runs
      inside a ``torch_npu.profiler.profile`` context that is stepped after
      every iteration;
    * braking: when ``options["enable_brake"]`` is truthy, the process exits
      with status 0 as soon as ``self._iter`` equals ``options["brake_step"]``.

    Args:
        options: Patch options captured at build time. Reads
            ``enable_profiler``, ``enable_brake``, ``brake_step`` and, when
            profiling is enabled, the keys consumed by
            ``_parse_profiler_options``.

    Returns:
        A ``LegacyPatch`` targeting the ``mmengine`` module.
    """

    def _apply(module, _options):
        # The second argument is ignored; the ``options`` captured by the
        # enclosing builder are used instead.
        import contextlib
        import logging as std_logging
        import sys

        import torch_npu

        enable_profiler = bool(options.get("enable_profiler"))
        enable_brake = bool(options.get("enable_brake"))
        # Only consulted when braking is enabled (every use is guarded).
        brake_step = options.get("brake_step") if enable_brake else None

        if enable_profiler:
            path, level, activities, profiler_level, step_ctrl = _parse_profiler_options(options)
            wait, warmup, active, repeat, skip_first = step_ctrl

        print_log = module.logging.print_log

        def _profiler_context():
            """Return a fresh profiler, or a no-op context when profiling is off."""
            if not enable_profiler:
                return contextlib.nullcontext()
            return torch_npu.profiler.profile(
                activities=activities,
                with_stack=level == 2,
                record_shapes=level > 0,
                profile_memory=level == 2,
                schedule=torch_npu.profiler.schedule(wait, warmup, active, repeat, skip_first),
                experimental_config=torch_npu.profiler._ExperimentalConfig(profiler_level=profiler_level),
                on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(path),
            )

        def run(self):
            self.runner.call_hook("before_train")
            self.runner.call_hook("before_train_epoch")
            if self._iter > 0:
                # Resuming: skip the batches that were already consumed.
                # The level constant comes from the standard-library logging
                # module; the patched package's ``logging`` subpackage is not
                # where those constants are defined.
                print_log(
                    f"Advance dataloader {self._iter} steps to skip data that has already been trained",
                    logger="current",
                    level=std_logging.WARNING,
                )
                for _ in range(self._iter):
                    next(self.dataloader_iterator)

            # One loop serves both the profiled and the plain mode.
            with _profiler_context() as prof:
                while self._iter < self._max_iters and not self.stop_training:
                    self.runner.model.train()
                    data_batch = next(self.dataloader_iterator)
                    self.run_iter(data_batch)
                    if enable_profiler:
                        prof.step()
                    if enable_brake and self._iter == brake_step:
                        sys.exit(0)
                    self._decide_current_val_interval()
                    if (
                        self.runner.val_loop is not None
                        and self._iter >= self.val_begin
                        and (self._iter % self.val_interval == 0 or self._iter == self._max_iters)
                    ):
                        self.runner.val_loop.run()

            self.runner.call_hook("after_train_epoch")
            self.runner.call_hook("after_train")
            return self.runner.model

        # Make sure the ``runner`` submodule is loaded before patching it.
        importlib.import_module(f"{module.__name__}.runner")
        module.runner.IterBasedTrainLoop.run = run

    return LegacyPatch(_apply, target_module="mmengine")