"""DeepseekV3 models' APIs."""
import os
from mindformers.tools.register import MindFormerRegister, MindFormerModuleType
from research.deepseek3.deepseek3_model_train import TrainingDeepseekV3ForCausalLM
from research.deepseek3.deepseek3_model_infer import InferenceDeepseekV3ForCausalLM
from research.deepseek3.deepseek3_model_infer import InferenceDeepseekV3MTPForCausalLM
# Public names of this module: only the factory class defined below is
# exported by a star-import.
__all__ = ['DeepseekV3ForCausalLM']
@MindFormerRegister.register(MindFormerModuleType.MODELS)
class DeepseekV3ForCausalLM:
    r"""
    Provide DeepseekV3 Model for training and inference.

    This class works as a factory: ``__new__`` does not allocate an object
    through ``super().__new__``; it returns whatever one of the imported model
    classes builds. The class is chosen from the ``RUN_MODE`` environment
    variable and, in predict mode, from ``config.is_mtp_model``.

    Args:
        config (DeepseekV3Config): The config of DeepseekV3 model.
        *args: Extra positional arguments. Forwarded to the training model
            only; not forwarded in predict mode.
        **kwargs: Extra keyword arguments. Forwarded to the training model
            only; not forwarded in predict mode.

    Returns:
        The constructed model object:

        - ``RUN_MODE == "predict"`` and ``config.is_mtp_model`` is truthy:
          ``InferenceDeepseekV3MTPForCausalLM(config=config)``.
        - ``RUN_MODE == "predict"`` otherwise:
          ``InferenceDeepseekV3ForCausalLM(config=config)``.
        - any other ``RUN_MODE`` value (including unset):
          ``TrainingDeepseekV3ForCausalLM`` built from ``config`` plus the
          extra arguments.

        The original documentation of this class states that the network
        itself outputs a Tensor, the loss or logits.
    """

    def __new__(cls, config, *args, **kwargs):
        """Select and build the model that matches the current run mode."""
        if os.environ.get("RUN_MODE") != "predict":
            # Star-args are written before the keyword so the call reads in the
            # order Python actually binds it: positionals first, then config.
            # NOTE(review): if the training class takes `config` as its first
            # positional parameter, any non-empty *args will collide with the
            # `config` keyword and raise TypeError -- confirm against
            # TrainingDeepseekV3ForCausalLM's signature.
            return TrainingDeepseekV3ForCausalLM(*args, config=config, **kwargs)

        # Predict mode: the inference classes receive only the config; the
        # extra positional/keyword arguments are not forwarded here.
        if config.is_mtp_model:
            return InferenceDeepseekV3MTPForCausalLM(config=config)
        return InferenceDeepseekV3ForCausalLM(config=config)