from __future__ import absolute_import, division, print_function, unicode_literals
import argparse
import apex
from apex import amp
import builtins
import datetime
import json
import os
import sys
import time
import warnings
import dlrm_data_pytorch as dp
import extend_distributed as ext_dist
import mlperf_logger
import numpy as np
import sklearn.metrics
import torch
if torch.__version__ >= "1.8":
import torch_npu
import torch.nn as nn
from torch._ops import ops
from torch.autograd.profiler import record_function
from torch.nn.parallel.parallel_apply import parallel_apply
from torch.nn.parallel.replicate import replicate
from torch.nn.parallel.scatter_gather import gather, scatter
from torch.nn.parameter import Parameter
from torch.optim.lr_scheduler import _LRScheduler
import optim.rwsadagrad as RowWiseSparseAdagrad
from torch.utils.tensorboard import SummaryWriter
from tricks.md_embedding_bag import PrEmbeddingBag, md_solver
from tricks.qr_embedding_bag import QREmbeddingBag
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning)
try:
import onnx
except ImportError as error:
print("Unable to import onnx. ", error)
# Look up ``IOError`` on the ``builtins`` module. If the attribute is missing,
# the fallback is the *string* "FileNotFoundError", not an exception class.
# NOTE(review): nothing in the visible part of this file reads ``exc`` -
# confirm whether it is still needed.
exc = getattr(builtins, "IOError", "FileNotFoundError")
class NpuLinear(nn.Linear):
    """``nn.Linear`` variant whose forward pass calls ``torch_npu.npu_linear``.

    Construction and parameters are inherited unchanged from ``nn.Linear``;
    only the forward computation is redirected to the NPU operator.
    """

    def forward(self, input):
        """Apply the layer to ``input`` using this layer's weight and bias."""
        weight, bias = self.weight, self.bias
        return torch_npu.npu_linear(input, weight, bias)
def time_wrap(use_npu):
    """Return the current wall-clock time in seconds.

    When ``use_npu`` is truthy, ``torch.npu.synchronize()`` is called before
    the clock is read.
    """
    if not use_npu:
        return time.time()
    torch.npu.synchronize()
    return time.time()
def dlrm_wrap(X, lS_o, lS_i, use_npu, device, ndevices=1):
    """Run the module-level ``dlrm`` model on one batch and return its output.

    ``X`` is always moved to ``device``. The sparse inputs ``lS_i`` and
    ``lS_o`` are moved to ``device`` only when ``use_npu`` is truthy and
    ``ndevices == 1``; each may be either a list of tensors or a single
    tensor. The call is wrapped in a "DLRM forward" profiler range.
    """

    def _move(sparse_input):
        # Accept both a list of per-table tensors and one stacked tensor.
        if isinstance(sparse_input, list):
            return [tensor.to(device) for tensor in sparse_input]
        return sparse_input.to(device)

    with record_function("DLRM forward"):
        if use_npu and ndevices == 1:
            lS_i = _move(lS_i)
            lS_o = _move(lS_o)
        dense_on_device = X.to(device)
        return dlrm(dense_on_device, lS_o, lS_i)
def loss_fn_wrap(Z, T, use_npu, device):
    """Compute the training loss for predictions ``Z`` against targets ``T``.

    The loss kind is read from the module-level ``args.loss_function``:

    * "mse" / "bce": returns ``dlrm.loss_fn(Z, T)`` with ``T`` moved to
      ``device``.
    * "wbce": multiplies the element-wise loss by per-sample weights looked
      up in ``dlrm.loss_ws`` using the target values as indices, then
      returns the mean.

    Any other value falls through and returns ``None``. ``use_npu`` is
    accepted but not used here.
    """
    with record_function("DLRM loss compute"):
        loss_kind = args.loss_function
        if loss_kind == "mse" or loss_kind == "bce":
            return dlrm.loss_fn(Z, T.to(device))
        if loss_kind == "wbce":
            class_index = T.data.view(-1).long()
            sample_weights = dlrm.loss_ws[class_index].view_as(T).to(device)
            elementwise_loss = dlrm.loss_fn(Z, T.to(device))
            weighted = sample_weights * elementwise_loss
            return weighted.mean()
    return None
def unpack_batch(b):
    """Expand a 4-item batch into the 6-tuple used by the train/test loops.

    Returns ``(b[0], b[1], b[2], b[3], weights, None)`` where ``weights`` is
    a tensor of ones with the same size as ``b[3]``.
    """
    first, second, third, target = b[0], b[1], b[2], b[3]
    unit_weights = torch.ones(target.size())
    return first, second, third, target, unit_weights, None
class LRPolicyScheduler(_LRScheduler):
    """Learning-rate schedule with linear warm-up and quadratic decay.

    The schedule is driven by the scheduler's internal ``_step_count``:

    * ``step < num_warmup_steps``: the base rate is scaled linearly by
      ``step / num_warmup_steps``.
    * ``decay_start_step <= step < decay_start_step + num_decay_steps``:
      the base rate is scaled by ``(remaining / num_decay_steps) ** 2`` and
      clamped from below at ``1e-7``.
    * otherwise: if decay is configured (``num_decay_steps > 0``) the rate
      is frozen at the most recently computed value; if not, the base rates
      are returned unchanged.

    Args:
        optimizer: the wrapped optimizer.
        num_warmup_steps: number of warm-up steps (0 disables warm-up).
        decay_start_step: step at which decay begins; must not be smaller
            than ``num_warmup_steps``.
        num_decay_steps: length of the decay window (0 disables decay).

    The constructor calls ``sys.exit`` when ``decay_start_step`` is smaller
    than ``num_warmup_steps``.
    """

    def __init__(self, optimizer, num_warmup_steps, decay_start_step, num_decay_steps):
        self.num_warmup_steps = num_warmup_steps
        self.decay_start_step = decay_start_step
        self.decay_end_step = decay_start_step + num_decay_steps
        self.num_decay_steps = num_decay_steps

        if self.decay_start_step < self.num_warmup_steps:
            sys.exit("Learning rate warmup must finish before the decay starts")

        # These attributes must exist before the base constructor runs,
        # because it performs an initial step that calls get_lr().
        super(LRPolicyScheduler, self).__init__(optimizer)

    def get_lr(self):
        """Return the list of learning rates for the current step."""
        step_count = self._step_count
        if step_count < self.num_warmup_steps:
            # Linear warm-up.
            scale = 1.0 - (self.num_warmup_steps - step_count) / self.num_warmup_steps
            lr = [base_lr * scale for base_lr in self.base_lrs]
            self.last_lr = lr
        elif self.decay_start_step <= step_count and step_count < self.decay_end_step:
            # Quadratic decay, never dropping below min_lr.
            decayed_steps = step_count - self.decay_start_step
            scale = ((self.num_decay_steps - decayed_steps) / self.num_decay_steps) ** 2
            min_lr = 0.0000001
            lr = [max(min_lr, base_lr * scale) for base_lr in self.base_lrs]
            self.last_lr = lr
        else:
            if self.num_decay_steps > 0:
                # Freeze at the last computed rate (after decay, or between
                # warm-up and decay). Without warm-up no rate has been
                # computed yet when the gap before decay is reached, so fall
                # back to the base rates instead of raising AttributeError.
                lr = getattr(self, "last_lr", self.base_lrs)
            else:
                # No decay configured: leave the rates untouched.
                lr = self.base_lrs
        return lr
