* Copyright (C) 2025-2025. Huawei Technologies Co., Ltd. All rights reserved.
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "jni_common.h"
#include "deserializer.hh"
#include "common/common.h"
using namespace omniruntime::vec;
/**
 * Parses a serialized columnar shuffle batch into a newly allocated spark::VecBatch.
 *
 * @param env     JNI environment; used to raise the Java exception on failure.
 * @param obj     calling Java object (unused).
 * @param address native address of the serialized bytes.
 * @param length  number of serialized bytes available at address.
 * @return the address of the heap-allocated spark::VecBatch as an opaque handle, or 0 with a
 *         pending Java exception when the arguments are invalid or the bytes cannot be parsed.
 *         columnarShuffleParseClose deletes a handle of this kind.
 */
JNIEXPORT jlong JNICALL
Java_com_huawei_boostkit_spark_serialize_ShuffleDataSerializerUtils_columnarShuffleParseInit(
    JNIEnv *env, jobject obj, jlong address, jint length)
{
    JNI_FUNC_START
    // A negative length, or a null buffer that claims to hold bytes, can never be parsed safely.
    if (length < 0 || (address == 0 && length > 0)) {
        env->ThrowNew(runtimeExceptionClass, "columnarShuffleParseInit: invalid buffer address or length");
        return 0;
    }
    auto *vecBatch = new spark::VecBatch();
    // ParseFromArray reports malformed input through its return value; do not hand a
    // partially parsed message back to the caller, and do not leak it either.
    if (!vecBatch->ParseFromArray(reinterpret_cast<char*>(address), length)) {
        delete vecBatch;
        env->ThrowNew(runtimeExceptionClass, "columnarShuffleParseInit: failed to parse serialized VecBatch");
        return 0;
    }
    return reinterpret_cast<jlong>(vecBatch);
    JNI_FUNC_END(runtimeExceptionClass)
}
/**
 * Destroys the spark::VecBatch referenced by a native handle.
 *
 * @param env     JNI environment (only referenced by the surrounding JNI_FUNC_* macros, if at all).
 * @param obj     calling Java object (unused).
 * @param address handle that is reinterpreted as a spark::VecBatch pointer and deleted.
 */
JNIEXPORT void JNICALL
Java_com_huawei_boostkit_spark_serialize_ShuffleDataSerializerUtils_columnarShuffleParseClose(
    JNIEnv *env, jobject obj, jlong address)
{
    JNI_FUNC_START
    // The handle is the raw pointer value, so it can be cast back and deleted directly.
    delete reinterpret_cast<spark::VecBatch*>(address);
    JNI_FUNC_END_VOID(runtimeExceptionClass)
}
/**
 * Reads the vector (column) count stored in a parsed spark::VecBatch.
 *
 * @param env     JNI environment (only referenced by the surrounding JNI_FUNC_* macros, if at all).
 * @param obj     calling Java object (unused).
 * @param address handle that is reinterpreted as a spark::VecBatch pointer.
 * @return the value of the batch's veccnt field.
 */
JNIEXPORT jint JNICALL
Java_com_huawei_boostkit_spark_serialize_ShuffleDataSerializerUtils_columnarShuffleParseVecCount(
    JNIEnv *env, jobject obj, jlong address)
{
    JNI_FUNC_START
    const auto *batch = reinterpret_cast<const spark::VecBatch*>(address);
    return batch->veccnt();
    JNI_FUNC_END(runtimeExceptionClass)
}
/**
 * Reads the row count stored in a parsed spark::VecBatch.
 *
 * @param env     JNI environment (only referenced by the surrounding JNI_FUNC_* macros, if at all).
 * @param obj     calling Java object (unused).
 * @param address handle that is reinterpreted as a spark::VecBatch pointer.
 * @return the value of the batch's rowcnt field.
 */
