* Copyright (c) Huawei Technologies Co., Ltd. 2021-2023. All rights reserved.
* Description: Aggregation Base Class
*/
#include "aggregation.h"
#include "aggregator/aggregator_factory.h"
namespace omniruntime {
namespace op {
template <class T>
void AggregationCommonOperatorFactory::CreateAggregatorFactory(
std::vector<std::unique_ptr<AggregatorFactory>> &aggregatorFactories, int32_t maskCol)
{
if (maskCol == Aggregator::INVALID_MASK_COL) {
aggregatorFactories.push_back(std::make_unique<T>());
} else {
aggregatorFactories.push_back(std::make_unique<MaskAggregatorFactory<T>>(maskCol));
}
}
OmniStatus AggregationCommonOperatorFactory::CreateAggregatorFactories(
std::vector<std::unique_ptr<AggregatorFactory>> &aggregatorFactories, const std::vector<uint32_t> &funcTypesContext,
const std::vector<int32_t> &maskCols)
{
OmniStatus ret = OMNI_STATUS_NORMAL;
for (uint32_t i = 0; i < funcTypesContext.size(); ++i) {
switch (funcTypesContext[i]) {
case OMNI_AGGREGATION_TYPE_SUM: {
if (ConfigUtil::GetSupportContainerVecRule() == SupportContainerVecRule::NOT_SUPPORT) {
CreateAggregatorFactory<SumSparkAggregatorFactory>(aggregatorFactories, maskCols[i]);
} else {
CreateAggregatorFactory<SumAggregatorFactory>(aggregatorFactories, maskCols[i]);
}
break;
}
case OMNI_AGGREGATION_TYPE_COUNT_COLUMN: {
CreateAggregatorFactory<CountColumnAggregatorFactory>(aggregatorFactories, maskCols[i]);
break;
}
case OMNI_AGGREGATION_TYPE_COUNT_ALL: {
CreateAggregatorFactory<CountAllAggregatorFactory>(aggregatorFactories, maskCols[i]);
break;
}
case OMNI_AGGREGATION_TYPE_MAX: {
CreateAggregatorFactory<MaxAggregatorFactory>(aggregatorFactories, maskCols[i]);
break;
}
case OMNI_AGGREGATION_TYPE_MIN: {
CreateAggregatorFactory<MinAggregatorFactory>(aggregatorFactories, maskCols[i]);
break;
}
case OMNI_AGGREGATION_TYPE_AVG: {
if (ConfigUtil::GetSupportContainerVecRule() == SupportContainerVecRule::NOT_SUPPORT) {
CreateAggregatorFactory<AverageSparkAggregatorFactory>(aggregatorFactories, maskCols[i]);
} else {
CreateAggregatorFactory<AverageAggregatorFactory>(aggregatorFactories, maskCols[i]);
}
break;
}
case OMNI_AGGREGATION_TYPE_SAMP: {
CreateAggregatorFactory<StddevSampSparkAggregatorFactory>(aggregatorFactories, maskCols[i]);
break;
}
case OMNI_AGGREGATION_TYPE_FIRST_IGNORENULL: {
CreateAggregatorFactory<FirstIgnoreNullAggregatorFactory>(aggregatorFactories, maskCols[i]);
break;
}
case OMNI_AGGREGATION_TYPE_FIRST_INCLUDENULL: {
CreateAggregatorFactory<FirstIncludeNullAggregatorFactory>(aggregatorFactories, maskCols[i]);
break;
}
case OMNI_AGGREGATION_TYPE_TRY_SUM: {
CreateAggregatorFactory<TrySumSparkAggregatorFactory>(aggregatorFactories, maskCols[i]);
break;
}
case OMNI_AGGREGATION_TYPE_TRY_AVG: {
CreateAggregatorFactory<TryAverageSparkAggregatorFactory>(aggregatorFactories, maskCols[i]);
break;
}
default: {
std::string omniExceptionInfo = "In function CreateAggregatorFactories, No such agg func type " +
std::to_string(funcTypesContext[i]);
throw omniruntime::exception::OmniException("UNSUPPORTED_ERROR", omniExceptionInfo);
}
}
}
return ret;
}
}
}