* Copyright (c) Huawei Technologies Co., Ltd. 2020-2024. All rights reserved.
*/
#ifndef __OMNI_OPERATOR_H__
#define __OMNI_OPERATOR_H__
#include <cstdint>
#include <vector>
#include "execution_context.h"
#include "status.h"
#include "vector/vector_batch.h"
#include "vector/vector_helper.h"
#include "metrics/metrics.h"
#include "compute/reason.h"
#include "compute/operator_stats.h"
namespace omniruntime {
namespace op {
using namespace omniruntime::vec;
using namespace omniruntime::compute;
class Operator {
public:
Operator()
: sourceTypes(nullptr),
executionContext(std::make_unique<ExecutionContext>()),
inputVecBatch(nullptr),
status(OMNI_STATUS_NORMAL)
{
}
virtual ~Operator()
{
}
virtual int32_t AddInput(omniruntime::vec::VectorBatch* vecBatch) = 0;
virtual int32_t GetOutput(omniruntime::vec::VectorBatch** result) = 0;
static void DeleteOperator(Operator* op)
{
op->Close();
delete op;
}
OmniStatus GetStatus()
{
return status;
}
void SetStatus(OmniStatus omniStatus)
{
this->status = omniStatus;
};
virtual OmniStatus Init()
{
return OMNI_STATUS_NORMAL;
}
virtual OmniStatus Close()
{
return OMNI_STATUS_NORMAL;
}
virtual uint64_t GetSpilledBytes()
{
return 0;
}
virtual uint64_t GetUsedMemBytes()
{
return 0;
}
virtual uint64_t GetTotalMemBytes()
{
return 0;
}
virtual std::vector<uint64_t> GetSpecialMetricsInfo()
{
return {};
}
virtual uint64_t GetHashMapUniqueKeys()
{
return 0;
}
virtual VectorBatch* AlignSchema(VectorBatch* inputVecBatch)
{
return inputVecBatch;
}
omniruntime::vec::VectorBatch* GetInputVecBatch()
{
return inputVecBatch;
}
void SetInputVecBatch(vec::VectorBatch* inVecBatch)
{
inputVecBatch = inVecBatch;
}
void ResetInputVecBatch()
{
inputVecBatch = nullptr;
}
virtual BlockingReason IsBlocked(ContinueFuture* future)
{
return BlockingReason::kNotBlocked;
}
virtual void noMoreInput()
{
noMoreInput_ = true;
}
virtual void setNoMoreInput(bool noMoreInput)
{
noMoreInput_ = noMoreInput;
}
virtual bool needsInput()
{
return status != OMNI_STATUS_FINISHED && !noMoreInput_;
}
bool isFinished()
{
return status == OMNI_STATUS_FINISHED;
}
bool hasInputedData()
{
return hasInputedData_;
}
void setInputedData(bool hasInputedData)
{
this->hasInputedData_ = hasInputedData;
}
OperatorStats stats(bool clear)
{
OperatorStats stats = stats_;
if (clear) {
stats = stats_;
stats_.Clear();
}
return stats;
}
void SetOperatorType(string opType)
{
operatorType_ = opType;
}
const std::string& operatorType() const
{
return operatorType_;
}
const PlanNodeId& planNodeId() const
{
return planNodeId_;
}
void SetPlanNodeId(PlanNodeId nodeId)
{
planNodeId_ = nodeId;
}
int32_t GetOperatorId() const
{
return operatorId_;
}
void SetOperatorId(int32_t opId)
{
operatorId_ = opId;
}
OperatorStats& stats()
{
return stats_;
}
void setInputOperatorCnt(int32_t cnt)
{
inputOperatorCnt_ = cnt;
}
OperatorStats stats_;
protected:
int32_t* sourceTypes;
std::unique_ptr<ExecutionContext> executionContext;
vec::VectorBatch* inputVecBatch = nullptr;
bool noMoreInput_{true};
int32_t inputOperatorCnt_{0};
void UpdateAddInputInfo(int32_t rowCount)
{
if (LIKELY(!IsDebugEnable())) {
return;
}
metrics.UpdateAddInputInfo(rowCount, executionContext);
}
void UpdateGetOutputInfo(int32_t rowCount)
{
if (LIKELY(!IsDebugEnable())) {
return;
}
metrics.UpdateGetOutputInfo(rowCount, executionContext);
}
void UpdateSpillFileInfo(uint32_t fileCount)
{
if (LIKELY(!IsDebugEnable())) {
return;
}
metrics.UpdateSpillFileInfo(fileCount, executionContext);
}
void UpdateSpillTimesInfo()
{
if (LIKELY(!IsDebugEnable())) {
return;
}
metrics.UpdateSpillTimesInfo(executionContext);
}
void UpdateCloseInfo()
{
if (LIKELY(!IsDebugEnable())) {
return;
}
metrics.UpdateCloseInfo(executionContext);
}
void SetOperatorName(const std::string& operatorName)
{
operatorType_ = operatorName;
if (LIKELY(!IsDebugEnable())) {
return;
}
metrics.SetOperatorName(operatorName);
}
private:
OmniStatus status;
Metrics metrics;
PlanNodeId planNodeId_;
std::string operatorType_;
int32_t operatorId_;
bool hasInputedData_{false};
};
}
}
#endif