JNIEXPORT jint JNICALL
Java_com_huawei_boostkit_spark_serialize_ShuffleDataSerializerUtils_columnarShuffleParseRowCount(
    JNIEnv *env, jobject obj, jlong address)
{
    JNI_FUNC_START
    const auto *batch = reinterpret_cast<const spark::VecBatch*>(address);
    return batch->rowcnt();
    JNI_FUNC_END(runtimeExceptionClass)
}
/**
 * Materializes every column of a parsed spark::VecBatch as a native vector and reports the
 * per-column metadata back to Java.
 *
 * For each of the batch's veccnt columns this creates a flat vector of rowcnt rows, copies the
 * serialized values (plus the offsets for OMNI_CHAR / OMNI_VARCHAR columns), marks the rows whose
 * serialized null flag is non-zero, and writes the column's type id, precision and scale into the
 * output arrays. The native vector addresses are written to vecNativeIdArray only after all
 * columns were built, so Java never receives the address of a vector that a failure path deleted.
 *
 * @param env              JNI environment.
 * @param obj              calling Java object (unused).
 * @param address          handle that is reinterpreted as a spark::VecBatch pointer.
 * @param typeIdArray      output: type id per column; needs at least veccnt elements.
 * @param precisionArray   output: precision per column; needs at least veccnt elements.
 * @param scaleArray       output: scale per column; needs at least veccnt elements.
 * @param vecNativeIdArray output: native vector address per column; needs at least veccnt elements.
 *
 * A Java exception of runtimeExceptionClass is raised when the batch counts are inconsistent or
 * the output arrays are too short. NOTE(review): JNI_FUNC_START / JNI_FUNC_END_WITH_VECTORS are
 * defined outside this file; presumably they wrap the body in a try/catch that deletes `vecs`
 * and raises the Java exception - confirm against jni_common.h.
 */
