import os
import types
from dataclasses import dataclass, field
import torch
from transformers.utils import is_torch_npu_available
from mindspeed.fsdp.utils.torch_patch import apply_hccl_premul_sum_patch
from mindspeed_llm.fsdp2.models.model_factory import ModelFactory
from mindspeed_llm.fsdp2.data.tokenizer import TokenizerFactory
from mindspeed_llm.fsdp2.inference.inferencer import Inferencer
from mindspeed_llm.fsdp2.utils.arguments import (
ModelArguments, ParallelArguments, InferenceArguments, OptimizationArguments, fsdp2_parse_args
)
from mindspeed_llm.fsdp2.utils.logging import setup_global_logging, get_logger
from mindspeed_llm.fsdp2.utils.global_vars import set_args
from mindspeed_llm.fsdp2.utils.device import set_accelerator_compatible
logger = get_logger(__name__)
@dataclass
class Arguments:
model: ModelArguments = field(default_factory=ModelArguments)
parallel: ParallelArguments = field(default_factory=ParallelArguments)
inference: InferenceArguments = field(default_factory=InferenceArguments)
optimization: OptimizationArguments = field(default_factory=OptimizationArguments)
class AutoInferencer:
"""
Responsible for setting up the runtime environment: NPU initialization,
distributed setup, and loading the FSDP model.
"""
def __init__(self):
root_args = fsdp2_parse_args(Arguments)
self.model_args = root_args.model
self.parallel_args = root_args.parallel
self.inference_args = root_args.inference
self.args = types.SimpleNamespace(**{
k: v for ns in [root_args.model, root_args.parallel, root_args.inference, root_args.optimization]
for k, v in ns.__dict__.items()
})
set_args(self.args)
self._initialize()
logger.info_rank0("> Building Tokenizer...")
self.tokenizer = TokenizerFactory.create(self.model_args)
self.parallel_args.recompute = False
logger.info_rank0("> Building Model for Inference...")
self.model = ModelFactory.create(self.model_args, self.parallel_args)
self.inferencer = Inferencer(
model=self.model,
tokenizer=self.tokenizer,
args=self.inference_args
)
@staticmethod
def _initialize():
"""Initialize underlying hardware and distributed environment."""
if is_torch_npu_available():
fallback = torch.npu
dist_backend = "hccl"
apply_hccl_premul_sum_patch()
elif torch.cuda.is_available():
fallback = torch.cuda
dist_backend = "nccl"
set_accelerator_compatible(fallback)
local_rank = int(os.environ.get("LOCAL_RANK", 0))
torch.accelerator.set_device_index(local_rank)
torch.accelerator.set_device(local_rank)
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend=dist_backend)
logger.info_rank0(f"> Distributed environment initialized. World size: {torch.distributed.get_world_size()}")
def chat(self):
"""Launch interactive chat."""
self.inferencer.run_interactive_chat()
if __name__ == "__main__":
try:
runner = AutoInferencer()
runner.chat()
except KeyboardInterrupt:
logger.info_rank0("\n> Inference interrupted by user. Exiting...")
except Exception as e:
logger.error(f"Inference failed with error: {e}")
raise
finally:
if torch.distributed.is_initialized():
torch.distributed.destroy_process_group()