* Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <algorithm>
#include <cstdint>
#include <memory>
#include <string>
#include <string_view>
#include <tuple>
#include <vector>
#include "expression/expressions.h"
#include "operator/config/operator_config.h"
#include "type/data_types.h"
#include "util/config/QueryConfig.h"
namespace omniruntime {
namespace op {
// Which join input a join node treats as its build side; OMNI_BUILD_UNKNOWN means it was not specified.
enum BuildSide {
    OMNI_BUILD_UNKNOWN = 0,
    OMNI_BUILD_LEFT = 1,
    OMNI_BUILD_RIGHT = 2
};
} // namespace op
// NOTE(review): namespace-scope using-directives in a header leak into every file that includes it;
// kept because code elsewhere may rely on the unqualified names.
using namespace type;
using namespace op;
using namespace expressions;
// Raw pointer to an expression node. Ownership depends on the consuming plan node
// (e.g. AbstractJoinNode deletes its partition keys; the other nodes here do not delete expressions).
using ExprPtr = Expr *;
// Plan-node identifiers are plain strings.
using PlanNodeId = std::string;
/**
 * Value type describing how one sort key is ordered: its direction and where NULLs are placed.
 */
class SortOrderInfo {
public:
    SortOrderInfo(bool ascending, bool nullsFirst) noexcept : ascending(ascending), nullsFirst(nullsFirst) {}

    /// @return true for ascending order, false for descending.
    bool IsAscending() const
    {
        return ascending;
    }

    /// @return true when NULLs sort before non-NULL values.
    bool IsNullsFirst() const
    {
        return nullsFirst;
    }

    /// Two orders are equal when both the direction and the NULL placement match.
    bool operator==(const SortOrderInfo &other) const
    {
        return ascending == other.ascending && nullsFirst == other.nullsFirst;
    }

    bool operator!=(const SortOrderInfo &other) const
    {
        return !operator==(other);
    }

    /// Renders the order as SQL-like text, e.g. "ASC NULLS FIRST" or "DESC NULLS LAST".
    std::string ToString() const
    {
        const char *direction = ascending ? "ASC" : "DESC";
        const char *placement = nullsFirst ? "FIRST" : "LAST";
        return Format("{} NULLS {}", direction, placement);
    }

private:
    bool ascending;
    bool nullsFirst;
};
// The four possible sort orders as shared constants. They are only declared here; the definitions live
// in a translation unit not shown in this header (presumably the matching .cpp — confirm).
extern const SortOrderInfo K_ASC_NULLS_FIRST;
extern const SortOrderInfo K_ASC_NULLS_LAST;
extern const SortOrderInfo K_DESC_NULLS_FIRST;
extern const SortOrderInfo K_DESC_NULLS_LAST;
class PlanNode {
public:
explicit PlanNode(const PlanNodeId &id) : id{id} {}
virtual ~PlanNode() = default;
virtual const DataTypesPtr &OutputType() const = 0;
virtual std::string_view Name() const = 0;
const PlanNodeId &Id() const { return id; }
virtual bool CanSpill(const config::QueryConfig &queryConfig) const { return false; }
virtual const std::vector<std::shared_ptr<const PlanNode>> &Sources() const = 0;
private:
const std::string id;
};
using PlanNodePtr = std::shared_ptr<const PlanNode>;
class OrderByNode : public PlanNode {
public:
OrderByNode(const PlanNodeId& id, const std::vector<int32_t>& sortCols, const std::vector<int32_t>& sortAscending,
const std::vector<int32_t>& sortNullFirsts, const PlanNodePtr& source,
std::vector<omniruntime::expressions::Expr*>& sortExpressions)
: PlanNode(id),
sourceTypes(source->OutputType()),
sortCols(sortCols),
sortAscending(sortAscending),
sortNullFirsts(sortNullFirsts),
sources({source}),
sortExpressions(sortExpressions)
{
outputCols.reserve(sourceTypes->GetSize());
for (int i = 0; i < sourceTypes->GetSize(); ++i) {
outputCols.push_back(i);
}
}
~OrderByNode() override = default;
const DataTypesPtr &OutputType() const override { return sources[0]->OutputType(); }
const std::vector<std::shared_ptr<const PlanNode>> &Sources() const override { return sources; }
const std::vector<int32_t> &GetSortCols() const { return sortCols; }
const std::vector<int32_t> &GetOutputCols() const { return outputCols; }
const std::vector<omniruntime::expressions::Expr*>& GetExpressions() const
{
return sortExpressions;
}
const DataTypesPtr &GetSourceTypes() const { return sourceTypes; }
const std::vector<int32_t> &GetSortAscending() const { return sortAscending; }
const std::vector<int32_t> &GetNullFirsts() const { return sortNullFirsts; }
bool CanSpill(const config::QueryConfig &queryConfig) const override { return queryConfig.orderBySpillEnabled(); }
std::string_view Name() const override { return "OrderBy"; }
private:
const DataTypesPtr sourceTypes;
std::vector<int32_t> outputCols;
const std::vector<int32_t> sortCols;
const std::vector<int32_t> sortAscending;
const std::vector<int32_t> sortNullFirsts;
const std::vector<PlanNodePtr> sources;
const OperatorConfig operatorConfig;
std::vector<omniruntime::expressions::Expr *> sortExpressions;
};
class FilterNode : public PlanNode {
public:
FilterNode(const PlanNodeId &id, ExprPtr filter, PlanNodePtr source, const std::vector<ExprPtr> &projectList)
: PlanNode(id), sources{std::move(source)}, filter(std::move(filter)), projectList(projectList)
{
std::vector<DataTypePtr> joinInputTypes;
if (!projectList.empty()) {
for (int i = 0; i < projectList.size(); i++) {
joinInputTypes.push_back(projectList[i]->dataType);
}
this->outputType = std::make_shared<DataTypes>(std::move(joinInputTypes));
} else {
this->outputType = sources[0]->OutputType();
}
}
~FilterNode() override = default;
const DataTypesPtr &OutputType() const override { return outputType; }
const std::vector<PlanNodePtr> &Sources() const override { return sources; }
ExprPtr GetFilterExpr() const { return filter; }
const std::vector<ExprPtr> &ProjectList() const {return projectList; }
std::string_view Name() const override { return "Filter"; }
private:
const std::vector<PlanNodePtr> sources;
ExprPtr filter;
const std::vector<ExprPtr> projectList;
DataTypesPtr outputType;
};
/**
 * Plan node that evaluates a list of projection expressions over its single source.
 * The output schema is the return type of each projection, in order.
 */
class ProjectNode : public PlanNode {
public:
    ProjectNode(const PlanNodeId &id, std::vector<ExprPtr> &&inProjections, PlanNodePtr source)
        : PlanNode(id),
          sources({source}),
          projections(std::move(inProjections)),
          outputType(MakeOutputType(projections))
    {
    }