JNIEXPORT void JNICALL
Java_com_huawei_boostkit_spark_serialize_ShuffleDataSerializerUtils_columnarShuffleParseBatch(
    JNIEnv *env, jobject obj, jlong address, jintArray typeIdArray, jintArray precisionArray,
    jintArray scaleArray, jlongArray vecNativeIdArray)
{
    spark::VecBatch* vecBatch = reinterpret_cast<spark::VecBatch*>(address);
    int32_t vecCount = vecBatch->veccnt();
    int32_t rowCount = vecBatch->rowcnt();

    // The counts come from deserialized data: reject values that would index past the repeated
    // `vecs` field or size the local array below with a negative length.
    if (vecCount < 0 || rowCount < 0 || vecBatch->vecs_size() < vecCount) {
        env->ThrowNew(runtimeExceptionClass,
            "columnarShuffleParseBatch: vector/row counts do not match the serialized batch");
        return;
    }
    // Every output array is indexed up to vecCount - 1 below.
    if (env->GetArrayLength(typeIdArray) < vecCount || env->GetArrayLength(precisionArray) < vecCount ||
        env->GetArrayLength(scaleArray) < vecCount || env->GetArrayLength(vecNativeIdArray) < vecCount) {
        env->ThrowNew(runtimeExceptionClass,
            "columnarShuffleParseBatch: output arrays are shorter than the vector count");
        return;
    }

    jint *typeIdArrayElements = nullptr;
    jint *precisionArrayElements = nullptr;
    jint *scaleArrayElements = nullptr;
    jlong *vecNativeIdArrayElements = nullptr;
    // Hands the element buffers back to the JVM; safe to call more than once.
    // mode 0 commits the written values, JNI_ABORT discards them.
    auto releaseArrays = [&](jint mode) {
        if (typeIdArrayElements != nullptr) {
            env->ReleaseIntArrayElements(typeIdArray, typeIdArrayElements, mode);
            typeIdArrayElements = nullptr;
        }
        if (precisionArrayElements != nullptr) {
            env->ReleaseIntArrayElements(precisionArray, precisionArrayElements, mode);
            precisionArrayElements = nullptr;
        }
        if (scaleArrayElements != nullptr) {
            env->ReleaseIntArrayElements(scaleArray, scaleArrayElements, mode);
            scaleArrayElements = nullptr;
        }
        if (vecNativeIdArrayElements != nullptr) {
            env->ReleaseLongArrayElements(vecNativeIdArray, vecNativeIdArrayElements, mode);
            vecNativeIdArrayElements = nullptr;
        }
    };

    omniruntime::vec::BaseVector* vecs[vecCount]{};
    JNI_FUNC_START
    try {
        typeIdArrayElements = env->GetIntArrayElements(typeIdArray, nullptr);
        precisionArrayElements = env->GetIntArrayElements(precisionArray, nullptr);
        scaleArrayElements = env->GetIntArrayElements(scaleArray, nullptr);
        vecNativeIdArrayElements = env->GetLongArrayElements(vecNativeIdArray, nullptr);
        if (typeIdArrayElements == nullptr || precisionArrayElements == nullptr ||
            scaleArrayElements == nullptr || vecNativeIdArrayElements == nullptr) {
            // No vector has been created yet, so only the acquired buffers need to be returned.
            releaseArrays(JNI_ABORT);
            if (!env->ExceptionCheck()) {
                env->ThrowNew(runtimeExceptionClass,
                    "columnarShuffleParseBatch: failed to access the output arrays");
            }
            return;
        }

        for (int32_t i = 0; i < vecCount; ++i) {
            const spark::Vec& protoVec = vecBatch->vecs(i);
            const spark::VecType& protoTypeId = protoVec.vectype();
            scaleArrayElements[i] = protoTypeId.scale();
            precisionArrayElements[i] = protoTypeId.precision();
            typeIdArrayElements[i] = static_cast<jint>(protoTypeId.typeid_());
            auto vectorDataTypeId = static_cast<omniruntime::type::DataTypeId>(protoTypeId.typeid_());
            vecs[i] = VectorHelper::CreateVector(OMNI_FLAT, vectorDataTypeId, rowCount);
            auto values = protoVec.values().data();
            auto offsets = protoVec.offset().data();
            auto nulls = protoVec.nulls().data();
            // NOTE(review): the memcpy_s calls below pass the source size as the destination
            // capacity, so they do not bound the copy by the vector's real buffer size; the
            // capacity of the value/offset buffers is not visible from this file - confirm.
            if (vectorDataTypeId == OMNI_CHAR || vectorDataTypeId == OMNI_VARCHAR) {
                auto charVec = reinterpret_cast<Vector<LargeStringContainer<std::string_view>> *>(vecs[i]);
                char *valuesAddress =
                    omniruntime::vec::unsafe::UnsafeStringVector::ExpandStringBuffer(charVec, protoVec.values().size());
                auto offsetsAddress = (uint8_t *)VectorHelper::UnsafeGetOffsetsAddr(vecs[i]);
                memcpy_s(valuesAddress, protoVec.values().size(), values, protoVec.values().size());
                memcpy_s(offsetsAddress, protoVec.offset().size(), offsets, protoVec.offset().size());
            } else {
                auto *valuesAddress = (uint8_t *)VectorHelper::UnsafeGetValues(vecs[i]);
                memcpy_s(valuesAddress, protoVec.values().size(), values, protoVec.values().size());
            }
            // The vector holds rowCount rows, so never mark a row index beyond that even when the
            // serialized null buffer is longer.
            int32_t nullCount = rowCount;
            if (protoVec.nulls().size() < static_cast<size_t>(rowCount)) {
                nullCount = static_cast<int32_t>(protoVec.nulls().size());
            }
            for (int32_t j = 0; j < nullCount; ++j) {
                if (int(nulls[j])) {
                    vecs[i]->SetNull(j);
                }
            }
        }

        // Publish the native addresses only once the whole batch has been built.
        for (int32_t i = 0; i < vecCount; ++i) {
            vecNativeIdArrayElements[i] = (jlong)(vecs[i]);
        }
    } catch (...) {
        // Return the element buffers without committing them, then let the enclosing handler
        // deal with the exception exactly as before.
        releaseArrays(JNI_ABORT);
        throw;
    }
    releaseArrays(0);
    JNI_FUNC_END_WITH_VECTORS(runtimeExceptionClass, vecs)
}
/**
 * Parses a serialized row-based shuffle batch into a newly allocated spark::ProtoRowBatch.
 *
 * @param env     JNI environment; used to raise the Java exception on failure.
 * @param obj     calling Java object (unused).
 * @param address native address of the serialized bytes.
 * @param length  number of serialized bytes available at address.
 * @return the address of the heap-allocated spark::ProtoRowBatch as an opaque handle, or 0 with a
 *         pending Java exception when the arguments are invalid or the bytes cannot be parsed.
 *         rowShuffleParseClose deletes a handle of this kind.
 */
