Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <algorithm>
#include <atomic>
#include <cstdint>
#include <ctime>
#include <map>
#include <numeric>
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "key_process/key_process.h"
#include "key_process/feature_admit_and_evict.h"
#include "utils/common.h"
#include "utils/safe_queue.h"
#include "utils/singleton.h"
#include "utils/time_cost.h"
#include "error/error.h"
using namespace tensorflow;
using shape_inference::InferenceContext;
using shape_inference::ShapeHandle;
using namespace std;
using namespace chrono;
using namespace MxRec;
using OpKernelConstructionPtr = OpKernelConstruction*;
using OpKernelContextPtr = OpKernelContext*;
using InferenceContextPtr = ::tensorflow::shape_inference::InferenceContext*;
namespace {
// Stopwatch shared by the ReadEmbKeyV2 / ReadEmbKeyV2Dynamic kernels: each Compute
// logs the time elapsed since the previous read and then restarts it.
// (Names in an unnamed namespace already have internal linkage, so no `static`.)
TimeCost g_staticSw{};
}
namespace MxRec {
/**
 * @brief Kernel that resets the HybridMgmtBlock state of a single channel.
 *
 * The channel is selected by attr `channel_id`, which must lie in
 * [0, MAX_CHANNEL_NUM); an out-of-range id fails kernel construction.
 */
class ClearChannel : public OpKernel {
public:
    explicit ClearChannel(OpKernelConstructionPtr context) : OpKernel(context)
    {
        LOG_DEBUG("clear channel init");
        OP_REQUIRES_OK(context, context->GetAttr("channel_id", &channelId));
        const bool channelInRange = (channelId >= 0) && (channelId < MAX_CHANNEL_NUM);
        if (!channelInRange) {
            const auto reason = StringFormat(
                "ClearChannel channelId invalid. It should be in range [0, MAX_CHANNEL_NUM:%d)",
                MAX_CHANNEL_NUM);
            context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", reason));
        }
    }

    ~ClearChannel() override = default;