    /// Builds a schema made of each expression's return type.
    static DataTypesPtr MakeOutputType(const std::vector<ExprPtr> &projections)
    {
        std::vector<DataTypePtr> columnTypes;
        for (size_t idx = 0; idx < projections.size(); ++idx) {
            columnTypes.push_back(projections[idx]->GetReturnType());
        }
        return std::make_shared<DataTypes>(std::move(columnTypes));
    }

    ~ProjectNode() override = default;

    std::string_view Name() const override { return "Project"; }
    const DataTypesPtr &OutputType() const override { return outputType; }
    const std::vector<PlanNodePtr> &Sources() const override { return sources; }
    const std::vector<ExprPtr> &GetProjections() const { return projections; }

protected:
    // projections must stay declared before outputType: outputType is computed from it.
    const std::vector<PlanNodePtr> sources;
    const std::vector<ExprPtr> projections;
    const DataTypesPtr outputType;
};
class AggregationNode : public PlanNode {
public:
enum class Step {
K_PARTIAL,
K_FINAL,
K_INTERMEDIATE,
K_SINGLE
};
AggregationNode(const PlanNodeId &id, const std::vector<ExprPtr> &groupByKeys, const uint32_t groupByNum,
const std::vector<std::vector<ExprPtr>> &aggKeys, const DataTypesPtr sourceDataTypes,
const std::vector<DataTypes> aggsOutputTypes, const std::vector<uint32_t> &aggFuncTypes,
const std::vector<ExprPtr> &aggFilters, const std::vector<uint32_t> &maskColumns,
const std::vector<bool> &inputRaws, const std::vector<bool> &outputPartials, const bool isStatisticalAggregate,
const DataTypesPtr outputType, PlanNodePtr source)
: PlanNode(id), groupByKeys(groupByKeys), groupByNum(groupByNum), aggKeys(aggKeys),
sourceDataTypes(sourceDataTypes), aggsOutputTypes(aggsOutputTypes), aggFuncTypes(aggFuncTypes),
aggFilters(aggFilters), maskColumns(maskColumns), inputRaws(inputRaws), outputPartials(outputPartials),
isStatisticalAggregate(isStatisticalAggregate), outputType(outputType), sources({source}) {}
~AggregationNode() override = default;
std::string_view Name() const override { return "Aggregation"; }
const DataTypesPtr &OutputType() const override { return outputType; }
const std::vector<PlanNodePtr> &Sources() const override { return sources; }
const std::vector<ExprPtr> &GetGroupByKeys() const { return groupByKeys; }
const uint32_t GetGroupByNum() const { return groupByNum; }
const std::vector<std::vector<ExprPtr>> GetAggsKeys() const { return aggKeys; }
DataTypesPtr GetSourceDataTypes() const { return sourceDataTypes; }
const std::vector<uint32_t> &GetAggFuncTypes() const { return aggFuncTypes; }
const std::vector<ExprPtr> GetAggFilters() const { return aggFilters; }
const std::vector<uint32_t> GetMaskColumns() const { return maskColumns; }
const std::vector<bool> GetInputRaws() const { return inputRaws; }
const std::vector<bool> GetOutputPartials() const { return outputPartials; }
const bool GetIsStatisticalAggregate() const { return isStatisticalAggregate; }
const std::vector<DataTypes> &GetAggsOutputTypes() const { return aggsOutputTypes; }
private:
const std::vector<ExprPtr> groupByKeys;
const uint32_t groupByNum;
const std::vector<std::vector<ExprPtr>> aggKeys;
const DataTypesPtr sourceDataTypes;
const std::vector<DataTypes> aggsOutputTypes;
const std::vector<uint32_t> aggFuncTypes;
const std::vector<ExprPtr> aggFilters;
const std::vector<uint32_t> maskColumns;
const std::vector<bool> inputRaws;
const std::vector<bool> outputPartials;
bool isStatisticalAggregate;
const DataTypesPtr outputType;
const std::vector<PlanNodePtr> sources;
};
class WindowNode : public PlanNode {
public:
WindowNode(const PlanNodeId &id,
const std::vector<int32_t> &windowFunctionTypes,
const std::vector<int32_t> &partitionCols,
const std::vector<int32_t> &preGroupedCols,
const std::vector<int32_t> &sortCols,
const std::vector<int32_t> &sortAscending,
const std::vector<int32_t> &sortNullFirsts,
const int32_t preSortedChannelPrefix,
const int32_t expectedPositionsCount,
const DataTypesPtr &windowFunctionReturnTypes,
const DataTypesPtr &allTypes,
const std::vector<omniruntime::expressions::Expr *> argumentKeys,
const std::vector<int32_t> &windowFrameTypes,
const std::vector<int32_t> &windowFrameStartTypes,
const std::vector<int32_t> &windowFrameStartChannels,
const std::vector<int32_t> &windowFrameEndTypes,
const std::vector<int32_t> &windowFrameEndChannels,
const PlanNodePtr &source)
: PlanNode(id),
windowFunctionTypes(windowFunctionTypes),
partitionCols(partitionCols),
preGroupedCols(preGroupedCols),
sortCols(sortCols),
sortAscending(sortAscending),
sortNullFirsts(sortNullFirsts),
preSortedChannelPrefix(preSortedChannelPrefix),
expectedPositionsCount(expectedPositionsCount),
windowFunctionReturnTypes(windowFunctionReturnTypes),
allTypes(allTypes),
argumentKeys(argumentKeys),
windowFrameTypes(windowFrameTypes),
windowFrameStartTypes(windowFrameStartTypes),
windowFrameStartChannels(windowFrameStartChannels),
windowFrameEndTypes(windowFrameEndTypes),
windowFrameEndChannels(windowFrameEndChannels),
sources({source}),
sourceTypes(source->OutputType())
{
outputCols.reserve(sourceTypes->GetSize());
for (int i = 0; i < sourceTypes->GetSize(); ++i) {
outputCols.push_back(i);
}
}
~WindowNode() override = default;
const DataTypesPtr &GetSourceTypes() const { return sourceTypes; }
const std::vector<int32_t> &GetOutputCols() const { return outputCols; }
const std::vector<int32_t> &GetWindowFunctionTypes() const { return windowFunctionTypes; }
const std::vector<int32_t> &GetPartitionCols() const { return partitionCols; }
const std::vector<int32_t> &GetPreGroupedCols() const { return preGroupedCols; }
const std::vector<int32_t> &GetSortCols() const { return sortCols; }
const std::vector<int32_t> &GetSortAscending() const { return sortAscending; }
const std::vector<int32_t> &GetNullFirsts() const { return sortNullFirsts; }
const int32_t GetPreSortedChannelPrefix() const { return preSortedChannelPrefix; }
const int32_t GetExpectedPositionsCount() const { return expectedPositionsCount; }
const DataTypesPtr &GetWindowFunctionReturnTypes() const { return windowFunctionReturnTypes; }
const std::vector<omniruntime::expressions::Expr *> &GetArgumentKeys() const { return argumentKeys; }
const std::vector<int32_t> &GetWindowFrameTypes() const { return windowFrameTypes; }
const std::vector<int32_t> &GetWindowFrameStartTypes() const { return windowFrameStartTypes; }
const std::vector<int32_t> &GetWindowFrameStartChannels() const { return windowFrameStartChannels; }
const std::vector<int32_t> &GetWindowFrameEndTypes() const { return windowFrameEndTypes; }
const std::vector<int32_t> &GetWindowFrameEndChannels() const{ return windowFrameEndChannels; }
const std::vector<PlanNodePtr> &Sources() const override { return sources; }
const DataTypesPtr &OutputType() const override { return allTypes; }
std::string_view Name() const override { return "Window"; }
bool CanSpill(const config::QueryConfig &queryConfig) const override {
return queryConfig.orderBySpillEnabled();
}
private:
const std::vector<int32_t> windowFunctionTypes;
const std::vector<int32_t> partitionCols;
const std::vector<int32_t> preGroupedCols;
const std::vector<int32_t> sortCols;
const std::vector<int32_t> sortAscending;
const std::vector<int32_t> sortNullFirsts;
const int32_t preSortedChannelPrefix;
const int32_t expectedPositionsCount;
const DataTypesPtr windowFunctionReturnTypes;
const DataTypesPtr allTypes;
const std::vector<omniruntime::expressions::Expr *> argumentKeys;
const std::vector<int32_t> windowFrameTypes;
const std::vector<int32_t> windowFrameStartTypes;
const std::vector<int32_t> windowFrameStartChannels;
const std::vector<int32_t> windowFrameEndTypes;
const std::vector<int32_t> windowFrameEndChannels;
const std::vector<PlanNodePtr> sources;
const DataTypesPtr sourceTypes;
std::vector<int32_t> outputCols;
};
// Logical join kinds understood by the join plan nodes. The values are spelled out so they stay stable.
enum JoinType {
    OMNI_JOIN_TYPE_INNER = 0,
    OMNI_JOIN_TYPE_LEFT = 1,
    OMNI_JOIN_TYPE_RIGHT = 2,
    OMNI_JOIN_TYPE_FULL = 3,
    OMNI_JOIN_TYPE_LEFT_SEMI = 4,
    OMNI_JOIN_TYPE_LEFT_ANTI = 5,
    OMNI_JOIN_TYPE_EXISTENCE = 6
};
/**
 * Common base of the key-based join plan nodes (HashJoinNode, MergeJoinNode).
 *
 * sources[0] is the left input and sources[1] is the right input. The output schema is built from the
 * partition-key expressions when any are given; otherwise it is derived by GetOutputType() from the join
 * type and build side.
 *
 * Ownership: the destructor deletes the partitionKeys expressions. leftKeys, rightKeys and filter are NOT
 * deleted here.
 */
class AbstractJoinNode : public PlanNode {
public:
    AbstractJoinNode(const PlanNodeId &id, JoinType joinType_, BuildSide buildSide_,
        const std::vector<ExprPtr> &leftKeys_, const std::vector<ExprPtr> &rightKeys_, ExprPtr filter_,
        PlanNodePtr left_, PlanNodePtr right_, DataTypesPtr leftOutputType_, DataTypesPtr rightOutputType_,
        const std::vector<ExprPtr> &partitionKeys_)
        : PlanNode(id),
          joinType(joinType_),
          buildSide(buildSide_),
          leftKeys(leftKeys_),
          rightKeys(rightKeys_),
          filter(filter_),
          sources({std::move(left_), std::move(right_)}),
          leftOutputType(std::move(leftOutputType_)),
          rightOutputType(std::move(rightOutputType_)),
          partitionKeys(partitionKeys_)
    {
        if (partitionKeys.empty()) {
            outputType = GetOutputType();
            return;
        }
        // With partition keys, the output schema is the data type of each key, in order.
        std::vector<DataTypePtr> partitionTypes;
        partitionTypes.reserve(partitionKeys.size());
        for (const ExprPtr key : partitionKeys) {
            partitionTypes.push_back(key->dataType);
        }
        outputType = std::make_shared<DataTypes>(std::move(partitionTypes));
    }