JNIEXPORT jlong JNICALL
Java_com_huawei_boostkit_spark_serialize_ShuffleDataSerializerUtils_rowShuffleParseInit(
    JNIEnv *env, jobject obj, jlong address, jint length)
{
    JNI_FUNC_START
    // A negative length, or a null buffer that claims to hold bytes, can never be parsed safely.
    if (length < 0 || (address == 0 && length > 0)) {
        env->ThrowNew(runtimeExceptionClass, "rowShuffleParseInit: invalid buffer address or length");
        return 0;
    }
    auto *protoRowBatch = new spark::ProtoRowBatch();
    // ParseFromArray reports malformed input through its return value; do not hand a
    // partially parsed message back to the caller, and do not leak it either.
    if (!protoRowBatch->ParseFromArray(reinterpret_cast<char*>(address), length)) {
        delete protoRowBatch;
        env->ThrowNew(runtimeExceptionClass, "rowShuffleParseInit: failed to parse serialized ProtoRowBatch");
        return 0;
    }
    return reinterpret_cast<jlong>(protoRowBatch);
    JNI_FUNC_END(runtimeExceptionClass)
}
/**
 * Destroys the spark::ProtoRowBatch referenced by a native handle.
 *
 * @param env     JNI environment (only referenced by the surrounding JNI_FUNC_* macros, if at all).
 * @param obj     calling Java object (unused).
 * @param address handle that is reinterpreted as a spark::ProtoRowBatch pointer and deleted.
 */
JNIEXPORT void JNICALL
Java_com_huawei_boostkit_spark_serialize_ShuffleDataSerializerUtils_rowShuffleParseClose(
    JNIEnv *env, jobject obj, jlong address)
{
    JNI_FUNC_START
    // The handle is the raw pointer value, so it can be cast back and deleted directly.
    delete reinterpret_cast<spark::ProtoRowBatch*>(address);
    JNI_FUNC_END_VOID(runtimeExceptionClass)
}
/**
 * Reads the vector (column) count stored in a parsed spark::ProtoRowBatch.
 *
 * @param env     JNI environment (only referenced by the surrounding JNI_FUNC_* macros, if at all).
 * @param obj     calling Java object (unused).
 * @param address handle that is reinterpreted as a spark::ProtoRowBatch pointer.
 * @return the value of the batch's veccnt field.
 */
JNIEXPORT jint JNICALL
Java_com_huawei_boostkit_spark_serialize_ShuffleDataSerializerUtils_rowShuffleParseVecCount(
    JNIEnv *env, jobject obj, jlong address)
{
    JNI_FUNC_START
    const auto *rowBatch = reinterpret_cast<const spark::ProtoRowBatch*>(address);
    return rowBatch->veccnt();
    JNI_FUNC_END(runtimeExceptionClass)
}
/**
 * Reads the row count stored in a parsed spark::ProtoRowBatch.
 *
 * @param env     JNI environment (only referenced by the surrounding JNI_FUNC_* macros, if at all).
 * @param obj     calling Java object (unused).
 * @param address handle that is reinterpreted as a spark::ProtoRowBatch pointer.
 * @return the value of the batch's rowcnt field.
 */
