import os
from typing import Tuple, Optional, List
import numpy as np
import torch
from torch import Tensor
from torch.cuda import Stream
from torch.utils.data import Dataset, DataLoader
import tqdm
from dlrm.data.defaults import TRAIN_MAPPING, TEST_MAPPING, DTYPE_SELECTOR
from dlrm.data.feature_spec import FeatureSpec
def collate_split_tensors(
        tensors: Tuple[Tensor, Tensor, Tensor],
        device: str,
        orig_stream: Stream,
        numerical_type: torch.dtype = torch.float32
):
    """Move a ``(numerical, categorical, click)`` batch onto ``device``.

    Args:
        tensors: Triple of numerical features, categorical features and click
            labels. Any element may be ``None`` and is passed through as ``None``.
        device: Target device. Anything accepted by ``torch.device`` works
            (``'cpu'``, ``'cuda'``, ``'cuda:0'``, a ``torch.device`` object).
        orig_stream: CUDA stream that will consume the returned tensors. Only
            used when the target device is a CUDA device.
        numerical_type: dtype the numerical features are cast to after the copy.

    Returns:
        Tuple ``(numerical_features, categorical_features, click)`` on ``device``.
    """
    moved = [
        tensor.to(device, non_blocking=True) if tensor is not None else None
        for tensor in tensors
    ]
    numerical_features, categorical_features, click = moved

    # Cast before recording streams: a dtype change allocates a new tensor, and
    # it is that new tensor (not the intermediate copy) that the consumer
    # stream will read.
    if numerical_features is not None:
        numerical_features = numerical_features.to(numerical_type)

    # Compare the device *type* so that 'cuda:0' or torch.device('cuda') are
    # treated the same as the plain string 'cuda'.
    if torch.device(device).type == 'cuda':
        for tensor in (numerical_features, categorical_features, click):
            if tensor is not None:
                # Tell the caching allocator that orig_stream uses this memory,
                # so it is not reused while that stream still has work queued.
                tensor.record_stream(orig_stream)

    return numerical_features, categorical_features, click
def collate_array(
        array: np.ndarray,
        device: str,
        orig_stream: Stream,
        num_numerical_features: int,
        selected_categorical_features: Optional[Tensor] = None
):
    """Split a packed 2-D batch array into feature tensors on ``device``.

    Column layout of ``array`` as read by this function: column 0 holds the
    click label, the next ``num_numerical_features`` columns hold the numerical
    features, and all remaining columns hold the categorical features.

    Args:
        array: Packed batch. The numerical columns are reinterpreted bit-for-bit
            as float32 (no value conversion), so the array is presumably stored
            with a 4-byte item type such as int32 -- TODO confirm with the
            producer of these arrays.
        device: Target device. Anything accepted by ``torch.device`` works
            (``'cpu'``, ``'cuda'``, ``'cuda:0'``, a ``torch.device`` object).
        orig_stream: CUDA stream that will consume the returned tensors. Only
            used when the target device is a CUDA device.
        num_numerical_features: Number of numerical feature columns.
        selected_categorical_features: Optional index tensor choosing which
            categorical columns to keep (and in which order).

    Returns:
        Tuple ``(numerical_features, categorical_features, click)`` where the
        numerical features are float32, the categorical features are int64 and
        the click labels are float32.
    """
    first_categorical = 1 + num_numerical_features

    # Reinterpret the raw bits of the numerical columns as float32.
    numerical_features = array[:, 1:first_categorical].view(dtype=np.float32)
    numerical_features = torch.from_numpy(numerical_features)
    categorical_features = torch.from_numpy(array[:, first_categorical:])
    click = torch.from_numpy(array[:, 0])

    categorical_features = categorical_features.to(device, non_blocking=True).to(torch.long)
    numerical_features = numerical_features.to(device, non_blocking=True)
    click = click.to(torch.float32).to(device, non_blocking=True)

    if selected_categorical_features is not None:
        categorical_features = categorical_features[:, selected_categorical_features]

    # Compare the device *type* so that 'cuda:0' or torch.device('cuda') are
    # treated the same as the plain string 'cuda'.
    if torch.device(device).type == 'cuda':
        # Tell the caching allocator that orig_stream uses this memory, so it
        # is not reused while that stream still has work queued.
        numerical_features.record_stream(orig_stream)
        categorical_features.record_stream(orig_stream)
        click.record_stream(orig_stream)

    return numerical_features, categorical_features, click
def _open_binary_output(path: str):
    """Open ``path`` in ``'wb+'`` mode, creating its parent directory if needed."""
    parent_dir = os.path.dirname(path)
    # os.makedirs('') raises FileNotFoundError, so a bare file name (no
    # directory component) must skip directory creation.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    return open(path, 'wb+')