    // NOTE(review): this class deletes raw pointers it holds but still has the implicitly generated copy
    // constructor. Copying a join node would delete the same partition-key expressions twice.
    // Consider `= delete` on the copy operations once no caller copies join nodes.
    ~AbstractJoinNode() override
    {
        for (const ExprPtr expr : partitionKeys) {
            delete expr;
        }
    }

    const std::vector<PlanNodePtr> &Sources() const override { return sources; }
    const DataTypesPtr &OutputType() const override { return outputType; }
    const DataTypesPtr &LeftOutputType() const { return leftOutputType; }
    const DataTypesPtr &RightOutputType() const { return rightOutputType; }
    JoinType GetJoinType() const { return joinType; }
    BuildSide GetBuildSide() const { return buildSide; }

    bool IsInnerJoin() const { return joinType == JoinType::OMNI_JOIN_TYPE_INNER; }
    bool IsLeftJoin() const { return joinType == JoinType::OMNI_JOIN_TYPE_LEFT; }
    bool IsRightJoin() const { return joinType == JoinType::OMNI_JOIN_TYPE_RIGHT; }
    bool IsFullJoin() const { return joinType == JoinType::OMNI_JOIN_TYPE_FULL; }
    bool IsLeftSemi() const { return joinType == JoinType::OMNI_JOIN_TYPE_LEFT_SEMI; }
    bool IsLeftAnti() const { return joinType == JoinType::OMNI_JOIN_TYPE_LEFT_ANTI; }
    bool IsExistence() const { return joinType == JoinType::OMNI_JOIN_TYPE_EXISTENCE; }
    bool IsBuildLeft() const { return buildSide == BuildSide::OMNI_BUILD_LEFT; }
    bool IsBuildRight() const { return buildSide == BuildSide::OMNI_BUILD_RIGHT; }