    /** Resets every piece of blocking state HybridMgmtBlock keeps for `channelId`. */
    void Compute(OpKernelContextPtr context) override
    {
        LOG_DEBUG("clear channel {}, context {}", channelId, context->step_id());
        Singleton<HybridMgmtBlock>::GetInstance()->ResetAll(channelId);
    }

private:
    int channelId {};
};
class SetThreshold : public OpKernel {
public:
explicit SetThreshold(OpKernelConstructionPtr context) : OpKernel(context)
{
LOG_DEBUG("SetThreshold init");
OP_REQUIRES_OK(context, context->GetAttr("emb_name", &embName));
OP_REQUIRES_OK(context, context->GetAttr("ids_name", &idsName));
}
~SetThreshold() override = default;
void Compute(OpKernelContextPtr context) override
{
LOG_DEBUG("enter SetThreshold");
int threshold = 1;
const Tensor& inputTensor = context->input(TensorIndex::TENSOR_INDEX_0);
int available = ParseThresholdAndCheck(inputTensor, threshold);
if (available == 0) {
context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ",
StringFormat("threshold[%d] error", threshold)));
return;
}
if (!FeatureAdmitAndEvict::m_cfgThresholds.empty()) {
auto keyProcess = Singleton<KeyProcess>::GetInstance();
if (!keyProcess->isRunning) {
context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", "KeyProcess not running."));
return;
}
if (!keyProcess->GetFeatAdmitAndEvict().SetTableThresholds(threshold, embName)) {
context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", "threshold set error ...")
);
return;
}
} else {
LOG_DEBUG("SetThreshold failed, because feature admit-and-evict switch is closed");
}
Tensor* output = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape {}, &output));
auto out = output->flat<int32>();
out(0) = available;
}
int ParseThresholdAndCheck(const Tensor& inputTensor, int& threshold) const
{
auto src = reinterpret_cast<const int*>(inputTensor.tensor_data().data());
std::copy(src, src + 1, &threshold);
if (threshold < 0) {
auto error = Error(ModuleName::M_DATASET_OPS, ErrorType::INVALID_ARGUMENT,
StringFormat("Threshold should >= 0, get:%d.", threshold));
LOG_ERROR(error.ToString());
return 0;
}
LOG_DEBUG("ParseThresholdAndCheck, emb_name:[{}], ids_name: [{}], threshold: [{}]",
embName, idsName, threshold);
return 1;
}
private:
string embName {};
string idsName {};
};
class ReturnTimestamp : public OpKernel {
public:
explicit ReturnTimestamp(OpKernelConstructionPtr context) : OpKernel(context)
{}
~ReturnTimestamp() override = default;
void Compute(OpKernelContextPtr context) override
{
Tensor* output = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape {}, &output));
auto out = output->flat<int64>();
out(0) = time(nullptr);
}
};
class ReadEmbKeyV2Dynamic : public OpKernel {
public:
explicit ReadEmbKeyV2Dynamic(OpKernelConstructionPtr context) : OpKernel(context)
{
LOG_DEBUG("ReadEmbKeyV2Dynamic init");
OP_REQUIRES_OK(context, context->GetAttr("channel_id", &channelId));
OP_REQUIRES_OK(context, context->GetAttr("emb_name", &embNames));
OP_REQUIRES_OK(context, context->GetAttr("timestamp", &isTimestamp));
hybridMgmtBlock = Singleton<HybridMgmtBlock>::GetInstance();
if (!FeatureAdmitAndEvict::m_cfgThresholds.empty() &&
!FeatureAdmitAndEvict::IsThresholdCfgOK(
FeatureAdmitAndEvict::m_cfgThresholds, embNames, isTimestamp)
) {
context->SetStatus(
errors::Aborted(__FILE__, ":", __LINE__, " ", "threshold config, or timestamp error ..."));
return;
}
if (channelId < 0 || channelId >= MAX_CHANNEL_NUM) {
context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", StringFormat(
"ReadEmbKeyV2Dynamic channelId invalid. It should be in range [0, MAX_CHANNEL_NUM:%d)",
MAX_CHANNEL_NUM)));
return;
}
LOG_DEBUG(HYBRID_BLOCKING + " reset channel {}", channelId);
hybridMgmtBlock->ResetAll(channelId);
Singleton<HDTransfer>::GetInstance()->ClearTransChannel(channelId);
threadNum = GetThreadNumEnv();
if (threadNum <= 0) {
context->SetStatus(
errors::Aborted(__FILE__, ":", __LINE__, " ", "ThreadNum invalid. It should be bigger than 0 ..."));
return;
}
auto keyProcess = Singleton<KeyProcess>::GetInstance();
if (!keyProcess->isRunning) {
context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", "KeyProcess not running."));
return;
}
maxStep = keyProcess->GetMaxStep(channelId);
}
~ReadEmbKeyV2Dynamic() override = default;
void Compute(OpKernelContextPtr context) override
{
LOG_DEBUG("enter ReadEmbKeyV2Dynamic");
TimeCost tc = TimeCost();
int batchId = hybridMgmtBlock->readEmbedBatchId[channelId];
Tensor* output = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape {}, &output));
auto out = output->flat<int32>();
out(0) = batchId;
if (channelId == 1) {
if (maxStep != -1 && batchId >= maxStep) {
LOG_DEBUG("skip excess batch after {}/{}", batchId, maxStep);
return;
}
}
hybridMgmtBlock->readEmbedBatchId[channelId] += 1;
const Tensor& inputTensor = context->input(TensorIndex::TENSOR_INDEX_0);
const auto& splits = context->input(TENSOR_INDEX_1).flat<int32>();
int fieldNum = 0;
for (int i = 0; i < splits.size(); ++i) {
fieldNum += splits(i);
}
size_t dataSize = inputTensor.NumElements();
time_t timestamp = -1;
if (isTimestamp && !ParseTimestampAndCheck(inputTensor, batchId, fieldNum, timestamp, dataSize)) {
context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", StringFormat(
"timestamp[%d] error, skip excess batch after %d/%d", timestamp, batchId, maxStep)));
return;
}
SetCurrEmbNamesStatus(embNames, FeatureAdmitAndEvict::m_embStatus);
int batchQueueId = (batchId % threadNum) + (MAX_KEY_PROCESS_THREAD * channelId);
TimeCost enqueueTC;
EnqueueBatchData(std::vector<int>{batchId, batchQueueId}, timestamp, inputTensor, splits);
LOG_DEBUG(KEY_PROCESS "ReadEmbKeyV2Dynamic read batch cost(ms):{}, elapsed from last(ms):{},"
" enqueueTC(ms):{}, batch[{}]:{}",
tc.ElapsedMS(), g_staticSw.ElapsedMS(), enqueueTC.ElapsedMS(), channelId, batchId);
g_staticSw = TimeCost();
}
void CheckEmbTables()
{
auto keyProcess = Singleton<KeyProcess>::GetInstance();
for (size_t i = 0; i < embNames.size(); ++i) {
if (!keyProcess->HasEmbName(embNames.at(i))) {
LOG_DEBUG("ReadEmbKeyV2Dynamic not found emb_name:{} {}", i, embNames.at(i));
tableUsed.push_back(false);
} else {
tableUsed.push_back(true);
}
}
}
void EnqueueBatchData(std::vector<int> ids, time_t timestamp,
const Tensor& inputTensor, const TTypes<int32>::ConstFlat& splits)
{
if (tableUsed.empty()) {
CheckEmbTables();
}
auto queue = SingletonQueue<EmbBatchT>::GetInstances(ids[1]);
size_t offset = 0;
if (isTimestamp) {
offset += 1;
}
for (int i = 0; i < splits.size(); ++i) {
if (!tableUsed.at(i)) {
offset += splits(i);
continue;
}
auto batchData = queue->GetOne();
batchData->name = embNames.at(i);
size_t len = splits(i);
batchData->channel = channelId;
batchData->isEos = false;
batchData->batchId = ids[0];
batchData->sample.resize(len);
if (isTimestamp) {
batchData->timestamp = timestamp;
}
if (inputTensor.dtype() == tensorflow::DT_INT32 || inputTensor.dtype() == tensorflow::DT_INT32_REF) {
auto src = reinterpret_cast<const int32_t*>(inputTensor.tensor_data().data());
copy(src + offset, src + offset + len, batchData->sample.data());
} else {
auto src = reinterpret_cast<const int64_t*>(inputTensor.tensor_data().data());
copy(src + offset, src + offset + len, batchData->sample.data());
}
offset += len;
queue->Pushv(move(batchData));
}
}
bool ParseTimestampAndCheck(const Tensor& inputTensor, int batchId, int fieldNumTmp, time_t& timestamp,
size_t& dataSize) const
{
if (dataSize - fieldNumTmp != 1) {
auto error = Error(ModuleName::M_DATASET_OPS, ErrorType::INVALID_ARGUMENT,
StringFormat("Timestamp field not found, dataSize:%ld, fieldNum:%d.",
dataSize, fieldNumTmp));
LOG_ERROR(error.ToString());
return false;
}
auto src = reinterpret_cast<const time_t*>(inputTensor.tensor_data().data());
std::copy(src, src + 1, ×tamp);
LOG_DEBUG("current batchId[{}] timestamp[{}]", batchId, timestamp);
dataSize -= 1;
if (timestamp <= 0) {
auto error = Error(ModuleName::M_DATASET_OPS, ErrorType::INVALID_ARGUMENT,
StringFormat("Timestamp should greater than 0, get:%ld.", timestamp));
LOG_ERROR(error.ToString());
return false;
}
return true;
}
void SetCurrEmbNamesStatus(const vector<string>& embeddingNames,
absl::flat_hash_map<std::string, SingleEmbTableStatus>& embStatus) const
{
for (size_t i = 0; i < embeddingNames.size(); ++i) {
auto it = embStatus.find(embeddingNames[i]);
if (it == embStatus.end()) {
embStatus.insert(std::pair<std::string,
SingleEmbTableStatus>(embeddingNames[i], SingleEmbTableStatus::SETS_NONE));
}
}
}
int channelId {};
vector<string> embNames {};
vector<bool> tableUsed{};
int maxStep = 0;
bool isTimestamp { false };
int threadNum = 0;
HybridMgmtBlock* hybridMgmtBlock;
};
/**
 * @brief Reads one batch of sparse keys whose per-table sizes are fixed by attr `splits`
 *        and pushes one EmbBatchT per used table into a key-process queue.
 *
 * Input 0 `sample` (int32 or int64): keys of all tables laid out back to back; when attr
 * `timestamp` is true, element 0 is a timestamp and the keys start at element 1.
 * Output 0: scalar int32 holding the batch id assigned to this read.
 */
