import argparse
import json
import sys
import numpy as np
import torch
import tritonclient.http as http_client
from sklearn.metrics import roc_auc_score
from tqdm import tqdm
from dlrm.data.datasets import SyntheticDataset, SplitCriteoDataset
from dlrm.utils.distributed import get_device_mapping
def get_data_loader(batch_size, *, data_path, model_config):
    """Create the DataLoader that feeds batches to the inference loop.

    Args:
        batch_size: Batch size handed to the underlying dataset.
        data_path: Location of the inference data. When falsy, a
            ``SyntheticDataset`` is used instead of ``SplitCriteoDataset``.
        model_config: Object exposing ``dataset_config`` (path to a JSON
            file), ``test_batches`` and, depending on the branch taken,
            ``drop_last_batch`` (real data) or ``num_numerical_features``
            (synthetic data).

    Returns:
        A ``torch.utils.data.DataLoader`` built with ``batch_size=None``, so
        the loader does no batching of its own; the datasets receive
        ``batch_size`` directly.
    """
    with open(model_config.dataset_config) as config_file:
        raw_sizes = json.load(config_file)
    # Each size read from the config is enlarged by one before use.
    # NOTE(review): presumably reserves one extra embedding index - confirm
    # against the dataset preprocessing.
    feature_sizes = [size + 1 for size in raw_sizes.values()]

    # Computed on both paths, although only the real-data branch reads it.
    mapping = get_device_mapping(feature_sizes, num_gpus=1)

    if not data_path:
        dataset = SyntheticDataset(
            num_entries=batch_size * 1024,
            batch_size=batch_size,
            numerical_features=model_config.num_numerical_features,
            categorical_feature_sizes=feature_sizes,
            device="cpu",
        )
    else:
        dataset = SplitCriteoDataset(
            data_path=data_path,
            batch_size=batch_size,
            numerical_features=True,
            categorical_features=mapping['embedding'][0],
            categorical_feature_sizes=feature_sizes,
            prefetch_depth=1,
            drop_last_batch=model_config.drop_last_batch,
        )

    # A positive test_batches restricts the run to the first N dataset items.
    batch_limit = model_config.test_batches
    if batch_limit > 0:
        dataset = torch.utils.data.Subset(dataset, list(range(batch_limit)))

    return torch.utils.data.DataLoader(
        dataset,
        batch_size=None,
        num_workers=0,
        pin_memory=False,
    )
def run_infer(model_name, model_version, numerical_features, categorical_features, headers=None):
    """Send one inference request to the Triton server and return its result.

    Uses the module-level ``triton_client`` created by the script entry point.

    Args:
        model_name: Name of the deployed model to query.
        model_version: Model version; ``-1`` is sent as an empty version
            string.
        numerical_features: NumPy array sent as ``input__0``. Declared as
            FP16 when its dtype is ``np.float16``, otherwise as FP32.
        categorical_features: NumPy array sent as ``input__1``, declared as
            INT64.
        headers: Optional HTTP headers forwarded with the request.

    Returns:
        The object returned by ``triton_client.infer``; ``output__0`` is the
        only requested output.
    """
    numerical_dtype = "FP32"
    if numerical_features.dtype == np.float16:
        numerical_dtype = "FP16"

    numerical_input = http_client.InferInput(
        'input__0', numerical_features.shape, numerical_dtype)
    categorical_input = http_client.InferInput(
        'input__1', categorical_features.shape, "INT64")

    # Numerical data goes out in binary form, categorical data does not.
    numerical_input.set_data_from_numpy(numerical_features, binary_data=True)
    categorical_input.set_data_from_numpy(categorical_features, binary_data=False)

    requested_outputs = [
        http_client.InferRequestedOutput('output__0', binary_data=True),
    ]

    version = '' if model_version == -1 else str(model_version)
    return triton_client.infer(model_name,
                               [numerical_input, categorical_input],
                               model_version=version,
                               outputs=requested_outputs,
                               headers=headers)