    const std::vector<ExprPtr> &LeftKeys() const { return leftKeys; }
    const std::vector<ExprPtr> &RightKeys() const { return rightKeys; }
    ExprPtr Filter() const { return filter; }

    /**
     * Derives the join's output schema from the join type, the build side and the two input schemas.
     *
     * - Left columns are excluded only for LEFT_SEMI/EXISTENCE joins that build the left side.
     * - Right columns are excluded for LEFT_SEMI/EXISTENCE joins that build the right side, and for LEFT_ANTI.
     * - Both sides included: left columns then right columns. With OMNI_BUILD_LEFT the result is rotated so
     *   the right columns come first.
     * - One side included: that side's columns, plus one trailing BOOLEAN column for EXISTENCE joins.
     *   With OMNI_BUILD_LEFT the BOOLEAN column is rotated to the front.
     *
     * No join type / build side combination excludes both sides. That path now returns an empty schema
     * instead of flowing off the end of the function (undefined behavior).
     */
    std::shared_ptr<DataTypes> GetOutputType()
    {
        const bool isSemiOrExistence = IsLeftSemi() || IsExistence();
        const bool outputMayIncludeLeftColumns = !(isSemiOrExistence && IsBuildLeft());
        const bool outputMayIncludeRightColumns = !((isSemiOrExistence && IsBuildRight()) || IsLeftAnti());

        std::vector<DataTypePtr> joinInputTypes;

        // Appends one side's types plus the optional EXISTENCE BOOLEAN column, then applies the
        // build-left rotation.
        const auto appendOneSide = [this, &joinInputTypes](const auto &sideTypes) {
            const int extraCnt = IsExistence() ? 1 : 0;
            joinInputTypes.reserve(sideTypes.size() + extraCnt);
            joinInputTypes.insert(joinInputTypes.end(), sideTypes.begin(), sideTypes.end());
            if (extraCnt > 0) {
                joinInputTypes.emplace_back(BooleanDataType::Instance());
            }
            if (buildSide == OMNI_BUILD_LEFT) {
                std::rotate(joinInputTypes.begin(), joinInputTypes.begin() + sideTypes.size(), joinInputTypes.end());
            }
        };

        if (outputMayIncludeLeftColumns && outputMayIncludeRightColumns) {
            // Bind each Get() result once so begin()/end() always come from the same container.
            const auto &leftTypes = leftOutputType->Get();
            const auto &rightTypes = rightOutputType->Get();
            joinInputTypes.reserve(leftTypes.size() + rightTypes.size());
            joinInputTypes.insert(joinInputTypes.end(), leftTypes.begin(), leftTypes.end());
            joinInputTypes.insert(joinInputTypes.end(), rightTypes.begin(), rightTypes.end());
            if (buildSide == OMNI_BUILD_LEFT) {
                std::rotate(joinInputTypes.begin(), joinInputTypes.begin() + leftTypes.size(), joinInputTypes.end());
            }
        } else if (outputMayIncludeLeftColumns) {
            appendOneSide(leftOutputType->Get());
        } else if (outputMayIncludeRightColumns) {
            appendOneSide(rightOutputType->Get());
        }
        return std::make_shared<DataTypes>(std::move(joinInputTypes));
    }

protected:
    // Declaration order matters: the constructor reads the join type, build side and input types
    // before outputType is assigned.
    const JoinType joinType;
    const BuildSide buildSide;
    const std::vector<ExprPtr> leftKeys;
    const std::vector<ExprPtr> rightKeys;
    const ExprPtr filter;
    const std::vector<PlanNodePtr> sources;
    const DataTypesPtr leftOutputType;
    const DataTypesPtr rightOutputType;
    DataTypesPtr outputType;
    const std::vector<ExprPtr> partitionKeys;
};
class HashJoinNode : public AbstractJoinNode {
public:
HashJoinNode(const PlanNodeId &id, JoinType joinType, BuildSide buildSide, bool nullAware, bool isShuffle, const std::vector<ExprPtr> &leftKeys,
const std::vector<ExprPtr> &rightKeys, ExprPtr filter, PlanNodePtr left, PlanNodePtr right, DataTypesPtr leftOutputType,
DataTypesPtr rightOutputType, const std::vector<omniruntime::expressions::Expr*>& partitionKeys)
: AbstractJoinNode(id, joinType, buildSide, leftKeys, rightKeys, std::move(filter), std::move(left), std::move(right), std::move(leftOutputType), std::move(rightOutputType), partitionKeys),
nullAware{nullAware}, isShuffle{isShuffle} {}
std::string_view Name() const override
{
return "HashJoin";
}
bool IsNullAware() const
{
return nullAware;
}
bool IsShuffle() const
{
return isShuffle;
}
std::vector<omniruntime::expressions::Expr*> PartitionKeys() const
{
return partitionKeys;
}
private:
const bool nullAware;
const bool isShuffle;
};
/**
 * Merge-join plan node. It adds nothing to the common join description beyond its operator name.
 */
class MergeJoinNode : public AbstractJoinNode {
public:
    MergeJoinNode(const PlanNodeId &id, JoinType joinType, BuildSide buildSide,
        const std::vector<ExprPtr> &leftKeys, const std::vector<ExprPtr> &rightKeys, ExprPtr filter,
        PlanNodePtr left, PlanNodePtr right, DataTypesPtr leftOutputType, DataTypesPtr rightOutputType,
        const std::vector<ExprPtr> &partitionKeys)
        : AbstractJoinNode(id,
              joinType,
              buildSide,
              leftKeys,
              rightKeys,
              std::move(filter),
              std::move(left),
              std::move(right),
              std::move(leftOutputType),
              std::move(rightOutputType),
              partitionKeys)
    {
    }