class ReadEmbKeyV2 : public OpKernel {
public:
    /**
     * Reads attrs `channel_id`, `emb_name`, `splits` and `timestamp`. It validates the
     * threshold config, that splits matches emb_name in length and is non-negative, and
     * the channel id. It then resets the channel and caches the thread count and max step.
     * Failures are reported via SetStatus.
     */
    explicit ReadEmbKeyV2(OpKernelConstructionPtr context) : OpKernel(context)
    {
        LOG_DEBUG("ReadEmbKeyV2 init");
        OP_REQUIRES_OK(context, context->GetAttr("channel_id", &channelId));
        OP_REQUIRES_OK(context, context->GetAttr("emb_name", &embNames));
        OP_REQUIRES_OK(context, context->GetAttr("splits", &splits));
        OP_REQUIRES_OK(context, context->GetAttr("timestamp", &isTimestamp));
        fieldNum = accumulate(splits.begin(), splits.end(), 0);
        hybridMgmtBlock = Singleton<HybridMgmtBlock>::GetInstance();
        if (!FeatureAdmitAndEvict::m_cfgThresholds.empty() &&
            !FeatureAdmitAndEvict::IsThresholdCfgOK(FeatureAdmitAndEvict::m_cfgThresholds, embNames, isTimestamp)) {
            context->SetStatus(
                errors::Aborted(__FILE__, ":", __LINE__, " ", "threshold config, or timestamp error ..."));
            return;
        }
        if (splits.size() != embNames.size()) {
            context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", StringFormat(
                "splits & embNames size error.%zu %zu", splits.size(), embNames.size())));
            return;
        }
        // A negative split would become a huge size_t length in EnqueueBatchData.
        if (std::any_of(splits.begin(), splits.end(), [](int split) { return split < 0; })) {
            context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ",
                "splits invalid. Every split should be >= 0 ..."));
            return;
        }
        if (channelId < 0 || channelId >= MAX_CHANNEL_NUM) {
            context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", StringFormat(
                "ReadEmbKeyV2 channelId invalid. It should be in range [0, MAX_CHANNEL_NUM:%d)",
                MAX_CHANNEL_NUM)));
            return;
        }
        LOG_DEBUG(HYBRID_BLOCKING + " reset channel {}", channelId);
        hybridMgmtBlock->ResetAll(channelId);
        Singleton<HDTransfer>::GetInstance()->ClearTransChannel(channelId);
        threadNum = GetThreadNumEnv();
        if (threadNum <= 0) {
            context->SetStatus(
                errors::Aborted(__FILE__, ":", __LINE__, " ", "ThreadNum invalid. It should be bigger than 0 ..."));
            return;
        }
        auto keyProcess = Singleton<KeyProcess>::GetInstance();
        if (!keyProcess->isRunning) {
            context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", "KeyProcess not running."));
            return;
        }
        maxStep = keyProcess->GetMaxStep(channelId);
    }
    ~ReadEmbKeyV2() override = default;

    void Compute(OpKernelContextPtr context) override
    {
        LOG_DEBUG("enter ReadEmbKeyV2");
        TimeCost tc = TimeCost();
        int batchId = hybridMgmtBlock->readEmbedBatchId[channelId];
        Tensor* output = nullptr;
        OP_REQUIRES_OK(context, context->allocate_output(0, TensorShape {}, &output));
        auto out = output->flat<int32>();
        out(0) = batchId;
        // Only channel 1 is capped at maxStep batches; maxStep == -1 disables the cap.
        if (channelId == 1 && maxStep != -1 && batchId >= maxStep) {
            LOG_DEBUG("skip excess batch after {}/{}", batchId, maxStep);
            return;
        }
        hybridMgmtBlock->readEmbedBatchId[channelId] += 1;
        const Tensor& inputTensor = context->input(TensorIndex::TENSOR_INDEX_0);
        size_t dataSize = inputTensor.NumElements();
        time_t timestamp = -1;
        if (isTimestamp && !ParseTimestampAndCheck(inputTensor, batchId, fieldNum, timestamp, dataSize)) {
            context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", StringFormat(
                "timestamp[%lld] error, skip excess batch after %d/%d",
                static_cast<long long>(timestamp), batchId, maxStep)));
            return;
        }
        // dataSize now counts key elements only. EnqueueBatchData copies fieldNum of them,
        // so without this check a short tensor would be read out of bounds.
        if (static_cast<size_t>(fieldNum) > dataSize) {
            context->SetStatus(errors::Aborted(__FILE__, ":", __LINE__, " ", StringFormat(
                "sample size error, dataSize:%zu, fieldNum:%d.", dataSize, fieldNum)));
            return;
        }
        SetCurrEmbNamesStatus(embNames, FeatureAdmitAndEvict::m_embStatus);
        int batchQueueId = (batchId % threadNum) + (MAX_KEY_PROCESS_THREAD * channelId);
        TimeCost enqueueTC;
        EnqueueBatchData(batchId, batchQueueId, timestamp, inputTensor);
        LOG_DEBUG(KEY_PROCESS "ReadEmbKeyV2Static read batch cost(ms):{}, elapsed from last(ms):{},"
            " enqueueTC(ms):{}, batch[{}]:{}",
            tc.ElapsedMS(), g_staticSw.ElapsedMS(), enqueueTC.ElapsedMS(), channelId, batchId);
        g_staticSw = TimeCost();
    }

    /** Fills `tableUsed[i]` with whether KeyProcess knows table embNames[i]. */
    void CheckEmbTables()
    {
        auto keyProcess = Singleton<KeyProcess>::GetInstance();
        for (size_t i = 0; i < splits.size(); ++i) {
            if (!keyProcess->HasEmbName(embNames.at(i))) {
                LOG_DEBUG("ReadEmbKeyV2 not found emb_name:{} {}", i, embNames.at(i));
                tableUsed.push_back(false);
            } else {
                tableUsed.push_back(true);
            }
        }
    }

    /**
     * @brief Slices the sample tensor by the `splits` attr and pushes one batch per used table.
     * @param batchId      Batch id stamped on every pushed EmbBatchT.
     * @param batchQueueId Index of the SingletonQueue instance to push into.
     * @param timestamp    Stored on each batch when attr `timestamp` is true.
     * @param inputTensor  Sample tensor (int32 or int64); a leading timestamp element is skipped.
     */
    void EnqueueBatchData(int batchId, int batchQueueId, time_t timestamp, const Tensor& inputTensor)
    {
        if (tableUsed.empty()) {
            CheckEmbTables();
        }
        auto queue = SingletonQueue<EmbBatchT>::GetInstances(batchQueueId);
        size_t offset = 0;
        if (isTimestamp) {
            offset += 1;
        }
        for (size_t i = 0; i < splits.size(); ++i) {
            if (!tableUsed.at(i)) {
                offset += splits.at(i);
                continue;
            }
            auto batchData = queue->GetOne();
            batchData->name = embNames.at(i);
            size_t len = splits.at(i);
            batchData->channel = channelId;
            batchData->isEos = false;
            batchData->batchId = batchId;
            batchData->sample.resize(len);
            if (isTimestamp) {
                batchData->timestamp = timestamp;
            }
            if (inputTensor.dtype() == tensorflow::DT_INT32 || inputTensor.dtype() == tensorflow::DT_INT32_REF) {
                auto src = reinterpret_cast<const int32_t*>(inputTensor.tensor_data().data());
                copy(src + offset, src + offset + len, batchData->sample.data());
            } else {
                auto src = reinterpret_cast<const int64_t*>(inputTensor.tensor_data().data());
                copy(src + offset, src + offset + len, batchData->sample.data());
            }
            offset += len;
            queue->Pushv(move(batchData));
        }
    }

    /**
     * @brief Extracts the leading timestamp element of the sample tensor and validates it.
     * @param fieldNumTmp Total number of key elements (sum of splits).
     * @param timestamp   [out] Timestamp read from element 0, using the sample's dtype.
     * @param dataSize    [in,out] Element count of the tensor; decremented by one on a
     *                    successful read so that it then counts key elements only.
     * @return false if the tensor is not exactly fieldNumTmp + 1 elements or the timestamp is <= 0.
     */
    bool ParseTimestampAndCheck(const Tensor& inputTensor, int batchId, int fieldNumTmp, time_t& timestamp,
                                size_t& dataSize) const
    {
        if (fieldNumTmp < 0 || dataSize != static_cast<size_t>(fieldNumTmp) + 1) {
            auto error = Error(ModuleName::M_DATASET_OPS, ErrorType::INVALID_ARGUMENT,
                               StringFormat("Timestamp field not found, dataSize:%zu, fieldNum:%d.",
                                            dataSize, fieldNumTmp));
            LOG_ERROR(error.ToString());
            return false;
        }
        // The timestamp has the sample's dtype. Reading a time_t from an int32 tensor would
        // merge it with the first key, or overrun a 1-element buffer.
        const char* raw = inputTensor.tensor_data().data();
        if (inputTensor.dtype() == tensorflow::DT_INT32 || inputTensor.dtype() == tensorflow::DT_INT32_REF) {
            timestamp = static_cast<time_t>(*reinterpret_cast<const int32_t*>(raw));
        } else {
            timestamp = static_cast<time_t>(*reinterpret_cast<const int64_t*>(raw));
        }
        LOG_DEBUG("current batchId[{}] timestamp[{}]", batchId, timestamp);
        dataSize -= 1;
        if (timestamp <= 0) {
            auto error = Error(ModuleName::M_DATASET_OPS, ErrorType::INVALID_ARGUMENT,
                               StringFormat("Timestamp should greater than 0, get:%lld.",
                                            static_cast<long long>(timestamp)));
            LOG_ERROR(error.ToString());
            return false;
        }
        return true;
    }

    /** Registers each name in `embeddingNames` with status SETS_NONE if it is not already present. */
    void SetCurrEmbNamesStatus(const vector<string>& embeddingNames,
                               absl::flat_hash_map<std::string, SingleEmbTableStatus>& embStatus) const
    {
        for (const auto& name : embeddingNames) {
            embStatus.try_emplace(name, SingleEmbTableStatus::SETS_NONE);
        }
    }

    int channelId {};
    vector<int> splits {};
    vector<bool> tableUsed{};
    int fieldNum {};
    vector<string> embNames {};
    int maxStep = 0;
    bool isTimestamp { false };
    int threadNum = KEY_PROCESS_THREAD;
    HybridMgmtBlock* hybridMgmtBlock { nullptr };
};
}
namespace tensorflow {
REGISTER_OP("ClearChannel").Attr("channel_id : int");
REGISTER_KERNEL_BUILDER(Name("ClearChannel").Device(DEVICE_CPU), MxRec::ClearChannel);
REGISTER_OP("SetThreshold")
.Input("input: int32")
.Attr("emb_name: string = ''")
.Attr("ids_name: string = ''")
.Output("output: int32")
.SetShapeFn([](InferenceContextPtr c) {
c->set_output(TensorIndex::TENSOR_INDEX_0, c->Scalar());
return Status::OK();
});
REGISTER_KERNEL_BUILDER(Name("SetThreshold").Device(DEVICE_CPU), MxRec::SetThreshold);
REGISTER_OP("ReturnTimestamp")
.Input("input: int64")
.Output("output: int64")
.SetShapeFn([](InferenceContextPtr c) {
c->set_output(TensorIndex::TENSOR_INDEX_0, c->Scalar());
return Status::OK();
});
REGISTER_KERNEL_BUILDER(Name("ReturnTimestamp").Device(DEVICE_CPU), MxRec::ReturnTimestamp);
REGISTER_OP("ReadEmbKeyV2Dynamic")
.Input("sample: T")
.Input("splits: int32")
.Output("output: int32")
.Attr("T: {int64, int32}")
.Attr("channel_id: int")
.Attr("emb_name: list(string)")
.Attr("timestamp: bool")
.SetShapeFn([](InferenceContextPtr c) {
c->set_output(TensorIndex::TENSOR_INDEX_0, c->Scalar());
return Status::OK();
});
REGISTER_KERNEL_BUILDER(Name("ReadEmbKeyV2Dynamic").Device(DEVICE_CPU), MxRec::ReadEmbKeyV2Dynamic);
REGISTER_OP("ReadEmbKeyV2")
.Input("sample: T")
.Output("output: int32")
.Attr("T: {int64, int32}")
.Attr("channel_id: int")
.Attr("splits: list(int)")
.Attr("emb_name: list(string)")
.Attr("timestamp: bool")
.SetShapeFn([](InferenceContextPtr c) {
c->set_output(TensorIndex::TENSOR_INDEX_0, c->Scalar());
return Status::OK();
});
REGISTER_KERNEL_BUILDER(Name("ReadEmbKeyV2").Device(DEVICE_CPU), MxRec::ReadEmbKeyV2);
}