/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "AggregateSerializationUtils.h"
#include <AggregateFunctions/IAggregateFunction.h>
#include <Columns/ColumnAggregateFunction.h>
#include <Columns/ColumnFixedString.h>
#include <Columns/ColumnString.h>
#include <DataTypes/DataTypeAggregateFunction.h>
#include <DataTypes/DataTypeFixedString.h>
#include <DataTypes/DataTypeNullable.h>
#include <DataTypes/DataTypeString.h>
#include <Functions/FunctionHelpers.h>
#include <IO/WriteBufferFromVector.h>
#include <IO/WriteHelpers.h>
#include <Common/Arena.h>

using namespace DB;

namespace local_engine
{

bool isFixedSizeStateAggregateFunction(const String& name)
{
    static const std::set<String> function_set = {"min", "max", "sum", "count", "avg"};
    return function_set.contains(name);
}

bool isFixedSizeArguments(const DataTypes& data_types)
{
    return removeNullable(data_types.front())->isValueRepresentedByNumber();
}

bool isFixedSizeAggregateFunction(const DB::AggregateFunctionPtr& function)
{
    return isFixedSizeStateAggregateFunction(function->getName()) && isFixedSizeArguments(function->getArgumentTypes());
}

DB::ColumnWithTypeAndName convertAggregateStateToFixedString(const DB::ColumnWithTypeAndName& col)
{
    const auto *aggregate_col = checkAndGetColumn<ColumnAggregateFunction>(&*col.column);
    if (!aggregate_col)
    {
        return col;
    }
    // only support known fixed size aggregate function
    if (!isFixedSizeAggregateFunction(aggregate_col->getAggregateFunction()))
    {
        return col;
    }
    size_t state_size = aggregate_col->getAggregateFunction()->sizeOfData();
    auto res_type = std::make_shared<DataTypeFixedString>(state_size);
    auto res_col = res_type->createColumn();
    PaddedPODArray<UInt8> & column_chars_t = assert_cast<ColumnFixedString &>(*res_col).getChars();
    column_chars_t.reserve_exact(aggregate_col->size() * state_size);
    for (const auto & item : aggregate_col->getData())
    {
        column_chars_t.insert_assume_reserved(item, item + state_size);
    }
    return DB::ColumnWithTypeAndName(std::move(res_col), res_type, col.name);
}

DB::ColumnWithTypeAndName convertAggregateStateToString(const DB::ColumnWithTypeAndName& col)
{
    const auto *aggregate_col = checkAndGetColumn<ColumnAggregateFunction>(&*col.column);
    if (!aggregate_col)
    {
        return col;
    }
    auto res_type = std::make_shared<DataTypeString>();
    auto res_col = res_type->createColumn();
    PaddedPODArray<UInt8> & column_chars = assert_cast<ColumnString &>(*res_col).getChars();
    IColumn::Offsets & column_offsets = assert_cast<ColumnString &>(*res_col).getOffsets();
    auto value_writer = WriteBufferFromVector<PaddedPODArray<UInt8>>(column_chars);
    column_offsets.reserve_exact(aggregate_col->size());
    for (const auto & item : aggregate_col->getData())
    {
        aggregate_col->getAggregateFunction()->serialize(item, value_writer);
        writeChar('\0', value_writer);
        column_offsets.emplace_back(value_writer.count());
    }
    return DB::ColumnWithTypeAndName(std::move(res_col), res_type, col.name);
}

DB::ColumnWithTypeAndName convertFixedStringToAggregateState(const DB::ColumnWithTypeAndName & col, const DB::DataTypePtr & type)
{
    chassert(WhichDataType(type).isAggregateFunction());
    auto res_col = type->createColumn();
    const auto * agg_type = checkAndGetDataType<DataTypeAggregateFunction>(type.get());
    ColumnAggregateFunction & real_column = typeid_cast<ColumnAggregateFunction &>(*res_col);
    auto & arena = real_column.createOrGetArena();
    ColumnAggregateFunction::Container & vec = real_column.getData();
    vec.reserve_exact(col.column->size());
    auto agg_function = agg_type->getFunction();
    size_t size_of_state = agg_function->sizeOfData();
    size_t align_of_state = agg_function->alignOfData();

    for (size_t i = 0; i < col.column->size(); ++i)
    {
        AggregateDataPtr place = arena.alignedAlloc(size_of_state, align_of_state);

        agg_function->create(place);

        auto value = col.column->getDataAt(i);
        memcpy(place, value.data, value.size);

        vec.push_back(place);
    }
    return DB::ColumnWithTypeAndName(std::move(res_col), type, col.name);
}

DB::Block convertAggregateStateInBlock(DB::Block& block)
{
    ColumnsWithTypeAndName columns;
    columns.reserve(block.columns());
    for (const auto & item : block.getColumnsWithTypeAndName())
    {
        if (WhichDataType(item.type).isAggregateFunction())
        {
            const auto *aggregate_col = checkAndGetColumn<ColumnAggregateFunction>(&*item.column);
            if (isFixedSizeAggregateFunction(aggregate_col->getAggregateFunction()))
                columns.emplace_back(convertAggregateStateToFixedString(item));
            else
                columns.emplace_back(convertAggregateStateToString(item));
        }
        else
        {
            columns.emplace_back(item);
        }
    }

    return columns;
}

}