* Copyright (c) Huawei Technologies Co., Ltd. 2022-2022. All rights reserved.
* Description: batch expression codegen
*/
#include "batch_expression_codegen.h"
namespace omniruntime::codegen {
namespace {
const int BATCH_EXPRFUNC_ROWCNT_INDEX = 3;
const int BATCH_EXPRFUNC_OUT_LENGTH_ARG_INDEX = 5;
const int BATCH_EXPRFUNC_OUT_NULL_INDEX = 8;
const int BATCH_EXPRFUNC_OUT_DATA_INDEX = 9;
}
BatchExpressionCodeGen::BatchExpressionCodeGen(std::string name, const Expr &cpExpr, op::OverflowConfig *overflowConfig)
: CodegenBase(name, cpExpr, overflowConfig)
{}
bool BatchExpressionCodeGen::InitializeBatchCodegenContext(iterator_range<llvm::Function::arg_iterator> args)
{
this->batchCodegenContext = std::make_unique<BatchCodegenContext>();
for (auto &arg : args) {
auto argName = arg.getName().str();
if (argName == "data") {
batchCodegenContext->data = &arg;
} else if (argName == "nullBitmap") {
batchCodegenContext->nullBitmap = &arg;
} else if (argName == "offsets") {
batchCodegenContext->offsets = &arg;
} else if (argName == "rowCnt") {
batchCodegenContext->rowCnt = &arg;
} else if (argName == "rowIdxArray") {
batchCodegenContext->rowIdxArray = &arg;
} else if (argName == "outputLength" || argName == "outputNull" || argName == "outputData") {
continue;
} else if (argName == "executionContext") {
batchCodegenContext->executionContext = &arg;
} else if (argName == "dictionaryVectors") {
batchCodegenContext->dictionaryVectors = &arg;
} else {
LogWarn("Invalid argument %s", argName.c_str());
return false;
}
}
return true;
}
llvm::Function *BatchExpressionCodeGen::CreateBatchFunction()
{
std::vector<Type *> args {
llvmTypes->I64PtrType(),
llvmTypes->I64PtrType(),
llvmTypes->I64PtrType(),
llvmTypes->I32Type(),
llvmTypes->I32PtrType(),
llvmTypes->I32PtrType(),
llvmTypes->I64Type(),
llvmTypes->I64PtrType(),
llvmTypes->I1PtrType(),
llvmTypes->ToBatchDataPointerType(expr->GetReturnTypeId())
};
FunctionType *prototype = FunctionType::get(llvmTypes->I32Type(), args, false);
func = llvm::Function::Create(prototype, llvm::Function::ExternalLinkage, funcName, modulePtr);
std::string argNames[] = {
"data", "nullBitmap", "offsets", "rowCnt", "rowIdxArray",
"outputLength", "executionContext", "dictionaryVectors", "outputNull", "outputData"
};
int32_t idx = 0;
for (auto &arg : func->args()) {
arg.setName(argNames[idx]);
idx++;
}
BasicBlock *body = BasicBlock::Create(*context, "CREATED_BATCH_FUNC_BODY", func);
builder->SetInsertPoint(body);
if (!InitializeBatchCodegenContext(func->args())) {
return nullptr;
}
auto result = VisitExpr(*expr);
if (result->data == nullptr) {
return nullptr;
}
if (result->length != nullptr) {
CallExternFunction("batch_copy", { OMNI_INT }, OMNI_INT,
{ func->getArg(BATCH_EXPRFUNC_OUT_LENGTH_ARG_INDEX), result->length,
func->getArg(BATCH_EXPRFUNC_ROWCNT_INDEX) },
nullptr, "copy_length");
}
CallExternFunction("batch_copy", { expr->GetReturnTypeId() }, expr->GetReturnTypeId(),
{ func->getArg(BATCH_EXPRFUNC_OUT_DATA_INDEX), result->data, func->getArg(BATCH_EXPRFUNC_ROWCNT_INDEX) },
nullptr, "copy_data");
CallExternFunction("batch_copy", { OMNI_BOOLEAN }, OMNI_BOOLEAN,
{ func->getArg(BATCH_EXPRFUNC_OUT_NULL_INDEX), result->isNull, func->getArg(BATCH_EXPRFUNC_ROWCNT_INDEX) },
nullptr, "copy_null");
builder->CreateRet(func->getArg(BATCH_EXPRFUNC_ROWCNT_INDEX));
verifyFunction(*func);
return func;
}
CodeGenValuePtr BatchExpressionCodeGen::VisitExpr(const Expr &e)
{
e.Accept(*this);
return this->value;
}
void BatchExpressionCodeGen::Visit(const LiteralExpr &lExpr)
{
this->value.reset(BatchLiteralExprConstantHelper(lExpr));
}
CodeGenValue *BatchExpressionCodeGen::BatchLiteralExprConstantHelper(const LiteralExpr &lExpr)
{
bool isNullLiteral = lExpr.isNull;
Value *isNull = llvmTypes->CreateConstantBool(isNullLiteral);
AllocaInst *nullArrayPtr =
builder->CreateAlloca(llvmTypes->I1Type(), this->batchCodegenContext->rowCnt, "IS_NULL_PTR");
AllocaInst *literalArrayPtr = GetResultArray(lExpr.GetReturnTypeId(), this->batchCodegenContext->rowCnt);
Value *literalValue = nullptr;
Value *length = nullptr;
Value *lengthArrayPtr = nullptr;
Value *precisionVal = nullptr;
Value *scaleVal = nullptr;
switch (lExpr.GetReturnTypeId()) {
case OMNI_INT:
case OMNI_DATE32: {
literalValue = llvmTypes->CreateConstantInt(lExpr.intVal);
break;
}
case OMNI_TIMESTAMP:
case OMNI_LONG: {
literalValue = llvmTypes->CreateConstantLong(lExpr.longVal);
break;
}
case OMNI_DOUBLE: {
literalValue = llvmTypes->CreateConstantDouble(lExpr.doubleVal);
break;
}
case OMNI_CHAR:
case OMNI_VARCHAR: {
literalValue = this->CreateConstantString(*(lExpr.stringVal));
lengthArrayPtr =
builder->CreateAlloca(llvmTypes->I32Type(), this->batchCodegenContext->rowCnt, "LENGTH_PTR");
length = llvmTypes->CreateConstantInt(lExpr.stringVal->length());
break;
}
case OMNI_BOOLEAN: {
literalValue = llvmTypes->CreateConstantBool(lExpr.boolVal);
break;
}
case OMNI_DECIMAL64: {
precisionVal = llvmTypes->CreateConstantInt(
static_cast<Decimal64DataType *>(lExpr.GetReturnType().get())->GetPrecision());
scaleVal =
llvmTypes->CreateConstantInt(static_cast<Decimal64DataType *>(lExpr.GetReturnType().get())->GetScale());
literalValue = llvmTypes->CreateConstantLong(lExpr.longVal);
break;
}
case OMNI_DECIMAL128: {
std::string dec128String = isNullLiteral ? "0" : *lExpr.stringVal;
__uint128_t dec128 = Decimal128Utils::StrToUint128_t(dec128String.c_str());
dec128String = Decimal128Utils::Uint128_tToStr(dec128);
precisionVal = llvmTypes->CreateConstantInt(
dynamic_cast<Decimal128DataType *>(lExpr.GetReturnType().get())->GetPrecision());
scaleVal = llvmTypes->CreateConstantInt(
dynamic_cast<Decimal128DataType *>(lExpr.GetReturnType().get())->GetScale());
literalValue = llvm::ConstantInt::get(llvm::Type::getInt128Ty(*context), dec128String, 10);
break;
}
default: {
LogWarn("Unsupported data type in LITERAL Expr %d", lExpr.GetReturnTypeId());
return new CodeGenValue(nullptr, nullptr);
}
}
std::vector<Value *> funcArgs;
if (TypeUtil::IsStringType(lExpr.GetReturnTypeId())) {
funcArgs = { this->batchCodegenContext->executionContext,
literalArrayPtr,
nullArrayPtr,
lengthArrayPtr,
literalValue,
isNull,
length,
this->batchCodegenContext->rowCnt };
CallExternFunction("batch_fill_literal", { OMNI_VARCHAR }, OMNI_VARCHAR, funcArgs, nullptr,
"fill_literal_array");
return new CodeGenValue(literalArrayPtr, nullArrayPtr, lengthArrayPtr);
} else if (TypeUtil::IsDecimalType(lExpr.GetReturnTypeId())) {
funcArgs = { literalArrayPtr, nullArrayPtr, literalValue, isNull, this->batchCodegenContext->rowCnt };
CallExternFunction("batch_fill_literal", { lExpr.GetReturnTypeId() }, lExpr.GetReturnTypeId(), funcArgs,
nullptr, "fill_literal_array");
return new DecimalValue(literalArrayPtr, nullArrayPtr, precisionVal, scaleVal);
} else {
funcArgs = { literalArrayPtr, nullArrayPtr, literalValue, isNull, this->batchCodegenContext->rowCnt };
CallExternFunction("batch_fill_literal", { lExpr.GetReturnTypeId() }, lExpr.GetReturnTypeId(), funcArgs,
nullptr, "fill_literal_array");
return new CodeGenValue(literalArrayPtr, nullArrayPtr);
}
}
void BatchExpressionCodeGen::Visit(const FieldExpr &fExpr)
{
Value *rowCnt = this->batchCodegenContext->rowCnt;
Value *vecBatch = this->batchCodegenContext->data;
Value *bitmap = this->batchCodegenContext->nullBitmap;
Value *offsets = this->batchCodegenContext->offsets;
Value *dictionaryVectors = this->batchCodegenContext->dictionaryVectors;
Value *rowIdxArray = this->batchCodegenContext->rowIdxArray;
Value *colIdx = llvmTypes->CreateConstantInt(fExpr.colVal);
Value *gep = builder->CreateGEP(llvmTypes->I64Type(), vecBatch, colIdx);
auto dictionaryVectorGEP = builder->CreateGEP(llvmTypes->I64Type(), dictionaryVectors, colIdx);
Value *dictionaryVectorPtr = builder->CreateLoad(llvmTypes->I64Type(), dictionaryVectorGEP);
auto condition = builder->CreateIsNotNull(dictionaryVectorPtr);
BasicBlock *trueBlock = BasicBlock::Create(*context, "DICTIONARY_NOT_NULL", func);
BasicBlock *falseBlock = BasicBlock::Create(*context, "DICTIONARY_IS_NULL");
BasicBlock *mergeBlock = BasicBlock::Create(*context, "field_data");
builder->CreateCondBr(condition, trueBlock, falseBlock);
builder->SetInsertPoint(trueBlock);
AllocaInst *dicLengthArray = builder->CreateAlloca(llvmTypes->I32Type(), rowCnt, "dic_varchar_length");
auto dicArrayPtr = this->GetDictionaryVectorValue(*(fExpr.GetReturnType()), rowIdxArray, rowCnt,
dictionaryVectorPtr, dicLengthArray);
builder->CreateBr(mergeBlock);
trueBlock = builder->GetInsertBlock();
func->getBasicBlockList().push_back(falseBlock);
builder->SetInsertPoint(falseBlock);
Value *elementAddr = builder->CreateLoad(llvmTypes->I64Type(), gep);
AllocaInst *lengthArray = nullptr;
Value *dataArrayPtr = nullptr;
if (TypeUtil::IsStringType(fExpr.GetReturnTypeId())) {
lengthArray = builder->CreateAlloca(llvmTypes->I32Type(), rowCnt, "varchar_length");
auto offsetsGEP = builder->CreateGEP(llvmTypes->I64Type(), offsets, colIdx);
Value *offsetPtr = builder->CreateLoad(llvmTypes->I64Type(), offsetsGEP);
offsetPtr = builder->CreateIntToPtr(offsetPtr, llvmTypes->I32PtrType());
dataArrayPtr =
this->GetVectorValue(*(fExpr.GetReturnType()), rowIdxArray, rowCnt, elementAddr, offsetPtr, lengthArray);
} else {
dataArrayPtr =
this->GetVectorValue(*(fExpr.GetReturnType()), rowIdxArray, rowCnt, elementAddr, nullptr, nullptr);
}
builder->CreateBr(mergeBlock);
falseBlock = builder->GetInsertBlock();
func->getBasicBlockList().push_back(mergeBlock);
builder->SetInsertPoint(mergeBlock);
int32_t numReservedValues = 2;
Type *phiType = llvmTypes->ToBatchDataPointerType(fExpr.GetReturnTypeId());
PHINode *phiValue = builder->CreatePHI(phiType, numReservedValues, "data");
phiValue->addIncoming(dicArrayPtr, trueBlock);
phiValue->addIncoming(dataArrayPtr, falseBlock);
PHINode *phiLength = nullptr;
if (TypeUtil::IsStringType(fExpr.GetReturnTypeId())) {
phiLength = builder->CreatePHI(llvmTypes->I32PtrType(), numReservedValues, "length");
phiLength->addIncoming(dicLengthArray, trueBlock);
phiLength->addIncoming(lengthArray, falseBlock);
}
auto bitmapGEP = builder->CreateGEP(llvmTypes->I64Type(), bitmap, colIdx);
Value *nullBitsPtr = builder->CreateLoad(llvmTypes->I64Type(), bitmapGEP);
nullBitsPtr = builder->CreateIntToPtr(nullBitsPtr, llvmTypes->I32PtrType());
auto dstNullArray = GetResultArray(OMNI_BOOLEAN, rowCnt);
CallExternFunction("batch_BitsToNullArray", { OMNI_BOOLEAN }, OMNI_BOOLEAN,
{ dstNullArray, nullBitsPtr, rowIdxArray, rowCnt }, nullptr, "copy_null");
if (TypeUtil::IsDecimalType(fExpr.GetReturnTypeId())) {
Value *precision =
llvmTypes->CreateConstantInt(static_cast<DecimalDataType *>(fExpr.GetReturnType().get())->GetPrecision());
Value *scale =
llvmTypes->CreateConstantInt(static_cast<DecimalDataType *>(fExpr.GetReturnType().get())->GetScale());
this->value = std::make_shared<DecimalValue>(phiValue, dstNullArray, precision, scale);
} else if (TypeUtil::IsStringType(fExpr.GetReturnTypeId())) {
this->value = std::make_shared<CodeGenValue>(phiValue, dstNullArray, phiLength);
} else {
this->value = std::make_shared<CodeGenValue>(phiValue, dstNullArray);
}
}
Value *BatchExpressionCodeGen::GetDictionaryVectorValue(const DataType &dataType, Value *rowIdxArray,
llvm::Value *rowCnt, llvm::Value *dictionaryVectorPtr, AllocaInst *lengthArrayPtr)
{
std::vector<DataTypeId> paramTypes = { OMNI_LONG };
DataTypeId retTypeId = dataType.GetId();
AllocaInst *dataArrayPtr = GetResultArray(retTypeId, rowCnt);
std::vector<Value *> funcArgs;
if (TypeUtil::IsStringType(retTypeId)) {
funcArgs = { batchCodegenContext->executionContext,
dictionaryVectorPtr,
rowIdxArray,
rowCnt,
dataArrayPtr,
lengthArrayPtr };
} else {
funcArgs = { dictionaryVectorPtr, rowIdxArray, rowCnt, dataArrayPtr };
}
CallExternFunction("batch_GetDic", { OMNI_LONG }, retTypeId, funcArgs, nullptr, "get_dictionary_value");
return dataArrayPtr;
}
Value *BatchExpressionCodeGen::GetVectorValue(const DataType &dataType, Value *rowIdxArray, llvm::Value *rowCnt,
llvm::Value *dataVectorPtr, Value *offsetArrayPtr, llvm::Value *lengthArrayPtr)
{
std::vector<DataTypeId> paramTypes = { OMNI_LONG };
DataTypeId retTypeId = dataType.GetId();
AllocaInst *dataArrayPtr = GetResultArray(retTypeId, rowCnt);
std::vector<Value *> funcArgs;
if (TypeUtil::IsStringType(retTypeId)) {
funcArgs = { batchCodegenContext->executionContext,
offsetArrayPtr,
dataVectorPtr,
rowIdxArray,
rowCnt,
dataArrayPtr,
lengthArrayPtr };
} else {
funcArgs = { dataVectorPtr, rowIdxArray, rowCnt, dataArrayPtr };
}
CallExternFunction("batch_GetData", { OMNI_LONG }, retTypeId, funcArgs, nullptr, "get_vector_value");
return dataArrayPtr;
}
void BatchExpressionCodeGen::Visit(const UnaryExpr &uExpr)
{
auto val = VisitExpr(*(uExpr.exp));
if (!val->IsValidValue()) {
this->value = CreateInvalidCodeGenValue();
return;
}
switch (uExpr.op) {
case omniruntime::expressions::Operator::NOT: {
std::vector<Value *> funcArgs { val->data, this->batchCodegenContext->rowCnt };
CallExternFunction("batch_not", { uExpr.exp->GetReturnTypeId() }, uExpr.GetReturnTypeId(), funcArgs,
nullptr, "logical_not");
this->value = std::make_shared<CodeGenValue>(val->data, val->isNull);
break;
}
default: {
this->value = CreateInvalidCodeGenValue();
break;
}
}
}
void BatchExpressionCodeGen::Visit(const BinaryExpr &binaryExpr)
{
auto *bExpr = const_cast<BinaryExpr *>(&binaryExpr);
CodeGenValuePtr left = VisitExpr(*(bExpr->left));
if (!left->IsValidValue()) {
this->value = CreateInvalidCodeGenValue();
return;
}
Value *leftValue = left->data;
Value *leftLen = left->length;
Value *leftNull = left->isNull;
CodeGenValuePtr right = VisitExpr(*(bExpr->right));
if (!right->IsValidValue()) {
this->value = CreateInvalidCodeGenValue();
return;
}
Value *rightValue = right->data;
Value *rightLen = right->length;
Value *rightNull = right->isNull;
if (bExpr->op == omniruntime::expressions::Operator::AND) {
std::vector<omniruntime::type::DataTypeId> boolParams { OMNI_BOOLEAN, OMNI_BOOLEAN };
std::vector<Value *> andFuncParams { leftValue, leftNull, rightValue, rightNull,
this->batchCodegenContext->rowCnt };
CallExternFunction("batch_and_expr", boolParams, OMNI_BOOLEAN, andFuncParams, nullptr, "and_expr");
this->value = std::make_shared<CodeGenValue>(leftValue, leftNull);
return;
}
if (bExpr->op == omniruntime::expressions::Operator::OR) {
std::vector<omniruntime::type::DataTypeId> boolParams { OMNI_BOOLEAN, OMNI_BOOLEAN };
std::vector<Value *> orFuncParams { leftValue, leftNull, rightValue, rightNull,
this->batchCodegenContext->rowCnt };
CallExternFunction("batch_or_expr", boolParams, OMNI_BOOLEAN, orFuncParams, nullptr, "or_expr");
this->value = std::make_shared<CodeGenValue>(leftValue, leftNull);
return;
}
if (bExpr->left->GetReturnTypeId() == OMNI_INT || bExpr->left->GetReturnTypeId() == OMNI_DATE32 ||
bExpr->left->GetReturnTypeId() == OMNI_LONG || bExpr->left->GetReturnTypeId() == OMNI_TIMESTAMP) {
this->BatchBinaryExprIntLongHelper(bExpr, leftValue, rightValue, leftNull, rightNull);
return;
} else if (bExpr->left->GetReturnTypeId() == OMNI_DOUBLE) {
this->BatchBinaryExprDoubleHelper(bExpr, leftValue, rightValue, leftNull, rightNull);
return;
} else if (TypeUtil::IsStringType(bExpr->left->GetReturnTypeId())) {
this->BatchBinaryExprStringHelper(bExpr, leftValue, leftLen, rightValue, rightLen, leftNull, rightNull);
return;
} else if (TypeUtil::IsDecimalType(bExpr->left->GetReturnTypeId())) {
this->BatchBinaryExprDecimalHelper(bExpr, static_cast<DecimalValue &>(*left.get()),
static_cast<DecimalValue &>(*right.get()), leftNull, rightNull);
return;
}
LogWarn("Unsupported data type for BINARY expr %d", bExpr->left->GetReturnTypeId());
this->value = CreateInvalidCodeGenValue();
}
void BatchExpressionCodeGen::Visit(const BetweenExpr &btExpr)
{
auto bExpr = const_cast<BetweenExpr *>(&btExpr);
auto val = VisitExpr(*(bExpr->value));
if (!val->IsValidValue()) {
this->value = CreateInvalidCodeGenValue();
return;
}
auto lowerVal = VisitExpr(*(bExpr->lowerBound));
if (!lowerVal->IsValidValue()) {
this->value = CreateInvalidCodeGenValue();
return;
}
auto upperVal = VisitExpr(*(bExpr->upperBound));
if (!upperVal->IsValidValue()) {
this->value = CreateInvalidCodeGenValue();
return;
}
AllocaInst *cmpLeft = builder->CreateAlloca(llvmTypes->I1Type(), this->batchCodegenContext->rowCnt, "cmpLeft");
AllocaInst *cmpRight = builder->CreateAlloca(llvmTypes->I1Type(), this->batchCodegenContext->rowCnt, "cmpRight");
std::pair<AllocaInst **, AllocaInst **> cmpPair = std::make_pair(&cmpLeft, &cmpRight);
BatchVisitBetweenExprHelper(*bExpr, val, lowerVal, upperVal, cmpPair);
}
void BatchExpressionCodeGen::Visit(const IsNullExpr &isNullExpr)
{
Expr *valueExpr = isNullExpr.value;
auto isNullValue = VisitExpr(*valueExpr);
if (!isNullValue->IsValidValue()) {
this->value = CreateInvalidCodeGenValue();
return;
}
std::vector<Value *> funcArgs { isNullValue->isNull, llvmTypes->CreateConstantBool(true),
this->batchCodegenContext->rowCnt };
CallExternFunction("batch_equal", { OMNI_BOOLEAN, OMNI_BOOLEAN }, OMNI_BOOLEAN, funcArgs, nullptr, "is_null");
AllocaInst *nullArrayPtr =
builder->CreateAlloca(llvmTypes->I1Type(), this->batchCodegenContext->rowCnt, "IS_NULL_PTR");
funcArgs = { nullArrayPtr, llvmTypes->CreateConstantBool(false), this->batchCodegenContext->rowCnt };
CallExternFunction("batch_fill_null", { OMNI_BOOLEAN }, OMNI_BOOLEAN, funcArgs, nullptr, "batch_fill_null");
this->value = std::make_shared<CodeGenValue>(isNullValue->isNull, nullArrayPtr);
}
llvm::AllocaInst *BatchExpressionCodeGen::GetResultArray(omniruntime::type::DataTypeId dataTypeId, Value *rowCnt)
{
AllocaInst *resultArray = nullptr;
switch (dataTypeId) {
case OMNI_INT:
case OMNI_DATE32: {
resultArray = builder->CreateAlloca(llvmTypes->I32Type(), rowCnt, "DATA_PTR");
break;
}
case OMNI_TIMESTAMP:
case OMNI_DECIMAL64:
case OMNI_LONG: {
resultArray = builder->CreateAlloca(llvmTypes->I64Type(), rowCnt, "DATA_PTR");
break;
}
case OMNI_DOUBLE: {
resultArray = builder->CreateAlloca(llvmTypes->DoubleType(), rowCnt, "DATA_PTR");
break;
}
case OMNI_CHAR:
case OMNI_VARCHAR: {
resultArray = builder->CreateAlloca(llvmTypes->I8PtrType(), rowCnt, "DATA_PTR");
break;
}
case OMNI_BOOLEAN: {
resultArray = builder->CreateAlloca(llvmTypes->I1Type(), rowCnt, "DATA_PTR");
break;
}
case OMNI_DECIMAL128: {
resultArray = builder->CreateAlloca(llvmTypes->I128Type(), rowCnt, "DATA_PTR");
break;
}
default: {
LogWarn("Unsupported type when creating array %d", dataTypeId);
break;
}
}
if (resultArray == nullptr) {
LogWarn("Failed to create result array");
}
return resultArray;
}
static std::string ChangeFuncNameToNull(const FuncExpr &fExpr)
{
auto typeSize = static_cast<int32_t>(fExpr.arguments.size() + 1);
auto originalFuncName = fExpr.function->GetId();
auto originalFuncChars = originalFuncName.c_str();
int32_t separatorIdx = 0;
auto pos = static_cast<int32_t>(originalFuncName.length() - 1);
for (; pos >= 0; pos--) {
if (originalFuncChars[pos] == '_') {
separatorIdx++;
if (separatorIdx == typeSize) {
break;
}
}
}
return originalFuncName.insert(pos, "_null");
}
void BatchExpressionCodeGen::FuncExprOverflowNullHelper(const FuncExpr &fExpr)
{
Value *falseValue = llvmTypes->CreateConstantBool(false);
AllocaInst *isAnyNull =
builder->CreateAlloca(llvmTypes->I1Type(), this->batchCodegenContext->rowCnt, "IS_NULL_PTR");
std::vector<Value *> funcArgs { isAnyNull, falseValue, this->batchCodegenContext->rowCnt };
auto ret =
CallExternFunction("batch_fill_null", { OMNI_BOOLEAN }, OMNI_BOOLEAN, funcArgs, nullptr, "fill_null_array");
AllocaInst *overflowNull =
builder->CreateAlloca(llvmTypes->I1Type(), this->batchCodegenContext->rowCnt, "OVERFLOW_NULL_PTR");
funcArgs = { overflowNull, falseValue, this->batchCodegenContext->rowCnt };
CallExternFunction("batch_fill_null", { OMNI_BOOLEAN }, OMNI_BOOLEAN, funcArgs, nullptr,
"fill_overflow_null_array");
DataTypeId funcRetType = fExpr.GetReturnTypeId();
bool isInvalidExpr = false;
auto argVals = GetDataAndOverflowNullArgs(fExpr, isAnyNull, isInvalidExpr, overflowNull);
if (isInvalidExpr) {
this->value = CreateInvalidCodeGenValue();
return;
}
Value *isNullArray = PushAndGetNullFlagArray(fExpr, argVals, isAnyNull, false);
AllocaInst *resultArray = GetResultArray(funcRetType, this->batchCodegenContext->rowCnt);
argVals.push_back(resultArray);
AllocaInst *outputLenPtr = nullptr;
if (TypeUtil::IsStringType(funcRetType)) {
outputLenPtr = builder->CreateAlloca(llvmTypes->I32Type(), this->batchCodegenContext->rowCnt, "output_len");
auto defaultLength =
llvmTypes->CreateConstantInt(static_cast<CharDataType *>(fExpr.GetReturnType().get())->GetWidth());
funcArgs = { outputLenPtr, defaultLength, this->batchCodegenContext->rowCnt };
CallExternFunction("batch_fill_length_literal", { OMNI_INT }, OMNI_INT, funcArgs, nullptr,
"fill_literal_array");
argVals.push_back(outputLenPtr);
if (FuncExpr::IsCastStrStr(fExpr)) {
argVals.push_back(
llvmTypes->CreateConstantInt(static_cast<VarcharDataType *>(fExpr.GetReturnType().get())->GetWidth()));
}
} else if (TypeUtil::IsDecimalType(funcRetType)) {
argVals.push_back(
llvmTypes->CreateConstantInt(static_cast<DecimalDataType *>(fExpr.GetReturnType().get())->GetPrecision()));
argVals.push_back(
llvmTypes->CreateConstantInt(static_cast<DecimalDataType *>(fExpr.GetReturnType().get())->GetScale()));
}
argVals.push_back(this->batchCodegenContext->rowCnt);
auto f = modulePtr->getFunction("batch_" + ChangeFuncNameToNull(fExpr));
if (f) {
ret = CreateCall(f, argVals, fExpr.function->GetId());
InlineFunctionInfo inlineFunctionInfo;
llvm::InlineFunction(*((CallInst *)ret), inlineFunctionInfo);
} else {
LogWarn("Unable to generate function : %s", fExpr.funcName.c_str());
this->value = CreateInvalidCodeGenValue();
return;
}
CallExternFunction("batch_or", { OMNI_BOOLEAN, OMNI_BOOLEAN }, OMNI_BOOLEAN,
{ isNullArray, overflowNull, this->batchCodegenContext->rowCnt }, nullptr);
if (TypeUtil::IsDecimalType(funcRetType)) {
this->value = std::make_shared<DecimalValue>(resultArray, isNullArray,
llvmTypes->CreateConstantInt(static_cast<DecimalDataType *>(fExpr.GetReturnType().get())->GetPrecision()),
llvmTypes->CreateConstantInt(static_cast<DecimalDataType *>(fExpr.GetReturnType().get())->GetScale()));
} else {
this->value = std::make_shared<CodeGenValue>(resultArray, isNullArray, outputLenPtr);
}
}
std::vector<llvm::Value *> BatchExpressionCodeGen::GetDataAndOverflowNullArgs(
const omniruntime::expressions::FuncExpr &fExpr, AllocaInst *isAnyNull, bool &isInvalidExpr,
AllocaInst *overflowNull)
{
std::vector<Value *> argVals;
argVals.push_back(overflowNull);
CodeGenValuePtr resultPtr;
int numArgs = fExpr.arguments.size();
std::vector<Value *> nullFuncParams;
auto signature = fExpr.function->GetSignatures()[0];
if (FunctionRegistry::IsNullExecutionContextSet(&signature)) {
argVals.push_back(this->batchCodegenContext->executionContext);
}
for (int i = 0; i < numArgs; i++) {
Expr *argN = fExpr.arguments[i];
resultPtr = VisitExpr(*argN);
if (!resultPtr->IsValidValue()) {
isInvalidExpr = true;
return argVals;
}
argVals.push_back(resultPtr->data);
nullFuncParams = { isAnyNull, resultPtr->isNull, this->batchCodegenContext->rowCnt };
CallExternFunction("batch_or", { OMNI_BOOLEAN, OMNI_BOOLEAN }, OMNI_BOOLEAN, nullFuncParams, nullptr,
"either_null");
if ((TypeUtil::IsStringType(argN->GetReturnTypeId()))) {
if (argN->GetReturnTypeId() == OMNI_CHAR) {
argVals.push_back(
llvmTypes->CreateConstantInt(static_cast<CharDataType *>(argN->GetReturnType().get())->GetWidth()));
}
if (FuncExpr::IsCastStrStr(fExpr)) {
argVals.push_back(llvmTypes->CreateConstantInt(
static_cast<VarcharDataType *>(argN->GetReturnType().get())->GetWidth()));
}
argVals.push_back(this->value->length);
}
if (TypeUtil::IsDecimalType(argN->GetReturnTypeId())) {
argVals.push_back(llvmTypes->CreateConstantInt(
static_cast<DecimalDataType *>(argN->GetReturnType().get())->GetPrecision()));
argVals.push_back(
llvmTypes->CreateConstantInt(static_cast<DecimalDataType *>(argN->GetReturnType().get())->GetScale()));
}
if (fExpr.function->GetNullableResultType() == INPUT_DATA_AND_NULL_AND_RETURN_NULL) {
argVals.push_back(this->value->isNull);
}
}
return argVals;
}
template <bool isNeedVerifyResult, bool isNeedVerifyVal>
std::vector<llvm::Value *> BatchExpressionCodeGen::GetDefaultFunctionArgValues(
const FuncExpr &fExpr,
AllocaInst *isAnyNull,
bool &isInvalidExpr)
{
std::vector<Value *> argVals;
CodeGenValuePtr resultPtr;
int numArgs = fExpr.arguments.size();
std::vector<Value *> nullFuncParams;
if (fExpr.function->IsExecutionContextSet()) {
argVals.push_back(this->batchCodegenContext->executionContext);
}
for (int i = 0; i < numArgs; i++) {
Expr *argN = fExpr.arguments[i];
resultPtr = VisitExpr(*argN);
if (!resultPtr->IsValidValue()) {
isInvalidExpr = true;
return argVals;
}
argVals.push_back(resultPtr->data);
if constexpr (isNeedVerifyResult) {
nullFuncParams = { isAnyNull, resultPtr->isNull, this->batchCodegenContext->rowCnt };
CallExternFunction("batch_or", { OMNI_BOOLEAN, OMNI_BOOLEAN }, OMNI_BOOLEAN, nullFuncParams, nullptr,
"either_null");
}
if ((TypeUtil::IsStringType(argN->GetReturnTypeId()))) {
if (argN->GetReturnTypeId() == OMNI_CHAR) {
argVals.push_back(
llvmTypes->CreateConstantInt(
static_cast<CharDataType *>(argN->GetReturnType().get())->GetWidth()));
}
if (FuncExpr::IsCastStrStr(fExpr)) {
argVals.push_back(llvmTypes->CreateConstantInt(
static_cast<VarcharDataType *>(argN->GetReturnType().get())->GetWidth()));
}
argVals.push_back(this->value->length);
}
if (TypeUtil::IsDecimalType(argN->GetReturnTypeId())) {
argVals.push_back(llvmTypes->CreateConstantInt(
static_cast<DecimalDataType *>(argN->GetReturnType().get())->GetPrecision()));
argVals.push_back(
llvmTypes->CreateConstantInt(
static_cast<DecimalDataType *>(argN->GetReturnType().get())->GetScale()));
}
if constexpr (isNeedVerifyVal) {
argVals.push_back(this->value->isNull);
}
}
return argVals;
}
inline std::vector<llvm::Value *> BatchExpressionCodeGen::GetDataArgs(const FuncExpr &fExpr, AllocaInst *isAnyNull,
bool &isInvalidExpr)
{
return GetDefaultFunctionArgValues<true, false>(fExpr, isAnyNull, isInvalidExpr);
}
inline std::vector<llvm::Value *> BatchExpressionCodeGen::GetDataAndNullArgs(const FuncExpr &fExpr,
AllocaInst *isAnyNull, bool &isInvalidExpr)
{
return GetDefaultFunctionArgValues<false, true>(fExpr, isAnyNull, isInvalidExpr);
}
inline std::vector<llvm::Value *> BatchExpressionCodeGen::GetDataAndNullArgsAndReturnNull(const FuncExpr &fExpr,
AllocaInst *isAnyNull, bool &isInvalidExpr)
{
return GetDefaultFunctionArgValues<true, true>(fExpr, isAnyNull, isInvalidExpr);
}
std::vector<llvm::Value *> BatchExpressionCodeGen::GetFunctionArgValues(const omniruntime::expressions::FuncExpr &fExpr,
AllocaInst *isAnyNull, bool &isInvalidExpr)
{
switch (fExpr.function->GetNullableResultType()) {
case INPUT_DATA:
return GetDataArgs(fExpr, isAnyNull, isInvalidExpr);
case INPUT_DATA_AND_NULL:
return GetDataAndNullArgs(fExpr, isAnyNull, isInvalidExpr);
case INPUT_DATA_AND_NULL_AND_RETURN_NULL:
return GetDataAndNullArgsAndReturnNull(fExpr, isAnyNull, isInvalidExpr);
default:
return GetDataArgs(fExpr, isAnyNull, isInvalidExpr);
}
}
Value *BatchExpressionCodeGen::ArenaAlloc(Value *sizeInBytes)
{
return CallExternFunction("ArenaAllocatorMalloc", { OMNI_LONG, OMNI_INT }, OMNI_CHAR,
{ batchCodegenContext->executionContext, sizeInBytes }, nullptr);
}
Value *BatchExpressionCodeGen::GetTypeSize(DataTypeId dataTypeId)
{
int32_t typeSize = 0;
switch (dataTypeId) {
case OMNI_INT:
case OMNI_DATE32:
typeSize = sizeof(int32_t);
break;
case OMNI_TIMESTAMP:
case OMNI_LONG:
case OMNI_DECIMAL64:
typeSize = sizeof(int64_t);
break;
case OMNI_DOUBLE:
typeSize = sizeof(double);
break;
case OMNI_BOOLEAN:
typeSize = sizeof(bool);
break;
case OMNI_SHORT:
typeSize = sizeof(int16_t);
break;
case OMNI_DECIMAL128:
typeSize = 2 * sizeof(int64_t);
break;
case OMNI_CHAR:
case OMNI_VARCHAR:
typeSize = sizeof(int64_t);
break;
default:
LogWarn("Unsupported data type in UDF funcExpr %d", dataTypeId);
return nullptr;
}
return llvmTypes->CreateConstantInt(typeSize);
}
std::vector<llvm::Value *> BatchExpressionCodeGen::GetHiveUdfArgValues(const FuncExpr &fExpr, bool &isInvalidExpr)
{
auto argSize = static_cast<int32_t>(fExpr.arguments.size());
auto size = llvmTypes->CreateConstantInt(argSize);
auto valueAddrArray = builder->CreateAlloca(llvmTypes->I64Type(), size);
auto nullAddrArray = builder->CreateAlloca(llvmTypes->I64Type(), size);
auto lengthAddrArray = builder->CreateAlloca(llvmTypes->I64Type(), size);
std::vector<Value *> argVals;
for (int32_t i = 0; i < argSize; i++) {
auto argExpr = fExpr.arguments[i];
auto argExprResult = VisitExpr(*argExpr);
if (!argExprResult->IsValidValue()) {
isInvalidExpr = true;
return argVals;
}
auto valuePtr = builder->CreateGEP(llvmTypes->I64Type(), valueAddrArray, llvmTypes->CreateConstantInt(i));
auto nullPtr = builder->CreateGEP(llvmTypes->I64Type(), nullAddrArray, llvmTypes->CreateConstantInt(i));
auto lengthPtr = builder->CreateGEP(llvmTypes->I64Type(), lengthAddrArray, llvmTypes->CreateConstantInt(i));
builder->CreateStore(argExprResult->data, valuePtr);
builder->CreateStore(argExprResult->isNull, nullPtr);
auto length = TypeUtil::IsStringType(argExpr->GetReturnTypeId()) ? argExprResult->length :
llvmTypes->CreateConstantLong(0);
builder->CreateStore(length, lengthPtr);
}
argVals.push_back(valueAddrArray);
argVals.push_back(nullAddrArray);
argVals.push_back(lengthAddrArray);
return argVals;
}
Value *BatchExpressionCodeGen::CreateHiveUdfArgTypes(const FuncExpr &fExpr)
{
auto elementSize = static_cast<int32_t>(fExpr.arguments.size());
auto alloca = builder->CreateAlloca(llvmTypes->I32Type(), llvmTypes->CreateConstantInt(elementSize));
for (int32_t i = 0; i < elementSize; i++) {
auto ptr = builder->CreateGEP(llvmTypes->I32Type(), alloca, llvmTypes->CreateConstantInt(i));
builder->CreateStore(llvmTypes->CreateConstantInt(fExpr.arguments[i]->GetReturnTypeId()), ptr);
}
return alloca;
}
void BatchExpressionCodeGen::CallHiveUdfFunction(const FuncExpr &fExpr)
{
auto returnTypeId = fExpr.GetReturnTypeId();
std::vector<Value *> argVals;
argVals.emplace_back(batchCodegenContext->executionContext);
argVals.emplace_back(CreateConstantString(fExpr.funcName));
argVals.emplace_back(CreateHiveUdfArgTypes(fExpr));
argVals.emplace_back(llvmTypes->CreateConstantInt(returnTypeId));
argVals.emplace_back(llvmTypes->CreateConstantInt(fExpr.arguments.size()));
argVals.emplace_back(batchCodegenContext->rowCnt);
bool isInvalidExpr = false;
auto inputArgs = GetHiveUdfArgValues(fExpr, isInvalidExpr);
if (isInvalidExpr) {
this->value = CreateInvalidCodeGenValue();
return;
}
argVals.insert(argVals.end(), inputArgs.begin(),
inputArgs.end());
auto returnTypeSize = GetTypeSize(returnTypeId);
if (returnTypeSize == nullptr) {
this->value = CreateInvalidCodeGenValue();
return;
}
auto arraySize = batchCodegenContext->rowCnt;
auto outputValuePtr = ArenaAlloc(builder->CreateMul(returnTypeSize, arraySize));
auto outputNullPtr = ArenaAlloc(arraySize);
auto outputLengthPtr = TypeUtil::IsStringType(returnTypeId) ?
ArenaAlloc(builder->CreateMul(GetTypeSize(OMNI_INT), arraySize)) :
llvmTypes->CreateConstantLong(0);
argVals.emplace_back(outputValuePtr);
argVals.emplace_back(outputNullPtr);
argVals.emplace_back(outputLengthPtr);
auto signature = FunctionSignature("EvaluateHiveUdfBatch", std::vector<DataTypeId> {}, OMNI_INT);
auto function = FunctionRegistry::LookupFunction(&signature);
auto f = modulePtr->getFunction(function->GetId());
if (f) {
auto ret = CreateCall(f, argVals, "call_evaluate_hive_udf");
InlineFunctionInfo inlineFunctionInfo;
llvm::InlineFunction(*((CallInst *)ret), inlineFunctionInfo);
this->value = std::make_shared<CodeGenValue>(outputValuePtr, outputNullPtr,
TypeUtil::IsStringType(returnTypeId) ? outputLengthPtr : nullptr);
} else {
LogWarn("Unable to generate udf function : %s", fExpr.funcName.c_str());
this->value = CreateInvalidCodeGenValue();
}
}
void BatchExpressionCodeGen::Visit(const FuncExpr &fExpr)
{
if (fExpr.functionType == HIVE_UDF) {
CallHiveUdfFunction(fExpr);
return;
}
if (this->overflowConfig != nullptr &&
this->overflowConfig->GetOverflowConfigId() == omniruntime::op::OVERFLOW_CONFIG_NULL) {
auto signature = fExpr.function->GetSignatures()[0];
if (FunctionRegistry::LookupNullFunction(&signature)) {
FuncExprOverflowNullHelper(fExpr);
return;
}
}
Value *falseValue = llvmTypes->CreateConstantBool(false);
AllocaInst *isAnyNull =
builder->CreateAlloca(llvmTypes->I1Type(), this->batchCodegenContext->rowCnt, "IS_NULL_PTR");
std::vector<Value *> funcArgs { isAnyNull, falseValue, this->batchCodegenContext->rowCnt };
CallExternFunction("batch_fill_null", { OMNI_BOOLEAN }, OMNI_BOOLEAN, funcArgs, nullptr, "fill_null_array");
DataTypeId funcRetType = fExpr.GetReturnTypeId();
bool isInvalidExpr = false;
auto argVals = GetFunctionArgValues(fExpr, isAnyNull, isInvalidExpr);
if (isInvalidExpr) {
this->value = CreateInvalidCodeGenValue();
return;
}
AllocaInst *resultArray = GetResultArray(funcRetType, this->batchCodegenContext->rowCnt);
Value *isNullArray = PushAndGetNullFlagArray(fExpr, argVals, isAnyNull, true);
argVals.push_back(resultArray);
AllocaInst *outputLenPtr = nullptr;
if (TypeUtil::IsStringType(funcRetType)) {
outputLenPtr = builder->CreateAlloca(llvmTypes->I32Type(), this->batchCodegenContext->rowCnt, "output_len");
auto defaultLength =
llvmTypes->CreateConstantInt(static_cast<CharDataType *>(fExpr.GetReturnType().get())->GetWidth());
funcArgs = { outputLenPtr, defaultLength, this->batchCodegenContext->rowCnt };
CallExternFunction("batch_fill_length_literal", { OMNI_INT }, OMNI_INT, funcArgs, nullptr,
"fill_literal_array");
argVals.push_back(outputLenPtr);
if (FuncExpr::IsCastStrStr(fExpr)) {
argVals.push_back(
llvmTypes->CreateConstantInt(static_cast<VarcharDataType *>(fExpr.GetReturnType().get())->GetWidth()));
}
} else if (TypeUtil::IsDecimalType(funcRetType)) {
argVals.push_back(
llvmTypes->CreateConstantInt(static_cast<DecimalDataType *>(fExpr.GetReturnType().get())->GetPrecision()));
argVals.push_back(
llvmTypes->CreateConstantInt(static_cast<DecimalDataType *>(fExpr.GetReturnType().get())->GetScale()));
}
argVals.push_back(this->batchCodegenContext->rowCnt);
auto f = modulePtr->getFunction("batch_" + fExpr.function->GetId());
if (f) {
auto ret = CreateCall(f, argVals, fExpr.function->GetId());
InlineFunctionInfo inlineFunctionInfo;
llvm::InlineFunction(*((CallInst *)ret), inlineFunctionInfo);
} else {
LogWarn("Unable to generate function : %s", fExpr.funcName.c_str());
this->value = CreateInvalidCodeGenValue();
return;
}
if (TypeUtil::IsDecimalType(funcRetType)) {
this->value = std::make_shared<DecimalValue>(resultArray, isNullArray,
llvmTypes->CreateConstantInt(static_cast<DecimalDataType *>(fExpr.GetReturnType().get())->GetPrecision()),
llvmTypes->CreateConstantInt(static_cast<DecimalDataType *>(fExpr.GetReturnType().get())->GetScale()));
} else {
this->value = std::make_shared<CodeGenValue>(resultArray, isNullArray, outputLenPtr);
}
}
void BatchExpressionCodeGen::Visit(const IfExpr &ifExpr)
{
Expr *cond = ifExpr.condition;
Expr *ifTrue = ifExpr.trueExpr;
Expr *ifFalse = ifExpr.falseExpr;
auto baseType = ifExpr.GetReturnTypeId();
CodeGenValuePtr evCond = VisitExpr(*cond);
if (!evCond->IsValidValue()) {
this->value = CreateInvalidCodeGenValue();
return;
}
auto evTrue = VisitExpr(*ifTrue);
if (!evTrue->IsValidValue()) {
this->value = CreateInvalidCodeGenValue();
return;
}
Value *evTrueValue = evTrue->data;
Value *evTrueLength = evTrue->length;
Value *evTrueNull = evTrue->isNull;
auto evFalse = VisitExpr(*ifFalse);
if (!evFalse->IsValidValue()) {
this->value = CreateInvalidCodeGenValue();
return;
}
Value *evFalseValue = evFalse->data;
Value *evFalseLength = evFalse->length;
Value *evFalseNull = evFalse->isNull;
switch (baseType) {
case OMNI_INT:
case OMNI_DATE32:
case OMNI_LONG:
case OMNI_TIMESTAMP:
case OMNI_DOUBLE:
case OMNI_BOOLEAN: {
CallExternFunction("batch_if", { baseType }, baseType,
{ evCond->data, evCond->isNull, evTrueValue, evTrueNull, evFalseValue, evFalseNull,
this->batchCodegenContext->rowCnt },
nullptr);
this->value = std::make_shared<CodeGenValue>(evTrueValue, evTrueNull);
return;
}
case OMNI_DECIMAL64:
case OMNI_DECIMAL128: {
auto &left = static_cast<DecimalValue &>(*evTrue);
auto &right = static_cast<DecimalValue &>(*evFalse);
std::vector<Value *> argValsCmp { evCond->data,
evCond->isNull,
left.data,
left.isNull,
right.data,
right.isNull,
this->batchCodegenContext->rowCnt };
CallExternFunction("batch_if", { baseType }, baseType, argValsCmp, nullptr);
this->value = std::make_shared<DecimalValue>(left.data, left.isNull,
const_cast<Value *>(left.GetPrecision()), const_cast<Value *>(left.GetScale()));
return;
}
case OMNI_CHAR:
case OMNI_VARCHAR: {
CallExternFunction("batch_if", { baseType }, baseType,
{ evCond->data, evCond->isNull, evTrueValue, evTrueNull, evTrueLength, evFalseValue, evFalseNull,
evFalseLength, this->batchCodegenContext->rowCnt },
nullptr);
this->value = std::make_shared<CodeGenValue>(evTrueValue, evTrueNull, evTrueLength);
return;
}
default: {
LogWarn("Unsupported data type in IF expr %d", baseType);
this->value = CreateInvalidCodeGenValue();
return;
}
}
}
void BatchExpressionCodeGen::Visit(const CoalesceExpr &cExpr)
{
Expr *value1Expr = cExpr.value1;
Expr *value2Expr = cExpr.value2;
CodeGenValuePtr value1 = VisitExpr(*value1Expr);
if (!value1->IsValidValue()) {
this->value = CreateInvalidCodeGenValue();
return;
}
auto value2 = VisitExpr(*value2Expr);
if (!value2->IsValidValue()) {
this->value = CreateInvalidCodeGenValue();
return;
}
if (cExpr.GetReturnTypeId() == OMNI_BOOLEAN || cExpr.GetReturnTypeId() == OMNI_INT ||
cExpr.GetReturnTypeId() == OMNI_LONG || cExpr.GetReturnTypeId() == OMNI_DOUBLE ||
cExpr.GetReturnTypeId() == OMNI_DATE32 || cExpr.GetReturnTypeId() == OMNI_TIMESTAMP) {
CallExternFunction("batch_coalesce", { cExpr.GetReturnTypeId(), cExpr.GetReturnTypeId() },
cExpr.GetReturnTypeId(),
{ value1->data, value1->isNull, value2->data, value2->isNull, this->batchCodegenContext->rowCnt }, nullptr);
this->value = std::make_shared<CodeGenValue>(value1->data, value1->isNull);
} else if (TypeUtil::IsStringType(cExpr.GetReturnTypeId())) {
CallExternFunction("batch_coalesce", { OMNI_VARCHAR, OMNI_VARCHAR }, OMNI_VARCHAR,
{ value1->data, value1->isNull, value1->length, value2->data, value2->isNull, value2->length,
this->batchCodegenContext->rowCnt },
nullptr);
this->value = std::make_shared<CodeGenValue>(value1->data, value1->isNull, value1->length);
} else if (TypeUtil::IsDecimalType(cExpr.GetReturnTypeId())) {
auto value1Precision = (Value *)static_cast<DecimalValue &>(*value1.get()).GetPrecision();
auto value1Scale = (Value *)static_cast<DecimalValue &>(*value1.get()).GetScale();
CallExternFunction("batch_coalesce", { cExpr.GetReturnTypeId(), cExpr.GetReturnTypeId() },
cExpr.GetReturnTypeId(),
{ value1->data, value1->isNull, value2->data, value2->isNull, this->batchCodegenContext->rowCnt }, nullptr);
this->value = std::make_shared<DecimalValue>(value1->data, value1->isNull, value1Precision, value1Scale);
} else {
LogWarn("Unsupported data type in COALESCE expr %d", cExpr.GetReturnTypeId());
this->value = CreateInvalidCodeGenValue();
return;
}
}
void BatchExpressionCodeGen::Visit(const InExpr &inExpr)
{
auto iExpr = const_cast<InExpr *>(&inExpr);
Expr *toCompare = iExpr->arguments[0];
auto baseType = iExpr->arguments[0]->GetReturnTypeId();
int32_t size = iExpr->arguments.size();
auto valueToCompare = VisitExpr(*toCompare);
if (!valueToCompare->IsValidValue()) {
this->value = CreateInvalidCodeGenValue();
return;
}
Value *inArray = GetResultArray(OMNI_BOOLEAN, this->batchCodegenContext->rowCnt);
Value *isNull = GetResultArray(OMNI_BOOLEAN, this->batchCodegenContext->rowCnt);
std::vector<CodeGenValuePtr> cmps(size - 1);
for (int i = 1; i < size; ++i) {
Expr *cmp = iExpr->arguments[i];
cmps[i - 1] = VisitExpr(*cmp);
if (!cmps[i - 1]->IsValidValue()) {
this->value = CreateInvalidCodeGenValue();
return;
}
}
auto cmpValues = GetResultArray(OMNI_LONG, llvmTypes->CreateConstantInt(size - 1));
auto cmpBools = GetResultArray(OMNI_LONG, llvmTypes->CreateConstantInt(size - 1));
for (int i = 0; i < size - 1; ++i) {
Value *gep =
builder->CreateGEP(llvmTypes->I64Type(), cmpValues, llvmTypes->CreateConstantInt(i), "cmp_value_address");
builder->CreateStore(cmps[i]->data, gep);
gep = builder->CreateGEP(llvmTypes->I64Type(), cmpBools, llvmTypes->CreateConstantInt(i), "cmp_null_address");
builder->CreateStore(cmps[i]->isNull, gep);
}
std::vector<Value *> args;
switch (baseType) {
case OMNI_BOOLEAN:
case OMNI_INT:
case OMNI_DATE32:
case OMNI_LONG:
case OMNI_TIMESTAMP:
case OMNI_DOUBLE:
case OMNI_DECIMAL64:
case OMNI_DECIMAL128: {
args = { llvmTypes->CreateConstantInt(size - 1),
cmpValues,
cmpBools,
valueToCompare->data,
valueToCompare->isNull,
inArray,
isNull,
this->batchCodegenContext->rowCnt };
CallExternFunction("batch_in", { baseType }, OMNI_BOOLEAN, args, nullptr);
break;
}
case OMNI_CHAR:
case OMNI_VARCHAR: {
auto cmpLengths = GetResultArray(OMNI_LONG, llvmTypes->CreateConstantInt(size - 1));
for (int i = 0; i < size - 1; ++i) {
Value *gep = builder->CreateGEP(llvmTypes->I64Type(), cmpLengths, llvmTypes->CreateConstantInt(i),
"cmp_length_address");
builder->CreateStore(cmps[i]->length, gep);
}
args = { llvmTypes->CreateConstantInt(size - 1),
cmpValues,
cmpBools,
cmpLengths,
valueToCompare->data,
valueToCompare->isNull,
valueToCompare->length,
inArray,
isNull,
this->batchCodegenContext->rowCnt };
CallExternFunction("batch_in", { baseType }, OMNI_BOOLEAN, args, nullptr);
break;
}
default: {
LogWarn("Unsupported data type in IN expr %d", baseType);
this->value = CreateInvalidCodeGenValue();
return;
}
}
this->value = std::make_shared<CodeGenValue>(inArray, isNull);
}
void BatchExpressionCodeGen::Visit(const SwitchExpr &switchExpr)
{
auto switchDataType = switchExpr.GetReturnTypeId();
Expr *elseExpr = switchExpr.falseExpr;
std::vector<std::pair<Expr *, Expr *>> whenClause = switchExpr.whenClause;
const int size = whenClause.size();
AllocaInst *finalValue = GetResultArray(switchDataType, this->batchCodegenContext->rowCnt);
AllocaInst *finalNull = GetResultArray(OMNI_BOOLEAN, this->batchCodegenContext->rowCnt);
std::vector<CodeGenValuePtr> conditions(size);
std::vector<CodeGenValuePtr> results(size);
for (int i = 0; i < size; ++i) {
Expr *cond = whenClause[i].first;
Expr *resExpr = whenClause[i].second;
conditions[i] = VisitExpr(*cond);
if (!conditions[i]->IsValidValue()) {
this->value = CreateInvalidCodeGenValue();
return;
}
results[i] = VisitExpr(*resExpr);
if (!results[i]->IsValidValue()) {
this->value = CreateInvalidCodeGenValue();
return;
}
}
auto whenClauses = GetResultArray(OMNI_LONG, llvmTypes->CreateConstantInt(size));
auto whenBools = GetResultArray(OMNI_LONG, llvmTypes->CreateConstantInt(size));
auto resultValues = GetResultArray(OMNI_LONG, llvmTypes->CreateConstantInt(size));
auto resultNulls = GetResultArray(OMNI_LONG, llvmTypes->CreateConstantInt(size));
for (int i = 0; i < size; ++i) {
Value *gep = builder->CreateGEP(llvmTypes->I64Type(), whenClauses, llvmTypes->CreateConstantInt(i),
"when_value_address");
builder->CreateStore(conditions[i]->data, gep);
gep = builder->CreateGEP(llvmTypes->I64Type(), whenBools, llvmTypes->CreateConstantInt(i), "when_null_address");
builder->CreateStore(conditions[i]->isNull, gep);
gep = builder->CreateGEP(llvmTypes->I64Type(), resultValues, llvmTypes->CreateConstantInt(i),
"result_value_address");
builder->CreateStore(results[i]->data, gep);
gep = builder->CreateGEP(llvmTypes->I64Type(), resultNulls, llvmTypes->CreateConstantInt(i),
"result_null_address");
builder->CreateStore(results[i]->isNull, gep);
}
auto evFalse = VisitExpr(*elseExpr);
if (!evFalse->IsValidValue()) {
this->value = CreateInvalidCodeGenValue();
return;
}
std::vector<Value *> args;
switch (switchDataType) {
case OMNI_INT:
case OMNI_BOOLEAN:
case OMNI_DATE32:
case OMNI_LONG:
case OMNI_TIMESTAMP:
case OMNI_DOUBLE: {
args = { llvmTypes->CreateConstantInt(size),
whenClauses,
whenBools,
resultValues,
resultNulls,
evFalse->data,
evFalse->isNull,
finalValue,
finalNull,
this->batchCodegenContext->rowCnt };
CallExternFunction("batch_switch", { switchDataType }, switchDataType, args, nullptr);
this->value = std::make_shared<CodeGenValue>(finalValue, finalNull);
return;
}
case OMNI_DECIMAL64:
case OMNI_DECIMAL128: {
auto returnDecimalValue = BuildDecimalValue(nullptr, *switchExpr.GetReturnType(), nullptr);
args = { llvmTypes->CreateConstantInt(size),
whenClauses,
whenBools,
resultValues,
resultNulls,
evFalse->data,
evFalse->isNull,
finalValue,
finalNull,
this->batchCodegenContext->rowCnt };
CallExternFunction("batch_switch", { switchDataType }, switchDataType, args, nullptr);
this->value = std::make_shared<DecimalValue>(finalValue, finalNull,
const_cast<Value *>(returnDecimalValue->GetPrecision()),
const_cast<Value *>(returnDecimalValue->GetScale()));
return;
}
case OMNI_CHAR:
case OMNI_VARCHAR: {
auto resultLengths = GetResultArray(OMNI_LONG, llvmTypes->CreateConstantInt(size));
for (int i = 0; i < size; ++i) {
Value *gep = builder->CreateGEP(llvmTypes->I64Type(), resultLengths, llvmTypes->CreateConstantInt(i),
"result_length_address");
builder->CreateStore(results[i]->length, gep);
}
AllocaInst *finalLength = GetResultArray(OMNI_INT, this->batchCodegenContext->rowCnt);
args = { llvmTypes->CreateConstantInt(size),
whenClauses,
whenBools,
resultValues,
resultNulls,
resultLengths,
evFalse->data,
evFalse->isNull,
evFalse->length,
finalValue,
finalNull,
finalLength,
this->batchCodegenContext->rowCnt };
CallExternFunction("batch_switch", { switchDataType }, switchDataType, args, nullptr);
this->value = std::make_shared<CodeGenValue>(finalValue, finalNull, finalLength);
return;
}
default: {
LogWarn("Unsupported data type in SWITCH expr %d", switchDataType);
this->value = CreateInvalidCodeGenValue();
return;
}
}
}
void BatchExpressionCodeGen::BatchBinaryExprIntLongHelper(const omniruntime::expressions::BinaryExpr *binaryExpr,
llvm::Value *left, llvm::Value *right, llvm::Value *leftIsNull, llvm::Value *rightIsNull)
{
DataTypeId returnTypeId = binaryExpr->GetReturnTypeId();
std::vector<omniruntime::type::DataTypeId> typeParams;
if (binaryExpr->left->GetReturnTypeId() == OMNI_INT || binaryExpr->left->GetReturnTypeId() == OMNI_DATE32) {
typeParams = { OMNI_INT, OMNI_INT };
} else {
typeParams = { OMNI_LONG, OMNI_LONG };
}
std::vector<omniruntime::type::DataTypeId> boolParams { OMNI_BOOLEAN, OMNI_BOOLEAN };
AllocaInst *logicalArrayPtr = nullptr;
std::vector<Value *> logicalFuncParams;
if (returnTypeId == OMNI_BOOLEAN) {
logicalArrayPtr = builder->CreateAlloca(llvmTypes->I1Type(), this->batchCodegenContext->rowCnt, "LOGICAL_PTR");
logicalFuncParams = { left, right, logicalArrayPtr, this->batchCodegenContext->rowCnt };
}
std::vector<Value *> arithFuncParams { left, right, this->batchCodegenContext->rowCnt };
std::vector<Value *> nullFuncParams { leftIsNull, rightIsNull, this->batchCodegenContext->rowCnt };
CallExternFunction("batch_or", boolParams, OMNI_BOOLEAN, nullFuncParams, nullptr, "either_null");
switch (binaryExpr->op) {
case omniruntime::expressions::Operator::LT:
CallExternFunction("batch_lessThan", typeParams, OMNI_BOOLEAN, logicalFuncParams, nullptr, "relational_lt");
this->value = std::make_shared<CodeGenValue>(logicalArrayPtr, leftIsNull);
return;
case omniruntime::expressions::Operator::GT:
CallExternFunction("batch_greaterThan", typeParams, OMNI_BOOLEAN, logicalFuncParams, nullptr,
"relational_gt");
this->value = std::make_shared<CodeGenValue>(logicalArrayPtr, leftIsNull);
return;
case omniruntime::expressions::Operator::LTE:
CallExternFunction("batch_lessThanEqual", typeParams, OMNI_BOOLEAN, logicalFuncParams, nullptr,
"relational_le");
this->value = std::make_shared<CodeGenValue>(logicalArrayPtr, leftIsNull);
return;
case omniruntime::expressions::Operator::GTE:
CallExternFunction("batch_greaterThanEqual", typeParams, OMNI_BOOLEAN, logicalFuncParams, nullptr,
"relational_ge");
this->value = std::make_shared<CodeGenValue>(logicalArrayPtr, leftIsNull);
return;
case omniruntime::expressions::Operator::EQ:
CallExternFunction("batch_equal", typeParams, OMNI_BOOLEAN, logicalFuncParams, nullptr, "relational_eq");
this->value = std::make_shared<CodeGenValue>(logicalArrayPtr, leftIsNull);
return;
case omniruntime::expressions::Operator::NEQ:
CallExternFunction("batch_notEqual", typeParams, OMNI_BOOLEAN, logicalFuncParams, nullptr,
"relational_neq");
this->value = std::make_shared<CodeGenValue>(logicalArrayPtr, leftIsNull);
return;
case omniruntime::expressions::Operator::ADD:
CallExternFunction("batch_add", typeParams, returnTypeId, arithFuncParams, nullptr, "arithmetic_add");
this->value = std::make_shared<CodeGenValue>(left, leftIsNull);
return;
case omniruntime::expressions::Operator::SUB:
CallExternFunction("batch_subtract", typeParams, returnTypeId, arithFuncParams, nullptr, "arithmetic_sub");
this->value = std::make_shared<CodeGenValue>(left, leftIsNull);
return;
case omniruntime::expressions::Operator::MUL:
CallExternFunction("batch_multiply", typeParams, returnTypeId, arithFuncParams, nullptr, "arithmetic_mul");
this->value = std::make_shared<CodeGenValue>(left, leftIsNull);
return;
case omniruntime::expressions::Operator::DIV:
arithFuncParams.push_back(leftIsNull);
CallExternFunction("batch_divide", typeParams, returnTypeId, arithFuncParams,
batchCodegenContext->executionContext, "arithmetic_div");
this->value = std::make_shared<CodeGenValue>(left, leftIsNull);
return;
case omniruntime::expressions::Operator::MOD:
arithFuncParams.push_back(leftIsNull);
CallExternFunction("batch_modulus", typeParams, returnTypeId, arithFuncParams,
batchCodegenContext->executionContext, "arithmetic_mod");
this->value = std::make_shared<CodeGenValue>(left, leftIsNull);
return;
default: {
LogError("Unsupported int or long binary operator %u", static_cast<uint32_t>(binaryExpr->op));
this->value = CreateInvalidCodeGenValue();
return;
}
}
}
void BatchExpressionCodeGen::BatchBinaryExprDoubleHelper(const omniruntime::expressions::BinaryExpr *binaryExpr,
llvm::Value *left, llvm::Value *right, llvm::Value *leftIsNull, llvm::Value *rightIsNull)
{
std::vector<omniruntime::type::DataTypeId> doubleParams { OMNI_DOUBLE, OMNI_DOUBLE };
std::vector<omniruntime::type::DataTypeId> boolParams { OMNI_BOOLEAN, OMNI_BOOLEAN };
DataTypeId returnTypeId = binaryExpr->GetReturnTypeId();
AllocaInst *logicalArrayPtr = nullptr;
std::vector<Value *> logicalFuncParams;
if (returnTypeId == OMNI_BOOLEAN) {
logicalArrayPtr = builder->CreateAlloca(llvmTypes->I1Type(), this->batchCodegenContext->rowCnt, "LOGICAL_PTR");
logicalFuncParams = { left, right, logicalArrayPtr, this->batchCodegenContext->rowCnt };
}
std::vector<Value *> arithFuncParams { left, right, this->batchCodegenContext->rowCnt };
std::vector<Value *> nullFuncParams { leftIsNull, rightIsNull, this->batchCodegenContext->rowCnt };
CallExternFunction("batch_or", boolParams, OMNI_BOOLEAN, nullFuncParams, nullptr, "either_null");
switch (binaryExpr->op) {
case omniruntime::expressions::Operator::LT:
CallExternFunction("batch_lessThan", doubleParams, OMNI_BOOLEAN, logicalFuncParams, nullptr,
"relational_lt");
this->value = std::make_shared<CodeGenValue>(logicalArrayPtr, leftIsNull);
return;
case omniruntime::expressions::Operator::GT:
CallExternFunction("batch_greaterThan", doubleParams, OMNI_BOOLEAN, logicalFuncParams, nullptr,
"relational_gt");
this->value = std::make_shared<CodeGenValue>(logicalArrayPtr, leftIsNull);
return;
case omniruntime::expressions::Operator::LTE:
CallExternFunction("batch_lessThanEqual", doubleParams, OMNI_BOOLEAN, logicalFuncParams, nullptr,
"relational_le");
this->value = std::make_shared<CodeGenValue>(logicalArrayPtr, leftIsNull);
return;
case omniruntime::expressions::Operator::GTE:
CallExternFunction("batch_greaterThanEqual", doubleParams, OMNI_BOOLEAN, logicalFuncParams, nullptr,
"relational_ge");
this->value = std::make_shared<CodeGenValue>(logicalArrayPtr, leftIsNull);
return;
case omniruntime::expressions::Operator::EQ:
CallExternFunction("batch_equal", doubleParams, OMNI_BOOLEAN, logicalFuncParams, nullptr, "relational_eq");
this->value = std::make_shared<CodeGenValue>(logicalArrayPtr, leftIsNull);
return;
case omniruntime::expressions::Operator::NEQ:
CallExternFunction("batch_notEqual", doubleParams, OMNI_BOOLEAN, logicalFuncParams, nullptr,
"relational_neq");
this->value = std::make_shared<CodeGenValue>(logicalArrayPtr, leftIsNull);
return;
case omniruntime::expressions::Operator::ADD:
CallExternFunction("batch_add", doubleParams, OMNI_DOUBLE, arithFuncParams, nullptr, "arithmetic_add");
this->value = std::make_shared<CodeGenValue>(left, leftIsNull);
return;
case omniruntime::expressions::Operator::SUB:
CallExternFunction("batch_subtract", doubleParams, OMNI_DOUBLE, arithFuncParams, nullptr, "arithmetic_sub");
this->value = std::make_shared<CodeGenValue>(left, leftIsNull);
return;
case omniruntime::expressions::Operator::MUL:
CallExternFunction("batch_multiply", doubleParams, OMNI_DOUBLE, arithFuncParams, nullptr, "arithmetic_mul");
this->value = std::make_shared<CodeGenValue>(left, leftIsNull);
return;
case omniruntime::expressions::Operator::DIV:
CallExternFunction("batch_divide", doubleParams, OMNI_DOUBLE, arithFuncParams, nullptr, "arithmetic_div");
this->value = std::make_shared<CodeGenValue>(left, leftIsNull);
return;
case omniruntime::expressions::Operator::MOD:
CallExternFunction("batch_modulus", doubleParams, OMNI_DOUBLE, arithFuncParams, nullptr, "arithmetic_mod");
this->value = std::make_shared<CodeGenValue>(left, leftIsNull);
return;
default: {
LogError("Unsupported double binary operator %u", static_cast<uint32_t>(binaryExpr->op));
this->value = CreateInvalidCodeGenValue();
return;
}
}
}
void BatchExpressionCodeGen::BatchBinaryExprDecimalHelper(const omniruntime::expressions::BinaryExpr *binaryExpr,
DecimalValue &left, DecimalValue &right, llvm::Value *leftIsNull, llvm::Value *rightIsNull)
{
std::vector<DataTypeId> decimalParams { binaryExpr->left->GetReturnTypeId(), binaryExpr->right->GetReturnTypeId() };
std::vector<DataTypeId> boolParams { OMNI_BOOLEAN, OMNI_BOOLEAN };
DataTypeId returnTypeId = binaryExpr->GetReturnTypeId();
std::shared_ptr<DecimalValue> returnDecimalValue = nullptr;
AllocaInst *logicalArrayPtr = nullptr;
AllocaInst *arithArrayPtr = nullptr;
std::vector<Value *> logicalFuncParams;
std::vector<Value *> arithFuncParams;
if (returnTypeId == OMNI_BOOLEAN) {
logicalArrayPtr = builder->CreateAlloca(llvmTypes->I1Type(), this->batchCodegenContext->rowCnt, "LOGICAL_PTR");
logicalFuncParams = {
left.data, const_cast<Value *>(left.GetPrecision()), const_cast<Value *>(left.GetScale()),
right.data, const_cast<Value *>(right.GetPrecision()), const_cast<Value *>(right.GetScale()),
logicalArrayPtr, this->batchCodegenContext->rowCnt
};
} else if (TypeUtil::IsDecimalType(returnTypeId)) {
returnDecimalValue = BuildDecimalValue(nullptr, *binaryExpr->GetReturnType(), nullptr);
arithFuncParams = { left.data,
const_cast<Value *>(left.GetPrecision()),
const_cast<Value *>(left.GetScale()),
right.data,
const_cast<Value *>(right.GetPrecision()),
const_cast<Value *>(right.GetScale()),
const_cast<Value *>(returnDecimalValue->GetPrecision()),
const_cast<Value *>(returnDecimalValue->GetScale()),
this->batchCodegenContext->rowCnt };
if (decimalParams == std::vector<DataTypeId> { OMNI_DECIMAL128, OMNI_DECIMAL128 } &&
returnTypeId == OMNI_DECIMAL64) {
arithArrayPtr = builder->CreateAlloca(llvmTypes->I64Type(), this->batchCodegenContext->rowCnt, "ARITH_PTR");
arithFuncParams.insert(std::begin(arithFuncParams) + 6, arithArrayPtr);
} else if (decimalParams == std::vector<DataTypeId> { OMNI_DECIMAL64, OMNI_DECIMAL64 } &&
returnTypeId == OMNI_DECIMAL128) {
arithArrayPtr =
builder->CreateAlloca(llvmTypes->I128Type(), this->batchCodegenContext->rowCnt, "ARITH_PTR");
arithFuncParams.insert(std::begin(arithFuncParams) + 6, arithArrayPtr);
}
}
std::vector<Value *> nullFuncParams { leftIsNull, rightIsNull, this->batchCodegenContext->rowCnt };
CallExternFunction("batch_or", boolParams, OMNI_BOOLEAN, nullFuncParams, nullptr, "either_null");
if (!overflowConfig || overflowConfig->GetOverflowConfigId() != omniruntime::op::OVERFLOW_CONFIG_NULL) {
arithFuncParams.insert(arithFuncParams.begin(), leftIsNull);
}
Value *falseValue = llvmTypes->CreateConstantBool(false);
AllocaInst *overflowNull =
builder->CreateAlloca(Type::getInt1Ty(*context), this->batchCodegenContext->rowCnt, "OVERFLOW_NULL_PTR");
std::vector<Value *> funcArgs { overflowNull, falseValue, this->batchCodegenContext->rowCnt };
std::vector<DataTypeId> paramsVec = { OMNI_BOOLEAN };
CallExternFunction("batch_fill_null", paramsVec, OMNI_BOOLEAN, funcArgs, nullptr, "fill_null_array");
switch (binaryExpr->op) {
case omniruntime::expressions::Operator::LT:
CallExternFunction("batch_lessThan", decimalParams, returnTypeId, logicalFuncParams, nullptr,
"relational_lt");
break;
case omniruntime::expressions::Operator::GT:
CallExternFunction("batch_greaterThan", decimalParams, returnTypeId, logicalFuncParams, nullptr,
"relational_gt");
break;
case omniruntime::expressions::Operator::LTE:
CallExternFunction("batch_lessThanEqual", decimalParams, returnTypeId, logicalFuncParams, nullptr,
"relational_le");
break;
case omniruntime::expressions::Operator::GTE:
CallExternFunction("batch_greaterThanEqual", decimalParams, returnTypeId, logicalFuncParams, nullptr,
"relational_ge");
break;
case omniruntime::expressions::Operator::EQ:
CallExternFunction("batch_equal", decimalParams, returnTypeId, logicalFuncParams, nullptr, "relational_eq");
break;
case omniruntime::expressions::Operator::NEQ:
CallExternFunction("batch_notEqual", decimalParams, returnTypeId, logicalFuncParams, nullptr,
"relational_neq");
break;
case omniruntime::expressions::Operator::ADD:
CallExternFunction("batch_add", decimalParams, returnTypeId, arithFuncParams,
batchCodegenContext->executionContext, "arithmetic_add", this->overflowConfig, overflowNull);
break;
case omniruntime::expressions::Operator::SUB:
CallExternFunction("batch_subtract", decimalParams, returnTypeId, arithFuncParams,
batchCodegenContext->executionContext, "arithmetic_sub", this->overflowConfig, overflowNull);
break;
case omniruntime::expressions::Operator::MUL:
CallExternFunction("batch_multiply", decimalParams, returnTypeId, arithFuncParams,
batchCodegenContext->executionContext, "arithmetic_mul", this->overflowConfig, overflowNull);
break;
case omniruntime::expressions::Operator::DIV:
CallExternFunction("batch_divide", decimalParams, returnTypeId, arithFuncParams,
batchCodegenContext->executionContext, "arithmetic_div", this->overflowConfig, overflowNull);
break;
case omniruntime::expressions::Operator::MOD:
CallExternFunction("batch_modulus", decimalParams, returnTypeId, arithFuncParams,
batchCodegenContext->executionContext, "arithmetic_mod", this->overflowConfig, overflowNull);
break;
default: {
LogError("Unsupported decimal binary operator %u", static_cast<uint32_t>(binaryExpr->op));
this->value = CreateInvalidCodeGenValue();
return;
}
}
if (TypeUtil::IsDecimalType(returnTypeId)) {
std::vector<Value *> isAnyNullParams { leftIsNull, overflowNull, this->batchCodegenContext->rowCnt };
CallExternFunction("batch_or", boolParams, OMNI_BOOLEAN, isAnyNullParams, nullptr, "either_null");
llvm::Value *dataValue = nullptr;
if (returnTypeId == omniruntime::type::OMNI_DECIMAL128) {
if (decimalParams[0] == OMNI_DECIMAL128) {
dataValue = left.data;
} else if (decimalParams[1] == OMNI_DECIMAL128) {
dataValue = right.data;
} else {
dataValue = arithArrayPtr;
}
this->value = std::make_shared<DecimalValue>(dataValue, leftIsNull,
const_cast<Value *>(returnDecimalValue->GetPrecision()),
const_cast<Value *>(returnDecimalValue->GetScale()));
} else if (returnTypeId == omniruntime::type::OMNI_DECIMAL64) {
if (decimalParams[0] == OMNI_DECIMAL64) {
dataValue = left.data;
} else if (decimalParams[1] == OMNI_DECIMAL64) {
dataValue = right.data;
} else {
dataValue = arithArrayPtr;
}
this->value = std::make_shared<DecimalValue>(dataValue, leftIsNull,
const_cast<Value *>(returnDecimalValue->GetPrecision()),
const_cast<Value *>(returnDecimalValue->GetScale()));
}
} else {
this->value = std::make_shared<CodeGenValue>(logicalArrayPtr, leftIsNull);
}
}
void BatchExpressionCodeGen::BatchBinaryExprStringHelper(const omniruntime::expressions::BinaryExpr *binaryExpr,
llvm::Value *left, llvm::Value *leftLen, llvm::Value *right, llvm::Value *rightLen, llvm::Value *leftIsNull,
llvm::Value *rightIsNull)
{
std::vector<omniruntime::type::DataTypeId> strParams { OMNI_VARCHAR, OMNI_VARCHAR };
std::vector<omniruntime::type::DataTypeId> boolParams { OMNI_BOOLEAN, OMNI_BOOLEAN };
auto logicalArrayPtr = builder->CreateAlloca(llvmTypes->I1Type(), this->batchCodegenContext->rowCnt, "LOGICAL_PTR");
std::vector<Value *> logicalFuncParams { left, leftLen, right,
rightLen, logicalArrayPtr, this->batchCodegenContext->rowCnt };
std::vector<Value *> nullFuncParams { leftIsNull, rightIsNull, this->batchCodegenContext->rowCnt };
CallExternFunction("batch_or", boolParams, OMNI_BOOLEAN, nullFuncParams, nullptr, "either_null");
switch (binaryExpr->op) {
case omniruntime::expressions::Operator::LT:
CallExternFunction("batch_lessThan", strParams, OMNI_BOOLEAN, logicalFuncParams, nullptr, "relational_lt");
this->value = std::make_shared<CodeGenValue>(logicalArrayPtr, leftIsNull);
return;
case omniruntime::expressions::Operator::GT:
CallExternFunction("batch_greaterThan", strParams, OMNI_BOOLEAN, logicalFuncParams, nullptr,
"relational_gt");
this->value = std::make_shared<CodeGenValue>(logicalArrayPtr, leftIsNull);
return;
case omniruntime::expressions::Operator::LTE:
CallExternFunction("batch_lessThanEqual", strParams, OMNI_BOOLEAN, logicalFuncParams, nullptr,
"relational_le");
this->value = std::make_shared<CodeGenValue>(logicalArrayPtr, leftIsNull);
return;
case omniruntime::expressions::Operator::GTE:
CallExternFunction("batch_greaterThanEqual", strParams, OMNI_BOOLEAN, logicalFuncParams, nullptr,
"relational_ge");
this->value = std::make_shared<CodeGenValue>(logicalArrayPtr, leftIsNull);
return;
case omniruntime::expressions::Operator::EQ:
CallExternFunction("batch_equal", strParams, OMNI_BOOLEAN, logicalFuncParams, nullptr, "relational_eq");
this->value = std::make_shared<CodeGenValue>(logicalArrayPtr, leftIsNull);
return;
case omniruntime::expressions::Operator::NEQ:
CallExternFunction("batch_notEqual", strParams, OMNI_BOOLEAN, logicalFuncParams, nullptr, "relational_neq");
this->value = std::make_shared<CodeGenValue>(logicalArrayPtr, leftIsNull);
return;
default: {
LogError("Unsupported string binary operator %u", static_cast<uint32_t>(binaryExpr->op));
this->value = CreateInvalidCodeGenValue();
return;
}
}
}
void BatchExpressionCodeGen::BatchVisitBetweenExprHelper(BetweenExpr &bExpr, const std::shared_ptr<CodeGenValue> &val,
const std::shared_ptr<CodeGenValue> &lowerVal, const std::shared_ptr<CodeGenValue> &upperVal,
std::pair<AllocaInst **, AllocaInst **> cmpPair)
{
auto cmpLeft = cmpPair.first;
auto cmpRight = cmpPair.second;
std::vector<omniruntime::type::DataTypeId> params;
if (TypeUtil::IsStringType(bExpr.value->GetReturnTypeId())) {
params = { OMNI_VARCHAR, OMNI_VARCHAR };
} else if (bExpr.value->GetReturnTypeId() == type::OMNI_DATE32) {
params = { OMNI_INT, OMNI_INT };
} else if (bExpr.value->GetReturnTypeId() == type::OMNI_TIMESTAMP) {
params = { OMNI_LONG, OMNI_LONG };
} else {
params = { bExpr.value->GetReturnTypeId(), bExpr.value->GetReturnTypeId() };
}
std::vector<Value *> logicalFuncParams1;
std::vector<Value *> logicalFuncParams2;
bool isSupportedType = false;
if (bExpr.value->GetReturnTypeId() == OMNI_INT || bExpr.value->GetReturnTypeId() == OMNI_LONG ||
bExpr.value->GetReturnTypeId() == OMNI_DATE32 || bExpr.value->GetReturnTypeId() == OMNI_DOUBLE ||
bExpr.value->GetReturnTypeId() == type::OMNI_TIMESTAMP) {
logicalFuncParams1 = { lowerVal->data, val->data, *cmpLeft, this->batchCodegenContext->rowCnt };
logicalFuncParams2 = { val->data, upperVal->data, *cmpRight, this->batchCodegenContext->rowCnt };
isSupportedType = true;
} else if (TypeUtil::IsStringType(bExpr.value->GetReturnTypeId())) {
logicalFuncParams1 = { lowerVal->data, lowerVal->length, val->data,
val->length, *cmpLeft, this->batchCodegenContext->rowCnt };
logicalFuncParams2 = { val->data, val->length, upperVal->data,
upperVal->length, *cmpRight, this->batchCodegenContext->rowCnt };
isSupportedType = true;
} else if (TypeUtil::IsDecimalType(bExpr.value->GetReturnTypeId())) {
logicalFuncParams1 = { lowerVal->data,
const_cast<Value *>(static_cast<DecimalValue &>(*lowerVal).GetPrecision()),
const_cast<Value *>(static_cast<DecimalValue &>(*lowerVal).GetScale()),
val->data,
const_cast<Value *>(static_cast<DecimalValue &>(*val).GetPrecision()),
const_cast<Value *>(static_cast<DecimalValue &>(*val).GetScale()),
*cmpLeft,
this->batchCodegenContext->rowCnt };
logicalFuncParams2 = { val->data,
const_cast<Value *>(static_cast<DecimalValue &>(*val).GetPrecision()),
const_cast<Value *>(static_cast<DecimalValue &>(*val).GetScale()),
upperVal->data,
const_cast<Value *>(static_cast<DecimalValue &>(*upperVal).GetPrecision()),
const_cast<Value *>(static_cast<DecimalValue &>(*upperVal).GetScale()),
*cmpRight,
this->batchCodegenContext->rowCnt };
isSupportedType = true;
}
if (isSupportedType) {
CallExternFunction("batch_lessThanEqual", params, OMNI_BOOLEAN, logicalFuncParams1, nullptr, "relational_le");
CallExternFunction("batch_lessThanEqual", params, OMNI_BOOLEAN, logicalFuncParams2, nullptr, "relational_le");
std::vector<omniruntime::type::DataTypeId> boolParams { OMNI_BOOLEAN, OMNI_BOOLEAN };
std::vector<Value *> nullFuncParams { lowerVal->isNull, val->isNull, this->batchCodegenContext->rowCnt };
CallExternFunction("batch_or", boolParams, OMNI_BOOLEAN, nullFuncParams, nullptr, "either_null");
nullFuncParams = { lowerVal->isNull, upperVal->isNull, this->batchCodegenContext->rowCnt };
CallExternFunction("batch_or", boolParams, OMNI_BOOLEAN, nullFuncParams, nullptr, "either_null");
std::vector<Value *> betweenFuncParams = { *cmpLeft, *cmpRight, this->batchCodegenContext->rowCnt };
CallExternFunction("batch_and", boolParams, OMNI_BOOLEAN, betweenFuncParams, nullptr, "between_pass");
this->value = std::make_shared<CodeGenValue>(*cmpLeft, lowerVal->isNull);
return;
}
LogError("Unsupported data type for BETWEEN expr %d", bExpr.value->GetReturnTypeId());
this->value = CreateInvalidCodeGenValue();
}
Value *BatchExpressionCodeGen::PushAndGetNullFlagArray(const FuncExpr &fExpr, std::vector<llvm::Value *> &argVals,
Value *nullFlagArray, bool needAdd)
{
if (fExpr.function->GetNullableResultType() == INPUT_DATA_AND_NULL_AND_RETURN_NULL) {
AllocaInst *isNullArrPtr =
builder->CreateAlloca(llvmTypes->I1Type(), this->batchCodegenContext->rowCnt, "IS_RET_NULL_PTR");
CallExternFunction("batch_fill_null", { OMNI_BOOLEAN }, OMNI_BOOLEAN,
{ isNullArrPtr, llvmTypes->CreateConstantBool(false), this->batchCodegenContext->rowCnt }, nullptr,
"fill_ret_null_array");
argVals.push_back(isNullArrPtr);
return isNullArrPtr;
}
if (needAdd) {
argVals.push_back(nullFlagArray);
}
return nullFlagArray;
}
}