class DLRM_Net(nn.Module):
    """DLRM network: embedding tables, a bottom MLP and a top MLP.

    ``__init__`` builds ``emb_l`` (embedding modules), ``bot_l`` and
    ``top_l`` (MLPs) and selects the loss function. ``forward`` dispatches to
    a distributed, sequential or multi-device ("parallel") implementation
    depending on ``ext_dist.my_size`` and ``self.ndevices``.
    """

    def create_mlp(self, ln, sigmoid_layer):
        """Build an MLP as an ``nn.Sequential``.

        ``ln`` is indexed as an array of layer widths (``ln.size`` is used),
        so consecutive entries give the in/out features of each
        ``NpuLinear``. Every linear layer is followed by ``ReLU`` except the
        one at index ``sigmoid_layer``, which is followed by ``Sigmoid``.
        """
        layers = nn.ModuleList()
        for i in range(0, ln.size - 1):
            n = ln[i]
            m = ln[i + 1]
            LL = NpuLinear(int(n), int(m), bias=True)
            # Custom initialisation drawn from numpy's global RNG:
            # weights ~ N(0, sqrt(2 / (m + n))), bias ~ N(0, sqrt(1 / m)).
            mean = 0.0
            std_dev = np.sqrt(2 / (m + n))
            W = np.random.normal(mean, std_dev, size=(m, n)).astype(np.float32)
            std_dev = np.sqrt(1 / m)
            bt = np.random.normal(mean, std_dev, size=m).astype(np.float32)
            LL.weight.data = torch.tensor(W, requires_grad=True)
            LL.bias.data = torch.tensor(bt, requires_grad=True)
            layers.append(LL)
            if i == sigmoid_layer:
                layers.append(nn.Sigmoid())
            else:
                layers.append(nn.ReLU())
        return torch.nn.Sequential(*layers)

    def create_emb(self, m, ln, weighted_pooling=None):
        """Create one embedding module per entry of ``ln``.

        Returns ``(emb_l, v_W_l)``: an ``nn.ModuleList`` of embedding modules
        and a parallel list holding, per table, either ``None`` (when
        ``weighted_pooling`` is ``None``) or a ones tensor of length ``n``.
        When running distributed (``ext_dist.my_size > 1``) only the tables
        listed in ``self.local_emb_indices`` are created.
        """
        emb_l = nn.ModuleList()
        v_W_l = []
        for i in range(0, ln.size):
            if ext_dist.my_size > 1:
                if i not in self.local_emb_indices:
                    continue
            n = ln[i]
            if self.qr_flag and n > self.qr_threshold:
                # Large table with the quotient-remainder option enabled.
                EE = QREmbeddingBag(
                    n,
                    m,
                    self.qr_collisions,
                    operation=self.qr_operation,
                    mode="sum",
                    sparse=True,
                )
            elif self.md_flag and n > self.md_threshold:
                # Mixed-dimension option: here ``m`` is indexed per table.
                base = max(m)
                _m = m[i] if n > self.md_threshold else base
                EE = PrEmbeddingBag(n, _m, base)
                W = np.random.uniform(
                    low=-np.sqrt(1 / n), high=np.sqrt(1 / n), size=(n, _m)
                ).astype(np.float32)
                EE.embs.weight.data = torch.tensor(W, requires_grad=True)
            else:
                # Default: a plain nn.Embedding initialised uniformly in
                # [-sqrt(1/n), sqrt(1/n)].
                EE = nn.Embedding(n, m)
                W = np.random.uniform(
                    low=-np.sqrt(1 / n), high=np.sqrt(1 / n), size=(n, m)
                ).astype(np.float32)
                EE.weight.data = torch.tensor(W, requires_grad=True)
            if weighted_pooling is None:
                v_W_l.append(None)
            else:
                v_W_l.append(torch.ones(n, dtype=torch.float32))
            emb_l.append(EE)
        return emb_l, v_W_l

    def __init__(
        self,
        m_spa=None,
        ln_emb=None,
        ln_bot=None,
        ln_top=None,
        arch_interaction_op=None,
        arch_interaction_itself=False,
        sigmoid_bot=-1,
        sigmoid_top=-1,
        sync_dense_params=True,
        loss_threshold=0.0,
        ndevices=-1,
        qr_flag=False,
        qr_operation="mult",
        qr_collisions=0,
        qr_threshold=200,
        md_flag=False,
        md_threshold=200,
        weighted_pooling=None,
        loss_function="bce"
    ):
        """Store the configuration and build embeddings, MLPs and loss.

        Nothing beyond ``nn.Module`` initialisation happens unless
        ``m_spa``, ``ln_emb``, ``ln_bot``, ``ln_top`` and
        ``arch_interaction_op`` are all provided. Exits the process via
        ``sys.exit`` for an unsupported ``loss_function`` or, when
        distributed, if there are fewer embedding tables than ranks.
        """
        super(DLRM_Net, self).__init__()
        if (
            (m_spa is not None)
            and (ln_emb is not None)
            and (ln_bot is not None)
            and (ln_top is not None)
            and (arch_interaction_op is not None)
        ):
            # Save arguments and parallel-forward bookkeeping.
            self.ndevices = ndevices
            self.output_d = 0
            self.parallel_model_batch_size = -1
            self.parallel_model_is_not_prepared = True
            self.arch_interaction_op = arch_interaction_op
            self.arch_interaction_itself = arch_interaction_itself
            self.sync_dense_params = sync_dense_params
            self.loss_threshold = loss_threshold
            self.loss_function=loss_function
            # Any value other than None / "fixed" is treated as "learned".
            if weighted_pooling is not None and weighted_pooling != "fixed":
                self.weighted_pooling = "learned"
            else:
                self.weighted_pooling = weighted_pooling
            # QR / MD attributes exist only when the matching flag is set.
            self.qr_flag = qr_flag
            if self.qr_flag:
                self.qr_collisions = qr_collisions
                self.qr_operation = qr_operation
                self.qr_threshold = qr_threshold
            self.md_flag = md_flag
            if self.md_flag:
                self.md_threshold = md_threshold
            # When distributed, work out which embedding tables are local.
            if ext_dist.my_size > 1:
                n_emb = len(ln_emb)
                if n_emb < ext_dist.my_size:
                    sys.exit(
                        "only (%d) sparse features for (%d) devices, table partitions will fail"
                        % (n_emb, ext_dist.my_size)
                    )
                self.n_global_emb = n_emb
                self.n_local_emb, self.n_emb_per_rank = ext_dist.get_split_lengths(
                    n_emb
                )
                self.local_emb_slice = ext_dist.get_my_slice(n_emb)
                self.local_emb_indices = list(range(n_emb))[self.local_emb_slice]
            # Embeddings are only created here for ndevices <= 1; otherwise
            # emb_l / v_W_l are not set by this constructor.
            if ndevices <= 1:
                self.emb_l, w_list = self.create_emb(m_spa, ln_emb, weighted_pooling)
                if self.weighted_pooling == "learned":
                    self.v_W_l = nn.ParameterList()
                    for w in w_list:
                        self.v_W_l.append(Parameter(w))
                else:
                    self.v_W_l = w_list
            self.bot_l = self.create_mlp(ln_bot, sigmoid_bot)
            self.top_l = self.create_mlp(ln_top, sigmoid_top)
            # Quantization state (see quantize_embedding).
            self.quantize_emb = False
            self.emb_l_q = []
            self.quantize_bits = 32
            # Select the loss function.
            if self.loss_function == "mse":
                self.loss_fn = torch.nn.MSELoss(reduction="mean")
            elif self.loss_function == "bce":
                self.loss_fn = torch.nn.BCELoss(reduction="mean")
            elif self.loss_function == "wbce":
                # Weights come from the module-level ``args``, not from a
                # constructor argument.
                self.loss_ws = torch.tensor(
                    np.fromstring(args.loss_weights, dtype=float, sep="-")
                )
                self.loss_fn = torch.nn.BCELoss(reduction="none")
            else:
                sys.exit(
                    "ERROR: --loss-function=" + self.loss_function + " is not supported"
                )

    def apply_mlp(self, x, layers):
        """Apply the callable ``layers`` to ``x`` and return the result."""
        return layers(x)

    def apply_emb(self, lS_o, lS_i, emb_l, v_W_l):
        """Look up every sparse feature and return the list of results.

        Iterates over ``lS_i`` (one index batch per table) with the matching
        entries of ``lS_o``, ``emb_l`` and ``v_W_l``.
        """
        ly = []
        for k, sparse_index_group_batch in enumerate(lS_i):
            sparse_offset_group_batch = lS_o[k]
            if v_W_l[k] is not None:
                per_sample_weights = v_W_l[k].gather(0, sparse_index_group_batch)
            else:
                per_sample_weights = None
            if self.quantize_emb:
                # NOTE(review): s1 and s2 are computed by the same expression,
                # so the two printed sizes are always equal.
                s1 = self.emb_l_q[k].element_size() * self.emb_l_q[k].nelement()
                s2 = self.emb_l_q[k].element_size() * self.emb_l_q[k].nelement()
                print("quantized emb sizes:", s1, s2)
                if self.quantize_bits == 4:
                    QV = ops.quantized.embedding_bag_4bit_rowwise_offsets(
                        self.emb_l_q[k],
                        sparse_index_group_batch,
                        sparse_offset_group_batch,
                        per_sample_weights=per_sample_weights,
                    )
                elif self.quantize_bits == 8:
                    QV = ops.quantized.embedding_bag_byte_rowwise_offsets(
                        self.emb_l_q[k],
                        sparse_index_group_batch,
                        sparse_offset_group_batch,
                        per_sample_weights=per_sample_weights,
                    )
                ly.append(QV)
            else:
                # Only the indices are passed to the module here; the offsets
                # and per_sample_weights computed above are not used.
                E = emb_l[k]
                V = E(sparse_index_group_batch)
                ly.append(V)
        return ly

    def quantize_embedding(self, bits):
        """Replace ``emb_l`` with 4- or 8-bit prepacked weights in ``emb_l_q``.

        For any other ``bits`` value the method returns from inside the loop
        without setting ``quantize_emb``; ``emb_l_q`` has already been reset
        to a list of ``None`` at that point.
        """
        n = len(self.emb_l)
        self.emb_l_q = [None] * n
        for k in range(n):
            if bits == 4:
                self.emb_l_q[k] = ops.quantized.embedding_bag_4bit_prepack(
                    self.emb_l[k].weight
                )
            elif bits == 8:
                self.emb_l_q[k] = ops.quantized.embedding_bag_byte_prepack(
                    self.emb_l[k].weight
                )
            else:
                return
        # The float embedding modules are dropped once prepacking succeeded.
        self.emb_l = None
        self.quantize_emb = True
        self.quantize_bits = bits

    def interact_features(self, x, ly):
        """Combine the dense feature ``x`` with the embedding outputs ``ly``.

        "dot": concatenates ``x`` and ``ly``, reshapes to
        ``(batch_size, -1, d)``, multiplies by its own transpose and
        appends the selected pairwise entries to ``x``. "cat": plain
        concatenation along dim 1. Any other op exits via ``sys.exit``.
        """
        if self.arch_interaction_op == "dot":
            (batch_size, d) = x.shape
            T = torch.cat([x] + ly, dim=1).view((batch_size, -1, d))
            # NOTE(review): meaning of the trailing ``[]`` argument of
            # npu_bmmV2 is not visible here - confirm against torch_npu docs.
            Z = torch_npu.npu_bmmV2(T, T.permute(0, 2, 1), [])
            _, ni, nj = Z.shape
            # Select pairs (i, j) with j < i, or j <= i when
            # arch_interaction_itself is set (offset == 1).
            offset = 1 if self.arch_interaction_itself else 0
            li = torch.tensor([i for i in range(ni) for j in range(i + offset)])
            lj = torch.tensor([j for i in range(nj) for j in range(i + offset)])
            # Flattened positions of the selected pairs in each ni*nj matrix.
            l = li * nj + lj
            Zflat = Z.reshape(-1, ni * nj)[:, l]
            R = torch.cat([x] + [Zflat], dim=1)
        elif self.arch_interaction_op == "cat":
            R = torch.cat([x] + ly, dim=1)
        else:
            sys.exit(
                "ERROR: --arch-interaction-op="
                + self.arch_interaction_op
                + " is not supported"
            )
        return R

    def forward(self, dense_x, lS_o, lS_i):
        """Dispatch to the distributed, sequential or parallel forward."""
        if ext_dist.my_size > 1:
            return self.distributed_forward(dense_x, lS_o, lS_i)
        elif self.ndevices <= 1:
            return self.sequential_forward(dense_x, lS_o, lS_i)
        else:
            return self.parallel_forward(dense_x, lS_o, lS_i)

    def distributed_forward(self, dense_x, lS_o, lS_i):
        """Forward pass used when ``ext_dist.my_size > 1``.

        Exits via ``sys.exit`` if the batch is smaller than, or not evenly
        divisible by, the number of ranks, or if the sliced inputs do not
        match the number of local embedding tables.
        """
        batch_size = dense_x.size()[0]
        if batch_size < ext_dist.my_size:
            sys.exit(
                "ERROR: batch_size (%d) must be larger than number of ranks (%d)"
                % (batch_size, ext_dist.my_size)
            )
        if batch_size % ext_dist.my_size != 0:
            sys.exit(
                "ERROR: batch_size %d can not split across %d ranks evenly"
                % (batch_size, ext_dist.my_size)
            )
        # Keep this rank's slice of the dense batch and of the sparse inputs.
        dense_x = dense_x[ext_dist.get_my_slice(batch_size)]
        lS_o = lS_o[self.local_emb_slice]
        lS_i = lS_i[self.local_emb_slice]
        if (len(self.emb_l) != len(lS_o)) or (len(self.emb_l) != len(lS_i)):
            sys.exit(
                "ERROR: corrupted model input detected in distributed_forward call"
            )
        with record_function("DLRM embedding forward"):
            ly = self.apply_emb(lS_o, lS_i, self.emb_l, self.v_W_l)
        if len(self.emb_l) != len(ly):
            sys.exit("ERROR: corrupted intermediate result in distributed_forward call")
        # The all-to-all request is issued before the bottom MLP runs and is
        # waited on afterwards.
        a2a_req = ext_dist.alltoall(ly, self.n_emb_per_rank)
        with record_function("DLRM bottom nlp forward"):
            x = self.apply_mlp(dense_x, self.bot_l)
        ly = a2a_req.wait()
        ly = list(ly)
        with record_function("DLRM interaction forward"):
            z = self.interact_features(x, ly)
        with record_function("DLRM top nlp forward"):
            p = self.apply_mlp(z, self.top_l)
        # Clamp predictions only when 0 < loss_threshold < 1.
        if 0.0 < self.loss_threshold and self.loss_threshold < 1.0:
            z = torch.clamp(p, min=self.loss_threshold, max=(1.0 - self.loss_threshold))
        else:
            z = p
        return z

    def sequential_forward(self, dense_x, lS_o, lS_i):
        """Single-device forward: bottom MLP, embeddings, interaction, top MLP."""
        x = self.apply_mlp(dense_x, self.bot_l)
        ly = self.apply_emb(lS_o, lS_i, self.emb_l, self.v_W_l)
        z = self.interact_features(x, ly)
        p = self.apply_mlp(z, self.top_l)
        # Clamp predictions only when 0 < loss_threshold < 1.
        if 0.0 < self.loss_threshold and self.loss_threshold < 1.0:
            z = torch.clamp(p, min=self.loss_threshold, max=(1.0 - self.loss_threshold))
        else:
            z = p
        return z

    def parallel_forward(self, dense_x, lS_o, lS_i):
        """Multi-device forward within one process.

        The MLPs are replicated across devices and the dense batch is
        scattered; embedding table ``k`` is placed on device
        ``"npu:" + str(k % ndevices)``. The result is gathered on
        ``self.output_d``.
        """
        batch_size = dense_x.size()[0]
        ndevices = min(self.ndevices, batch_size, len(self.emb_l))
        device_ids = range(ndevices)
        # A change in batch size forces the model to be redistributed.
        if self.parallel_model_batch_size != batch_size:
            self.parallel_model_is_not_prepared = True
        if self.parallel_model_is_not_prepared or self.sync_dense_params:
            # Replicate the MLPs onto every device.
            self.bot_l_replicas = replicate(self.bot_l, device_ids)
            self.top_l_replicas = replicate(self.top_l, device_ids)
            self.parallel_model_batch_size = batch_size
        if self.parallel_model_is_not_prepared:
            # Move each embedding table (and its pooling weights) to its device.
            t_list = []
            w_list = []
            for k, emb in enumerate(self.emb_l):
                d = torch.device("npu:" + str(k % ndevices))
                t_list.append(emb.to(d))
                if self.weighted_pooling == "learned":
                    w_list.append(Parameter(self.v_W_l[k].to(d)))
                elif self.weighted_pooling == "fixed":
                    w_list.append(self.v_W_l[k].to(d))
                else:
                    w_list.append(None)
            self.emb_l = nn.ModuleList(t_list)
            if self.weighted_pooling == "learned":
                self.v_W_l = nn.ParameterList(w_list)
            else:
                self.v_W_l = w_list
            self.parallel_model_is_not_prepared = False
        # Prepare inputs: scatter the dense batch, move sparse inputs to the
        # device holding the matching table.
        dense_x = scatter(dense_x, device_ids, dim=0)
        if (len(self.emb_l) != len(lS_o)) or (len(self.emb_l) != len(lS_i)):
            sys.exit("ERROR: corrupted model input detected in parallel_forward call")
        t_list = []
        i_list = []
        for k, _ in enumerate(self.emb_l):
            d = torch.device("npu:" + str(k % ndevices))
            t_list.append(lS_o[k].to(d))
            i_list.append(lS_i[k].to(d))
        lS_o = t_list
        lS_i = i_list
        x = parallel_apply(self.bot_l_replicas, dense_x, None, device_ids)
        ly = self.apply_emb(lS_o, lS_i, self.emb_l, self.v_W_l)
        if len(self.emb_l) != len(ly):
            sys.exit("ERROR: corrupted intermediate result in parallel_forward call")
        # Scatter each table's output across devices, then transpose the
        # nesting so that ly[device] is the list of per-table chunks.
        t_list = []
        for k, _ in enumerate(self.emb_l):
            # NOTE(review): ``d`` is assigned but not used in this loop.
            d = torch.device("npu:" + str(k % ndevices))
            y = scatter(ly[k], device_ids, dim=0)
            t_list.append(y)
        ly = list(map(lambda y: list(y), zip(*t_list)))
        z = []
        for k in range(ndevices):
            zk = self.interact_features(x[k], ly[k])
            z.append(zk)
        p = parallel_apply(self.top_l_replicas, z, None, device_ids)
        p0 = gather(p, self.output_d, dim=0)
        # Clamp predictions only when 0 < loss_threshold < 1.
        if 0.0 < self.loss_threshold and self.loss_threshold < 1.0:
            z0 = torch.clamp(
                p0, min=self.loss_threshold, max=(1.0 - self.loss_threshold)
            )
        else:
            z0 = p0
        return z0
