* Copyright (c) Huawei Technologies Co., Ltd. 2025-2025. All rights reserved.
*/
#include "grouping.h"
#include "operator/aggregation/group_aggregation_expr.h"
namespace omniruntime::op {
Operator *GroupingOperatorFactory::CreateOperator()
{
return new GroupingOperator(groupingNode_, queryConfig_);
}
GroupingOperatorFactory *GroupingOperatorFactory::CreateGroupingOperatorFactory(
const std::shared_ptr<const GroupingNode> &groupingNode, const config::QueryConfig &queryConfig)
{
return new GroupingOperatorFactory(groupingNode, queryConfig);
}
static std::vector<ExprPtr> BuildAggregateExpressions(
const std::shared_ptr<const AggregationNode>& aggPlanNode,
const std::vector<ExprPtr>& projection)
{
auto size = aggPlanNode->OutputType()->GetSize();
auto groupByKeySize = aggPlanNode->GetGroupByKeys().size();
std::vector<ExprPtr> expressions(size);
for (int32_t i = 0; i < size; i++) {
if (i >= groupByKeySize) {
expressions[i] = new FieldExpr(i, aggPlanNode->OutputType()->GetType(i));
continue;
}
auto fieldExpr = dynamic_cast<FieldExpr *>(aggPlanNode->GetGroupByKeys()[i]);
if (!fieldExpr) {
expressions[i] = new FieldExpr(i, aggPlanNode->OutputType()->GetType(i));
continue;
}
auto literalExpr = dynamic_cast<LiteralExpr *>(projection[fieldExpr->colVal]);
if (!literalExpr) {
expressions[i] = new FieldExpr(i, aggPlanNode->OutputType()->GetType(i));
continue;
}
expressions[i] = literalExpr->Copy();
}
return expressions;
}
GroupingOperator::GroupingOperator(const std::shared_ptr<const GroupingNode> &groupingNode,
config::QueryConfig &queryConfig): queryConfig_(queryConfig)
{
auto expandNode = groupingNode->GetExpandPlanNode();
aggPlanNode_ = groupingNode->GetAggregationNode();
auto projections = expandNode->GetProjections();
auto sourceTypes = *(expandNode->InputType());
auto residualOutputType = aggPlanNode_->OutputType();
auto overflowConfig = queryConfig_.IsOverFlowASNull() == true
? std::make_shared<OverflowConfig>(OVERFLOW_CONFIG_NULL)
: std::make_shared<OverflowConfig>(OVERFLOW_CONFIG_EXCEPTION);
aggFactor_ = CreateOperatorFactory(aggPlanNode_, queryConfig_);
int32_t index = 0;
aggOperators_.resize(projections.size());
residualAggFactor_ = CreateResidualOperatorFactory(aggPlanNode_, residualOutputType, aggPlanNode_->GetGroupByNum(), queryConfig_);
for (const auto &projection : projections) {
if (index == 0) {
auto exprEvaluator = std::make_shared<ExpressionEvaluator>(projection, sourceTypes, overflowConfig.get());
exprEvaluator->ProjectFuncGeneration();
expressionEvaluators_.push_back(exprEvaluator);
aggOperators_[index] = std::shared_ptr<Operator>(aggFactor_->CreateOperator());
} else {
std::vector <ExprPtr> expressions = BuildAggregateExpressions(aggPlanNode_, projection);
auto exprEvaluator = std::make_shared<ExpressionEvaluator>(expressions, *residualOutputType.get(),
overflowConfig.get());
exprEvaluator->ProjectFuncGeneration();
expressionEvaluators_.push_back(exprEvaluator);
aggOperators_[index] = std::shared_ptr<Operator>(residualAggFactor_->CreateOperator());
for (const auto p : projection) {
delete p;
}
}
index++;
}
}
int32_t GroupingOperator::AddInput(VectorBatch *vecBatch)
{
if (vecBatch->GetRowCount() <= 0) {
VectorHelper::FreeVecBatch(vecBatch);
ResetInputVecBatch();
return 0;
}
hasInput = true;
auto projectedVecBatch = this->expressionEvaluators_[0]->Evaluate(vecBatch, executionContext.get());
aggOperators_[0]->AddInput(projectedVecBatch);
VectorHelper::FreeVecBatch(vecBatch);
return 0;
}
int32_t GroupingOperator::GetOutput(VectorBatch **outputVecBatch)
{
if (!noMoreInput_) {
SetStatus(OMNI_STATUS_NORMAL);
return 0;
}
if (!hasInput) {
SetStatus(OMNI_STATUS_FINISHED);
return 0;
}
const auto size = expressionEvaluators_.size();
while (index_ < size) {
aggOperators_[index_ - 1]->GetOutput(outputVecBatch);
if (*outputVecBatch != nullptr) {
auto projectedVecBatch = this->expressionEvaluators_[index_]->Evaluate(*outputVecBatch,
executionContext.get());
aggOperators_[index_]->AddInput(projectedVecBatch);
return 0;
}
index_++;
}
aggOperators_[index_ - 1]->GetOutput(outputVecBatch);
if (*outputVecBatch != nullptr) {
return 0;
}
SetStatus(OMNI_STATUS_FINISHED);
return 0;
}
std::shared_ptr<OperatorFactory> GroupingOperator::CreateOperatorFactory(
const std::shared_ptr<const AggregationNode> &aggregationNode, const config::QueryConfig &queryConfig)
{
if (aggregationNode->GetGroupByKeys().empty()) {
return std::shared_ptr<AggregationWithExprOperatorFactory>(
AggregationWithExprOperatorFactory::CreateAggregationWithExprOperatorFactory(aggregationNode, queryConfig));
}
return std::shared_ptr<HashAggregationWithExprOperatorFactory>(
HashAggregationWithExprOperatorFactory::CreateAggregationWithExprOperatorFactory(aggregationNode, queryConfig));
}
std::shared_ptr<OperatorFactory> GroupingOperator::CreateResidualOperatorFactory(
const std::shared_ptr<const AggregationNode> &aggregationNode, const DataTypesPtr &sourceTypes, int rollUpSize,
const config::QueryConfig &queryConfig)
{
auto aggsOutputTypes = aggregationNode->GetAggsOutputTypes();
int aggsOutputTypeSize = 0;
for (const auto &aggsOutputType : aggsOutputTypes) {
aggsOutputTypeSize += aggsOutputType.GetSize();
}
auto groupBySize = aggregationNode->GetGroupByKeys().size();
groupByKeys_.resize(groupBySize);
std::vector<bool> inputRaws(aggregationNode->GetInputRaws().size(), false);
for (int32_t i = 0; i < groupBySize; i++) {
groupByKeys_[i] = new FieldExpr(i, sourceTypes->GetType(i));
}
int aggIndex = aggregationNode->OutputType()->GetSize() - aggsOutputTypeSize;
for (const auto &aggsOutputType : aggsOutputTypes) {
std::vector<ExprPtr> tmpExprs;
for (int32_t i = 0; i < aggsOutputType.GetSize(); i++) {
tmpExprs.push_back(new FieldExpr(aggIndex, sourceTypes->GetType(aggIndex)));
aggIndex++;
}
aggsKeys_.push_back(tmpExprs);
}
auto residualPlanNoded = std::make_shared<const AggregationNode>(aggregationNode->Id(), groupByKeys_,
aggregationNode->GetGroupByNum(), aggsKeys_, aggregationNode->OutputType(),
aggregationNode->GetAggsOutputTypes(), aggregationNode->GetAggFuncTypes(), aggregationNode->GetAggFilters(),
aggregationNode->GetMaskColumns(), inputRaws, aggregationNode->GetOutputPartials(),
aggregationNode->GetIsStatisticalAggregate(), aggregationNode->OutputType(), aggregationNode->Sources()[0]);
if (aggregationNode->GetGroupByKeys().empty()) {
return std::shared_ptr<AggregationWithExprOperatorFactory>(
AggregationWithExprOperatorFactory::CreateAggregationWithExprOperatorFactory(residualPlanNoded,
queryConfig));
}
return std::shared_ptr<HashAggregationWithExprOperatorFactory>(
HashAggregationWithExprOperatorFactory::CreateAggregationWithExprOperatorFactory(residualPlanNoded,
queryConfig));
}
OmniStatus GroupingOperator::Close()
{
for (auto aggOperator : aggOperators_) {
aggOperator->Close();
}
return OMNI_STATUS_FINISHED;
}
}