def write_dataset_to_disk(dataset_train: Dataset, dataset_test: Dataset, feature_spec: FeatureSpec,
                          saving_batch_size=512) -> None:
    """Dump the train and test datasets to the binary files named by ``feature_spec``.

    For each of the train and test mappings, three kinds of files are written
    to the paths returned by ``feature_spec.get_mapping_paths``:

    * one file with all numerical features, stored as float16,
    * one file with the labels, stored as bool,
    * one file per categorical feature, stored with the dtype the feature spec
      lists for that feature under ``DTYPE_SELECTOR``.

    After both mappings are written, ``feature_spec.to_yaml()`` is called.

    Args:
        dataset_train: Dataset written under ``TRAIN_MAPPING``.
        dataset_test: Dataset written under ``TEST_MAPPING``.
        feature_spec: Feature specification providing feature names, dtypes and
            output paths.
        saving_batch_size: Batch size passed to the ``DataLoader`` used to
            iterate over each dataset.

    Raises:
        AssertionError: If the last dimension of a numerical or categorical
            batch does not match the feature counts from ``feature_spec``.
    """
    feature_spec.check_feature_spec()

    categorical_features_list = feature_spec.get_categorical_feature_names()
    categorical_features_types = [feature_spec.feature_spec[feature_name][DTYPE_SELECTOR]
                                  for feature_name in categorical_features_list]
    number_of_numerical_features = feature_spec.get_number_of_numerical_features()
    number_of_categorical_features = len(categorical_features_list)

    for mapping_name, dataset in zip((TRAIN_MAPPING, TEST_MAPPING),
                                     (dataset_train, dataset_test)):
        # Every successfully opened file is registered here immediately so the
        # finally-clause closes it even if a later open or write fails.
        file_streams = []
        label_path, numerical_path, categorical_paths = feature_spec.get_mapping_paths(mapping_name)
        try:
            numerical_f = _open_binary_output(numerical_path)
            file_streams.append(numerical_f)

            label_f = _open_binary_output(label_path)
            file_streams.append(label_f)

            categorical_fs = []
            for feature_name in categorical_features_list:
                fs = _open_binary_output(categorical_paths[feature_name])
                categorical_fs.append(fs)
                file_streams.append(fs)

            for numerical, categorical, label in tqdm.tqdm(
                    DataLoader(dataset, saving_batch_size),
                    desc=mapping_name + " dataset saving",
                    unit_scale=saving_batch_size
            ):
                assert (numerical.shape[-1] == number_of_numerical_features)
                assert (categorical.shape[-1] == number_of_categorical_features)

                numerical_f.write(numerical.to(torch.float16).cpu().numpy().tobytes())
                label_f.write(label.to(torch.bool).cpu().numpy().tobytes())
                # The categorical batch is indexed as a 3-D tensor with the
                # feature index last. NOTE(review): presumably the dataset
                # items are already batches and the DataLoader adds another
                # leading dimension -- confirm against the dataset classes.
                for cat_idx, (cat_file, cat_feature_type) in enumerate(
                        zip(categorical_fs, categorical_features_types)):
                    cat_file.write(
                        categorical[:, :, cat_idx].cpu().numpy().astype(cat_feature_type).tobytes())
        finally:
            for stream in file_streams:
                stream.close()
    feature_spec.to_yaml()
def prefetcher(load_iterator, prefetch_stream):
    """Yield batches from ``load_iterator`` while fetching the next one ahead of time.

    Each batch is pulled from ``load_iterator`` inside the ``prefetch_stream``
    CUDA stream context. Before a batch is handed out, the current CUDA stream
    is made to wait on ``prefetch_stream``, and the following batch is already
    requested.

    Iteration ends when ``load_iterator`` is exhausted. ``None`` doubles as the
    end marker, so a ``None`` item produced by ``load_iterator`` also stops the
    generator.

    Args:
        load_iterator: Iterator producing data batches.
        prefetch_stream: CUDA stream under which batches are fetched.

    Yields:
        The batches of ``load_iterator``, in order.
    """

    def _fetch_next():
        # Pull one batch under the side stream; None signals exhaustion.
        with torch.cuda.stream(prefetch_stream):
            return next(load_iterator, None)

    pending = _fetch_next()
    while pending is not None:
        # The consumer stream must not touch the batch before the work queued
        # on the prefetch stream has finished.
        torch.cuda.current_stream().wait_stream(prefetch_stream)
        # Start fetching the following batch before handing out this one.
        ready, pending = pending, _fetch_next()
        yield ready
def get_embedding_sizes(fspec: FeatureSpec, max_table_size: Optional[int]) -> List[int]:
    """Return the categorical feature sizes, optionally capped.

    Args:
        fspec: Feature specification; its ``get_categorical_sizes()`` result
            supplies the sizes.
        max_table_size: Upper bound applied to every size, or ``None`` to
            return the sizes as reported by ``fspec``.

    Returns:
        The sizes from ``fspec.get_categorical_sizes()``, each limited to
        ``max_table_size`` when a limit is given.
    """
    sizes = fspec.get_categorical_sizes()
    if max_table_size is None:
        return sizes
    return [min(size, max_table_size) for size in sizes]