def dash_separated_ints(value):
    """argparse ``type`` validator for strings such as ``"4-3-2"``.

    Every piece obtained by splitting ``value`` on "-" must be accepted by
    ``int``. Returns ``value`` unchanged (still a string) on success and
    raises ``argparse.ArgumentTypeError`` otherwise.
    """
    try:
        for piece in value.split("-"):
            int(piece)
    except ValueError:
        raise argparse.ArgumentTypeError(
            "%s is not a valid dash separated list of ints" % value
        )
    return value
def dash_separated_floats(value):
    """argparse ``type`` validator for strings such as ``"1.0-1.0"``.

    Every piece obtained by splitting ``value`` on "-" must be accepted by
    ``float``. Returns ``value`` unchanged (still a string) on success and
    raises ``argparse.ArgumentTypeError`` otherwise.
    """
    try:
        for piece in value.split("-"):
            float(piece)
    except ValueError:
        raise argparse.ArgumentTypeError(
            "%s is not a valid dash separated list of floats" % value
        )
    return value
def inference(
    args,
    dlrm,
    best_acc_test,
    best_auc_test,
    test_ld,
    device,
    use_npu,
    log_iter=-1,
):
    """Evaluate ``dlrm`` on ``test_ld`` and report the resulting metrics.

    With ``args.mlperf_logging`` set, all scores and targets are collected
    and recall, precision, f1, ap, roc_auc and accuracy are computed with
    scikit-learn and written to the summary writer. Otherwise a running
    accuracy (rounded prediction == target) is accumulated per batch.

    Reads the module-level globals ``nbatches``, ``nbatches_test``,
    ``ndevices`` and ``writer``.

    Returns:
        ``(model_metrics_dict, is_best)`` where ``is_best`` tells whether
        this evaluation beat ``best_auc_test`` (MLPerf mode) or
        ``best_acc_test`` (otherwise).
    """
    mlperf_mode = args.mlperf_logging
    correct_total = 0
    sample_total = 0
    if mlperf_mode:
        scores = []
        targets = []

    for batch_idx, batch in enumerate(test_ld):
        # Stop early once the configured number of batches has been seen.
        if nbatches > 0 and batch_idx >= nbatches:
            break

        X_test, lS_o_test, lS_i_test, T_test, _, _ = unpack_batch(batch)

        # Skip batches that cannot be divided evenly across the ranks.
        if ext_dist.my_size > 1 and X_test.size(0) % ext_dist.my_size != 0:
            print(
                "Warning: Skiping the batch %d with size %d"
                % (batch_idx, X_test.size(0))
            )
            continue

        Z_test = dlrm_wrap(
            X_test, lS_o_test, lS_i_test, use_npu, device, ndevices=ndevices
        )

        # Synchronize the device before gathering results from all ranks.
        if Z_test.is_npu:
            torch.npu.synchronize()
        _, batch_split_lengths = ext_dist.get_split_lengths(X_test.size(0))
        if ext_dist.my_size > 1:
            Z_test = ext_dist.all_gather(Z_test, batch_split_lengths)

        if mlperf_mode:
            predicted = Z_test.detach().cpu().numpy()
            expected = T_test.detach().cpu().numpy()
            scores.append(predicted)
            targets.append(expected)
            continue

        with record_function("DLRM accuracy compute"):
            predicted = Z_test.detach().cpu().numpy()
            expected = T_test.detach().cpu().numpy()
            batch_rows = expected.shape[0]
            hits = np.sum((np.round(predicted, 0) == expected).astype(np.uint8))
            correct_total += hits
            sample_total += batch_rows

    if mlperf_mode:
        with record_function("DLRM mlperf sklearn metrics compute"):
            scores = np.concatenate(scores, axis=0)
            targets = np.concatenate(targets, axis=0)

            def _on_rounded(score_fn):
                # Adapt a label-based metric to take raw scores.
                def _metric(y_true, y_score):
                    return score_fn(y_true=y_true, y_pred=np.round(y_score))

                return _metric

            metrics = {
                "recall": _on_rounded(sklearn.metrics.recall_score),
                "precision": _on_rounded(sklearn.metrics.precision_score),
                "f1": _on_rounded(sklearn.metrics.f1_score),
                "ap": sklearn.metrics.average_precision_score,
                "roc_auc": sklearn.metrics.roc_auc_score,
                "accuracy": _on_rounded(sklearn.metrics.accuracy_score),
            }

        validation_results = {}
        for metric_name, metric_function in metrics.items():
            metric_value = metric_function(targets, scores)
            validation_results[metric_name] = metric_value
            writer.add_scalar(
                "mlperf-metrics-test/" + metric_name, metric_value, log_iter
            )
        acc_test = validation_results["accuracy"]
    else:
        acc_test = correct_total / sample_total
        writer.add_scalar("Test/Acc", acc_test, log_iter)

    model_metrics_dict = {
        "nepochs": args.nepochs,
        "nbatches": nbatches,
        "nbatches_test": nbatches_test,
        "state_dict": dlrm.state_dict(),
        "test_acc": acc_test,
    }

    if mlperf_mode:
        auc_now = validation_results["roc_auc"]
        is_best = auc_now > best_auc_test
        if is_best:
            best_auc_test = auc_now
            model_metrics_dict["test_auc"] = best_auc_test
        report = (
            "recall {:.4f}, precision {:.4f},".format(
                validation_results["recall"],
                validation_results["precision"],
            )
            + " f1 {:.4f}, ap {:.4f},".format(
                validation_results["f1"], validation_results["ap"]
            )
            + " auc {:.4f}, best auc {:.4f},".format(
                validation_results["roc_auc"], best_auc_test
            )
            + " accuracy {:3.3f} %, best accuracy {:3.3f} %".format(
                validation_results["accuracy"] * 100, best_acc_test * 100
            )
        )
        print(report, flush=True)
    else:
        is_best = acc_test > best_acc_test
        if is_best:
            best_acc_test = acc_test
        report = " accuracy {:3.3f} %, best {:3.3f} %".format(
            acc_test * 100, best_acc_test * 100
        )
        print(report, flush=True)

    return model_metrics_dict, is_best