    std::string_view Name() const override { return "MergeJoin"; }
};
/**
 * Nested-loop join plan node with a left (sources[0]) and a right (sources[1]) input.
 * The output schema is always the left columns followed by the right columns.
 * IsSupported() reports which join types this node accepts.
 */
class NestedLoopJoinNode : public PlanNode {
public:
    NestedLoopJoinNode(const PlanNodeId &id, JoinType joinType_, ExprPtr filter_, PlanNodePtr left_, PlanNodePtr right_,
        DataTypesPtr leftOutputType_, DataTypesPtr rightOutputType_)
        : PlanNode(id),
          joinType(joinType_),
          filter(filter_),
          sources({std::move(left_), std::move(right_)}),
          leftOutputType(std::move(leftOutputType_)),
          rightOutputType(std::move(rightOutputType_))
    {
        // Bind each Get() result once so begin()/end() always come from the same container.
        const auto &leftTypes = leftOutputType->Get();
        const auto &rightTypes = rightOutputType->Get();
        std::vector<DataTypePtr> joinInputTypes;
        joinInputTypes.reserve(leftTypes.size() + rightTypes.size());
        joinInputTypes.insert(joinInputTypes.end(), leftTypes.begin(), leftTypes.end());
        joinInputTypes.insert(joinInputTypes.end(), rightTypes.begin(), rightTypes.end());
        outputType = std::make_shared<DataTypes>(std::move(joinInputTypes));
    }

