from mindspeed_mm.fsdp.data.dataloader.dataloader import (
prepare_base_dataloader,
prepare_sampler_dataloader,
)
from mindspeed_mm.fsdp.utils.register import data_register
def build_mm_dataset(dataset_param):
"""
Build a multimodal dataset based on different tasks
Args:
dataset_param
Return:
dataset
"""
if not isinstance(dataset_param, dict):
dataset_param = dataset_param.to_dict()
for check_key in ["dataset_type", "basic_parameters", "preprocess_parameters"]:
if check_key not in dataset_param:
raise AssertionError(f"Key parameter missing: {check_key}")
dataset_type = dataset_param["dataset_type"]
basic_param = dataset_param["basic_parameters"]
preprocess_param = dataset_param["preprocess_parameters"]
dataset_cls_or_func = data_register.get(dataset_type)
return dataset_cls_or_func(
basic_param=basic_param,
preprocess_param=preprocess_param,
dataset_param=dataset_param
)
def build_mm_dataloader(
dataset,
dataloader_param,
process_group=None,
dataset_param=None,
model=None,
):
"""
Build a multimodal dataloader based on different tasks
dataloader_type interpretation:
base: raw dataloader based on torch.utils.data.DataLoader
sampler: prepare a dataloader for distributed training by building a specific sampler
variable: used for variable dataset
Args:
dataloader_param_dict
Return:
dataloader
"""
if not isinstance(dataloader_param, dict):
dataloader_param = dataloader_param.to_dict()
if "dataloader_mode" not in dataloader_param:
raise AssertionError("Key parameter missing: dataloader_mode")
dataloader_mode = dataloader_param.pop("dataloader_mode")
if dataloader_mode == "base":
data_loader = prepare_base_dataloader(dataset, **dataloader_param, dataset_param=dataset_param)
return data_loader
elif dataloader_mode == "sampler":
data_loader = prepare_sampler_dataloader(
dataset, **dataloader_param,
process_group=process_group,
dataset_param=dataset_param,
model=model,
)
return data_loader
else:
raise NotImplementedError(f"Unsupported dataloader_mode: {dataloader_mode}")