# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""DeepseekV3 models' APIs."""
import os
from mindformers.tools.register import MindFormerRegister, MindFormerModuleType
from research.deepseek3.deepseek3_model_train import TrainingDeepseekV3ForCausalLM
from research.deepseek3.deepseek3_model_infer import InferenceDeepseekV3ForCausalLM
__all__ = ['DeepseekV3ForCausalLM']
@MindFormerRegister.register(MindFormerModuleType.MODELS)
class DeepseekV3ForCausalLM:
    r"""
    Provide DeepseekV3 Model for training and inference.

    This class acts as a factory: ``__new__`` does not build an instance of
    this class, it builds and returns either the inference or the training
    implementation, selected by the ``RUN_MODE`` environment variable at
    instantiation time.

    Args:
        config (DeepseekV3Config): The config of DeepseekV3 model.
        *args: Extra positional arguments, forwarded to the training model only.
        **kwargs: Extra keyword arguments, forwarded to the training model only.

    Returns:
        ``InferenceDeepseekV3ForCausalLM`` when the environment variable
        ``RUN_MODE`` equals ``"predict"``, otherwise
        ``TrainingDeepseekV3ForCausalLM``. The returned network presumably
        produces the loss or logits tensor (per the original documentation).
    """

    def __new__(cls, config, *args, **kwargs):
        # The run mode decides which implementation is constructed.
        # "predict" mode is the deployment path and uses the inference model;
        # it receives only `config` (extra args/kwargs are not forwarded).
        if os.environ.get("RUN_MODE") == "predict":
            return InferenceDeepseekV3ForCausalLM(config=config)

        # Every other run mode (including online_predict, which is used for
        # inference when predict mode is not supported) uses the training model.
        # `config` is passed positionally *before* `*args`: writing
        # `config=config, *args` binds the unpacked positionals first, so any
        # extra positional argument would collide with the `config` keyword.
        # NOTE(review): assumes `config` is the first positional parameter of
        # TrainingDeepseekV3ForCausalLM - confirm against its signature.
        return TrainingDeepseekV3ForCausalLM(config, *args, **kwargs)