from __future__ import absolute_import, division, print_function, unicode_literals
import functools
import operator
import time
import copy
import dlrm_data_pytorch as dp
import numpy as np
import sklearn.metrics
import warnings
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=DeprecationWarning)
try:
import onnx
import caffe2.python.onnx.frontend
except ImportError as error:
print('Unable to import onnx or caffe2.python.onnx.frontend ', error)
from caffe2.proto import caffe2_pb2
from caffe2.python import brew, core, dyndep, model_helper, net_drawer, workspace
"""
# auxiliary routine used to split input on the mini-bacth dimension
def where_to_split(mini_batch_size, ndevices, _add_leftover=False):
n = (mini_batch_size + ndevices - 1) // ndevices # ceiling
l = mini_batch_size - n * (ndevices - 1) # leftover
s = [n] * (ndevices - 1)
if _add_leftover:
ls += [l if l > 0 else n]
return ls
"""
class DLRM_Net(object):
def FeedBlobWrapper(self, tag, val, add_prefix=True, split=False, device_id=-1):
if self.ndevices > 1 and add_prefix:
if split:
mini_batch_size = val.shape[0]
if mini_batch_size % self.ndevices != 0:
sys.exit("ERROR: caffe2 net assumes that the mini_batch_size "
+ str(mini_batch_size)
+ " is evenly divisible by the number of available devices"
+ str(self.ndevices))
vals = np.split(val, self.ndevices, axis=0)
"""
# approach 2: np and caffe2 operators do not assume exact divisibility
if args.mini_batch_size != mini_batch_size:
sys.exit("ERROR: caffe2 net was prepared for mini-batch size "
+ str(args.mini_batch_size)
+ " which is different from current mini-batch size "
+ str(mini_batch_size) + " being passed to it. "
+ "This is common for the last mini-batch, when "
+ "mini-batch size does not evenly divided the number of "
+ "elements in the data set.")
ls = where_to_split(mini_batch_size, self.ndevices)
vals = np.split(val, ls, axis=0)
"""
for d in range(self.ndevices):
tag_on_device = "gpu_" + str(d) + "/" + tag
_d = core.DeviceOption(workspace.GpuDeviceType, d)
workspace.FeedBlob(tag_on_device, vals[d], device_option=_d)
else:
for d in range(self.ndevices):
tag_on_device = "gpu_" + str(d) + "/" + tag
_d = core.DeviceOption(workspace.GpuDeviceType, d)
workspace.FeedBlob(tag_on_device, val, device_option=_d)
else:
if device_id >= 0:
_d = core.DeviceOption(workspace.GpuDeviceType, device_id)
workspace.FeedBlob(tag, val, device_option=_d)
else:
workspace.FeedBlob(tag, val)
def FetchBlobWrapper(self, tag, add_prefix=True, reduce_across=None, device_id=-1):
if self.ndevices > 1 and add_prefix:
vals = []
for d in range(self.ndevices):
if tag.__class__ == list:
tag_on_device = tag[d]
else:
tag_on_device = "gpu_" + str(0) + "/" + tag
val = workspace.FetchBlob(tag_on_device)
vals.append(val)
if reduce_across == "add":
return functools.reduce(operator.add, vals)
elif reduce_across == "concat":
return np.concatenate(vals)
else:
return vals
else:
if device_id >= 0:
tag_on_device = "gpu_" + str(device_id) + "/" + tag
return workspace.FetchBlob(tag_on_device)
else:
return workspace.FetchBlob(tag)
def AddLayerWrapper(self, layer, inp_blobs, out_blobs,
add_prefix=True, reset_grad=False, **kwargs):
def adjust_tag(blobs, on_device):
if blobs.__class__ == str:
_blobs = on_device + blobs
elif blobs.__class__ == list:
_blobs = list(map(lambda tag: on_device + tag, blobs))
else:
_blobs = blobs
return _blobs
if self.ndevices > 1 and add_prefix:
ll = []
for d in range(self.ndevices):
on_device = "gpu_" + str(d) + "/"
_inp_blobs = adjust_tag(inp_blobs, on_device)
_out_blobs = adjust_tag(out_blobs, on_device)
if reset_grad:
w_grad = self.gradientMap[_inp_blobs[0]]
_inp_blobs[2] = w_grad
with core.DeviceScope(core.DeviceOption(workspace.GpuDeviceType, d)):
if kwargs:
new_layer = layer(_inp_blobs, _out_blobs, **kwargs)
else:
new_layer = layer(_inp_blobs, _out_blobs)
ll.append(new_layer)
return ll
else:
if reset_grad:
w_grad = self.gradientMap[inp_blobs[0]]
inp_blobs[2] = w_grad
if kwargs:
new_layer = layer(inp_blobs, out_blobs, **kwargs)
else:
new_layer = layer(inp_blobs, out_blobs)
return new_layer
def create_mlp(self, ln, sigmoid_layer, model, tag):
(tag_layer, tag_in, tag_out) = tag
layers = []
weights = []
for i in range(1, ln.size):
n = ln[i - 1]
m = ln[i]
tag_fc_w = tag_layer + ":::" + "fc" + str(i) + "_w"
tag_fc_b = tag_layer + ":::" + "fc" + str(i) + "_b"
tag_fc_y = tag_layer + ":::" + "fc" + str(i) + "_y"
tag_fc_z = tag_layer + ":::" + "fc" + str(i) + "_z"
if i == ln.size - 1:
tag_fc_z = tag_out
weights.append(tag_fc_w)
weights.append(tag_fc_b)
mean = 0.0
std_dev = np.sqrt(2 / (m + n))
W = np.random.normal(mean, std_dev, size=(m, n)).astype(np.float32)
std_dev = np.sqrt(1 / m)
b = np.random.normal(mean, std_dev, size=m).astype(np.float32)
self.FeedBlobWrapper(tag_fc_w, W)
self.FeedBlobWrapper(tag_fc_b, b)
if self.emb_optimizer in ["adagrad", "rwsadagrad"]:
self.FeedBlobWrapper(
"momentum_mlp_{}_{}".format(tag_layer, 2 * i - 1),
np.full((m, n), 0, dtype=np.float32)
)
self.FeedBlobWrapper(
"momentum_mlp_{}_{}".format(tag_layer, 2 * i),
np.full((m), 0, dtype=np.float32)
)
if self.save_onnx:
self.onnx_tsd[tag_fc_w] = (onnx.TensorProto.FLOAT, W.shape)
self.onnx_tsd[tag_fc_b] = (onnx.TensorProto.FLOAT, b.shape)
fc = self.AddLayerWrapper(
model.net.FC, [tag_in, tag_fc_w, tag_fc_b], tag_fc_y
)
layers.append(fc)
if i == sigmoid_layer:
layer = self.AddLayerWrapper(model.net.Sigmoid, tag_fc_y, tag_fc_z)
else:
layer = self.AddLayerWrapper(model.net.Relu, tag_fc_y, tag_fc_z)
tag_in = tag_fc_z
layers.append(layer)
return layers, weights
def create_emb(self, m, ln, model, tag):
(tag_layer, tag_in, tag_out) = tag
emb_l = []
weights_l = []
vw_l = []
for i in range(0, ln.size):
n = ln[i]
if self.ndevices > 1:
d = i % self.ndevices
else:
d = -1
on_device = "" if self.ndevices <= 1 else "gpu_" + str(d) + "/"
len_s = on_device + tag_layer + ":::" + "sls" + str(i) + "_l"
ind_s = on_device + tag_layer + ":::" + "sls" + str(i) + "_i"
tbl_s = on_device + tag_layer + ":::" + "sls" + str(i) + "_w"
sum_s = on_device + tag_layer + ":::" + "sls" + str(i) + "_z"
weights_l.append(tbl_s)
W = np.random.uniform(low=-np.sqrt(1 / n),
high=np.sqrt(1 / n),
size=(n, m)).astype(np.float32)
self.FeedBlobWrapper(tbl_s, W, False, device_id=d)
if self.emb_optimizer == "adagrad":
self.FeedBlobWrapper("momentum_emb_{}".format(i),
np.full((n, m), 0), add_prefix=False, device_id=d)
elif self.emb_optimizer == "rwsadagrad":
self.FeedBlobWrapper("momentum_emb_{}".format(i),
np.full((n), 0), add_prefix=False, device_id=d)
if self.save_onnx:
self.onnx_tsd[tbl_s] = (onnx.TensorProto.FLOAT, W.shape)
if self.weighted_pooling is not None:
vw_s = on_device + tag_layer + ":::" + "sls" + str(i) + "_v"
psw_s = on_device + tag_layer + ":::" + "sls" + str(i) + "_s"
VW = np.ones(n).astype(np.float32)
self.FeedBlobWrapper(vw_s, VW, False, device_id=d)
if self.weighted_pooling == "learned":
vw_l.append(vw_s)
grad_on_weights = True
else:
grad_on_weights = False
if self.save_onnx:
self.onnx_tsd[vw_s] = (onnx.TensorProto.FLOAT, VW.shape)
if self.ndevices <= 1:
PSW = model.net.Gather([vw_s, ind_s], [psw_s])
EE = model.net.SparseLengthsWeightedSum(
[tbl_s, PSW, ind_s, len_s], [sum_s],
grad_on_weights=grad_on_weights
)
else:
with core.DeviceScope(
core.DeviceOption(workspace.GpuDeviceType, d)
):
PSW = model.net.Gather([vw_s, ind_s], [psw_s])
EE = model.net.SparseLengthsWeightedSum(
[tbl_s, PSW, ind_s, len_s], [sum_s],
grad_on_weights=grad_on_weights
)
else:
if self.ndevices <= 1:
EE = model.net.SparseLengthsSum(
[tbl_s, ind_s, len_s], [sum_s]
)
else:
with core.DeviceScope(
core.DeviceOption(workspace.GpuDeviceType, d)
):
EE = model.net.SparseLengthsSum(
[tbl_s, ind_s, len_s], [sum_s]
)
emb_l.append(EE)
return emb_l, weights_l, vw_l
def create_interactions(self, x, ly, model, tag):
(tag_dense_in, tag_sparse_in, tag_int_out) = tag
if self.arch_interaction_op == "dot":
tag_int_out_info = tag_int_out + "_info"
T, T_info = model.net.Concat(
x + ly,
[tag_int_out + "_cat_axis0", tag_int_out_info + "_cat_axis0"],
axis=1,
add_axis=1,
)
Z = model.net.BatchMatMul([T, T], tag_int_out + "_matmul", trans_b=1)
Zflat_all = model.net.Flatten(Z, tag_int_out + "_flatten_all", axis=1)
Zflat = model.net.BatchGather(
[Zflat_all, tag_int_out + "_tril_indices"],
tag_int_out + "_flatten"
)
R, R_info = model.net.Concat(
x + [Zflat], [tag_int_out, tag_int_out_info], axis=1
)
elif self.arch_interaction_op == "cat":
tag_int_out_info = tag_int_out + "_info"
R, R_info = model.net.Concat(
x + ly, [tag_int_out, tag_int_out_info], axis=1
)
else:
sys.exit("ERROR: --arch-interaction-op="
+ self.arch_interaction_op + " is not supported")
return R
def create_sequential_forward_ops(self):
tag = (self.temb, self.tsin, self.tsout)
self.emb_l, self.emb_w, self.emb_vw = self.create_emb(
self.m_spa, self.ln_emb, self.model, tag
)
tag = (self.tbot, self.tdin, self.tdout)
self.bot_l, self.bot_w = self.create_mlp(self.ln_bot, self.sigmoid_bot,
self.model, tag)
tag = (self.tdout, self.tsout, self.tint)
Z = self.create_interactions([self.bot_l[-1]], self.emb_l, self.model, tag)
tag = (self.ttop, Z, self.tout)
self.top_l, self.top_w = self.create_mlp(self.ln_top, self.sigmoid_top,
self.model, tag)
self.last_output = self.top_l[-1]
def create_parallel_forward_ops(self):
tag = (self.temb, self.tsin, self.tsout)
self.emb_l, self.emb_w, self.emb_vw = self.create_emb(
self.m_spa, self.ln_emb, self.model, tag
)
tag = (self.tbot, self.tdin, self.tdout)
self.bot_l, self.bot_w = self.create_mlp(self.ln_bot, self.sigmoid_bot,
self.model, tag)
t_list = []
for i, emb_output in enumerate(self.emb_l):
src_d = i % self.ndevices
lo = [emb_output + "_split_" + str(d) for d in range(self.ndevices)]
with core.DeviceScope(core.DeviceOption(workspace.GpuDeviceType, src_d)):
self.model.net.Split(emb_output, lo, axis=0)
"""
# approach 2: np and caffe2 operators do not assume exact divisibility
ls = where_to_split(args.mini_batch_size, self.ndevices, _add_leftover=True)
with core.DeviceScope(core.DeviceOption(workspace.GpuDeviceType, src_d)):
emb_output_split = self.model.net.Split(
emb_output, lo, split=lp, axis=0
)
"""
y = []
for dst_d in range(len(lo)):
src_blob = lo[dst_d]
dst_blob = str(src_blob).replace(
"gpu_" + str(src_d), "gpu_" + str(dst_d), 1
)
if src_blob != dst_blob:
with core.DeviceScope(
core.DeviceOption(workspace.GpuDeviceType, dst_d)
):
blob = self.model.Copy(src_blob, dst_blob)
else:
blob = dst_blob
y.append(blob)
t_list.append(y)
x = list(map(lambda x: list(x), zip(*self.bot_l)))
ly = list(map(lambda y: list(y), zip(*t_list)))
for d in range(self.ndevices):
on_device = "gpu_" + str(d) + "/"
tag = (on_device + self.tdout, on_device + self.tsout, on_device + self.tint)
with core.DeviceScope(core.DeviceOption(workspace.GpuDeviceType, d)):
self.create_interactions([x[d][-1]], ly[d], self.model, tag)
tag = (self.ttop, self.tint, self.tout)
self.top_l, self.top_w = self.create_mlp(self.ln_top, self.sigmoid_top,
self.model, tag)
self.last_output = self.top_l[-1]
def __init__(
self,
m_spa,
ln_emb,
ln_bot,
ln_top,
arch_interaction_op,
arch_interaction_itself=False,
sigmoid_bot=-1,
sigmoid_top=-1,
save_onnx=False,
model=None,
test_net=None,
tag=None,
ndevices=-1,
forward_ops=True,
enable_prof=False,
weighted_pooling=None,
emb_optimizer="sgd"
):
super(DLRM_Net, self).__init__()
if model is None:
global_init_opt = ["caffe2", "--caffe2_log_level=0"]
if enable_prof:
global_init_opt += [
"--logtostderr=0",
"--log_dir=$HOME",
"--caffe2_logging_print_net_summary=1",
]
workspace.GlobalInit(global_init_opt)
self.set_tags()
self.model = model_helper.ModelHelper(name="DLRM", init_params=True)
self.test_net = None
else:
self.set_tags(tag[0], tag[1], tag[2], tag[3], tag[4], tag[5], tag[6],
tag[7], tag[8], tag[9])
self.model = model
self.test_net = test_net
self.m_spa = m_spa
self.ln_emb = ln_emb
self.ln_bot = ln_bot
self.ln_top = ln_top
self.arch_interaction_op = arch_interaction_op
self.arch_interaction_itself = arch_interaction_itself
self.sigmoid_bot = sigmoid_bot
self.sigmoid_top = sigmoid_top
self.save_onnx = save_onnx
self.ndevices = ndevices
self.emb_optimizer = emb_optimizer
if weighted_pooling is not None and weighted_pooling != "fixed":
self.weighted_pooling = "learned"
else:
self.weighted_pooling = weighted_pooling
if self.save_onnx:
self.onnx_tsd = {}
if forward_ops:
if self.ndevices <= 1:
return self.create_sequential_forward_ops()
else:
return self.create_parallel_forward_ops()
def set_tags(
self,
_tag_layer_top_mlp="top",
_tag_layer_bot_mlp="bot",
_tag_layer_embedding="emb",
_tag_feature_dense_in="dense_in",
_tag_feature_dense_out="dense_out",
_tag_feature_sparse_in="sparse_in",
_tag_feature_sparse_out="sparse_out",
_tag_interaction="interaction",
_tag_dense_output="prob_click",
_tag_dense_target="target",
):
self.ttop = _tag_layer_top_mlp
self.tbot = _tag_layer_bot_mlp
self.temb = _tag_layer_embedding
self.tdin = _tag_feature_dense_in
self.tdout = _tag_feature_dense_out
self.tsin = _tag_feature_sparse_in
self.tsout = _tag_feature_sparse_out
self.tint = _tag_interaction
self.ttar = _tag_dense_target
self.tout = _tag_dense_output
def parameters(self):
return self.model
def get_loss(self):
return self.FetchBlobWrapper(self.loss, reduce_across="add")
def get_output(self):
return self.FetchBlobWrapper(self.last_output, reduce_across="concat")
def create(self, X, S_lengths, S_indices, T):
self.create_input(X, S_lengths, S_indices, T)
self.create_model(X, S_lengths, S_indices, T)
def create_input(self, X, S_lengths, S_indices, T):
self.FeedBlobWrapper(self.tdin, X, split=True)
if self.save_onnx:
self.onnx_tsd[self.tdin] = (onnx.TensorProto.FLOAT, X.shape)
for i in range(len(self.emb_l)):
if self.ndevices > 1:
d = i % self.ndevices
else:
d = -1
on_device = "" if self.ndevices <= 1 else "gpu_" + str(d) + "/"
len_s = on_device + self.temb + ":::" + "sls" + str(i) + "_l"
ind_s = on_device + self.temb + ":::" + "sls" + str(i) + "_i"
self.FeedBlobWrapper(len_s, np.array(S_lengths[i]), False, device_id=d)
self.FeedBlobWrapper(ind_s, np.array(S_indices[i]), False, device_id=d)
if self.save_onnx:
lshape = (len(S_lengths[i]),)
ishape = (len(S_indices[i]),)
self.onnx_tsd[len_s] = (onnx.TensorProto.INT32, lshape)
self.onnx_tsd[ind_s] = (onnx.TensorProto.INT32, ishape)
if T is not None:
zeros_fp32 = np.zeros(T.shape).astype(np.float32)
self.FeedBlobWrapper(self.ttar, zeros_fp32, split=True)
if self.save_onnx:
self.onnx_tsd[self.ttar] = (onnx.TensorProto.FLOAT, T.shape)
def create_model(self, X, S_lengths, S_indices, T):
offset = 1 if self.arch_interaction_itself else 0
num_fea = len(self.emb_l) + 1
tril_indices = np.array([j + i * num_fea
for i in range(num_fea) for j in range(i + offset)])
self.FeedBlobWrapper(self.tint + "_tril_indices", tril_indices)
if T is not None:
workspace.RunNetOnce(self.model.param_init_net)
workspace.CreateNet(self.model.net)
if self.test_net is not None:
workspace.CreateNet(self.test_net)
def run(self, X, S_lengths, S_indices, T, test_net=False, enable_prof=False):
self.FeedBlobWrapper(self.tdin, X, split=True)
for i in range(len(self.emb_l)):
if self.ndevices > 1:
d = i % self.ndevices
else:
d = -1
on_device = "" if self.ndevices <= 1 else "gpu_" + str(d) + "/"
len_s = on_device + self.temb + ":::" + "sls" + str(i) + "_l"
ind_s = on_device + self.temb + ":::" + "sls" + str(i) + "_i"
self.FeedBlobWrapper(len_s, np.array(S_lengths[i]), False, device_id=d)
self.FeedBlobWrapper(ind_s, np.array(S_indices[i]), False, device_id=d)
if T is not None:
self.FeedBlobWrapper(self.ttar, T, split=True)
if test_net:
workspace.RunNet(self.test_net)
else:
if enable_prof:
workspace.C.benchmark_net(self.model.net.Name(), 0, 1, True)
else:
workspace.RunNet(self.model.net)
def MSEloss(self, scale=1.0):
self.AddLayerWrapper(self.model.SquaredL2Distance, [self.tout, self.ttar], "sd")
self.AddLayerWrapper(self.model.Scale, "sd", "sd2", scale=2.0 * scale)
self.loss = self.AddLayerWrapper(self.model.AveragedLoss, "sd2", "loss")
def BCEloss(self, scale=1.0, threshold=0.0):
if 0.0 < threshold and threshold < 1.0:
self.AddLayerWrapper(self.model.Clip, self.tout, "tout_c",
min=threshold, max=(1.0 - threshold))
self.AddLayerWrapper(self.model.MakeTwoClass, "tout_c", "tout_2c")
else:
self.AddLayerWrapper(self.model.MakeTwoClass, self.tout, "tout_2c")
self.AddLayerWrapper(self.model.LabelCrossEntropy, ["tout_2c", self.ttar], "sd")
if scale == 1.0:
self.loss = self.AddLayerWrapper(self.model.AveragedLoss, "sd", "loss")
else:
self.AddLayerWrapper(self.model.Scale, "sd", "sd2", scale=scale)
self.loss = self.AddLayerWrapper(self.model.AveragedLoss, "sd2", "loss")
def sgd_optimizer(self, learning_rate,
T=None, _gradientMap=None, sync_dense_params=True):
if T is not None:
(tag_one, tag_it, tag_lr) = T
else:
(tag_one, tag_it, tag_lr) = ("const_one", "optim_it", "optim_lr")
self.AddLayerWrapper(self.model.param_init_net.ConstantFill,
[], tag_one, shape=[1], value=1.0)
self.AddLayerWrapper(brew.iter, self.model, tag_it)
self.AddLayerWrapper(self.model.LearningRate, tag_it, tag_lr,
base_lr=-1 * learning_rate, policy="fixed")
if self.save_onnx:
self.onnx_tsd[tag_one] = (onnx.TensorProto.FLOAT, (1,))
self.onnx_tsd[tag_it] = (onnx.TensorProto.INT64, (1,))
if _gradientMap is not None:
self.gradientMap = _gradientMap
else:
if self.loss.__class__ == list:
self.gradientMap = self.model.AddGradientOperators(self.loss)
else:
self.gradientMap = self.model.AddGradientOperators([self.loss])
for w in self.top_w:
if sync_dense_params and self.ndevices > 1:
grad_blobs = [
self.gradientMap["gpu_{}/".format(d) + w]
for d in range(self.ndevices)
]
self.model.NCCLAllreduce(grad_blobs, grad_blobs)
self.AddLayerWrapper(self.model.WeightedSum,
[w, tag_one, "", tag_lr], w, reset_grad=True)
for w in self.bot_w:
if sync_dense_params and self.ndevices > 1:
grad_blobs = [
self.gradientMap["gpu_{}/".format(d) + w]
for d in range(self.ndevices)
]
self.model.NCCLAllreduce(grad_blobs, grad_blobs)
self.AddLayerWrapper(self.model.WeightedSum,
[w, tag_one, "", tag_lr], w, reset_grad=True)
for i, w in enumerate(self.emb_w):
if self.ndevices > 1:
d = i % self.ndevices
on_device = "" if self.ndevices <= 1 else "gpu_" + str(d) + "/"
_tag_one = on_device + tag_one
_tag_lr = on_device + tag_lr
w_grad = self.gradientMap[w]
if self.ndevices > 1:
with core.DeviceScope(core.DeviceOption(workspace.GpuDeviceType, d)):
self.model.ScatterWeightedSum([w, _tag_one, w_grad.indices,
w_grad.values, _tag_lr], w)
else:
self.model.ScatterWeightedSum([w, _tag_one, w_grad.indices,
w_grad.values, _tag_lr], w)
if self.weighted_pooling == "learned":
for i, w in enumerate(self.emb_vw):
if self.ndevices > 1:
d = i % self.ndevices
on_device = "" if self.ndevices <= 1 else "gpu_" + str(d) + "/"
_tag_one = on_device + tag_one
_tag_lr = on_device + tag_lr
w_grad = self.gradientMap[w]
if self.ndevices > 1:
with core.DeviceScope(
core.DeviceOption(workspace.GpuDeviceType, d)
):
self.model.ScatterWeightedSum(
[w, _tag_one, w_grad.indices,
w_grad.values, _tag_lr], w
)
else:
self.model.ScatterWeightedSum(
[w, _tag_one, w_grad.indices, w_grad.values, _tag_lr], w
)
def adagrad_optimizer(self, learning_rate,
T=None, _gradientMap=None, sync_dense_params=True,
epsilon=1e-10, decay_=0.0, weight_decay_=0.0):
if T is not None:
(tag_one, tag_it, tag_lr) = T
else:
(tag_one, tag_it, tag_lr) = ("const_one", "optim_it", "optim_lr")
self.AddLayerWrapper(self.model.param_init_net.ConstantFill,
[], tag_one, shape=[1], value=1.0)
self.AddLayerWrapper(brew.iter, self.model, tag_it)
self.AddLayerWrapper(self.model.LearningRate, tag_it, tag_lr,
base_lr=-1 * learning_rate, policy="fixed")
if self.save_onnx:
self.onnx_tsd[tag_one] = (onnx.TensorProto.FLOAT, (1,))
self.onnx_tsd[tag_it] = (onnx.TensorProto.INT64, (1,))
if _gradientMap is not None:
self.gradientMap = _gradientMap
else:
if self.loss.__class__ == list:
self.gradientMap = self.model.AddGradientOperators(self.loss)
else:
self.gradientMap = self.model.AddGradientOperators([self.loss])
for i, w in enumerate(self.top_w):
if sync_dense_params and self.ndevices > 1:
grad_blobs = [
self.gradientMap["gpu_{}/".format(d) + w]
for d in range(self.ndevices)
]
self.model.NCCLAllreduce(grad_blobs, grad_blobs)
self.model.Adagrad(
[
w,
"momentum_mlp_top_{}".format(i + 1),
self.gradientMap[w],
tag_lr
],
[w, "momentum_mlp_top_{}".format(i + 1)],
epsilon=epsilon,
decay_=decay_,
weight_decay_=weight_decay_
)
for i, w in enumerate(self.bot_w):
if sync_dense_params and self.ndevices > 1:
grad_blobs = [
self.gradientMap["gpu_{}/".format(d) + w]
for d in range(self.ndevices)
]
self.model.NCCLAllreduce(grad_blobs, grad_blobs)
self.model.Adagrad(
[
w,
"momentum_mlp_bot_{}".format(i + 1),
self.gradientMap[w],
tag_lr
],
[w, "momentum_mlp_bot_{}".format(i + 1)],
epsilon=epsilon,
decay_=decay_,
weight_decay_=weight_decay_
)
for i, w in enumerate(self.emb_w):
if self.ndevices > 1:
d = i % self.ndevices
on_device = "" if self.ndevices <= 1 else "gpu_" + str(d) + "/"
_tag_one = on_device + tag_one
_tag_lr = on_device + tag_lr
w_grad = self.gradientMap[w]
def add_optimizer():
self.model.Unique(
w_grad.indices,
["unique_w_grad_indices", "remapping_w_grad_indices"]
)
self.model.UnsortedSegmentSum(
[w_grad.values, "remapping_w_grad_indices"],
"unique_w_grad_values"
)
if self.emb_optimizer == "adagrad":
self.model.SparseAdagrad(
[
w,
"momentum_emb_{}".format(i),
"unique_w_grad_indices",
"unique_w_grad_values",
_tag_lr
],
[w, "momentum_emb_{}".format(i)],
epsilon=epsilon,
decay_=decay_,
weight_decay_=weight_decay_
)
elif self.emb_optimizer == "rwsadagrad":
self.model.RowWiseSparseAdagrad(
[
w,
"momentum_emb_{}".format(i),
"unique_w_grad_indices",
"unique_w_grad_values",
_tag_lr
],
[w, "momentum_emb_{}".format(i)],
epsilon=epsilon,
decay_=decay_,
weight_decay_=weight_decay_
)
if self.ndevices > 1:
with core.DeviceScope(core.DeviceOption(workspace.GpuDeviceType, d)):
add_optimizer()
else:
add_optimizer()
if self.weighted_pooling == "learned":
for i, w in enumerate(self.emb_vw):
if self.ndevices > 1:
d = i % self.ndevices
on_device = "" if self.ndevices <= 1 else "gpu_" + str(d) + "/"
_tag_one = on_device + tag_one
_tag_lr = on_device + tag_lr
w_grad = self.gradientMap[w]
if self.ndevices > 1:
with core.DeviceScope(
core.DeviceOption(workspace.GpuDeviceType, d)
):
self.model.ScatterWeightedSum(
[w, _tag_one, w_grad.indices,
w_grad.values, _tag_lr], w
)
else:
self.model.ScatterWeightedSum(
[w, _tag_one, w_grad.indices, w_grad.values, _tag_lr], w
)
def print_all(self):
print(workspace.Blobs(), end='\n')
for _, l in enumerate(workspace.Blobs()):
print(l)
print(self.FetchBlobWrapper(l))
def print_weights(self):
for _, l in enumerate(self.emb_w):
print(self.FetchBlobWrapper(l, False))
if self.weighted_pooling == "learned":
for _, l in enumerate(self.emb_vw):
print(self.FetchBlobWrapper(l, False))
for _, l in enumerate(self.bot_w):
if self.ndevices > 1:
print(self.FetchBlobWrapper(l, False, device_id=0))
else:
print(self.FetchBlobWrapper(l))
for _, l in enumerate(self.top_w):
if self.ndevices > 1:
print(self.FetchBlobWrapper(l, False, device_id=0))
else:
print(self.FetchBlobWrapper(l))
def print_activations(self):
for _, l in enumerate(self.emb_l):
print(l)
print(self.FetchBlobWrapper(l, False))
for _, l in enumerate(self.bot_l):
print(l)
print(self.FetchBlobWrapper(l))
print(self.tint)
print(self.FetchBlobWrapper(self.tint))
for _, l in enumerate(self.top_l):
print(l)
print(self.FetchBlobWrapper(l))
def define_metrics():
metrics = {
'loss': lambda y_true, y_score:
sklearn.metrics.log_loss(
y_true=y_true,
y_pred=y_score,
labels=[0,1]),
'recall': lambda y_true, y_score:
sklearn.metrics.recall_score(
y_true=y_true,
y_pred=np.round(y_score)
),
'precision': lambda y_true, y_score:
sklearn.metrics.precision_score(
y_true=y_true,
y_pred=np.round(y_score)
),
'f1': lambda y_true, y_score:
sklearn.metrics.f1_score(
y_true=y_true,
y_pred=np.round(y_score)
),
'ap': sklearn.metrics.average_precision_score,
'roc_auc': sklearn.metrics.roc_auc_score,
'accuracy': lambda y_true, y_score:
sklearn.metrics.accuracy_score(
y_true=y_true,
y_pred=np.round(y_score)
),
}
return metrics
def calculate_metrics(targets, scores):
scores = np.concatenate(scores, axis=0)
targets = np.concatenate(targets, axis=0)
metrics = define_metrics()
validation_results = {}
for metric_name, metric_function in metrics.items():
try:
validation_results[metric_name] = metric_function(
targets,
scores
)
except Exception as error :
validation_results[metric_name] = -1
print("{} in calculating {}".format(error, metric_name))
return validation_results
if __name__ == "__main__":
import sys
import argparse
parser = argparse.ArgumentParser(
description="Train Deep Learning Recommendation Model (DLRM)"
)
parser.add_argument("--arch-sparse-feature-size", type=int, default=2)
parser.add_argument("--arch-embedding-size", type=str, default="4-3-2")
parser.add_argument("--arch-mlp-bot", type=str, default="4-3-2")
parser.add_argument("--arch-mlp-top", type=str, default="4-2-1")
parser.add_argument("--arch-interaction-op", type=str, default="dot")
parser.add_argument("--arch-interaction-itself", action="store_true", default=False)
parser.add_argument("--activation-function", type=str, default="relu")
parser.add_argument("--loss-function", type=str, default="mse")
parser.add_argument("--loss-threshold", type=float, default=0.0)
parser.add_argument("--round-targets", type=bool, default=False)
parser.add_argument("--weighted-pooling", type=str, default=None)
parser.add_argument("--data-size", type=int, default=1)
parser.add_argument("--num-batches", type=int, default=0)
parser.add_argument("--data-generation", type=str, default="random")
parser.add_argument("--rand-data-dist", type=str, default="uniform")
parser.add_argument("--rand-data-min", type=float, default=0)
parser.add_argument("--rand-data-max", type=float, default=1)
parser.add_argument("--rand-data-mu", type=float, default=-1)
parser.add_argument("--rand-data-sigma", type=float, default=1)
parser.add_argument("--data-trace-file", type=str, default="./input/dist_emb_j.log")
parser.add_argument("--data-set", type=str, default="kaggle")
parser.add_argument("--raw-data-file", type=str, default="")
parser.add_argument("--processed-data-file", type=str, default="")
parser.add_argument("--data-randomize", type=str, default="total")
parser.add_argument("--data-trace-enable-padding", type=bool, default=False)
parser.add_argument("--max-ind-range", type=int, default=-1)
parser.add_argument("--data-sub-sample-rate", type=float, default=0.0)
parser.add_argument("--num-indices-per-lookup", type=int, default=10)
parser.add_argument("--num-indices-per-lookup-fixed", type=bool, default=False)
parser.add_argument("--num-workers", type=int, default=0)
parser.add_argument("--memory-map", action="store_true", default=False)
parser.add_argument("--mini-batch-size", type=int, default=1)
parser.add_argument("--nepochs", type=int, default=1)
parser.add_argument("--learning-rate", type=float, default=0.01)
parser.add_argument("--print-precision", type=int, default=5)
parser.add_argument("--numpy-rand-seed", type=int, default=123)
parser.add_argument("--sync-dense-params", type=bool, default=True)
parser.add_argument("--caffe2-net-type", type=str, default="")
parser.add_argument("--optimizer", type=str, default="sgd",
help="""This is the optimizer for embedding tables.""")
parser.add_argument(
"--dataset-multiprocessing",
action="store_true",
default=False,
help="The Kaggle dataset can be multiprocessed in an environment \
with more than 7 CPU cores and more than 20 GB of memory. \n \
The Terabyte dataset can be multiprocessed in an environment \
with more than 24 CPU cores and at least 1 TB of memory.",
)
parser.add_argument("--inference-only", action="store_true", default=False)
parser.add_argument("--save-onnx", action="store_true", default=False)
parser.add_argument("--save-proto-types-shapes", action="store_true", default=False)
parser.add_argument("--use-gpu", action="store_true", default=False)
parser.add_argument("--print-freq", type=int, default=1)
parser.add_argument("--test-freq", type=int, default=-1)
parser.add_argument("--test-mini-batch-size", type=int, default=-1)
parser.add_argument("--test-num-workers", type=int, default=-1)
parser.add_argument("--print-time", action="store_true", default=False)
parser.add_argument("--debug-mode", action="store_true", default=False)
parser.add_argument("--enable-profiling", action="store_true", default=False)
parser.add_argument("--plot-compute-graph", action="store_true", default=False)
parser.add_argument("--mlperf-logging", action="store_true", default=False)
parser.add_argument("--mlperf-acc-threshold", type=float, default=0.0)
parser.add_argument("--mlperf-auc-threshold", type=float, default=0.0)
args = parser.parse_args()
if args.dataset_multiprocessing:
assert float(sys.version[:3]) > 3.7, "The dataset_multiprocessing " + \
"flag is susceptible to a bug in Python 3.7 and under. " + \
"https://github.com/facebookresearch/dlrm/issues/172"
np.random.seed(args.numpy_rand_seed)
np.set_printoptions(precision=args.print_precision)
if (args.test_mini_batch_size < 0):
args.test_mini_batch_size = args.mini_batch_size
if (args.test_num_workers < 0):
args.test_num_workers = args.num_workers
use_gpu = args.use_gpu
if use_gpu:
device_opt = core.DeviceOption(workspace.GpuDeviceType, 0)
ngpus = workspace.NumGpuDevices()
print("Using {} GPU(s)...".format(ngpus))
else:
device_opt = core.DeviceOption(caffe2_pb2.CPU)
print("Using CPU...")
ln_bot = np.fromstring(args.arch_mlp_bot, dtype=int, sep="-")
if args.data_generation == "dataset":
if args.num_workers > 0 or args.test_num_workers > 0:
print("WARNING: non default --num-workers or --test-num-workers options"
+ " are not supported and will be ignored")
if args.mini_batch_size != args.test_mini_batch_size:
print("WARNING: non default ----test-mini-batch-size option"
+ " is not supported and will be ignored")
train_data, train_ld, test_data, test_ld = \
dp.make_criteo_data_and_loaders(
args,
offset_to_length_converter=True,
)
nbatches = args.num_batches if args.num_batches > 0 \
else len(train_ld)
nbatches_test = len(test_ld)
ln_emb = train_data.counts
m_den = train_data.m_den
if args.max_ind_range > 0:
ln_emb = np.array(list(map(
lambda x: x if x < args.max_ind_range else args.max_ind_range,
ln_emb
)))
ln_bot[0] = m_den
else:
if args.num_workers > 0 or args.test_num_workers > 0:
print("WARNING: non default --num-workers or --test-num-workers options"
+ " are not supported and will be ignored")
if args.mini_batch_size != args.test_mini_batch_size:
print("WARNING: non default ----test-mini-batch-size option"
+ " is not supported and will be ignored")
ln_emb = np.fromstring(args.arch_embedding_size, dtype=int, sep="-")
m_den = ln_bot[0]
train_data, train_ld, test_data, test_ld = dp.make_random_data_and_loader(args, ln_emb, m_den, \
offset_to_length_converter=True,
)
nbatches = args.num_batches if args.num_batches > 0 else len(train_ld)
nbatches_test = len(test_ld)
m_spa = args.arch_sparse_feature_size
ln_emb = np.asarray(ln_emb)
num_fea = ln_emb.size + 1
m_den_out = ln_bot[ln_bot.size - 1]
if args.arch_interaction_op == "dot":
if args.arch_interaction_itself:
num_int = (num_fea * (num_fea + 1)) // 2 + m_den_out
else:
num_int = (num_fea * (num_fea - 1)) // 2 + m_den_out
elif args.arch_interaction_op == "cat":
num_int = num_fea * m_den_out
else:
sys.exit("ERROR: --arch-interaction-op="
+ args.arch_interaction_op + " is not supported")
arch_mlp_top_adjusted = str(num_int) + "-" + args.arch_mlp_top
ln_top = np.fromstring(arch_mlp_top_adjusted, dtype=int, sep="-")
if m_den != ln_bot[0]:
sys.exit("ERROR: arch-dense-feature-size "
+ str(m_den) + " does not match first dim of bottom mlp " + str(ln_bot[0]))
if m_spa != m_den_out:
sys.exit("ERROR: arch-sparse-feature-size "
+ str(m_spa) + " does not match last dim of bottom mlp " + str(m_den_out))
if num_int != ln_top[0]:
sys.exit("ERROR: # of feature interactions "
+ str(num_int) + " does not match first dim of top mlp " + str(ln_top[0]))
if args.debug_mode:
print("model arch:")
print("mlp top arch " + str(ln_top.size - 1)
+ " layers, with input to output dimensions:")
print(ln_top)
print("# of interactions")
print(num_int)
print("mlp bot arch " + str(ln_bot.size - 1)
+ " layers, with input to output dimensions:")
print(ln_bot)
print("# of features (sparse and dense)")
print(num_fea)
print("dense feature size")
print(m_den)
print("sparse feature size")
print(m_spa)
print("# of embeddings (= # of sparse features) " + str(ln_emb.size)
+ ", with dimensions " + str(m_spa) + "x:")
print(ln_emb)
print("data (inputs and targets):")
for j, inputBatch in enumerate(train_ld):
lX_j, lS_l_j, lS_i_j, lT_j = inputBatch
print("mini-batch: %d" % j)
print(lX_j)
print(lS_l_j)
print(lS_i_j)
print(lT_j)
ndevices = min(ngpus, args.mini_batch_size, num_fea - 1) if use_gpu else -1
flag_types_shapes = args.save_onnx or args.save_proto_types_shapes
flag_forward_ops = not (use_gpu and ndevices > 1)
with core.DeviceScope(device_opt):
dlrm = DLRM_Net(
m_spa,
ln_emb,
ln_bot,
ln_top,
args.arch_interaction_op,
arch_interaction_itself=args.arch_interaction_itself,
sigmoid_bot=-1,
sigmoid_top=ln_top.size - 1,
save_onnx=flag_types_shapes,
ndevices=ndevices,
enable_prof=args.enable_profiling,
weighted_pooling=args.weighted_pooling,
emb_optimizer=args.optimizer
)
if args.sync_dense_params and ndevices > 1:
dyndep.InitOpsLibrary("//caffe2/caffe2/contrib/nccl:nccl_ops")
if args.caffe2_net_type:
dlrm.parameters().net.Proto().type = args.caffe2_net_type
if args.plot_compute_graph:
graph = net_drawer.GetPydotGraph(
dlrm.parameters().net,
"dlrm_s_caffe2_graph",
"BT"
)
graph.write_pdf(graph.get_name() + ".pdf")
if args.debug_mode:
print("initial parameters (weights and bias):")
dlrm.print_weights()
if not args.inference_only:
with core.DeviceScope(device_opt):
nd = 1.0 if dlrm.ndevices <= 1 else 1.0 / dlrm.ndevices
if args.loss_function == "mse":
dlrm.MSEloss(scale=nd)
elif args.loss_function == "bce":
dlrm.BCEloss(scale=nd, threshold=args.loss_threshold)
else:
sys.exit("ERROR: --loss-function=" + args.loss_function
+ " is not supported")
dlrm.test_net = core.Net(copy.deepcopy(dlrm.model.net.Proto()))
if args.optimizer == "sgd":
dlrm.sgd_optimizer(
args.learning_rate, sync_dense_params=args.sync_dense_params
)
elif args.optimizer in ["adagrad", "rwsadagrad"]:
dlrm.adagrad_optimizer(
args.learning_rate, sync_dense_params=args.sync_dense_params
)
else:
sys.exit("""ERROR: Select an optimizer for
embedding tables : 'sgd', 'adagrad',
or 'rwsadagrad' """)
X, lS_l, lS_i, T = next(iter(train_ld))
dlrm.create(X, lS_l, lS_i, T.int())
best_gA_test = 0
best_auc_test = 0
total_time = 0
total_loss = 0
total_accu = 0
total_iter = 0
total_samp = 0
k = 0
print("time/loss/accuracy (if enabled):")
while k < args.nepochs:
j = 0
for j, inputBatch in enumerate(train_ld):
time1 = time.time()
lX_j, lS_l_j, lS_i_j, lT_j = inputBatch
lT_j = lT_j.int() if args.loss_function == "bce" else lT_j
dlrm.run(lX_j, lS_l_j, lS_i_j, lT_j)
time2 = time.time()
total_time += time2 - time1
Z = dlrm.get_output()
T = lT_j.numpy()
'''
# debug prints
print("output and loss")
print(Z)
print(dlrm.get_loss())
'''
mbs = T.shape[0]
A = np.sum((np.round(Z, 0) == T).astype(np.uint8))
total_accu += 0 if args.inference_only else A
total_loss += 0 if args.inference_only else dlrm.get_loss() * mbs
total_iter += 1
total_samp += mbs
should_print = ((j + 1) % args.print_freq == 0) or (j + 1 == nbatches)
should_test = (
(args.test_freq > 0)
and (args.data_generation in ["dataset", "random"])
and (((j + 1) % args.test_freq == 0) or (j + 1 == nbatches))
)
if should_print or should_test:
gT = 1000. * total_time / total_iter if args.print_time else -1
total_time = 0
gA = total_accu / total_samp
total_accu = 0
gL = total_loss / total_samp
total_loss = 0
str_run_type = "inference" if args.inference_only else "training"
print(
"Finished {} it {}/{} of epoch {}, {:.2f} ms/it,".format(
str_run_type, j + 1, nbatches, k, gT
)
+ " loss {:.6f}".format(gL)
)
total_iter = 0
total_samp = 0
if should_test and not args.inference_only:
if args.mlperf_logging:
previous_iteration_time = None
test_accu = 0
test_loss = 0
test_samp = 0
if args.mlperf_logging:
scores = []
targets = []
for i, testBatch in enumerate(test_ld):
if nbatches > 0 and i >= nbatches:
break
lX_test_i, lS_l_test_i, lS_i_test_i, lT_test_i = testBatch
lT_test_i = lT_test_i.int() if args.loss_function == "bce" else lT_test_i
dlrm.run(lX_test_i, lS_l_test_i, lS_i_test_i, lT_test_i, test_net=True)
Z_test = dlrm.get_output()
T_test = lT_test_i.numpy()
if args.mlperf_logging:
scores.append(Z_test)
targets.append(T_test)
else:
L_test = dlrm.get_loss()
mbs_test = T_test.shape[0]
A_test = np.sum((np.round(Z_test, 0) == T_test).astype(np.uint8))
test_accu += A_test
test_loss += L_test * mbs_test
test_samp += mbs_test
if args.mlperf_logging:
validation_results = calculate_metrics(targets, scores)
gA_test = validation_results['accuracy']
gL_test = validation_results['loss']
else:
gA_test = test_accu / test_samp
gL_test = test_loss / test_samp
is_best = gA_test > best_gA_test
if is_best:
best_gA_test = gA_test
if args.mlperf_logging:
is_best = validation_results['roc_auc'] > best_auc_test
if is_best:
best_auc_test = validation_results['roc_auc']
print(
"Testing at - {}/{} of epoch {},".format(j + 1, nbatches, k)
+ " loss {:.6f}, recall {:.4f}, precision {:.4f},".format(
validation_results['loss'],
validation_results['recall'],
validation_results['precision']
)
+ " f1 {:.4f}, ap {:.4f},".format(
validation_results['f1'],
validation_results['ap'],
)
+ " auc {:.4f}, best auc {:.4f},".format(
validation_results['roc_auc'],
best_auc_test
)
+ " accuracy {:3.3f} %, best accuracy {:3.3f} %".format(
validation_results['accuracy'] * 100,
best_gA_test * 100
)
)
else:
print(
"Testing at - {}/{} of epoch {},".format(j + 1, nbatches, 0)
+ " loss {:.6f}, accuracy {:3.3f} %, best {:3.3f} %".format(
gL_test, gA_test * 100, best_gA_test * 100
)
)
if (args.mlperf_logging
and (args.mlperf_acc_threshold > 0)
and (best_gA_test > args.mlperf_acc_threshold)):
print("MLPerf testing accuracy threshold "
+ str(args.mlperf_acc_threshold)
+ " reached, stop training")
break
if (args.mlperf_logging
and (args.mlperf_auc_threshold > 0)
and (best_auc_test > args.mlperf_auc_threshold)):
print("MLPerf testing auc threshold "
+ str(args.mlperf_auc_threshold)
+ " reached, stop training")
break
j += 1
k += 1
if not args.inference_only and args.debug_mode:
print("updated parameters (weights and bias):")
dlrm.print_weights()
if args.save_onnx:
pnet = dlrm.parameters().net.Proto()
inet = dlrm.parameters().param_init_net.Proto()
value_info = dlrm.onnx_tsd
c2f = caffe2.python.onnx.frontend.Caffe2Frontend()
dlrm_caffe2_onnx = c2f.caffe2_net_to_onnx_model(pnet, inet, value_info)
onnx.checker.check_model(dlrm_caffe2_onnx)
with open("dlrm_s_caffe2.onnx", "w+") as dlrm_caffe2_onnx_file:
dlrm_caffe2_onnx_file.write(str(dlrm_caffe2_onnx))
if args.save_proto_types_shapes:
__TYPE_MAPPING = {
onnx.TensorProto.FLOAT: caffe2_pb2.TensorProto.FLOAT,
onnx.TensorProto.UINT8: caffe2_pb2.TensorProto.UINT8,
onnx.TensorProto.INT8: caffe2_pb2.TensorProto.INT8,
onnx.TensorProto.UINT16: caffe2_pb2.TensorProto.UINT16,
onnx.TensorProto.INT16: caffe2_pb2.TensorProto.INT16,
onnx.TensorProto.INT32: caffe2_pb2.TensorProto.INT32,
onnx.TensorProto.INT64: caffe2_pb2.TensorProto.INT64,
onnx.TensorProto.STRING: caffe2_pb2.TensorProto.STRING,
onnx.TensorProto.BOOL: caffe2_pb2.TensorProto.BOOL,
onnx.TensorProto.FLOAT16: caffe2_pb2.TensorProto.FLOAT16,
onnx.TensorProto.DOUBLE: caffe2_pb2.TensorProto.DOUBLE,
}
pnet = dlrm.parameters().net.Proto()
arg = pnet.arg.add()
arg.name = "input_shape_info"
for i in pnet.external_input:
if i in dlrm.onnx_tsd:
onnx_dtype, shape = dlrm.onnx_tsd[i]
t = arg.tensors.add()
t.name = i
t.data_type = __TYPE_MAPPING[onnx_dtype]
t.dims.extend(shape)
else:
print("Warning: we don't have shape/type info for input: {}".format(i))
with open("dlrm_s_caffe2.proto", "w+") as dlrm_s_proto_file:
dlrm_s_proto_file.write(str(pnet))
"""
# export the protobuf with types and shapes as well as weights
# see https://github.com/pytorch/pytorch/issues/9533
#save
net = dlrm.parameters().net
params = dlrm.parameters().params
init_net, predict_net = mobile_exporter.Export(workspace, net, params)
with open("dlrm_s_caffe2.predict", "wb") as dlrm_s_predict_file:
dlrm_s_predict_file.write(predict_net.SerializeToString())
with open("dlrm_s_caffe2.init", "wb") as dlrm_s_init_file:
dlrm_s_init_file.write(init_net.SerializeToString())
#load
net_def = caffe2_pb2.NetDef()
init_def= caffe2_pb2.NetDef()
with open("dlrm_s_caffe2.predict", "rb") as dlrm_s_predict_file:
net_def.ParseFromString(dlrm_s_predict_file.read())
print(net_def)
with open("dlrm_s_caffe2.init", "rb") as dlrm_s_init_file:
init_def.ParseFromString(dlrm_s_init_file.read())
print(init_def)
"""