    const std::vector<PlanNodePtr> &Sources() const override { return sources; }
    const DataTypesPtr &OutputType() const override { return outputType; }
    std::string_view Name() const override { return "NestedLoopJoin"; }
    ExprPtr Filter() const { return filter; }
    JoinType GetJoinType() const { return joinType; }
    const DataTypesPtr &LeftOutputType() const { return leftOutputType; }
    const DataTypesPtr &RightOutputType() const { return rightOutputType; }

    /// @return true for INNER, LEFT, RIGHT and FULL joins; false for every other join type.
    static bool IsSupported(JoinType joinTypeParam)
    {
        switch (joinTypeParam) {
            case OMNI_JOIN_TYPE_INNER:
            case OMNI_JOIN_TYPE_LEFT:
            case OMNI_JOIN_TYPE_RIGHT:
            case OMNI_JOIN_TYPE_FULL:
                return true;
            default:
                return false;
        }
    }

private:
    const JoinType joinType;
    const ExprPtr filter;
    const std::vector<PlanNodePtr> sources;
    const DataTypesPtr leftOutputType;
    const DataTypesPtr rightOutputType;
    DataTypesPtr outputType;
};
/**
 * Plan node keeping the first `count` rows of its single source, ordered by the given sort expressions.
 * The output schema is the source's schema. Sort expressions are stored as raw pointers and are not deleted.
 */
class TopNNode : public PlanNode {
public:
    TopNNode(const PlanNodeId &id, const std::vector<omniruntime::expressions::Expr *> &sortCols,
        const std::vector<int32_t> &sortAscending, const std::vector<int32_t> &sortNullFirsts, int32_t count,
        const PlanNodePtr &source)
        : PlanNode(id),
          sourceTypes(source->OutputType()),
          sortCols(sortCols),
          sortAscending(sortAscending),
          sortNullFirsts(sortNullFirsts),
          count(count),
          sources({source})
    {
    }

