* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "driver.h"
#include "codegen/time_util.h"
#include <memory>
namespace omniruntime::compute {
// Definition of the static atomic counter shared by all BlockingState
// instances. It starts at 0 and the BlockingState constructor increments it.
std::atomic_uint64_t BlockingState::numBlockdDrivers_{0};
/**
 * Records the driver, the future, the operator and the blocking reason that
 * are handed in, and bumps the shared numBlockdDrivers_ counter by one.
 *
 * @param driver  driver that is being blocked; ownership of this shared_ptr
 *                copy is moved into the new object
 * @param future  future associated with the blocking condition; moved in
 * @param op      operator that reported the blocking condition (stored as a
 *                raw, non-owning pointer)
 * @param reason  reason reported for the block
 */
BlockingState::BlockingState(
    std::shared_ptr<OmniDriver> driver,
    ContinueFuture &&future,
    omniruntime::op::Operator *op,
    BlockingReason reason)
    : driver_(std::move(driver)),
      future_(std::move(future)),
      operator_(op),
      reason_(reason)
{
    // Same effect as a post-increment on the atomic: add one, ignore the old value.
    numBlockdDrivers_.fetch_add(1);
}
/**
 * Runs the driver once and hands back the batch it produced, if any.
 *
 * @param future      receives blockingState->Future() when RunInternal
 *                    reported a blocking state; otherwise left untouched
 * @param stopReason  always receives the StopReason returned by RunInternal
 * @return the batch written by RunInternal, or nullptr when the driver is
 *         blocked or the stop reason is StopReason::kPause
 */
vec::VectorBatch *OmniDriver::Next(ContinueFuture *future, StopReason *stopReason)
{
    // Declared first so it is destroyed last, after any BlockingState created below.
    auto self = shared_from_this();
    std::shared_ptr<BlockingState> blocked;
    vec::VectorBatch *batch = nullptr;

    *stopReason = RunInternal(self, blocked, &batch);

    if (blocked) {
        *future = blocked->Future();
        return nullptr;
    }
    return (*stopReason == StopReason::kPause) ? nullptr : batch;
}
/**
 * Closes every operator owned by this driver and releases it.
 *
 * The call is a no-op once closed_ has been set. closed_ is only set after
 * all operators were closed, so if an operator's Close() throws the driver
 * stays "not closed" and close() may be called again.
 */
void OmniDriver::close()
{
    if (closed_) {
        return;
    }
    for (auto &op : operators_) {
        // Entries closed by an earlier, interrupted close() are already null;
        // skip them instead of dereferencing a null pointer on the retry.
        if (op == nullptr) {
            continue;
        }
        op->Close();
        op = nullptr;
    }
    closed_ = true;
}
// Brackets one operator call with call-status bookkeeping on opCallStatus_
// (which must be in scope where the macro is expanded):
//   Start()                records the operator id, the method name and the
//                          current thread CPU time,
//   call                   is the statement/expression to execute,
//   TimeSegmentStatistic() stores the CPU time spent since Start() in the
//                          operator's stats (AddInput / GetOutput only),
//   Stop()                 resets the recorded start time.
// The expansion is wrapped in do { ... } while (0) so it behaves as a single
// statement (safe in unbraced if/else) and is used as `CALL_OPERATOR(...);`.
// The last line intentionally carries no trailing backslash, so the macro
// definition ends here and cannot absorb the line that follows it.
#define CALL_OPERATOR(call, operatorPtr, operatorId, operatorMethod)        \
    do {                                                                    \
        opCallStatus_.Start(operatorId, operatorMethod);                    \
        call;                                                               \
        opCallStatus_.TimeSegmentStatistic(operatorPtr, operatorMethod);    \
        opCallStatus_.Stop();                                               \
    } while (0)
// Marks the beginning of an operator call: remembers which operator
// (operatorId) and which method (operatorMethod, stored as the pointer that
// was passed in, not copied) is running, and records the current thread CPU
// time as the start of the measured segment.
void OpCallStatus::Start(int32_t operatorId, const char* operatorMethod)
{
opId = operatorId;
method = operatorMethod;
// Taken last so the assignments above are not part of the measured segment.
cpuTimeStartNs = ThreadCpuNanos();
}
// Marks the end of the current operator call by resetting the recorded start
// time to 0. opId and method keep the values set by the last Start().
void OpCallStatus::Stop()
{
cpuTimeStartNs = 0;
}
/**
 * Converts the timing measured around a call into `op` into the timing that
 * is charged to `op` itself.
 *
 * @param op      operator the timing was measured for
 * @param timing  elapsed wall/CPU time of that call
 * @return `timing` unchanged when `op` is operators_[0]; otherwise a timing
 *         with count 1, the CPU time of `timing`, and the wall time of
 *         `timing` minus min(0, timing.wallNanos) (i.e. unchanged for
 *         non-negative wall times).
 *
 * No operator stats are modified here: the share that would be moved to
 * operators_[0] and the byte deltas are always zero in this implementation,
 * so there is nothing to record and no need to copy any OperatorStats.
 * NOTE(review): if a non-zero share is introduced later, it must be added
 * through a reference (`auto &`) to operators_[0]->stats(); writing to a
 * by-value copy of the stats silently drops the update.
 */
CpuWallTiming OmniDriver::processLazyIoStats(op::Operator& op, const CpuWallTiming& timing)
{
    if (&op == operators_[0].get()) {
        return timing;
    }

    // Wall time attributed to operators_[0] instead of `op`; currently none.
    constexpr int64_t kSourceWallShare = 0;
    const int64_t wallDelta = std::min<int64_t>(kSourceWallShare, timing.wallNanos);

    return CpuWallTiming{
        1,
        timing.wallNanos - wallDelta,
        timing.cpuNanos,
    };
}
/**
 * Stores the CPU time spent since the last Start() in the stats of `op`.
 *
 * Only the AddInput and GetOutput methods are tracked; for any other method
 * name a debug message is logged and nothing is stored.
 *
 * @param op              operator whose stats are updated
 * @param operatorMethod  name of the method that was just executed
 *
 * NOTE(review): the stored value is the elapsed time divided by 1e6 although
 * the field is called cpuNanos, and both cpuNanos and count are overwritten
 * rather than accumulated — confirm both are intended.
 */
void OpCallStatus::TimeSegmentStatistic(op::Operator* op, const char* operatorMethod) const
{
    // Sample the clock before anything else so the bookkeeping is not measured.
    const int64_t elapsedCpu = ThreadCpuNanos() - cpuTimeStartNs;

    const std::string_view methodName(operatorMethod);
    const bool isAddInput = (methodName == kOpMethodAddInput);
    const bool isGetOutput = (methodName == kOpMethodGetOutput);
    if (!(isAddInput || isGetOutput)) {
        LogDebug("not input or output for operator");
        return;
    }

    auto &stats = op->stats();
    auto &timing = isAddInput ? stats.addInputTime : stats.getOutputTime;
    timing.cpuNanos = elapsedCpu / static_cast<int64_t>(1e6);
    timing.count = 1;
}
// Runs the operator pipeline until it has to stop and reports why it stopped.
//
// The operators are visited from the last one back to the first. For every
// operator that has a successor, one batch is moved from op->GetOutput() to
// nextOp->AddInput() when the successor asks for input; for the last operator
// the output batch is written to *result.
//
// Parameters:
//   self           shared pointer to this driver; only forwarded to BlockDriver.
//   blockingState  set (through BlockDriver) when an operator reports that it
//                  is blocked; left untouched on every other path.
//   result         receives the batch produced by the last operator's
//                  GetOutput().
//
// Returns:
//   kBlock  when an operator is blocked (blockingState is set), or when the
//           last operator produced a batch and is not finished yet (in that
//           case blockingReason_ is kWaitForConsumer and blockingState is not
//           set).
//   kAtEnd  when shouldStop is set, or when the last operator is finished
//           (finished_ is set to true in that case).
//
// Throws std::runtime_error carrying what() of any std::exception raised
// while running the operators.
StopReason OmniDriver::RunInternal(
std::shared_ptr<OmniDriver> &self,
std::shared_ptr<BlockingState> &blockingState,
vec::VectorBatch **result)
{
try {
const uint32_t numOperators = operators_.size();
ContinueFuture future = OmniFuture::makeEmpty();
// NOTE(review): with an empty operators_ the inner loop below never runs and
// this outer loop appears to have no exit — confirm callers always provide at
// least one operator.
for (;;) {
// One pass over the pipeline, from the last operator down to the first.
for (int32_t i = numOperators - 1; i >= 0; --i) {
if (shouldStop) {
return StopReason::kAtEnd;
}
auto *op = operators_[i].get();
curOperatorId_ = i;
blockingReason_ = op->IsBlocked(&future);
if (blockingReason_ != BlockingReason::kNotBlocked) {
return BlockDriver(self, i, std::move(future), blockingState);
}
// Signed/unsigned comparison; i is never negative inside the loop body.
if (i < numOperators - 1) {
// op has a successor: try to move one batch from op to nextOp.
auto *nextOp = operators_[i + 1].get();
blockingReason_ = nextOp->IsBlocked(&future);
if (blockingReason_ != BlockingReason::kNotBlocked) {
return BlockDriver(self, i + 1, std::move(future), blockingState);
}
bool needsInput;
CALL_OPERATOR(needsInput = nextOp->needsInput(), nextOp, curOperatorId_ + 1, kOpMethodNeedsInput);
if (needsInput) {
uint64_t resultBytes = 0;
vec::VectorBatch *intermediateResult = nullptr;
// Pull a batch from op and, if there is one, record its size in op's output stats.
withDeltaCpuWallTimer(op, &OperatorStats::getOutputTime, [&]() {
CALL_OPERATOR(op->GetOutput(&intermediateResult), op, curOperatorId_, kOpMethodGetOutput);
if (intermediateResult) {
resultBytes = intermediateResult->CalculateTotalSize();
{
auto &lockedStats = op->stats();
lockedStats.AddOutputVector(resultBytes, intermediateResult->GetVectorCount(), intermediateResult->GetRowCount());
}
}
});
if (intermediateResult != nullptr) {
// Record the input stats first (the batch is read before it is handed
// over), then push the batch into nextOp.
withDeltaCpuWallTimer(nextOp, &OperatorStats::addInputTime, [&]() {
{
auto &lockedStats = nextOp->stats();
lockedStats.AddInputVector(resultBytes, intermediateResult->GetVectorCount(), intermediateResult->GetRowCount());
}
CALL_OPERATOR(nextOp->AddInput(intermediateResult), nextOp, curOperatorId_ + 1,
kOpMethodAddInput);
});
// Together with the loop's --i this makes the next iteration visit i + 1,
// i.e. the operator that just received input.
i += 2;
continue;
} else {
// op produced nothing: it is either blocked, finished, or has no output yet.
blockingReason_ = op->IsBlocked(&future);
if (blockingReason_ != BlockingReason::kNotBlocked) {
return BlockDriver(self, i, std::move(future), blockingState);
}
if (op->isFinished()) {
nextOp->noMoreInput();
// Leaves the inner loop; the outer loop restarts from the last operator.
break;
}
}
}
} else {
// Last operator of the pipeline: its output is the driver's result.
withDeltaCpuWallTimer(op, &OperatorStats::getOutputTime, [&]() {
CALL_OPERATOR(op->GetOutput(result), op, curOperatorId_, kOpMethodGetOutput);
if (*result != nullptr) {
{
auto &lockedStats = op->stats();
lockedStats.AddOutputVector((*result)->CalculateTotalSize(), (*result)->GetVectorCount(), (*result)->GetRowCount());
}
}
});
// A batch is available and more may follow: stop here so the caller can
// consume it. blockingState is not set on this path.
if (*result != nullptr && !op->isFinished()) {
blockingReason_ = BlockingReason::kWaitForConsumer;
return StopReason::kBlock;
}
bool finished{false};
finished = op->isFinished();
if (finished) {
finished_ = true;
return StopReason::kAtEnd;
}
}
}
}
} catch (const std::exception &e) {
// Re-thrown as std::runtime_error: only the message is kept, the dynamic
// type of the original exception is not.
throw std::runtime_error(e.what());
}
}
#undef CALL_OPERATOR
/**
 * Records that the driver is blocked on one of its operators.
 *
 * Remembers the index of the blocked operator in blockedOperatorId_ and
 * creates a BlockingState from the driver, the future, the blocked operator
 * and the current blockingReason_.
 *
 * @param self               shared pointer to this driver, stored in the state
 * @param blockedOperatorId  index into operators_ of the blocked operator
 * @param future             future reported by the blocked operator; moved
 *                           into the state
 * @param blockingState      receives the newly created BlockingState
 * @return always StopReason::kBlock
 */
StopReason OmniDriver::BlockDriver(
    const std::shared_ptr<OmniDriver> &self,
    size_t blockedOperatorId,
    ContinueFuture &&future,
    std::shared_ptr<BlockingState> &blockingState)
{
    blockedOperatorId_ = blockedOperatorId;
    blockingState = std::make_shared<BlockingState>(
        self,
        std::move(future),
        operators_[blockedOperatorId].get(),
        blockingReason_);
    return StopReason::kBlock;
}
/**
 * Runs `opFunction` and, when trackOperatorCpuUsage_ is enabled, adds the
 * time it took to one timing member of the operator's stats.
 *
 * @param op              operator whose stats receive the timing
 * @param opTimingMember  pointer-to-member selecting the timing to update
 * @param opFunction      callable to execute; invoked exactly once
 *
 * The elapsed time is passed through processLazyIoStats() before it is added.
 * NOTE(review): DeltaCpuWallTimer is declared elsewhere; presumably it calls
 * the supplied callback with the elapsed time when it goes out of scope —
 * confirm against its definition.
 */
template <typename Func>
void OmniDriver::withDeltaCpuWallTimer(op::Operator* op, TimingMemberPtr opTimingMember, Func&& opFunction)
{
    if (trackOperatorCpuUsage_) {
        auto recordElapsed = [this, op, opTimingMember](const CpuWallTiming& elapsed) {
            const auto selfTime = processLazyIoStats(*op, elapsed);
            (op->stats().*opTimingMember).Add(selfTime);
        };
        // The timer lives until the end of this block, i.e. across opFunction().
        DeltaCpuWallTimer<decltype(recordElapsed)> scopedTimer(std::move(recordElapsed));
        opFunction();
        return;
    }
    opFunction();
}
}