#include <gtest/gtest.h>
#include "table/data/vectorbatch/VectorBatch.h"
#include "table/data/binary/BinaryRowData.h"
#include "table/runtime/keyselector/KeySelector.h"
#include <OmniOperatorJIT/core/test/util/test_util.h>
// Builds a five-row, four-column batch (long, varchar, int, long-as-timestamp),
// asks KeySelector<StringRef> for the key of row 3, and compares it with a
// 24-byte buffer assembled by hand in this test.
TEST(KeySelectorTest, SelectFromVB) {
    constexpr int kRow = 3;      // row whose key is extracted
    constexpr int kKeyLen = 24;  // total size of the hand-built expected key

    auto batch = new omniruntime::vec::VectorBatch(5);
    std::vector<long> longCol {0, 1, 2, 3, 4};
    std::vector<std::string> textCol {"hello_world_0", "hello_world_1", "hello_world_2", "hello_world_3",
                                      "hello_world_4"};
    std::vector<int> intCol {20, 21, 22, 23, 24};
    std::vector<long> tsCol {1744053374, 1744053375, 1744053376, 1744053377, 1744053378};

    // Append one vector per column, in column order.
    BaseVector* longVec = omniruntime::TestUtil::CreateVector(longCol.size(), longCol.data());
    batch->Append(longVec);
    BaseVector* textVec = omniruntime::TestUtil::CreateVarcharVector(textCol.data(), textCol.size());
    batch->Append(textVec);
    BaseVector* intVec = omniruntime::TestUtil::CreateVector(intCol.size(), intCol.data());
    batch->Append(intVec);
    BaseVector* tsVec = omniruntime::TestUtil::CreateVector(tsCol.size(), tsCol.data());
    batch->Append(tsVec);
    omnistream::VectorBatch vb(batch, nullptr, nullptr);

    // All four columns take part in the key.
    std::vector<int> columnIndexes {0, 1, 2, 3};
    std::vector<int> columnTypes {omniruntime::type::OMNI_LONG, omniruntime::type::OMNI_VARCHAR,
                                  omniruntime::type::OMNI_INT, omniruntime::type::OMNI_TIMESTAMP_WITHOUT_TIME_ZONE};
    KeySelector<StringRef> selector(columnTypes, columnIndexes);
    auto key = selector.getKey(&vb, kRow);

    // Expected bytes, written front to back through a cursor:
    //   [1][1 byte of longCol][1][13][13 chars of textCol][1][1 byte of intCol][4][4 bytes of tsCol]
    // The small leading numbers look like byte counts for the field that follows
    // (for the string: a count of the length field, then the length) -- TODO confirm
    // against the KeySelector encoding.
    // Copying only the first 1/4 bytes of the integers matches their numeric value
    // only on a little-endian host.
    // NOTE(review): this buffer is never released here; whether StringRef takes
    // ownership is not visible from this file -- confirm before adding a delete[].
    char* expected = new char[kKeyLen]{};
    std::size_t cursor = 0;
    auto putByte = [&](char value) { expected[cursor++] = value; };
    auto putBytes = [&](const void* src, std::size_t len) {
        memcpy(expected + cursor, src, len);
        cursor += len;
    };

    putByte(1);
    putBytes(&longCol[kRow], 1);
    putByte(1);
    putByte(13);
    putBytes(textCol[kRow].data(), textCol[kRow].size());
    putByte(1);
    putBytes(&intCol[kRow], 1);
    putByte(4);
    putBytes(&tsCol[kRow], 4);

    StringRef expectedKey(expected, kKeyLen);
    EXPECT_TRUE(expectedKey == key);
}
// Hand-builds a 24-byte serialized key, asks KeySelector<StringRef> to write it
// into row 0 of an empty one-row, four-column batch (int64, varchar, int32, int64),
// and checks that each column then holds the value the key was built from.
TEST(KeySelectorTest, DeserializeToVB) {
    using VarcharVector = omniruntime::vec::Vector<omniruntime::vec::LargeStringContainer<std::string_view>>;

    // Destination batch: one row, one freshly allocated vector per key column.
    auto rawVB = new omniruntime::vec::VectorBatch(1);
    auto vec0 = new omniruntime::vec::Vector<int64_t>(1);
    auto vec1 = new VarcharVector(1);
    auto vec2 = new omniruntime::vec::Vector<int32_t>(1);
    auto vec3 = new omniruntime::vec::Vector<int64_t>(1);
    rawVB->Append(vec0);
    rawVB->Append(vec1);
    rawVB->Append(vec2);
    rawVB->Append(vec3);
    omnistream::VectorBatch vb(rawVB, nullptr, nullptr);

    // Values the serialized key is built from.
    long key0 = 3;
    std::string key1 = "hello_world_3";
    int key2 = 23;
    long key3 = 1744053374;

    // Serialized key layout used by this test:
    //   [1][1 byte of key0][1][13][13 chars of key1][1][1 byte of key2][4][4 bytes of key3]
    // The small leading numbers look like byte counts for the field that follows
    // -- TODO confirm against the KeySelector encoding.
    // Copying only the first 1/4 bytes of the integers matches their numeric value
    // only on a little-endian host.
    // NOTE(review): this buffer is never released here; whether StringRef takes
    // ownership is not visible from this file -- confirm before adding a delete[].
    char* key = new char[24]{};
    memset(key, 1, 1);
    memcpy(key + 1, &key0, 1);
    memset(key + 2, 1, 1);
    memset(key + 3, 13, 1);
    memcpy(key + 4, key1.data(), key1.size());
    memset(key + 17, 1, 1);
    memcpy(key + 18, &key2, 1);
    memset(key + 19, 4, 1);
    memcpy(key + 20, &key3, 4);
    StringRef keyRef(key, 24);

    std::vector<int> keyCols {0, 1, 2, 3};
    std::vector<int> keyTypes {omniruntime::type::OMNI_LONG, omniruntime::type::OMNI_VARCHAR,
                               omniruntime::type::OMNI_INT, omniruntime::type::OMNI_TIMESTAMP_WITHOUT_TIME_ZONE};
    KeySelector<StringRef> selector(keyTypes, keyCols);
    selector.fillKeyToVectorBatch(keyRef, &vb, 0, {0, 1, 2, 3});

    // EXPECT_EQ prints both operands on mismatch, so no unconditional debug print
    // of the batch contents is needed.
    auto* varcharColumn = reinterpret_cast<VarcharVector*>(vb.Get(1));
    EXPECT_EQ(vb.GetValueAt<int64_t>(0, 0), key0);
    EXPECT_EQ(std::string(varcharColumn->GetValue(0)), key1);
    EXPECT_EQ(vb.GetValueAt<int32_t>(2, 0), key2);
    EXPECT_EQ(vb.GetValueAt<int64_t>(3, 0), key3);
}