JNIEXPORT jint JNICALL
Java_com_huawei_boostkit_spark_serialize_ShuffleDataSerializerUtils_rowShuffleParseRowCount(
    JNIEnv *env, jobject obj, jlong address)
{
    JNI_FUNC_START
    const auto *rowBatch = reinterpret_cast<const spark::ProtoRowBatch*>(address);
    return rowBatch->rowcnt();
    JNI_FUNC_END(runtimeExceptionClass)
}
/**
 * Converts a parsed row-based shuffle batch (spark::ProtoRowBatch) into native column vectors and
 * reports the per-column metadata back to Java.
 *
 * One flat vector of rowcnt rows is created per entry of the batch's veccnt column types; every
 * serialized row is then located through the int32 offsets buffer and decoded into those vectors
 * by a RowParser. Type id, precision and scale of each column are written into the output arrays.
 * The native vector addresses are written to vecNativeIdArray only after every row was parsed, so
 * Java never receives the address of a vector that the failure path deleted.
 *
 * @param env              JNI environment.
 * @param obj              calling Java object (unused).
 * @param address          handle that is reinterpreted as a spark::ProtoRowBatch pointer.
 * @param typeIdArray      output: type id per column; needs at least veccnt elements.
 * @param precisionArray   output: precision per column; needs at least veccnt elements.
 * @param scaleArray       output: scale per column; needs at least veccnt elements.
 * @param vecNativeIdArray output: native vector address per column; needs at least veccnt elements.
 *
 * On any failure (inconsistent batch, short output arrays, a row offset outside the row buffer,
 * or a std::exception thrown while building the vectors) all vectors created so far are deleted,
 * the JNI array buffers are released without committing, and a Java exception of
 * runtimeExceptionClass is raised.
 */