def run():
parser = argparse.ArgumentParser(
description="Train Deep Learning Recommendation Model (DLRM)"
)
parser.add_argument("--arch-sparse-feature-size", type=int, default=2)
parser.add_argument(
"--arch-embedding-size", type=dash_separated_ints, default="4-3-2"
)
parser.add_argument("--arch-mlp-bot", type=dash_separated_ints, default="4-3-2")
parser.add_argument("--arch-mlp-top", type=dash_separated_ints, default="4-2-1")
parser.add_argument(
"--arch-interaction-op", type=str, choices=["dot", "cat"], default="dot"
)
parser.add_argument("--arch-interaction-itself", action="store_true", default=False)
parser.add_argument("--weighted-pooling", type=str, default=None)
parser.add_argument("--md-flag", action="store_true", default=False)
parser.add_argument("--md-threshold", type=int, default=200)
parser.add_argument("--md-temperature", type=float, default=0.3)
parser.add_argument("--md-round-dims", action="store_true", default=False)
parser.add_argument("--qr-flag", action="store_true", default=False)
parser.add_argument("--qr-threshold", type=int, default=200)
parser.add_argument("--qr-operation", type=str, default="mult")
parser.add_argument("--qr-collisions", type=int, default=4)
parser.add_argument("--activation-function", type=str, default="relu")
parser.add_argument("--loss-function", type=str, default="mse")
parser.add_argument(
"--loss-weights", type=dash_separated_floats, default="1.0-1.0"
)
parser.add_argument("--loss-threshold", type=float, default=0.0)
parser.add_argument("--round-targets", type=bool, default=False)
parser.add_argument("--data-size", type=int, default=1)
parser.add_argument("--num-batches", type=int, default=0)
parser.add_argument(
"--data-generation", type=str, default="random"
)
parser.add_argument(
"--rand-data-dist", type=str, default="uniform"
)
parser.add_argument("--rand-data-min", type=float, default=0)
parser.add_argument("--rand-data-max", type=float, default=1)
parser.add_argument("--rand-data-mu", type=float, default=-1)
parser.add_argument("--rand-data-sigma", type=float, default=1)
parser.add_argument("--data-trace-file", type=str, default="./input/dist_emb_j.log")
parser.add_argument("--data-set", type=str, default="kaggle")
parser.add_argument("--raw-data-file", type=str, default="")
parser.add_argument("--processed-data-file", type=str, default="")
parser.add_argument("--data-randomize", type=str, default="total")
parser.add_argument("--data-trace-enable-padding", type=bool, default=False)
parser.add_argument("--max-ind-range", type=int, default=-1)
parser.add_argument("--data-sub-sample-rate", type=float, default=0.0)
parser.add_argument("--num-indices-per-lookup", type=int, default=10)
parser.add_argument("--num-indices-per-lookup-fixed", type=bool, default=False)
parser.add_argument("--num-workers", type=int, default=0)
parser.add_argument("--memory-map", action="store_true", default=False)
parser.add_argument("--step", type=int, default=-1)
parser.add_argument("--mini-batch-size", type=int, default=1)
parser.add_argument("--nepochs", type=int, default=1)
parser.add_argument("--learning-rate", type=float, default=0.01)
parser.add_argument("--print-precision", type=int, default=5)
parser.add_argument("--numpy-rand-seed", type=int, default=123)
parser.add_argument("--sync-dense-params", type=bool, default=True)
parser.add_argument("--optimizer", type=str, default="sgd")
parser.add_argument(
"--dataset-multiprocessing",
action="store_true",
default=False,
help="The Kaggle dataset can be multiprocessed in an environment \
with more than 7 CPU cores and more than 20 GB of memory. \n \
The Terabyte dataset can be multiprocessed in an environment \
with more than 24 CPU cores and at least 1 TB of memory.",
)
parser.add_argument("--inference-only", action="store_true", default=False)
parser.add_argument("--quantize-mlp-with-bit", type=int, default=32)
parser.add_argument("--quantize-emb-with-bit", type=int, default=32)
parser.add_argument("--use-npu", action="store_true", default=False)
parser.add_argument("--local_rank", type=int, default=-1)
parser.add_argument("--world_size", type=int, default=1)
parser.add_argument("--dist-backend", type=str, default="hccl")
parser.add_argument("--print-freq", type=int, default=1)
parser.add_argument("--test-freq", type=int, default=-1)
parser.add_argument("--test-mini-batch-size", type=int, default=-1)
parser.add_argument("--test-num-workers", type=int, default=-1)
parser.add_argument("--print-time", action="store_true", default=False)
parser.add_argument("--print-wall-time", action="store_true", default=False)
parser.add_argument("--tensor-board-filename", type=str, default="run_kaggle_pt")
parser.add_argument("--save-model", type=str, default="")
parser.add_argument("--load-model", type=str, default="")
parser.add_argument("--mlperf-logging", action="store_true", default=False)
parser.add_argument("--mlperf-acc-threshold", type=float, default=0.0)
parser.add_argument("--mlperf-auc-threshold", type=float, default=0.0)
parser.add_argument("--mlperf-bin-loader", action="store_true", default=False)
parser.add_argument("--mlperf-bin-shuffle", action="store_true", default=False)
parser.add_argument("--mlperf-grad-accum-iter", type=int, default=1)
parser.add_argument("--lr-num-warmup-steps", type=int, default=0)
parser.add_argument("--lr-decay-start-step", type=int, default=0)
parser.add_argument("--lr-num-decay-steps", type=int, default=0)
global args
global nbatches
global nbatches_test
global writer
args = parser.parse_args()
if args.dataset_multiprocessing:
assert float(sys.version[:3]) > 3.7, "The dataset_multiprocessing " + \
"flag is susceptible to a bug in Python 3.7 and under. " + \
"https://github.com/facebookresearch/dlrm/issues/172"
if args.mlperf_logging:
mlperf_logger.log_event(key=mlperf_logger.constants.CACHE_CLEAR, value=True)
mlperf_logger.log_start(
key=mlperf_logger.constants.INIT_START, log_all_ranks=True
)
if args.weighted_pooling is not None:
if args.qr_flag:
sys.exit("ERROR: quotient remainder with weighted pooling is not supported")
if args.md_flag:
sys.exit("ERROR: mixed dimensions with weighted pooling is not supported")
if args.quantize_emb_with_bit in [4, 8]:
if args.qr_flag:
sys.exit(
"ERROR: 4 and 8-bit quantization with quotient remainder is not supported"
)
if args.md_flag:
sys.exit(
"ERROR: 4 and 8-bit quantization with mixed dimensions is not supported"
)
if args.use_npu:
sys.exit(
"ERROR: 4 and 8-bit quantization on GPU is not supported"
)
np.random.seed(args.numpy_rand_seed)
np.set_printoptions(precision=args.print_precision)
torch.set_printoptions(precision=args.print_precision)
torch.manual_seed(args.numpy_rand_seed)
if args.test_mini_batch_size < 0:
args.test_mini_batch_size = args.mini_batch_size
if args.test_num_workers < 0:
args.test_num_workers = args.num_workers
use_npu = args.use_npu
ext_dist.init_distributed(local_rank=args.local_rank, size= args.world_size, use_npu=use_npu, backend=args.dist_backend)
if use_npu:
if ext_dist.my_size > 1:
ngpus = 1
device = torch.device("npu", ext_dist.my_local_rank)
else:
ngpus = args.world_size
device = torch.device("npu", args.local_rank)
print("Using {} NPU(s)...".format(ngpus))
else:
device = torch.device("cpu")
print("Using CPU...")
ln_bot = np.fromstring(args.arch_mlp_bot, dtype=int, sep="-")
if args.mlperf_logging:
mlperf_logger.barrier()
mlperf_logger.log_end(key=mlperf_logger.constants.INIT_STOP)
mlperf_logger.barrier()
mlperf_logger.log_start(key=mlperf_logger.constants.RUN_START)
mlperf_logger.barrier()
if args.data_generation == "dataset":
train_data, train_ld, test_data, test_ld = dp.make_criteo_data_and_loaders(args)
table_feature_map = {idx: idx for idx in range(len(train_data.counts))}
nbatches = args.num_batches if args.num_batches > 0 else len(train_ld)
nbatches_test = len(test_ld)
ln_emb = train_data.counts
if args.max_ind_range > 0:
ln_emb = np.array(
list(
map(
lambda x: x if x < args.max_ind_range else args.max_ind_range,
ln_emb,
)
)
)
else:
ln_emb = np.array(ln_emb)
m_den = train_data.m_den
ln_bot[0] = m_den
else:
ln_emb = np.fromstring(args.arch_embedding_size, dtype=int, sep="-")
m_den = ln_bot[0]
train_data, train_ld, test_data, test_ld = dp.make_random_data_and_loader(args, ln_emb, m_den)
nbatches = args.num_batches if args.num_batches > 0 else len(train_ld)
nbatches_test = len(test_ld)
args.ln_emb = ln_emb.tolist()
if args.mlperf_logging:
print("command line args: ", json.dumps(vars(args)))
m_spa = args.arch_sparse_feature_size
ln_emb = np.asarray(ln_emb)
num_fea = ln_emb.size + 1
m_den_out = ln_bot[ln_bot.size - 1]
if args.arch_interaction_op == "dot":
if args.arch_interaction_itself:
num_int = (num_fea * (num_fea + 1)) // 2 + m_den_out
else:
num_int = (num_fea * (num_fea - 1)) // 2 + m_den_out
elif args.arch_interaction_op == "cat":
num_int = num_fea * m_den_out
else:
sys.exit(
"ERROR: --arch-interaction-op="
+ args.arch_interaction_op
+ " is not supported"
)
arch_mlp_top_adjusted = str(num_int) + "-" + args.arch_mlp_top
ln_top = np.fromstring(arch_mlp_top_adjusted, dtype=int, sep="-")
if m_den != ln_bot[0]:
sys.exit(
"ERROR: arch-dense-feature-size "
+ str(m_den)
+ " does not match first dim of bottom mlp "
+ str(ln_bot[0])
)
if args.qr_flag:
if args.qr_operation == "concat" and 2 * m_spa != m_den_out:
sys.exit(
"ERROR: 2 arch-sparse-feature-size "
+ str(2 * m_spa)
+ " does not match last dim of bottom mlp "
+ str(m_den_out)
+ " (note that the last dim of bottom mlp must be 2x the embedding dim)"
)
if args.qr_operation != "concat" and m_spa != m_den_out:
sys.exit(
"ERROR: arch-sparse-feature-size "
+ str(m_spa)
+ " does not match last dim of bottom mlp "
+ str(m_den_out)
)
else:
if m_spa != m_den_out:
sys.exit(
"ERROR: arch-sparse-feature-size "
+ str(m_spa)
+ " does not match last dim of bottom mlp "
+ str(m_den_out)
)
if num_int != ln_top[0]:
sys.exit(
"ERROR: # of feature interactions "
+ str(num_int)
+ " does not match first dimension of top mlp "
+ str(ln_top[0])
)
if args.md_flag:
m_spa = md_solver(
torch.tensor(ln_emb),
args.md_temperature,
d0=m_spa,
round_dim=args.md_round_dims,
).tolist()
global ndevices
ndevices = min(ngpus, args.mini_batch_size, num_fea - 1) if use_npu else -1
print(f'[run] ndevices={ndevices}')
global dlrm
dlrm = DLRM_Net(
m_spa,
ln_emb,
ln_bot,
ln_top,
arch_interaction_op=args.arch_interaction_op,
arch_interaction_itself=args.arch_interaction_itself,
sigmoid_bot=-1,
sigmoid_top=ln_top.size - 2,
sync_dense_params=args.sync_dense_params,
loss_threshold=args.loss_threshold,
ndevices=ndevices,
qr_flag=args.qr_flag,
qr_operation=args.qr_operation,
qr_collisions=args.qr_collisions,
qr_threshold=args.qr_threshold,
md_flag=args.md_flag,
md_threshold=args.md_threshold,
weighted_pooling=args.weighted_pooling,
loss_function=args.loss_function
)
if use_npu:
if dlrm.ndevices > 1:
dlrm.emb_l, dlrm.v_W_l = dlrm.create_emb(
m_spa, ln_emb, args.weighted_pooling
)
else:
if dlrm.weighted_pooling == "fixed":
for k, w in enumerate(dlrm.v_W_l):
dlrm.v_W_l[k] = w.npu()
dlrm = dlrm.to(device)
if ext_dist.my_size > 1:
if use_npu:
device_ids = [ext_dist.my_local_rank]
dlrm.bot_l = ext_dist.DDP(dlrm.bot_l, device_ids=device_ids)
dlrm.top_l = ext_dist.DDP(dlrm.top_l, device_ids=device_ids)
else:
dlrm.bot_l = ext_dist.DDP(dlrm.bot_l)
dlrm.top_l = ext_dist.DDP(dlrm.top_l)
if not args.inference_only:
parameters = (
dlrm.parameters()
if ext_dist.my_size == 1
else [
{
"params": [p for emb in dlrm.emb_l for p in emb.parameters()],
"lr": args.learning_rate,
},
{
"params": dlrm.bot_l.parameters(),
"lr": args.learning_rate,
},
{
"params": dlrm.top_l.parameters(),
"lr": args.learning_rate,
},
]
)
optimizer = apex.optimizers.NpuFusedSGD(parameters, lr=args.learning_rate)
lr_scheduler = LRPolicyScheduler(
optimizer,
args.lr_num_warmup_steps,
args.lr_decay_start_step,
args.lr_num_decay_steps,
)
amp.register_float_function(torch, 'sigmoid')
dlrm, optimizer = amp.initialize(dlrm, optimizer, opt_level="O1", loss_scale=128., combine_grad=True)
best_acc_test = 0
best_auc_test = 0
skip_upto_epoch = 0
skip_upto_batch = 0
total_time = 0
total_loss = 0
total_iter = 0
total_samp = 0
if args.mlperf_logging:
mlperf_logger.mlperf_submission_log("dlrm")
mlperf_logger.log_event(
key=mlperf_logger.constants.SEED, value=args.numpy_rand_seed
)
mlperf_logger.log_event(
key=mlperf_logger.constants.GLOBAL_BATCH_SIZE, value=args.mini_batch_size
)
if not (args.load_model == ""):
print("Loading saved model {}".format(args.load_model))
if use_npu:
if dlrm.ndevices > 1:
ld_model = torch.load(args.load_model)
else:
ld_model = torch.load(
args.load_model
)
else:
ld_model = torch.load(args.load_model, map_location=torch.device("cpu"))
dlrm.load_state_dict(ld_model["state_dict"])
ld_j = ld_model["iter"]
ld_k = ld_model["epoch"]
ld_nepochs = ld_model["nepochs"]
ld_nbatches = ld_model["nbatches"]
ld_nbatches_test = ld_model["nbatches_test"]
ld_train_loss = ld_model["train_loss"]
ld_total_loss = ld_model["total_loss"]
if args.mlperf_logging:
ld_gAUC_test = ld_model["test_auc"]
ld_acc_test = ld_model["test_acc"]
if not args.inference_only:
optimizer.load_state_dict(ld_model["opt_state_dict"])
best_acc_test = ld_acc_test
total_loss = ld_total_loss
skip_upto_epoch = ld_k
skip_upto_batch = ld_j
else:
args.print_freq = ld_nbatches
args.test_freq = 0
print(
"Saved at: epoch = {:d}/{:d}, batch = {:d}/{:d}, ntbatch = {:d}".format(
ld_k, ld_nepochs, ld_j, ld_nbatches, ld_nbatches_test
)
)
print(
"Training state: loss = {:.6f}".format(
ld_train_loss,
)
)
if args.mlperf_logging:
print(
"Testing state: accuracy = {:3.3f} %, auc = {:.3f}".format(
ld_acc_test * 100, ld_gAUC_test
)
)
else:
print("Testing state: accuracy = {:3.3f} %".format(ld_acc_test * 100))
if args.inference_only:
assert args.quantize_mlp_with_bit in [
8,
16,
32,
], "only support 8/16/32-bit but got {}".format(args.quantize_mlp_with_bit)
assert args.quantize_emb_with_bit in [
4,
8,
32,
], "only support 4/8/32-bit but got {}".format(args.quantize_emb_with_bit)
if args.quantize_mlp_with_bit != 32:
if args.quantize_mlp_with_bit in [8]:
quantize_dtype = torch.qint8
else:
quantize_dtype = torch.float16
dlrm = torch.quantization.quantize_dynamic(
dlrm, {torch.nn.Linear}, quantize_dtype
)
if args.quantize_emb_with_bit != 32:
dlrm.quantize_embedding(args.quantize_emb_with_bit)
print("time/loss/accuracy (if enabled):")
if args.mlperf_logging:
mlperf_logger.log_event(
key=mlperf_logger.constants.OPT_BASE_LR, value=args.learning_rate
)
mlperf_logger.log_event(
key=mlperf_logger.constants.OPT_LR_WARMUP_STEPS,
value=args.lr_num_warmup_steps,
)
mlperf_logger.log_event(
key="sgd_opt_base_learning_rate", value=args.learning_rate
)
mlperf_logger.log_event(
key="lr_decay_start_steps", value=args.lr_decay_start_step
)
mlperf_logger.log_event(
key="sgd_opt_learning_rate_decay_steps", value=args.lr_num_decay_steps
)
mlperf_logger.log_event(key="sgd_opt_learning_rate_decay_poly_power", value=2)
tb_file = "./" + args.tensor_board_filename
writer = SummaryWriter(tb_file)
ext_dist.barrier()
if not args.inference_only:
k = 0
total_time_begin = 0
while k < args.nepochs:
if args.mlperf_logging:
mlperf_logger.barrier()
mlperf_logger.log_start(
key=mlperf_logger.constants.BLOCK_START,
metadata={
mlperf_logger.constants.FIRST_EPOCH_NUM: (k + 1),
mlperf_logger.constants.EPOCH_COUNT: 1,
},
)
mlperf_logger.barrier()
mlperf_logger.log_start(
key=mlperf_logger.constants.EPOCH_START,
metadata={mlperf_logger.constants.EPOCH_NUM: (k + 1)},
)
if k < skip_upto_epoch:
continue
if args.mlperf_logging:
previous_iteration_time = None
for j, inputBatch in enumerate(train_ld):
if args.step == -1 or j < args.step:
if j < skip_upto_batch:
continue
X, lS_o, lS_i, T, W, CBPP = unpack_batch(inputBatch)
if args.mlperf_logging:
current_time = time_wrap(use_npu)
if previous_iteration_time:
iteration_time = current_time - previous_iteration_time
else:
iteration_time = 0
previous_iteration_time = current_time
else:
t1 = time_wrap(use_npu)
if nbatches > 0 and j >= nbatches:
break
if ext_dist.my_size > 1 and X.size(0) % ext_dist.my_size != 0:
print(
"Warning: Skiping the batch %d with size %d"
% (j, X.size(0))
)
continue
mbs = T.shape[0]
Z = dlrm_wrap(
X,
lS_o,
lS_i,
use_npu,
device,
ndevices=ndevices,
)
if ext_dist.my_size > 1:
T = T[ext_dist.get_my_slice(mbs)]
W = W[ext_dist.get_my_slice(mbs)]
E = loss_fn_wrap(Z, T, use_npu, device)
L = E.detach().cpu().numpy()
with record_function("DLRM backward"):
if (args.mlperf_logging and (j + 1) % args.mlperf_grad_accum_iter == 0) or not args.mlperf_logging:
optimizer.zero_grad()
with amp.scale_loss(E, optimizer) as scaled_loss:
scaled_loss.backward()
if (args.mlperf_logging and (j + 1) % args.mlperf_grad_accum_iter == 0) or not args.mlperf_logging:
optimizer.step()
lr_scheduler.step()
if args.mlperf_logging:
total_time += iteration_time
else:
t2 = time_wrap(use_npu)
total_time += t2 - t1
total_loss += L * mbs
total_iter += 1
total_samp += mbs
should_print = ((j + 1) % args.print_freq == 0) or (
j + 1 == nbatches
)
should_test = (
(args.test_freq > 0)
and (args.data_generation in ["dataset", "random"])
and (((j + 1) % args.test_freq == 0) or (j + 1 == nbatches))
)
if should_print or should_test:
gT = 1000.0 * total_time / total_iter if args.print_time else -1
total_time = 0
train_loss = total_loss / total_samp
total_loss = 0
str_run_type = (
"inference" if args.inference_only else "training"
)
wall_time = ""
if args.print_wall_time:
wall_time = " ({})".format(time.strftime("%H:%M"))
print(
"Finished {} it {}/{} of epoch {}, {:.2f} ms/it,".format(
str_run_type, j + 1, nbatches, k, gT
)
+ " loss {:.6f}".format(train_loss)
+ wall_time,
flush=True,
)
log_iter = nbatches * k + j + 1
writer.add_scalar("Train/Loss", train_loss, log_iter)
total_iter = 0
total_samp = 0
if should_test:
epoch_num_float = (j + 1) / len(train_ld) + k + 1
if args.mlperf_logging:
mlperf_logger.barrier()
mlperf_logger.log_start(
key=mlperf_logger.constants.EVAL_START,
metadata={
mlperf_logger.constants.EPOCH_NUM: epoch_num_float
},
)
if args.mlperf_logging:
previous_iteration_time = None
print(
"Testing at - {}/{} of epoch {},".format(j + 1, nbatches, k)
)
model_metrics_dict, is_best = inference(
args,
dlrm,
best_acc_test,
best_auc_test,
test_ld,
device,
use_npu,
log_iter,
)
if (
is_best
and not (args.save_model == "")
and not args.inference_only
):
model_metrics_dict["epoch"] = k
model_metrics_dict["iter"] = j + 1
model_metrics_dict["train_loss"] = train_loss
model_metrics_dict["total_loss"] = total_loss
model_metrics_dict[
"opt_state_dict"
] = optimizer.state_dict()
print("Saving model to {}".format(args.save_model))
torch.save(model_metrics_dict, args.save_model)
if args.mlperf_logging:
mlperf_logger.barrier()
mlperf_logger.log_end(
key=mlperf_logger.constants.EVAL_STOP,
metadata={
mlperf_logger.constants.EPOCH_NUM: epoch_num_float
},
)
if (
args.mlperf_logging
and (args.mlperf_acc_threshold > 0)
and (best_acc_test > args.mlperf_acc_threshold)
):
print(
"MLPerf testing accuracy threshold "
+ str(args.mlperf_acc_threshold)
+ " reached, stop training"
)
break
if (
args.mlperf_logging
and (args.mlperf_auc_threshold > 0)
and (best_auc_test > args.mlperf_auc_threshold)
):
print(
"MLPerf testing auc threshold "
+ str(args.mlperf_auc_threshold)
+ " reached, stop training"
)
if args.mlperf_logging:
mlperf_logger.barrier()
mlperf_logger.log_end(
key=mlperf_logger.constants.RUN_STOP,
metadata={
mlperf_logger.constants.STATUS: mlperf_logger.constants.SUCCESS
},
)
break
else:
break
if args.mlperf_logging:
mlperf_logger.barrier()
mlperf_logger.log_end(
key=mlperf_logger.constants.EPOCH_STOP,
metadata={mlperf_logger.constants.EPOCH_NUM: (k + 1)},
)
mlperf_logger.barrier()
mlperf_logger.log_end(
key=mlperf_logger.constants.BLOCK_STOP,
metadata={mlperf_logger.constants.FIRST_EPOCH_NUM: (k + 1)},
)
k += 1
if args.mlperf_logging and best_auc_test <= args.mlperf_auc_threshold:
mlperf_logger.barrier()
mlperf_logger.log_end(
key=mlperf_logger.constants.RUN_STOP,
metadata={
mlperf_logger.constants.STATUS: mlperf_logger.constants.ABORTED
},
)
else:
print("Testing for inference only")
inference(
args,
dlrm,
best_acc_test,
best_auc_test,
test_ld,
device,
use_npu,
)
total_time_end = time_wrap(use_npu)
def set_default_master_env(addr="127.0.0.1", port="29688"):
    """Give MASTER_ADDR / MASTER_PORT local defaults without clobbering them.

    Each variable is written only when it is absent from ``os.environ``, so a
    value already exported by a launcher or by the user is left untouched.
    NOTE(review): these variables are presumably consumed by the distributed
    setup performed inside ``run()``; that code is not visible from here.

    Args:
        addr: Fallback value stored in ``MASTER_ADDR`` when it is unset.
        port: Fallback value stored in ``MASTER_PORT`` when it is unset.
            Must be a string because ``os.environ`` only holds strings.
    """
    os.environ.setdefault("MASTER_ADDR", addr)
    os.environ.setdefault("MASTER_PORT", port)


if __name__ == "__main__":
    # Only fill in the endpoint when the caller has not supplied one.
    set_default_master_env()
    run()