"""Utils for model training"""
import os
import numpy as np
from mindspore.mindrecord import FileWriter
# Numeric dtypes the random generator knows how to fill.
_INTEGER_DTYPES = (np.dtype(np.int32), np.dtype(np.uint8), np.dtype(np.int64))
_FLOAT_DTYPES = (np.dtype(np.float16), np.dtype(np.float32), np.dtype(np.float64))


def _resolve_field_spec(field_name, field_info, seq_length):
    """Turn one ``data_schema`` entry into ``(name, numpy dtype, shape)``.

    Every ``-1`` dimension in ``field_info["shape"]`` is replaced by
    ``seq_length``. A missing ``"shape"`` key marks a scalar field and yields
    ``shape=None``.

    Raises:
        ValueError: if ``"type"`` is missing or is not a supported numeric dtype.
    """
    if "type" not in field_info:
        raise ValueError(f"Field '{field_name}' has no 'type' in data_schema.")
    dtype = field_info["type"]
    try:
        np_dtype = np.dtype(dtype)
    except TypeError as err:
        # e.g. "string" is not a numpy dtype name at all.
        raise ValueError(f"Unsupported dtype: {dtype}") from err
    if np_dtype not in _INTEGER_DTYPES + _FLOAT_DTYPES:
        raise ValueError(f"Unsupported dtype: {dtype}")
    shape = field_info.get("shape")
    if shape is not None:
        shape = tuple(seq_length if dim == -1 else dim for dim in shape)
    return field_name, np_dtype, shape


def _random_value(rng, np_dtype, shape):
    """Draw one random value for a field from ``rng``.

    Integers are drawn uniformly from [0, 1024) and floats from [0, 1), then
    cast to ``np_dtype``. The cast is unsafe, so ``uint8`` values wrap
    modulo 256, as before. ``shape=None`` returns a Python scalar; otherwise
    an ``np.ndarray`` of that shape is returned.
    """
    is_int = np_dtype in _INTEGER_DTYPES
    if shape is None:
        draw = rng.randint(0, 1024) if is_int else rng.rand()
        return np.asarray(draw).astype(np_dtype).item()
    draw = rng.randint(0, 1024, size=shape) if is_int else rng.rand(*shape)
    # asarray: rand() with an empty shape returns a bare float, not an array.
    return np.asarray(draw).astype(np_dtype)


def _remove_partial_output(dataset_path):
    """Delete ``dataset_path`` and its companion ``.db`` file, if they exist."""
    for path in (dataset_path, dataset_path + ".db"):
        if os.path.exists(path):
            os.remove(path)


def generate_mindrecord_file(
        seq_length: int = 128,
        batch_size: int = 1,
        train_steps: int = 1000,
        dataset_path: str = None,
        data_schema: dict = None,
        max_attempts: int = 3,
):
    """
    Generate a mindrecord file filled with seeded random data.

    Args:
        seq_length (int): Substituted for every ``-1`` dimension in a field's
            ``"shape"``. Default: 128.
        batch_size (int): Batch size for training. Default: 1.
        train_steps (int): Number of training steps. The file holds
            ``batch_size * train_steps`` samples. Default: 1000.
        dataset_path (str): Path of the mindrecord file to write. Its parent
            directory is created if needed. Required.
        data_schema (dict): Mapping ``field_name -> {"type": ..., "shape": ...}``.
            It is passed to ``FileWriter.add_schema``. ``"type"`` must be one of
            int32/uint8/int64/float16/float32/float64. ``"shape"`` is optional;
            omit it for scalar fields. Required.
        max_attempts (int): How many times to try writing the file before
            giving up. Values below 1 are treated as 1. Default: 3.

    Raises:
        ValueError: if ``dataset_path`` or ``data_schema`` is missing, or a
            field has a missing/unsupported ``"type"``.
        RuntimeError: if every write attempt failed. It is chained to the
            last underlying error.
    """
    if dataset_path is None:
        raise ValueError("dataset_path should be specified.")
    if data_schema is None:
        raise ValueError("data_schema should be specified.")
    # Validate the schema once, before writing. A bad dtype is deterministic,
    # so retrying it cannot help.
    field_specs = [
        _resolve_field_spec(name, info, seq_length)
        for name, info in data_schema.items()
    ]
    data_dir = os.path.dirname(dataset_path)
    if data_dir:
        os.makedirs(data_dir, exist_ok=True)
    data_num = batch_size * train_steps

    last_error = None
    attempt = 0
    for attempt in range(1, max(1, max_attempts) + 1):
        # Use a fresh, locally seeded generator per attempt. Every attempt then
        # writes the same data, matching the legacy np.random.seed(0) stream,
        # and NumPy's global RNG state is left untouched.
        rng = np.random.RandomState(0)
        try:
            writer = FileWriter(dataset_path)
            writer.add_schema(data_schema, "test-schema")
            for _ in range(data_num):
                features = {
                    name: _random_value(rng, np_dtype, shape)
                    for name, np_dtype, shape in field_specs
                }
                writer.write_raw_data([features])
            writer.commit()
            return
        except Exception as e:  # pylint: disable=broad-except
            # Any write failure: drop partial output, then retry.
            last_error = e
            _remove_partial_output(dataset_path)
            print(f"mindrecord data initialize failed, due to \"{e}\".")
    raise RuntimeError(
        f"mindrecord data initialize failed for {attempt} times."
    ) from last_error