import torch
from .native_pipeline import build_native_pipeline
from .input_iterators import ConvertDaliInputIterator, RateMatcher, FakeInputIterator
from torch.utils.data import DataLoader
from mlperf_logger import log_event
from mlperf_logging.mllog import constants
def prebuild_pipeline(args):
    """Return a pre-built training pipeline, if this backend provides one.

    This implementation does no pre-building and always returns ``None``.
    Presumably the result is meant to be handed to ``build_pipeline`` as its
    ``pipe`` argument; confirm at the call site.

    Args:
        args: run configuration (not read here; kept for interface
            compatibility with other pipeline backends).

    Returns:
        None, always.
    """
    return None
def build_pipeline(args, training=True, pipe=None):
    """Build the data pipeline for training or for evaluation.

    Args:
        args: run configuration. The training path reads ``fake_input``,
            ``N_gpu``, ``input_batch_multiplier`` and ``batch_size`` from it.
        training: if truthy, build the training loader; otherwise build the
            eval loader.
        pipe: optional pre-built pipeline. It is forwarded to
            ``build_native_pipeline`` on the training path only; the eval
            path ignores it.

    Returns:
        Training: ``(loader, epoch_size)``, where ``loader`` may be wrapped
        as described below.
        Eval: the result of ``build_native_pipeline(args, training=False)``
        is returned unchanged. Per the original notes this is
        ``(loader, inv_class_map, cocoGt)``; confirm against
        ``native_pipeline``.
    """
    # Eval: delegate entirely to the native builder, no wrappers or logging.
    if not training:
        return build_native_pipeline(args, training=False)

    loader, epoch_size = build_native_pipeline(args, training=True, pipe=pipe)
    log_event(key=constants.TRAIN_SAMPLES, value=epoch_size)

    # Optional wrappers. Order matters: FakeInputIterator wraps the native
    # loader first, and RateMatcher (if enabled) wraps the result of that.
    if args.fake_input:
        loader = FakeInputIterator(loader, epoch_size, args.N_gpu)
    if args.input_batch_multiplier > 1:
        # Only applied when the input batch is a multiple > 1; output size is
        # taken from args.batch_size.
        loader = RateMatcher(input_it=loader, output_size=args.batch_size)

    return loader, epoch_size