def str2bool(value):
    """Convert a command-line token into a bool (for use as argparse ``type``).

    ``type=bool`` treats every non-empty string, including ``"False"``, as
    True; this parser accepts the usual textual spellings instead.

    Args:
        value: A bool (returned unchanged) or a string such as ``"true"``,
            ``"no"``, ``"1"``; matching ignores case and surrounding spaces.

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays a "false" spelling because it was the only value
    # that evaluated to False under the previous ``type=bool``.
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_http_headers(raw_headers):
    """Turn ``"Header:Value"`` strings into a header dictionary.

    Only the first colon separates name from value, so values that contain
    colons themselves (URLs, tokens) are kept whole.

    Args:
        raw_headers: List of ``"Header:Value"`` strings, or None.

    Returns:
        A dict mapping header names to values, or None when ``raw_headers``
        is None.

    Raises:
        ValueError: If an entry contains no colon.
    """
    if raw_headers is None:
        return None
    parsed = {}
    for item in raw_headers:
        name, separator, value = item.partition(':')
        if not separator:
            raise ValueError(
                f"invalid HTTP header {item!r}: expected the format Header:Value")
        parsed[name] = value
    return parsed


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--triton-server-url',
                        type=str,
                        required=True,
                        help='URL adress of triton server (with port)')
    parser.add_argument('--triton-model-name', type=str, required=True,
                        help='Triton deployed model name')
    parser.add_argument('--triton-model-version', type=int, default=-1,
                        help='Triton model version')
    parser.add_argument('-v',
                        '--verbose',
                        action="store_true",
                        required=False,
                        default=False,
                        help='Enable verbose output')
    parser.add_argument('-H', dest='http_headers', metavar="HTTP_HEADER",
                        required=False, action='append',
                        help='HTTP headers to add to inference server requests. ' +
                             'Format is -H"Header:Value".')
    parser.add_argument("--dataset_config", type=str, required=True)
    parser.add_argument("--inference_data", type=str,
                        help="Path to file with inference data.")
    parser.add_argument("--batch_size", type=int, default=1,
                        help="Inference request batch size")
    # str2bool instead of bool: bool("False") is True, which made it
    # impossible to switch this option off from the command line.
    parser.add_argument("--drop_last_batch", type=str2bool, default=True,
                        help="Drops the last batch size if it's not full")
    parser.add_argument("--fp16", action="store_true", default=False,
                        help="Use 16bit for numerical input")
    parser.add_argument("--test_batches", type=int, default=0,
                        help="Specifies number of batches used in the inference")
    # get_data_loader reads model_config.num_numerical_features on the
    # synthetic-data path; without this argument that path raised
    # AttributeError.
    # NOTE(review): the default of 13 is an assumption - confirm it matches
    # the deployed model's numerical input width.
    parser.add_argument("--num_numerical_features", type=int, default=13,
                        help="Number of numerical features of the synthetic data "
                             "used when --inference_data is not given")
    FLAGS = parser.parse_args()

    try:
        headers_dict = parse_http_headers(FLAGS.http_headers)
    except ValueError as e:
        parser.error(str(e))

    # triton_client must stay a module-level name: run_infer looks it up as a
    # global.
    try:
        triton_client = http_client.InferenceServerClient(url=FLAGS.triton_server_url,
                                                          verbose=FLAGS.verbose)
    except Exception as e:
        print("channel creation failed: " + str(e))
        sys.exit(1)

    # Forward the user-supplied headers here as well, consistent with the
    # inference and statistics requests below.
    triton_client.load_model(FLAGS.triton_model_name, headers=headers_dict)
    if not triton_client.is_model_ready(FLAGS.triton_model_name, headers=headers_dict):
        print(f"FAILED: model {FLAGS.triton_model_name} is not ready")
        sys.exit(1)

    dataloader = get_data_loader(FLAGS.batch_size,
                                 data_path=FLAGS.inference_data,
                                 model_config=FLAGS)
    results = []
    tgt_list = []

    for numerical_features, categorical_features, target in tqdm(dataloader):
        numerical_features = numerical_features.cpu().numpy()
        numerical_features = numerical_features.astype(np.float16 if FLAGS.fp16 else np.float32)
        categorical_features = categorical_features.long().cpu().numpy()

        output = run_infer(FLAGS.triton_model_name, FLAGS.triton_model_version,
                           numerical_features, categorical_features, headers_dict)

        results.append(output.as_numpy('output__0'))
        tgt_list.append(target.cpu().numpy())

    # np.concatenate raises on an empty list; report the real cause instead.
    if not results:
        print("FAILED: the data loader produced no batches")
        sys.exit(1)

    results = np.concatenate(results).squeeze()
    tgt_list = np.concatenate(tgt_list)

    score = roc_auc_score(tgt_list, results)
    print(f"Model score: {score}")

    statistics = triton_client.get_inference_statistics(model_name=FLAGS.triton_model_name,
                                                        headers=headers_dict)
    print(statistics)
    if len(statistics['model_stats']) != 1:
        print("FAILED: Inference Statistics")
        sys.exit(1)