* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ArrayJoinHelper.h"
#include <Core/Settings.h>
#include <DataTypes/DataTypeArray.h>
#include <Interpreters/ActionsDAG.h>
#include <Interpreters/ArrayJoin.h>
#include <Interpreters/Context.h>
#include <Processors/QueryPlan/ArrayJoinStep.h>
#include <Processors/QueryPlan/ExpressionStep.h>
#include <Processors/QueryPlan/IQueryPlanStep.h>
#include <Processors/QueryPlan/QueryPlan.h>
#include <Poco/Logger.h>
#include <Common/logger_useful.h>
namespace DB
{
namespace
{
extern const int LOGICAL_ERROR;
}
namespace Setting
{
extern const SettingsUInt64 max_block_size;
}
}
namespace local_engine
{
namespace ArrayJoinHelper
{
const DB::ActionsDAG::Node * findArrayJoinNode(const DB::ActionsDAG & actions_dag)
{
const DB::ActionsDAG::Node * array_join_node = nullptr;
const auto & nodes = actions_dag.getNodes();
for (const auto & node : nodes)
{
if (node.type == DB::ActionsDAG::ActionType::ARRAY_JOIN)
{
if (array_join_node)
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Expect single ARRAY JOIN node in generate rel");
array_join_node = &node;
}
}
return array_join_node;
}
struct SplittedActionsDAGs
{
DB::ActionsDAG before_array_join;
DB::ActionsDAG array_join;
DB::ActionsDAG after_array_join;
};
static SplittedActionsDAGs splitActionsDAGInGenerate(const DB::ActionsDAG & actions_dag)
{
SplittedActionsDAGs res;
auto array_join_node = findArrayJoinNode(actions_dag);
std::unordered_set<const DB::ActionsDAG::Node *> first_split_nodes(array_join_node->children.begin(), array_join_node->children.end());
auto first_split_result = actions_dag.split(first_split_nodes);
res.before_array_join = std::move(first_split_result.first);
array_join_node = findArrayJoinNode(first_split_result.second);
std::unordered_set<const DB::ActionsDAG::Node *> second_split_nodes = {array_join_node};
auto second_split_result = first_split_result.second.split(second_split_nodes);
res.array_join = std::move(second_split_result.first);
second_split_result.second.removeUnusedActions();
res.after_array_join = std::move(second_split_result.second);
return res;
}
DB::ActionsDAG applyArrayJoinOnOneColumn(const DB::Block & header, size_t column_index)
{
auto arrayColumn = header.getByPosition(column_index);
if (!typeid_cast<const DB::DataTypeArray *>(arrayColumn.type.get()))
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Expect array column in array join");
DB::ActionsDAG actions_dag(header.getColumnsWithTypeAndName());
const auto * array_column_node = actions_dag.getInputs()[column_index];
auto array_join_name = array_column_node->result_name;
const auto * array_join_node = &actions_dag.addArrayJoin(*array_column_node, array_join_name);
actions_dag.addOrReplaceInOutputs(*array_join_node);
return std::move(actions_dag);
}
std::vector<DB::IQueryPlanStep *>
addArrayJoinStep(DB::ContextPtr context, DB::QueryPlan & plan, const DB::ActionsDAG & actions_dag, bool is_left)
{
auto logger = getLogger("ArrayJoinHelper");
std::vector<DB::IQueryPlanStep *> steps;
if (findArrayJoinNode(actions_dag))
{
LOG_DEBUG(logger, "original actions_dag:{}", actions_dag.dumpDAG());
auto splitted_actions_dags = splitActionsDAGInGenerate(actions_dag);
LOG_DEBUG(logger, "actions_dag before arrayJoin:{}", splitted_actions_dags.before_array_join.dumpDAG());
LOG_DEBUG(logger, "actions_dag during arrayJoin:{}", splitted_actions_dags.array_join.dumpDAG());
LOG_DEBUG(logger, "actions_dag after arrayJoin:{}", splitted_actions_dags.after_array_join.dumpDAG());
auto ignore_actions_dag = [](const DB::ActionsDAG & actions_dag_) -> bool
{
We should ignore actions_dag like:
0 : INPUT () (no column) String a
1 : INPUT () (no column) String b
Output nodes: 0, 1
*/
return actions_dag_.getOutputs().size() == actions_dag_.getNodes().size()
&& actions_dag_.getInputs().size() == actions_dag_.getNodes().size();
};
if (!ignore_actions_dag(splitted_actions_dags.before_array_join))
{
auto step_before_array_join
= std::make_unique<DB::ExpressionStep>(plan.getCurrentHeader(), std::move(splitted_actions_dags.before_array_join));
step_before_array_join->setStepDescription("Pre-projection In Generate");
steps.emplace_back(step_before_array_join.get());
plan.addStep(std::move(step_before_array_join));
}
DB::Names array_joined_columns{findArrayJoinNode(splitted_actions_dags.array_join)->result_name};
DB::ArrayJoin array_join;
array_join.columns = std::move(array_joined_columns);
array_join.is_left = is_left;
auto array_join_step = std::make_unique<DB::ArrayJoinStep>(
plan.getCurrentHeader(), std::move(array_join), false, context->getSettingsRef()[DB::Setting::max_block_size]);
array_join_step->setStepDescription("ARRAY JOIN In Generate");
steps.emplace_back(array_join_step.get());
plan.addStep(std::move(array_join_step));
if (!ignore_actions_dag(splitted_actions_dags.after_array_join))
{
auto step_after_array_join
= std::make_unique<DB::ExpressionStep>(plan.getCurrentHeader(), std::move(splitted_actions_dags.after_array_join));
step_after_array_join->setStepDescription("Post-projection In Generate");
steps.emplace_back(step_after_array_join.get());
plan.addStep(std::move(step_after_array_join));
}
}
else
{
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Expect array join node in actions_dag");
}
return steps;
}
}
}