JNIEXPORT void JNICALL
Java_com_huawei_boostkit_spark_serialize_ShuffleDataSerializerUtils_rowShuffleParseBatch(
    JNIEnv *env, jobject obj, jlong address, jintArray typeIdArray, jintArray precisionArray,
    jintArray scaleArray, jlongArray vecNativeIdArray)
{
    spark::ProtoRowBatch* protoRowBatch = reinterpret_cast<spark::ProtoRowBatch*>(address);
    const int32_t vecCount = protoRowBatch->veccnt();
    const int32_t rowCount = protoRowBatch->rowcnt();

    // The counts come from deserialized data: make sure they agree with the buffers that are
    // indexed with them below (one VecType per column, one int32 offset per row).
    if (vecCount < 0 || rowCount < 0 || protoRowBatch->vectypes_size() < vecCount ||
        protoRowBatch->offsets().size() < static_cast<size_t>(rowCount) * sizeof(int32_t)) {
        env->ThrowNew(runtimeExceptionClass,
            "rowShuffleParseBatch: vector/row counts do not match the serialized batch");
        return;
    }
    // Every output array is indexed up to vecCount - 1 below.
    if (env->GetArrayLength(typeIdArray) < vecCount || env->GetArrayLength(precisionArray) < vecCount ||
        env->GetArrayLength(scaleArray) < vecCount || env->GetArrayLength(vecNativeIdArray) < vecCount) {
        env->ThrowNew(runtimeExceptionClass,
            "rowShuffleParseBatch: output arrays are shorter than the vector count");
        return;
    }

    jint *typeIdArrayElements = nullptr;
    jint *precisionArrayElements = nullptr;
    jint *scaleArrayElements = nullptr;
    jlong *vecNativeIdArrayElements = nullptr;
    // Hands the element buffers back to the JVM; safe to call more than once.
    // mode 0 commits the written values, JNI_ABORT discards them.
    auto releaseArrays = [&](jint mode) {
        if (typeIdArrayElements != nullptr) {
            env->ReleaseIntArrayElements(typeIdArray, typeIdArrayElements, mode);
            typeIdArrayElements = nullptr;
        }
        if (precisionArrayElements != nullptr) {
            env->ReleaseIntArrayElements(precisionArray, precisionArrayElements, mode);
            precisionArrayElements = nullptr;
        }
        if (scaleArrayElements != nullptr) {
            env->ReleaseIntArrayElements(scaleArray, scaleArrayElements, mode);
            scaleArrayElements = nullptr;
        }
        if (vecNativeIdArrayElements != nullptr) {
            env->ReleaseLongArrayElements(vecNativeIdArray, vecNativeIdArrayElements, mode);
            vecNativeIdArrayElements = nullptr;
        }
    };

    // Slots stay nullptr until their vector exists, so cleanup can simply delete every slot.
    std::vector<omniruntime::vec::BaseVector*> vecs;
    // Single failure path: free the vectors, give the arrays back, raise the Java exception.
    auto failBatch = [&](const char *message) {
        for (auto *vec : vecs) {
            delete vec;
        }
        vecs.clear();
        releaseArrays(JNI_ABORT);
        if (!env->ExceptionCheck()) {
            env->ThrowNew(runtimeExceptionClass, message);
        }
    };

    try {
        vecs.assign(vecCount, nullptr);
        std::vector<omniruntime::type::DataTypeId> omniDataTypeIds(vecCount);

        typeIdArrayElements = env->GetIntArrayElements(typeIdArray, nullptr);
        precisionArrayElements = env->GetIntArrayElements(precisionArray, nullptr);
        scaleArrayElements = env->GetIntArrayElements(scaleArray, nullptr);
        vecNativeIdArrayElements = env->GetLongArrayElements(vecNativeIdArray, nullptr);
        if (typeIdArrayElements == nullptr || precisionArrayElements == nullptr ||
            scaleArrayElements == nullptr || vecNativeIdArrayElements == nullptr) {
            failBatch("rowShuffleParseBatch: failed to access the output arrays");
            return;
        }

        for (int32_t pos = 0; pos < vecCount; ++pos) {
            const spark::VecType& protoTypeId = protoRowBatch->vectypes(pos);
            scaleArrayElements[pos] = protoTypeId.scale();
            precisionArrayElements[pos] = protoTypeId.precision();
            typeIdArrayElements[pos] = static_cast<jint>(protoTypeId.typeid_());
            const auto vectorDataTypeId = static_cast<omniruntime::type::DataTypeId>(protoTypeId.typeid_());
            omniDataTypeIds[pos] = vectorDataTypeId;
            vecs[pos] = VectorHelper::CreateVector(OMNI_FLAT, vectorDataTypeId, rowCount);
        }

        // Automatic storage: the parser is destroyed on every exit path, including exceptions.
        RowParser parser(omniDataTypeIds);
        char *rows = const_cast<char*>(protoRowBatch->rows().data());
        const size_t rowsSize = protoRowBatch->rows().size();
        const int32_t *offsets = reinterpret_cast<const int32_t*>(protoRowBatch->offsets().data());
        for (int32_t i = 0; i < rowCount; ++i) {
            // A row must start inside the serialized row buffer.
            if (offsets[i] < 0 || static_cast<size_t>(offsets[i]) > rowsSize) {
                failBatch("rowShuffleParseBatch: row offset is outside the serialized row buffer");
                return;
            }
            char *rowPtr = rows + offsets[i];
            parser.ParseOneRow(reinterpret_cast<uint8_t*>(rowPtr), vecs.data(), i);
        }

        // Publish the native addresses only once the whole batch has been built.
        for (int32_t pos = 0; pos < vecCount; ++pos) {
            vecNativeIdArrayElements[pos] = (jlong)(vecs[pos]);
        }
    } catch (const std::exception &e) {
        failBatch(e.what());
        return;
    }
    releaseArrays(0);
}