import argparse
import time
from typing import Tuple, Optional
from torch.utils.data import DataLoader
from dlrm.data.datasets import ParametricDataset
from dlrm.data.factories import create_dataset_factory
from dlrm.data.feature_spec import FeatureSpec
def get_data_loaders(
    flags, feature_spec: FeatureSpec, device_mapping: Optional[dict] = None
) -> Tuple[DataLoader, DataLoader]:
    """Build the training and test data loaders through a dataset factory.

    Args:
        flags: Configuration object forwarded to ``create_dataset_factory``.
            Its ``shuffle_batch_order`` attribute is read here to decide
            whether a sampler is created for the training dataset.
        feature_spec: Feature specification forwarded to the factory.
        device_mapping: Optional mapping forwarded to the factory unchanged.

    Returns:
        A ``(train_loader, test_loader)`` pair. Both loaders share the same
        collate function; only the training loader may receive a sampler.
    """
    factory = create_dataset_factory(
        flags,
        feature_spec=feature_spec,
        device_mapping=device_mapping,
    )
    train_set, test_set = factory.create_datasets()

    # A sampler is only requested for training, and only when the flag is set.
    sampler = None
    if flags.shuffle_batch_order:
        sampler = factory.create_sampler(train_set)

    collate = factory.create_collate_fn()
    train_loader = factory.create_data_loader(train_set, collate_fn=collate, sampler=sampler)
    test_loader = factory.create_data_loader(test_set, collate_fn=collate)
    return train_loader, test_loader
def build_benchmark_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the dataloader benchmark.

    Returns:
        A parser exposing ``--fspec_path``, ``--batch_size``, ``--steps`` and
        ``--mapping``.
    """
    parser = argparse.ArgumentParser()
    # Both values are used unconditionally by the benchmark, so fail early
    # with a clear usage error instead of crashing later on ``None``.
    parser.add_argument('--fspec_path', type=str, required=True)
    parser.add_argument('--batch_size', type=int, required=True)
    parser.add_argument('--steps', type=int, default=1000)
    # The benchmark passes ``args.mapping`` to ParametricDataset; without this
    # argument the attribute does not exist and the script raises
    # AttributeError.
    # NOTE(review): 'train' is assumed to be a mapping name understood by
    # ParametricDataset / present in the feature spec -- confirm.
    parser.add_argument('--mapping', type=str, default='train')
    return parser


def run_benchmark(argv: Optional[list] = None) -> None:
    """Time ``--steps`` indexed reads from a ParametricDataset and print stats.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read
            ``sys.argv[1:]``.

    Raises:
        SystemExit: Raised by argparse when arguments are missing or invalid.
    """
    print('Dataloader benchmark')

    parser = build_benchmark_parser()
    args = parser.parse_args(argv)
    # Zero steps would divide by zero below; a non-positive batch size would
    # make the reported throughput meaningless.
    if args.steps <= 0:
        parser.error('--steps must be a positive integer')
    if args.batch_size <= 0:
        parser.error('--batch_size must be a positive integer')

    fspec = FeatureSpec.from_yaml(args.fspec_path)
    dataset = ParametricDataset(fspec, args.mapping, batch_size=args.batch_size, numerical_features_enabled=True,
                                categorical_features_to_read=fspec.get_categorical_feature_names())

    # perf_counter is monotonic and high resolution, unlike time.time(), which
    # can jump when the system clock is adjusted.
    begin = time.perf_counter()
    for i in range(args.steps):
        _ = dataset[i]
    end = time.perf_counter()

    step_time = (end - begin) / args.steps
    # Throughput treats every dataset[i] as one batch of ``batch_size``
    # samples. Guard against a zero elapsed time on very coarse clocks.
    throughput = args.batch_size / step_time if step_time > 0 else float('inf')

    print(f'Mean step time: {step_time:.6f} [s]')
    print(f'Mean throughput: {throughput:,.0f} [samples / s]')


if __name__ == '__main__':
    run_benchmark()