/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include <cstddef>
#include <AggregateFunctions/IAggregateFunction_fwd.h>
#include <Columns/ColumnAggregateFunction.h>
#include <Columns/ColumnVector.h>
#include <DataTypes/DataTypesNumber.h>
#include <IO/VarInt.h>
#include <base/types.h>
#include "Common/Exception.h"
#include <Common/assert_cast.h>


#include <IO/ReadHelpers.h>
#include <Interpreters/BloomFilter.h>

namespace DB::ErrorCodes
{
extern const int BAD_ARGUMENTS;
}

namespace local_engine
{

struct AggregateFunctionGroupBloomFilterData
{
    bool initted = false;
    // small default value because BloomFilter has no default ctor
    DB::BloomFilter bloom_filter = DB::BloomFilter(100, 2, 0);
    static const char * name() { return "groupBloomFilter"; }

    void read(DB::ReadBuffer & in)
    {
        UInt64 filter_size, filter_hashes, seed = 0;
        readVarUInt(filter_size, in);
        readVarUInt(filter_hashes, in);
        readVarUInt(seed, in);
        if unlikely (filter_size == 0)
        {
            initted = false;
        }
        else
        {
            bloom_filter = DB::BloomFilter(DB::BloomFilterParameters(filter_size, filter_hashes, seed));
            auto & v = bloom_filter.getFilter();
            in.readStrict(reinterpret_cast<char *>(v.data()), v.size() * sizeof(v[0]));
            initted = true;
        }
    }

    void write(DB::WriteBuffer & out) const
    {
        if likely (initted)
        {
            writeVarUInt(bloom_filter.getSize(), out);
            writeVarUInt(bloom_filter.getHashes(), out);
            writeVarUInt(bloom_filter.getSeed(), out);
            const auto & v = bloom_filter.getFilter();

            out.write(reinterpret_cast<const char *>(v.data()), v.size() * sizeof(v[0]));
        }
        else
        {
            writeVarUInt(0, out);
            writeVarUInt(0, out);
            writeVarUInt(0, out);
        }
    }
};

// Aggregate values (of type T, e.g. Int64) into a bloom filter.
// For groupBloomFilter, we don't actually care about the final numeric result (currently it always
// returns the configured bloom filter size). We just need its intermediate state, i.e. groupBloomFilterState.
template <typename T, typename Data>
class AggregateFunctionGroupBloomFilter final : public DB::IAggregateFunctionDataHelper<Data, AggregateFunctionGroupBloomFilter<T, Data>>
{
public:
    /// @param argument_types_ types of the aggregate function arguments (forwarded to the base helper).
    /// @param parameters_     aggregate function parameters (forwarded to the base helper).
    /// @param filter_size_    size passed to DB::BloomFilterParameters when a state is first created by add().
    /// @param filter_hashes_  number of hash functions passed to DB::BloomFilterParameters.
    /// @param seed_           hash seed passed to DB::BloomFilterParameters.
    explicit AggregateFunctionGroupBloomFilter(
        const DB::DataTypes & argument_types_, const DB::Array & parameters_, size_t filter_size_, size_t filter_hashes_, size_t seed_)
        : DB::IAggregateFunctionDataHelper<Data, AggregateFunctionGroupBloomFilter<T, Data>>(argument_types_, parameters_, createResultType())
        , filter_size(filter_size_)
        , filter_hashes(filter_hashes_)
        , seed(seed_)
    {
    }

    String getName() const override { return Data::name(); }

    /// The result column is a plain number of type T.
    static DB::DataTypePtr createResultType() { return std::make_shared<DB::DataTypeNumber<T>>(); }

    bool allocatesMemoryInArena() const override { return false; }

    /// Throws BAD_ARGUMENTS when `filter_size_` is 0, i.e. the bloom filter parameters were never set up.
    void checkFilterSize(size_t filter_size_) const
    {
        // size_t is unsigned, so "<= 0" can only ever mean "== 0".
        if (filter_size_ == 0)
        {
            throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "filter_size being LTE 0 means bloom filter is not properly initialized");
        }
    }

    /// Adds the value at `row_num` of the first argument column to the bloom filter in `place`,
    /// building the filter from this function's parameters on first use.
    void add(DB::AggregateDataPtr __restrict place, const DB::IColumn ** columns, size_t row_num, DB::Arena *) const override
    {
        auto & state = this->data(place);
        if unlikely (!state.initted)
        {
            checkFilterSize(filter_size);
            state.bloom_filter = DB::BloomFilter(DB::BloomFilterParameters(filter_size, filter_hashes, seed));
            state.initted = true;
        }

        // The raw bytes of the value are what gets hashed into the filter.
        const T value = assert_cast<const DB::ColumnVector<T> &>(*columns[0]).getData()[row_num];
        state.bloom_filter.add(reinterpret_cast<const char *>(&value), sizeof(T));
    }

    /// ORs the bloom filter held by `rhs` into the one held by `place`.
    /// An uninitialized `rhs` is ignored; an uninitialized `place` adopts the parameters of `rhs`.
    /// Throws BAD_ARGUMENTS if both sides are initialized but their filters differ in length.
    void merge(DB::AggregateDataPtr __restrict place, DB::ConstAggregateDataPtr rhs, DB::Arena *) const override
    {
        const auto & other = this->data(rhs);
        // Skip un-initted values
        if (!other.initted)
        {
            return;
        }

        auto & self = this->data(place);
        const auto & bloom_other = other.bloom_filter;
        if (!self.initted)
        {
            // We use filter_other's size/hashes/seed to avoid passing these parameters around to construct AggregateFunctionGroupBloomFilter.
            checkFilterSize(bloom_other.getSize());
            self.bloom_filter = DB::BloomFilter(DB::BloomFilterParameters(bloom_other.getSize(), bloom_other.getHashes(), bloom_other.getSeed()));
            self.initted = true;
        }

        const auto & filter_other = bloom_other.getFilter();
        auto & filter_self = self.bloom_filter.getFilter();

        // Filters of different lengths cannot be merged element-wise: indexing filter_self with
        // filter_other's indices would run past the end of the destination when it is the shorter one.
        if (filter_self.size() != filter_other.size())
        {
            throw DB::Exception(
                DB::ErrorCodes::BAD_ARGUMENTS,
                "Cannot merge bloom filters of different sizes: {} vs {}",
                filter_self.size(),
                filter_other.size());
        }

        for (size_t i = 0; i < filter_other.size(); ++i)
        {
            filter_self[i] |= filter_other[i];
        }
    }

    /// Writes the state in `place` to `buf` via Data::write(); the version argument is unused.
    void serialize(DB::ConstAggregateDataPtr __restrict place, DB::WriteBuffer & buf, std::optional<size_t> /* version */) const override
    {
        this->data(place).write(buf);
    }

    /// Restores the state in `place` from `buf` via Data::read(); the version argument is unused.
    void deserialize(DB::AggregateDataPtr __restrict place, DB::ReadBuffer & buf, std::optional<size_t> /* version */, DB::Arena *) const override
    {
        this->data(place).read(buf);
    }

    /// The final value does not depend on the aggregated state: it is always the configured filter_size.
    void insertResultInto(DB::AggregateDataPtr __restrict /*place*/, DB::IColumn & to, DB::Arena *) const override
    {
        assert_cast<DB::ColumnVector<T> &>(to).getData().push_back(static_cast<T>(filter_size));
    }

private:
    // Parameters used to build a fresh bloom filter the first time add() touches a state.
    size_t filter_size;
    size_t filter_hashes;
    size_t seed;
};

}