6cffd44a创建于 2025年3月17日历史提交
diff --git a/loops/epoch/training_epoch_loop.py b/loops/epoch/training_epoch_loop.py
index 3954718..3f2d40b 100644
--- a/loops/epoch/training_epoch_loop.py
+++ b/loops/epoch/training_epoch_loop.py
@@ -17,6 +17,7 @@ from typing import Any, Dict, Generator, List, Optional, overload, Tuple, Union

 import numpy as np
 import torch
+import time

 import pytorch_lightning as pl
 from pytorch_lightning import loops  # import as loops to avoid circular imports
@@ -204,9 +205,10 @@ class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):

             self.batch_progress.increment_started()

+            step_start_time = time.time()
             with self.trainer.profiler.profile("run_training_batch"):
                 batch_output = self.batch_loop.run(batch, batch_idx)
-
+            print(f"id {batch_idx} step_time= ", time.time() - step_start_time, flush=True)
         self.batch_progress.increment_processed()

         # update non-plateau LR schedulers
diff --git a/strategies/launchers/subprocess_script.py b/strategies/launchers/subprocess_script.py
index 5a8632f..a210573 100644
--- a/strategies/launchers/subprocess_script.py
+++ b/strategies/launchers/subprocess_script.py
@@ -13,13 +13,15 @@
 # limitations under the License.
 import os
 import subprocess
+from multiprocessing import cpu_count
 import sys
+from copy import deepcopy
 from time import sleep
 from typing import Any, Callable, Optional

 import __main__
 import numpy as np
-
+from copy import deepcopy
 import pytorch_lightning as pl
 from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment
 from pytorch_lightning.strategies.launchers.base import _Launcher
@@ -123,6 +125,9 @@ class _SubprocessScriptLauncher(_Launcher):
             command = [sys.executable] + command
         else:  # Script called as `python -m a.b.c`
             command = [sys.executable, "-m", __main__.__spec__.name] + sys.argv[1:]
+        command_src = deepcopy(command)
+        cpu_kernel_num = int(cpu_count() / self.num_processes / 2)
+        command = ['taskset', '-c', f'0-{cpu_kernel_num-1}'] + command_src

         os.environ["WORLD_SIZE"] = f"{self.num_processes * self.num_nodes}"

@@ -142,6 +147,7 @@ class _SubprocessScriptLauncher(_Launcher):
                     cwd = get_original_cwd()
                     os_cwd = f'"{os.getcwd()}"'
                     command += [f"hydra.run.dir={os_cwd}", f"hydra.job.name=train_ddp_process_{local_rank}"]
+            command = ['taskset', '-c', f'{local_rank * cpu_kernel_num}-{local_rank * cpu_kernel_num + (cpu_kernel_num-1)}'] + command_src
             subprocess.Popen(command, env=env_copy, cwd=cwd)

             # starting all processes at once can cause issues
diff --git a/trainer/connectors/accelerator_connector.py b/trainer/connectors/accelerator_connector.py
index e795358..db1d396 100644
--- a/trainer/connectors/accelerator_connector.py
+++ b/trainer/connectors/accelerator_connector.py
@@ -705,7 +705,7 @@ class AcceleratorConnector:
             )

             if self._amp_type_flag == AMPType.NATIVE:
-                device = "cpu" if self._accelerator_flag == "cpu" else "cuda"
+                device = "cpu" if self._accelerator_flag == "cpu" else "npu"

                 if isinstance(self.strategy, (DDPShardedStrategy, DDPSpawnShardedStrategy)):
                     return ShardedNativeMixedPrecisionPlugin(self._precision_flag, device)
diff --git a/utilities/types.py b/utilities/types.py
index c5e3841..a2a674d 100644
--- a/utilities/types.py
+++ b/utilities/types.py
@@ -26,6 +26,12 @@ from torch.utils.data import DataLoader
 from torchmetrics import Metric
 from typing_extensions import Protocol, runtime_checkable

+try:
+    from torch.optim.lr_scheduler import LRScheduler as TorchLRScheduler
+except ImportError:
+    # For torch <= 1.13.x
+    from torch.optim.lr_scheduler import _LRScheduler as TorchLRScheduler
+
 _NUMBER = Union[int, float]
 _METRIC = Union[Metric, torch.Tensor, _NUMBER]
 _METRIC_COLLECTION = Union[_METRIC, Mapping[str, _METRIC]]
@@ -93,9 +99,9 @@ class ReduceLROnPlateau(_Stateful, Protocol):


 # todo: improve LRSchedulerType naming/typing
-LRSchedulerTypeTuple = (torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
-LRSchedulerTypeUnion = Union[torch.optim.lr_scheduler._LRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau]
-LRSchedulerType = Union[Type[torch.optim.lr_scheduler._LRScheduler], Type[torch.optim.lr_scheduler.ReduceLROnPlateau]]
+LRSchedulerTypeTuple = (TorchLRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)
+LRSchedulerTypeUnion = Union[TorchLRScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau]
+LRSchedulerType = Union[Type[TorchLRScheduler], Type[torch.optim.lr_scheduler.ReduceLROnPlateau]]


 @dataclass