    std::string_view Name() const override { return "TopN"; }
    const std::vector<PlanNodePtr> &Sources() const override { return sources; }
    const DataTypesPtr &OutputType() const override { return sources.front()->OutputType(); }

    const DataTypesPtr &GetSourceTypes() const { return sourceTypes; }
    const std::vector<omniruntime::expressions::Expr *> &GetSortCols() const { return sortCols; }
    const std::vector<int32_t> &GetSortAscending() const { return sortAscending; }
    const std::vector<int32_t> &GetNullFirsts() const { return sortNullFirsts; }
    int32_t Count() const { return count; }

private:
    const DataTypesPtr sourceTypes;
    const std::vector<omniruntime::expressions::Expr *> sortCols;
    const std::vector<int32_t> sortAscending;
    const std::vector<int32_t> sortNullFirsts;
    const int32_t count;
    const std::vector<PlanNodePtr> sources;
};
/**
 * Plan node describing a per-partition top-N sort over its single source.
 *
 * Stores partition keys, sort keys with their direction / NULL-placement flags, the limit `n` and the
 * strictness flag. The output schema is the source's schema. Expressions are stored as raw pointers and
 * are not deleted by this node.
 */
class TopNSortNode : public PlanNode {
public:
    TopNSortNode(const PlanNodeId &id, const std::vector<omniruntime::expressions::Expr *> &partitionKeys,
        const std::vector<omniruntime::expressions::Expr *> &sortKeys, const std::vector<int32_t> &sortAscendings,
        const std::vector<int32_t> &sortNullFirsts, int32_t n, bool isStrictTopN, const PlanNodePtr &source)
        : PlanNode(id),
          sourceTypes(source->OutputType()),
          n(n),
          isStrictTopN(isStrictTopN),
          partitionKeys(partitionKeys),
          sortKeys(sortKeys),
          sortAscendings(sortAscendings),
          sortNullFirsts(sortNullFirsts),
          sources({source})
    {
    }

    std::string_view Name() const override { return "TopNSort"; }
    const DataTypesPtr &OutputType() const override { return sources.front()->OutputType(); }
    const std::vector<PlanNodePtr> &Sources() const override { return sources; }

