* Copyright (c) Huawei Technologies Co., Ltd. 2021-2021. All rights reserved.
* Description: project codegen
*/
#include "projection_codegen.h"
namespace omniruntime {
namespace codegen {
using namespace llvm;
using namespace orc;
using namespace omniruntime::expressions;
using namespace omniruntime::type;
namespace {
// Positions of the arguments of the generated "WRAPPER_FUNC" (see ProjectionCodeGen::CreateWrapper).
// The order must stay in sync with the argument type list built there.
constexpr int INPUT_TABLE_INDEX = 0;      // argument named "INPUT_TABLE"
constexpr int NUM_ROWS_INDEX = 1;         // argument named "NUM_ROWS"
constexpr int OUTPUT_ADDRESS_INDEX = 2;   // argument named "OUTPUT_ADDRESS"
constexpr int SELECTED = 3;               // argument named "SELECTED_ARRAY" (only named/read when a filter is set)
constexpr int NUM_SELECTED = 4;           // argument named "NUM_SELECTED" (only named/read when a filter is set)
constexpr int BITMAP = 5;                 // argument named "BITMAP"
constexpr int OFFSETS_INDEX = 6;          // argument named "OFFSETS"
constexpr int NEW_NULL_VALUES_INDEX = 7;  // argument named "NULL_VALUES_ADDRESS"
constexpr int OUTPUT_OFFSETS_INDEX = 8;   // argument named "OUTPUT_OFFSETS_ADDRESS"
constexpr int EXECUTION_CONTEXT_IDX = 9;  // argument named "EXECUTION_CONTEXT_ADDRESS"
constexpr int DICTIONARY_VECTORS_IDX = 10; // argument named "DICTIONARY_VECTORS"
}
intptr_t ProjectionCodeGen::GetFunction(const DataTypes &inputDataTypes)
{
llvm::Function *func = CreateFunction(inputDataTypes);
if (func == nullptr) {
return 0;
}
return CreateWrapper();
}
intptr_t ProjectionCodeGen::CreateWrapper()
{
std::vector<Type *> args {
llvmTypes->I64PtrType(),
llvmTypes->I32Type(),
llvmTypes->I64Type(),
llvmTypes->I32PtrType(),
llvmTypes->I32Type(),
llvmTypes->I64PtrType(),
llvmTypes->I64PtrType(),
llvmTypes->I32PtrType(),
llvmTypes->I32PtrType(),
llvmTypes->I64Type(),
llvmTypes->I64PtrType()
};
FunctionType *funcSignature = FunctionType::get(llvmTypes->I32Type(), args, false);
llvm::Function *funcDecl =
llvm::Function::Create(funcSignature, llvm::Function::ExternalLinkage, "WRAPPER_FUNC", modulePtr);
BasicBlock *preLoop = BasicBlock::Create(*context, "PRE_LOOP", funcDecl);
BasicBlock *loopBody = BasicBlock::Create(*context, "LOOP_BODY", funcDecl);
BasicBlock *addToOutput = BasicBlock::Create(*context, "ADD_OUTPUT", funcDecl);
BasicBlock *incrementCounter = BasicBlock::Create(*context, "INCREMENT_COUNTER", funcDecl);
BasicBlock *endBlock = BasicBlock::Create(*context, "END_BLOCK", funcDecl);
Argument *input = funcDecl->getArg(INPUT_TABLE_INDEX);
input->setName("INPUT_TABLE");
Argument *numRows = funcDecl->getArg(NUM_ROWS_INDEX);
numRows->setName("NUM_ROWS");
Argument *outputAddress = funcDecl->getArg(OUTPUT_ADDRESS_INDEX);
outputAddress->setName("OUTPUT_ADDRESS");
RecordMainFunction(funcDecl);
Argument *selected = nullptr;
Argument *numSelected = nullptr;
if (filter) {
selected = funcDecl->getArg(SELECTED);
selected->setName("SELECTED_ARRAY");
numSelected = funcDecl->getArg(NUM_SELECTED);
numSelected->setName("NUM_SELECTED");
}
Argument *bitmap = funcDecl->getArg(BITMAP);
bitmap->setName("BITMAP");
Argument *offsets = funcDecl->getArg(OFFSETS_INDEX);
offsets->setName("OFFSETS");
Argument *nullValuesAddress = funcDecl->getArg(NEW_NULL_VALUES_INDEX);
nullValuesAddress->setName("NULL_VALUES_ADDRESS");
Argument *outputOffsetsAddress = funcDecl->getArg(OUTPUT_OFFSETS_INDEX);
outputOffsetsAddress->setName("OUTPUT_OFFSETS_ADDRESS");
Argument *executionContext = funcDecl->getArg(EXECUTION_CONTEXT_IDX);
executionContext->setName("EXECUTION_CONTEXT_ADDRESS");
Argument *dictionaryVectors = funcDecl->getArg(DICTIONARY_VECTORS_IDX);
dictionaryVectors->setName("DICTIONARY_VECTORS");
Value *zero = llvmTypes->CreateConstantInt(0);
Value *one = llvmTypes->CreateConstantInt(1);
builder->SetInsertPoint(preLoop);
AllocaInst *indexStore = builder->CreateAlloca(llvmTypes->I32Type(), nullptr, "INDEX_COUNTER");
builder->CreateStore(zero, indexStore);
AllocaInst *offsetStore = builder->CreateAlloca(llvmTypes->I32Type(), nullptr, "CURRENT_OFFSET");
builder->CreateStore(zero, offsetStore);
Value *curIndexVal;
Value *rowIndexVal;
Value *nextIndexVal;
Value *selectedAddress;
FunctionSignature setBitNullFuncSignature = FunctionSignature("WrapSetBitNull", { OMNI_INT }, OMNI_BOOLEAN);
llvm::Function *setBitNullFunc =
modulePtr->getFunction(FunctionRegistry::LookupFunction(&setBitNullFuncSignature)->GetId());
llvm::Function *varcharVectorFunc = nullptr;
if (expr->GetReturnTypeId() == OMNI_CHAR || expr->GetReturnTypeId() == OMNI_VARCHAR) {
std::vector<DataTypeId> paramTypes = { OMNI_LONG, OMNI_INT, OMNI_VARCHAR };
FunctionSignature varcharVectorFuncSignature = FunctionSignature("WrapVarcharVector", paramTypes, OMNI_INT);
varcharVectorFunc =
modulePtr->getFunction(FunctionRegistry::LookupFunction(&varcharVectorFuncSignature)->GetId());
}
Type *outPtrType = llvmTypes->ToPointerType(expr->GetReturnTypeId());
if (outPtrType == nullptr) {
return 0;
}
Value *outColPtr = builder->CreateIntToPtr(outputAddress, outPtrType);
AllocaInst *outputLenPtr = builder->CreateAlloca(llvmTypes->I32Type(), nullptr, "OUTPUT_LENGTH");
auto isNullPtr = builder->CreateAlloca(llvmTypes->I1Type(), nullptr, "IS_NULL");
auto columnArgs = exprFunc->ToColumnArgs(input);
auto dicArgs = exprFunc->ToDicArgs(dictionaryVectors);
auto nullArgs = exprFunc->ToNullArgs(bitmap);
auto offsetArgs = exprFunc->ToOffsetArgs(offsets);
builder->CreateBr(loopBody);
builder->SetInsertPoint(loopBody);
curIndexVal = builder->CreateLoad(llvmTypes->I32Type(), indexStore, "CUR_INDEX");
if (filter) {
selectedAddress = builder->CreateGEP(llvmTypes->I32Type(), selected, curIndexVal, "SELECTED_ADDRESS");
rowIndexVal = builder->CreateLoad(llvmTypes->I32Type(), selectedAddress);
} else {
rowIndexVal = curIndexVal;
}
builder->CreateStore(llvmTypes->CreateConstantBool(false), isNullPtr);
std::vector<Value *> projFuncArgs;
int32_t argsSize = exprFunc->GetArgumentCount() + exprFunc->GetInputColumnCount() * 4;
projFuncArgs.reserve(argsSize);
projFuncArgs.push_back(rowIndexVal);
projFuncArgs.push_back(outputLenPtr);
projFuncArgs.push_back(executionContext);
projFuncArgs.push_back(isNullPtr);
projFuncArgs.insert(projFuncArgs.end(), columnArgs.begin(), columnArgs.end());
projFuncArgs.insert(projFuncArgs.end(), dicArgs.begin(), dicArgs.end());
projFuncArgs.insert(projFuncArgs.end(), nullArgs.begin(), nullArgs.end());
projFuncArgs.insert(projFuncArgs.end(), offsetArgs.begin(), offsetArgs.end());
CallInst *ret = builder->CreateCall(func, projFuncArgs, "ROW_PROCESS");
builder->CreateBr(addToOutput);
builder->SetInsertPoint(addToOutput);
Value *gep;
Type *ty = llvmTypes->VectorToLLVMType(*(expr->GetReturnType()));
if (TypeUtil::IsStringType(expr->GetReturnTypeId())) {
auto outputLen = builder->CreateLoad(llvmTypes->I32Type(), outputLenPtr, "OUTPUT_LENGTH");
auto stringPtr = builder->CreateIntToPtr(ret, Type::getInt8PtrTy(*context));
std::vector<Value *> argVals { outColPtr, curIndexVal, stringPtr, outputLen };
auto call = builder->CreateCall(varcharVectorFunc, argVals, "wrap_varchar_vector");
InlineFunctionInfo inlineFunctionInfo;
InlineFunction(*call, inlineFunctionInfo);
} else {
gep = builder->CreateGEP(ty, outColPtr, curIndexVal, "OUTPUT_ADDRESS");
builder->CreateStore(ret, gep);
}
auto setNullRet = builder->CreateCall(setBitNullFunc,
{ nullValuesAddress, curIndexVal, builder->CreateLoad(llvmTypes->I1Type(), isNullPtr) }, "wrap_set_bit_null");
InlineFunctionInfo inlineSetNullFuncInfo;
InlineFunction(*setNullRet, inlineSetNullFuncInfo);
builder->CreateBr(incrementCounter);
builder->SetInsertPoint(incrementCounter);
nextIndexVal = builder->CreateAdd(curIndexVal, one, "NEXT_INDEX");
builder->CreateStore(nextIndexVal, indexStore);
Value *sentinel;
if (filter) {
sentinel = numSelected;
} else {
sentinel = numRows;
}
Value *cond = builder->CreateICmpSLT(nextIndexVal, sentinel, "END_LOOP_COND");
builder->CreateCondBr(cond, loopBody, endBlock);
builder->SetInsertPoint(endBlock);
builder->CreateRet(nextIndexVal);
OptimizeFunctionsAndModule();
return Compile();
}
}
}