/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "substrait/SubstraitExtensionCollector.h"
#include "velox/common/base/tests/GTestUtils.h"
#include "velox/core/PlanNode.h"
#include "velox/functions/prestosql/registration/RegistrationFunctions.h"

using namespace facebook::velox;

namespace gluten {

class SubstraitExtensionCollectorTest : public ::testing::Test {
 protected:
  void SetUp() override {
    Test::SetUp();
    functions::prestosql::registerAllScalarFunctions();
  }

  int getReferenceNumber(const std::string& functionName, std::vector<TypePtr>&& arguments) {
    int referenceNumber1 = extensionCollector_->getReferenceNumber(functionName, arguments);
    // Repeat the call to make sure properly de-duplicated.
    int referenceNumber2 = extensionCollector_->getReferenceNumber(functionName, arguments);
    EXPECT_EQ(referenceNumber1, referenceNumber2);
    return referenceNumber1;
  }

  int getReferenceNumber(
      const std::string& functionName,
      std::vector<TypePtr>&& arguments,
      core::AggregationNode::Step step) {
    int referenceNumber1 = extensionCollector_->getReferenceNumber(functionName, arguments);
    // Repeat the call to make sure properly de-duplicated.
    int referenceNumber2 = extensionCollector_->getReferenceNumber(functionName, arguments);
    EXPECT_EQ(referenceNumber1, referenceNumber2);
    return referenceNumber2;
  }

  /// Given a substrait plan, return the sorted extension functions by the
  /// function anchor.
  ::google::protobuf::RepeatedPtrField<::substrait::extensions::SimpleExtensionDeclaration> getSortedSubstraitExtension(
      const ::substrait::Plan* substraitPlan) {
    auto substraitExtensions = substraitPlan->extensions();
    std::sort(substraitExtensions.begin(), substraitExtensions.end(), [](const auto& a, const auto& b) {
      return a.extension_function().function_anchor() < b.extension_function().function_anchor();
    });

    return substraitExtensions;
  }

  SubstraitExtensionCollectorPtr extensionCollector_ = std::make_shared<SubstraitExtensionCollector>();
};

TEST_F(SubstraitExtensionCollectorTest, getReferenceNumberForScalarFunction) {
  ASSERT_EQ(getReferenceNumber("plus", {INTEGER(), INTEGER()}), 0);
  ASSERT_EQ(getReferenceNumber("divide", {INTEGER(), INTEGER()}), 1);
  ASSERT_EQ(getReferenceNumber("cardinality", {ARRAY(INTEGER())}), 2);
  ASSERT_EQ(getReferenceNumber("array_sum", {ARRAY(INTEGER())}), 3);

  auto functionType = std::make_shared<const FunctionType>(std::vector<TypePtr>{INTEGER(), VARCHAR()}, BIGINT());
  std::vector<TypePtr> types = {MAP(INTEGER(), VARCHAR()), functionType};
  ASSERT_ANY_THROW(getReferenceNumber("transform_keys", std::move(types)));
}

TEST_F(SubstraitExtensionCollectorTest, getReferenceNumberForAggregateFunction) {
  // Sum aggregate function have same argument type for each aggregation step.
  ASSERT_EQ(getReferenceNumber("sum", {INTEGER()}, core::AggregationNode::Step::kSingle), 0);

  // Partial avg aggregate function should use primitive integral type.
  ASSERT_EQ(getReferenceNumber("avg", {INTEGER()}, core::AggregationNode::Step::kPartial), 1);

  // Final avg aggregate function should use struct type, like
  // 'ROW<DOUBLE,BIGINT>'
  ASSERT_EQ(getReferenceNumber("avg", {ROW({DOUBLE(), BIGINT()})}, core::AggregationNode::Step::kFinal), 2);

  // Count aggregate function have same argument type for each aggregation step.
  ASSERT_EQ(getReferenceNumber("count", {INTEGER()}, core::AggregationNode::Step::kFinal), 3);
}

TEST_F(SubstraitExtensionCollectorTest, addExtensionsToPlan) {
  getReferenceNumber("plus", {INTEGER(), INTEGER()});
  getReferenceNumber("divide", {INTEGER(), INTEGER()});
  getReferenceNumber("cardinality", {ARRAY(INTEGER())});
  getReferenceNumber("array_sum", {ARRAY(INTEGER())});
  getReferenceNumber("sum", {INTEGER()});
  getReferenceNumber("avg", {INTEGER()});
  getReferenceNumber("avg", {ROW({DOUBLE(), BIGINT()})});
  getReferenceNumber("count", {INTEGER()});

  google::protobuf::Arena arena;
  auto* substraitPlan = google::protobuf::Arena::CreateMessage<::substrait::Plan>(&arena);

  extensionCollector_->addExtensionsToPlan(substraitPlan);

  const auto& substraitExtensions = getSortedSubstraitExtension(substraitPlan);
  auto getFunctionName = [&](auto id) { return substraitExtensions[id].extension_function().name(); };

  ASSERT_EQ(substraitPlan->extensions().size(), 8);
  ASSERT_EQ(getFunctionName(0), "plus:i32_i32");
  ASSERT_EQ(getFunctionName(1), "divide:i32_i32");
  ASSERT_EQ(getFunctionName(2), "cardinality:list");
  ASSERT_EQ(getFunctionName(3), "array_sum:list");
  ASSERT_EQ(getFunctionName(4), "sum:i32");
  ASSERT_EQ(getFunctionName(5), "avg:i32");
  ASSERT_EQ(getFunctionName(6), "avg:struct<fp64,i64>");
  ASSERT_EQ(getFunctionName(7), "count:i32");
}

} // namespace gluten