    DataTypesPtr getSourceTypes() const { return sourceTypes; }
    int32_t getN() const { return n; }
    bool getIsStrictTopN() const { return isStrictTopN; }
    const std::vector<omniruntime::expressions::Expr *> &getPartitionKeys() const { return partitionKeys; }
    const std::vector<omniruntime::expressions::Expr *> &getSortKeys() const { return sortKeys; }
    const std::vector<int32_t> &getSortAscendings() const { return sortAscendings; }
    const std::vector<int32_t> &getSortNullFirsts() const { return sortNullFirsts; }

private:
    const DataTypesPtr sourceTypes;
    const int32_t n;
    const bool isStrictTopN;
    const std::vector<omniruntime::expressions::Expr *> partitionKeys;
    const std::vector<omniruntime::expressions::Expr *> sortKeys;
    const std::vector<int32_t> sortAscendings;
    const std::vector<int32_t> sortNullFirsts;
    const std::vector<PlanNodePtr> sources;
};
/**
 * Plan node limiting its single source to `count` rows after skipping `offset` rows.
 * `isPartial` is stored and exposed for the operator layer. The output schema is the source's schema.
 */
class LimitNode : public PlanNode {
public:
    LimitNode(const PlanNodeId &id, int32_t offset, int32_t count, bool isPartial, const PlanNodePtr &source)
        : PlanNode(id),
          offset(offset),
          count(count),
          isPartial(isPartial),
          sources({source})
    {
    }

    std::string_view Name() const override { return "Limit"; }
    const std::vector<PlanNodePtr> &Sources() const override { return sources; }
    const DataTypesPtr &OutputType() const override { return sources.front()->OutputType(); }

    int32_t Offset() const { return offset; }
    int32_t Count() const { return count; }
    bool IsPartial() const { return isPartial; }

private:
    const int32_t offset;
    const int32_t count;
    const bool isPartial;
    const std::vector<PlanNodePtr> sources;
};
/**
 * Plan node combining the rows of several sources, optionally with DISTINCT semantics.
 * The output schema is taken from the first source.
 * NOTE(review): the constructor does not check that `sources` is non-empty.
 */
class UnionNode : public PlanNode {
public:
    UnionNode(const PlanNodeId &id, std::vector<PlanNodePtr> sources, bool isDistinct)
        : PlanNode(id),
          isDistinct(isDistinct),
          sources(std::move(sources))
    {
    }

    std::string_view Name() const override { return "Union"; }
    const std::vector<PlanNodePtr> &Sources() const override { return sources; }
    const DataTypesPtr &OutputType() const override { return sources.front()->OutputType(); }
    const DataTypesPtr &GetSourceTypes() const { return sources.front()->OutputType(); }
    const bool &IsDistinct() const { return isDistinct; }

private:
    const bool isDistinct;
    const std::vector<PlanNodePtr> sources;
};
class ExpandNode : public PlanNode {
public:
ExpandNode(const PlanNodeId &id, std::vector<std::vector<ExprPtr>> &&projections, PlanNodePtr source)
: PlanNode(id), sources{source}, projections(std::move(projections))
{
std::vector<DataTypePtr> types;
if (this->projections.size() > 0) {
types.reserve(this->projections[0].size());
for (const auto &projection: this->projections[0]) {
types.push_back(projection->GetReturnType());
}
}
this->outputType = std::make_shared<DataTypes>(std::move(types));
}
const DataTypesPtr &OutputType() const override
{
return outputType;
}
const DataTypesPtr& InputType() const
{
return sources[0]->OutputType();
}
const std::vector<PlanNodePtr>& Sources() const override
{
return sources;
}
const std::vector<std::vector<ExprPtr>>& GetProjections() const
{
return projections;
}
std::string_view Name() const override
{
return "Expand";
}
private:
const std::vector<PlanNodePtr> sources;
const std::vector<std::vector<ExprPtr>> projections;
DataTypesPtr outputType;
};
/**
 * Plan node wrapping an ExpandNode together with an AggregationNode.
 * It reports the Expand node's sources as its own sources and the aggregation's output type as its own.
 * Both nodes must be non-null; they are dereferenced in the constructor.
 */
class GroupingNode : public PlanNode {
public:
    GroupingNode(const PlanNodeId &id, const std::shared_ptr<const ExpandNode> &expandPlanNode,
        const std::shared_ptr<const AggregationNode> &aggregationNode)
        : PlanNode(id),
          expandPlanNode_(expandPlanNode),
          aggregationNode_(aggregationNode),
          sources_(expandPlanNode->Sources()),
          outputType_(aggregationNode->OutputType())
    {
    }

    std::string_view Name() const override { return "Grouping"; }
    const DataTypesPtr &OutputType() const override { return outputType_; }
    const std::vector<PlanNodePtr> &Sources() const override { return sources_; }

    std::shared_ptr<const ExpandNode> GetExpandPlanNode() const { return expandPlanNode_; }
    std::shared_ptr<const AggregationNode> GetAggregationNode() const { return aggregationNode_; }

private:
    const std::shared_ptr<const ExpandNode> expandPlanNode_;
    const std::shared_ptr<const AggregationNode> aggregationNode_;
    const std::vector<PlanNodePtr> sources_;
    DataTypesPtr outputType_;
};
}