"""test adjust resume training"""
import numpy as np
import pytest
import mindspore as ms
from mindspore.dataset import GeneratorDataset
from mindformers.models.llama.llama import LlamaForCausalLM
from mindformers.models.llama.llama_config import LlamaConfig
from mindformers import Trainer, TrainingArguments
# Run in graph mode (ms.GRAPH_MODE == 0); use the named constant instead of a bare 0.
ms.set_context(mode=ms.GRAPH_MODE)
def generator_train(seq_len=513, num_samples=16, seed=0):
    """Yield dummy training samples for ``GeneratorDataset``.

    One int32 array of shape ``(seq_len,)`` with values in ``[0, 15)`` is drawn
    once, and the same one-element tuple ``(input_ids,)`` is yielded
    ``num_samples`` times. A seeded generator keeps the fixture reproducible
    across runs.

    Args:
        seq_len: Length of the ``input_ids`` sequence.
        num_samples: Number of samples to yield.
        seed: Seed for ``np.random.default_rng``.

    Yields:
        tuple: ``(input_ids,)``.
    """
    rng = np.random.default_rng(seed)
    # `high` is exclusive, matching the original np.random.randint call.
    input_ids = rng.integers(low=0, high=15, size=(seq_len,)).astype(np.int32)
    train_data = (input_ids,)
    for _ in range(num_samples):
        yield train_data
class DummyTrainer(Trainer):
    """Minimal ``Trainer`` wrapping a one-layer Llama model for resume-training tests.

    The constructor builds default ``TrainingArguments``, a batched
    ``GeneratorDataset`` over ``generator_train`` and a tiny
    ``LlamaForCausalLM``. It then overwrites two fields on ``self.config``.

    Args:
        resume_training: Value written to ``self.config.resume_training``.
        load_checkpoint: Value written to ``self.config.load_checkpoint``.
    """

    def __init__(self, resume_training=True, load_checkpoint=""):
        training_args = TrainingArguments()
        dataset = GeneratorDataset(generator_train, column_names=["input_ids"]).batch(batch_size=4)
        tiny_config = LlamaConfig(
            num_layers=1,
            hidden_size=1,
            num_heads=1,
            seq_length=1,
            vocab_size=1,
        )
        tiny_model = LlamaForCausalLM(tiny_config)
        super().__init__(
            task='text_generation',
            model=tiny_model,
            args=training_args,
            train_dataset=dataset,
        )
        # Override the resume settings only after the base class has built its config.
        self.config.resume_training = resume_training
        self.config.load_checkpoint = load_checkpoint
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_empty_string_checkpoint():
    """
    Feature: Resume training with an empty checkpoint path
    Description: resume_training=True combined with load_checkpoint="" (empty string)
    Expectation: resume_training becomes False and load_checkpoint stays ""
    """
    dummy = DummyTrainer(resume_training=True, load_checkpoint="")
    # The private hook is the unit under test, so it is invoked directly.
    dummy._adjust_resume_training_if_ckpt_path_invalid()  # pylint: disable=protected-access
    cfg = dummy.config
    assert cfg.resume_training is False
    assert cfg.load_checkpoint == ""
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_empty_directory_checkpoint(tmp_path):
    """
    Feature: Resume training with an empty directory as the checkpoint
    Description: resume_training=True combined with load_checkpoint pointing at an empty directory
    Expectation: resume_training becomes False and load_checkpoint is reset to ""
    """
    # pytest's tmp_path starts out empty, so no checkpoint file exists in it.
    empty_dir = str(tmp_path)
    dummy = DummyTrainer(resume_training=True, load_checkpoint=empty_dir)
    dummy._adjust_resume_training_if_ckpt_path_invalid()  # pylint: disable=protected-access
    cfg = dummy.config
    assert cfg.resume_training is False
    assert cfg.load_checkpoint == ""
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_valid_checkpoint(tmp_path):
    """
    Feature: Resume training with a valid checkpoint
    Description: resume_training=True combined with load_checkpoint pointing at a directory that holds a file
    Expectation: resume_training stays True and the load_checkpoint path is kept unchanged
    """
    # A single placeholder file is enough to make the directory non-empty.
    (tmp_path / "checkpoint.safetensors").write_text("dummy checkpoint")
    ckpt_dir = str(tmp_path)
    dummy = DummyTrainer(resume_training=True, load_checkpoint=ckpt_dir)
    dummy._adjust_resume_training_if_ckpt_path_invalid()  # pylint: disable=protected-access
    cfg = dummy.config
    assert cfg.resume_training is True
    assert cfg.load_checkpoint == ckpt_dir
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_resume_training_false():
    """
    Feature: Resume training stays off when the flag is False
    Description: resume_training=False, whatever the checkpoint path is
    Expectation: nothing changes; resume_training stays False and load_checkpoint stays ""
    """
    dummy = DummyTrainer(resume_training=False, load_checkpoint="")
    dummy._adjust_resume_training_if_ckpt_path_invalid()  # pylint: disable=protected-access
    cfg = dummy.config
    assert cfg.resume_training is False
    assert cfg.load_checkpoint == ""