/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2021-2023. All rights reserved.
 * Description: Aggregation Base Class
 */
#ifndef AGGREGATION_H
#define AGGREGATION_H

#include <vector>
#include <thread>
#include "operator/operator_factory.h"
#include "operator/aggregation/aggregator/only_aggregator_factory.h"
#include "memory/memory_pool.h"
#include "operator/status.h"

namespace omniruntime {
namespace op {
class AggregationCommonOperator : public Operator {
public:
    explicit AggregationCommonOperator(std::vector<std::unique_ptr<Aggregator>> &&aggs, std::vector<bool> &inputRaws,
        std::vector<bool> &outputPartials, bool isOverflowAsNull = false)
        : inputRaws(inputRaws), outputPartials(outputPartials), isOverflowAsNull(isOverflowAsNull)
    {
        for (auto &agg : aggs) {
            agg->SetExecutionContext(executionContext.get());
        }
        this->aggregators = std::move(aggs);
    }

    ~AggregationCommonOperator() override {};

protected:
    std::vector<std::unique_ptr<Aggregator>> aggregators;
    std::vector<bool> inputRaws;
    std::vector<bool> outputPartials;
    bool isOverflowAsNull;
};

class AggregationCommonOperatorFactory : public OperatorFactory {
public:
    AggregationCommonOperatorFactory(std::vector<bool> &inputRaws, std::vector<bool> &outputPartials,
        std::vector<uint32_t> &maskColsContext, bool isOverflowAsNull = false, bool isStatisticalAggregate = false)
        : inputRaws(inputRaws),
          outputPartials(outputPartials),
          isOverflowAsNull(isOverflowAsNull),
          isStatisticalAggregate(isStatisticalAggregate)
    {
        for (size_t i = 0; i < maskColsContext.size(); ++i) {
            maskCols.push_back(maskColsContext[i]);
        }
    }

    ~AggregationCommonOperatorFactory() override {};

    std::vector<int32_t> &GetMaskColumns()
    {
        return maskCols;
    }

    virtual OmniStatus Init() = 0;
    virtual OmniStatus Close() = 0;

    template <class T>
    void CreateAggregatorFactory(std::vector<std::unique_ptr<AggregatorFactory>> &aggregatorFactories, int32_t maskCol);

    OmniStatus CreateAggregatorFactories(std::vector<std::unique_ptr<AggregatorFactory>> &aggregatorFactories,
        const std::vector<uint32_t> &funcTypesContext, const std::vector<int32_t> &maskCols);

protected:
    std::vector<bool> inputRaws;
    std::vector<bool> outputPartials;
    std::vector<int32_t> maskCols;
    bool isOverflowAsNull;
    bool isStatisticalAggregate;
};
}
}
#endif