/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#include "RocksDBStateUploader.h"
#include <iostream>
#include <stdexcept>
namespace fs = std::filesystem;
/**
 * Copies the contents of a Java string into a std::string.
 *
 * The characters are obtained through GetStringUTFChars and released again
 * before returning.
 *
 * @param env  JNI environment of the calling thread.
 * @param jstr Java string to copy.
 * @return The copied characters; an empty string when env or jstr is null, or
 *         when the characters cannot be obtained (any pending Java exception
 *         is cleared in that case).
 */
std::string jstringToString(JNIEnv* env, jstring jstr)
{
    std::string converted;
    if (env == nullptr || jstr == nullptr) {
        return converted;
    }

    const char* utfChars = env->GetStringUTFChars(jstr, nullptr);
    if (utfChars == nullptr) {
        // The failed lookup may have raised a Java exception; drop it.
        env->ExceptionClear();
        return converted;
    }

    converted.assign(utfChars);
    env->ReleaseStringUTFChars(jstr, utfChars);
    return converted;
}
/**
 * Converts a Java org.apache.flink.core.fs.Path object to a std::string by
 * invoking its toString() method.
 *
 * @param env          JNI environment of the calling thread.
 * @param flinkPathObj Reference to the Java Path object.
 * @return The text returned by toString(); an empty string when env or
 *         flinkPathObj is null, or when the Java string cannot be read.
 * @throws std::runtime_error if the Path class or its toString() method cannot
 *         be resolved, or if the Java call raises an exception. Any pending
 *         Java exception is described and cleared before throwing.
 */
std::string flinkPathToString(JNIEnv* env, jobject flinkPathObj)
{
    if (!env || !flinkPathObj) {
        return "";
    }

    const char* pathClassPath = "org/apache/flink/core/fs/Path";
    jclass pathClass = env->FindClass(pathClassPath);
    if (!pathClass) {
        env->ExceptionDescribe();
        env->ExceptionClear();
        throw std::runtime_error("Failed to find org.apache.flink.core.fs.Path class");
    }

    jmethodID toStringMethod = env->GetMethodID(pathClass, "toString", "()Ljava/lang/String;");
    if (!toStringMethod) {
        env->ExceptionDescribe();
        // Clear explicitly so no Java exception stays pending while the C++
        // exception unwinds.
        env->ExceptionClear();
        env->DeleteLocalRef(pathClass);
        throw std::runtime_error("Failed to get Path.toString() method ID");
    }

    jstring pathStr = static_cast<jstring>(env->CallObjectMethod(flinkPathObj, toStringMethod));
    if (env->ExceptionCheck()) {
        env->ExceptionDescribe();
        env->ExceptionClear();
        if (pathStr) {
            env->DeleteLocalRef(pathStr);
        }
        env->DeleteLocalRef(pathClass);
        // The invoked method is toString(), so the message names it.
        throw std::runtime_error("Failed to call Path.toString() method");
    }

    std::string result = jstringToString(env, pathStr);
    if (pathStr) {
        env->DeleteLocalRef(pathStr);
    }
    env->DeleteLocalRef(pathClass);
    return result;
}
/**
 * Appends one Java path object per C++ path to an existing Java list.
 *
 * Each path is converted with the supplied static "get" method (pathsGet on
 * pathsClass), called with the path string and an empty String[] as the
 * varargs part, and the result is added to javaList via arrayListAdd.
 * The operation is best effort: a path that cannot be converted or added is
 * reported on std::cerr (and/or via ExceptionDescribe) and skipped.
 *
 * @param env          JNI environment of the calling thread.
 * @param files        Paths to append.
 * @param javaList     Java list that receives the entries.
 * @param arrayListAdd Method ID of the list's add(Object) method.
 * @param pathsClass   Class that owns the static pathsGet method.
 * @param pathsGet     Static method ID taking (String, String[]).
 * @return javaList, unchanged as a reference. It is returned untouched when
 *         any argument is null or the shared helper objects cannot be created.
 */
jobject RocksDBStateUploader::addToJavaPathList(JNIEnv* env,
                                                const std::vector<fs::path>& files,
                                                jobject javaList,
                                                jmethodID arrayListAdd,
                                                jclass pathsClass,
                                                jmethodID pathsGet)
{
    if (!env || !javaList || !arrayListAdd || !pathsClass || !pathsGet) {
        return javaList;
    }

    // Resolve java.lang.String once instead of once per file; the previous
    // per-iteration lookup never released its local reference.
    jclass stringClass = env->FindClass("java/lang/String");
    if (!stringClass) {
        env->ExceptionDescribe();
        env->ExceptionClear();
        std::cerr << "Failed to find java.lang.String class" << std::endl;
        return javaList;
    }

    // The empty varargs array is identical for every call, so it is shared.
    jobjectArray moreParts = env->NewObjectArray(0, stringClass, nullptr);
    if (!moreParts) {
        env->ExceptionDescribe();
        env->ExceptionClear();
        env->DeleteLocalRef(stringClass);
        std::cerr << "Failed to create empty Java String array" << std::endl;
        return javaList;
    }

    for (const auto& cppPath : files) {
        const std::string pathStr = cppPath.string();
        jstring javaStr = env->NewStringUTF(pathStr.c_str());
        if (!javaStr) {
            env->ExceptionClear();
            std::cerr << "Failed to create Java String for path: " << pathStr << std::endl;
            continue;
        }

        jobject javaPath = env->CallStaticObjectMethod(pathsClass, pathsGet, javaStr, moreParts);
        if (env->ExceptionCheck() || !javaPath) {
            env->ExceptionDescribe();
            env->ExceptionClear();
            if (javaPath) {
                env->DeleteLocalRef(javaPath);
            }
            env->DeleteLocalRef(javaStr);
            continue;
        }

        env->CallBooleanMethod(javaList, arrayListAdd, javaPath);
        // A pending exception must not leak into the next iteration's JNI calls.
        if (env->ExceptionCheck()) {
            env->ExceptionDescribe();
            env->ExceptionClear();
            std::cerr << "Failed to add path to Java list: " << pathStr << std::endl;
        }

        env->DeleteLocalRef(javaStr);
        env->DeleteLocalRef(javaPath);
    }

    env->DeleteLocalRef(moreParts);
    env->DeleteLocalRef(stringClass);
    return javaList;
}
/**
 * Creates a java.util.ArrayList and fills it with one java.nio.file.Path
 * (obtained through java.nio.file.Paths.get) per entry of files.
 *
 * @param env   JNI environment of the calling thread.
 * @param files Paths to place in the list.
 * @return Local reference to the new list, or nullptr when env is null or a
 *         required class, method or the list itself cannot be obtained. On
 *         failure the reason is written to std::cerr and any pending Java
 *         exception is described and cleared, so the caller can keep using
 *         env safely.
 */
jobject RocksDBStateUploader::createJavaPathList(JNIEnv* env, const std::vector<fs::path>& files)
{
    if (!env) {
        std::cerr << "Invalid JNI environment" << std::endl;
        return nullptr;
    }

    // Describes and clears whatever exception a failed lookup left behind.
    const auto dropPendingException = [env]() {
        if (env->ExceptionCheck()) {
            env->ExceptionDescribe();
            env->ExceptionClear();
        }
    };

    jclass arrayListClass = env->FindClass("java/util/ArrayList");
    if (!arrayListClass) {
        dropPendingException();
        std::cerr << "Failed to find ArrayList class" << std::endl;
        return nullptr;
    }

    jmethodID arrayListCtor = env->GetMethodID(arrayListClass, "<init>", "()V");
    // Skip the second lookup if the first one already failed: no JNI call
    // should be made while an exception is pending.
    jmethodID arrayListAdd =
        arrayListCtor ? env->GetMethodID(arrayListClass, "add", "(Ljava/lang/Object;)Z") : nullptr;
    if (!arrayListCtor || !arrayListAdd) {
        dropPendingException();
        std::cerr << "Failed to get ArrayList methods" << std::endl;
        env->DeleteLocalRef(arrayListClass);
        return nullptr;
    }

    jclass pathsClass = env->FindClass("java/nio/file/Paths");
    // Only look up the method when the class was found; passing a null class
    // to GetStaticMethodID is invalid.
    jmethodID pathsGet = pathsClass
        ? env->GetStaticMethodID(pathsClass, "get", "(Ljava/lang/String;[Ljava/lang/String;)Ljava/nio/file/Path;")
        : nullptr;
    if (!pathsClass || !pathsGet) {
        dropPendingException();
        std::cerr << "Failed to find Paths.get method" << std::endl;
        env->DeleteLocalRef(arrayListClass);
        if (pathsClass) {
            env->DeleteLocalRef(pathsClass);
        }
        return nullptr;
    }

    jobject javaList = env->NewObject(arrayListClass, arrayListCtor);
    if (!javaList) {
        dropPendingException();
        std::cerr << "Failed to create ArrayList" << std::endl;
        env->DeleteLocalRef(arrayListClass);
        env->DeleteLocalRef(pathsClass);
        return nullptr;
    }

    javaList = addToJavaPathList(env, files, javaList, arrayListAdd, pathsClass, pathsGet);

    env->DeleteLocalRef(arrayListClass);
    env->DeleteLocalRef(pathsClass);
    return javaList;
}
/**
 * Determines which of the known Java StreamStateHandle classes handleObj is
 * an instance of.
 *
 * The candidates are tested in the order ByteStreamStateHandle,
 * RelativeFileStateHandle, FileStateHandle and the first match wins.
 * NOTE(review): RelativeFileStateHandle is tested before FileStateHandle,
 * presumably because it is a subclass of FileStateHandle in Flink; keep that
 * order.
 *
 * @param env       JNI environment of the calling thread.
 * @param handleObj Java handle object to classify.
 * @return The matching StreamStateHandleType, or Unknown when env/handleObj is
 *         null or no candidate matches. A candidate class that cannot be
 *         loaded is reported via ExceptionDescribe, cleared, and treated as
 *         "no match".
 */
StreamStateHandleType getHandleType(JNIEnv* env, jobject handleObj)
{
    if (!env || !handleObj) {
        return StreamStateHandleType::Unknown;
    }

    // Looks up one class and tests handleObj against it. The exception is
    // handled right after each FindClass so that no further JNI call is made
    // while an exception is pending.
    const auto isInstanceOfClass = [env, handleObj](const char* className) -> bool {
        jclass candidate = env->FindClass(className);
        if (!candidate) {
            env->ExceptionDescribe();
            env->ExceptionClear();
            return false;
        }
        const bool matches = env->IsInstanceOf(handleObj, candidate);
        env->DeleteLocalRef(candidate);
        return matches;
    };

    if (isInstanceOfClass("org/apache/flink/runtime/state/memory/ByteStreamStateHandle")) {
        return StreamStateHandleType::ByteStreamStateHandle;
    }
    if (isInstanceOfClass("org/apache/flink/runtime/state/filesystem/RelativeFileStateHandle")) {
        return StreamStateHandleType::RelativeFileStateHandle;
    }
    if (isInstanceOfClass("org/apache/flink/runtime/state/filesystem/FileStateHandle")) {
        return StreamStateHandleType::FileStateHandle;
    }
    return StreamStateHandleType::Unknown;
}
* 从FileStateHandle获取属性并创建C++对象
*/
std::shared_ptr<StreamStateHandle> createFileStateHandle(JNIEnv* env, jobject handleObj)
{
if (!env || !handleObj) {
throw std::invalid_argument("Invalid JNI environment or handle object");
}
jclass handleClass = env->GetObjectClass(handleObj);
if (!handleClass) {
env->ExceptionDescribe();
throw std::runtime_error("Failed to get FileStateHandle class");
}
jmethodID getFilePathMethod = env->GetMethodID(handleClass, "getFilePath", "()Lorg/apache/flink/core/fs/Path;");
jmethodID getStateSizeMethod = env->GetMethodID(handleClass, "getStateSize", "()J");
if (!getFilePathMethod || !getStateSizeMethod) {
env->ExceptionDescribe();
env->DeleteLocalRef(handleClass);
throw std::runtime_error("Failed to get FileStateHandle method IDs");
}
jobject jFilePath = env->CallObjectMethod(handleObj, getFilePathMethod);
jlong jStateSize = env->CallLongMethod(handleObj, getStateSizeMethod);
if (env->ExceptionCheck()) {
env->ExceptionDescribe();
env->ExceptionClear();
env->DeleteLocalRef(handleClass);
if (jFilePath) env->DeleteLocalRef(jFilePath);
throw std::runtime_error("Failed to call FileStateHandle get methods");
}
std::string filePathStr = flinkPathToString(env, jFilePath);
Path filePath(filePathStr);
uint64_t stateSize = static_cast<uint64_t>(jStateSize);
env->DeleteLocalRef(jFilePath);
env->DeleteLocalRef(handleClass);
return std::make_shared<FileStateHandle>(filePath, stateSize);
}
* 从RelativeFileStateHandle获取属性并创建C++对象
*/
std::shared_ptr<StreamStateHandle> createRelativeFileStateHandle(JNIEnv* env, jobject handleObj)
{
if (!env || !handleObj) {
throw std::invalid_argument("Invalid JNI environment or handle object");
}
jclass handleClass = env->GetObjectClass(handleObj);
if (!handleClass) {
env->ExceptionDescribe();
throw std::runtime_error("Failed to get RelativeFileStateHandle class");
}
jmethodID getFilePathMethod = env->GetMethodID(
handleClass,
"getFilePath",
"()Lorg/apache/flink/core/fs/Path;"
);
jmethodID getRelativePathMethod = env->GetMethodID(handleClass, "getRelativePath", "()Ljava/lang/String;");
jmethodID getStateSizeMethod = env->GetMethodID(handleClass, "getStateSize", "()J");
if (!getFilePathMethod || !getRelativePathMethod || !getStateSizeMethod) {
env->ExceptionDescribe();
env->DeleteLocalRef(handleClass);
throw std::runtime_error("Failed to get RelativeFileStateHandle method IDs");
}
jobject jFilePath = env->CallObjectMethod(handleObj, getFilePathMethod);
jobject jRelativePath = env->CallObjectMethod(handleObj, getRelativePathMethod);
jlong jStateSize = env->CallLongMethod(handleObj, getStateSizeMethod);
if (env->ExceptionCheck()) {
env->ExceptionDescribe();
env->ExceptionClear();
env->DeleteLocalRef(handleClass);
if (jFilePath) env->DeleteLocalRef(jFilePath);
if (jRelativePath) env->DeleteLocalRef(jRelativePath);
throw std::runtime_error("Failed to call RelativeFileStateHandle get methods");
}
std::string filePathStr = flinkPathToString(env, jFilePath);
Path filePath(filePathStr);
std::string relativePath = jstringToString(env, static_cast<jstring>(jRelativePath));
uint64_t stateSize = static_cast<uint64_t>(jStateSize);
env->DeleteLocalRef(jFilePath);
env->DeleteLocalRef(jRelativePath);
env->DeleteLocalRef(handleClass);
return std::make_shared<RelativeFileStateHandle>(filePath, relativePath, stateSize);
}
* 从ByteStreamStateHandle获取属性并创建C++对象
*/
std::shared_ptr<StreamStateHandle> createByteStreamStateHandle(JNIEnv* env, jobject handleObj)
{
if (!env || !handleObj) {
throw std::invalid_argument("Invalid JNI environment or handle object");
}
jclass handleClass = env->GetObjectClass(handleObj);
if (!handleClass) {
env->ExceptionDescribe();
throw std::runtime_error("Failed to get ByteStreamStateHandle class");
}
jmethodID getHandleNameMethod = env->GetMethodID(handleClass, "getHandleName", "()Ljava/lang/String;");
jmethodID getDataMethod = env->GetMethodID(handleClass, "getData", "()[B");
if (!getHandleNameMethod || !getDataMethod) {
env->ExceptionDescribe();
env->DeleteLocalRef(handleClass);
throw std::runtime_error("Failed to get ByteStreamStateHandle method IDs");
}
jobject jHandleName = env->CallObjectMethod(handleObj, getHandleNameMethod);
jbyteArray jData = static_cast<jbyteArray>(env->CallObjectMethod(handleObj, getDataMethod));
if (env->ExceptionCheck()) {
env->ExceptionDescribe();
env->ExceptionClear();
env->DeleteLocalRef(handleClass);
if (jHandleName) env->DeleteLocalRef(jHandleName);
if (jData) env->DeleteLocalRef(jData);
throw std::runtime_error("Failed to call ByteStreamStateHandle get methods");
}
std::string handleName = jstringToString(env, static_cast<jstring>(jHandleName));
std::vector<uint8_t> data;
if (jData) {
jsize dataLen = env->GetArrayLength(jData);
jbyte* dataBytes = env->GetByteArrayElements(jData, nullptr);
if (dataBytes) {
data.assign(reinterpret_cast<uint8_t*>(dataBytes),
reinterpret_cast<uint8_t*>(dataBytes + dataLen));
env->ReleaseByteArrayElements(jData, dataBytes, 0);
}
env->DeleteLocalRef(jData);
}
env->DeleteLocalRef(jHandleName);
env->DeleteLocalRef(handleClass);
return std::make_shared<ByteStreamStateHandle>(handleName, data);
}
/**
 * Calls getHandleMethod on jHandleObj and converts the returned Java
 * StreamStateHandle into the matching C++ handle object.
 *
 * @param env             JNI environment of the calling thread.
 * @param jHandleObj      Java object on which getHandleMethod is invoked.
 * @param getHandleMethod Method ID of a no-argument method that returns the
 *                        Java handle object.
 * @return The converted handle, or an empty pointer when the Java method
 *         returned null.
 * @throws std::invalid_argument if any argument is null.
 * @throws std::runtime_error if the Java call raises an exception, if the
 *         handle's type is not recognized, or if the conversion fails.
 */
std::shared_ptr<StreamStateHandle> getStreamStateHandle(JNIEnv* env, jobject jHandleObj, jmethodID getHandleMethod)
{
    if (!env || !jHandleObj || !getHandleMethod) {
        throw std::invalid_argument("Invalid JNI environment, handle owner or method ID");
    }

    jobject jHandle = env->CallObjectMethod(jHandleObj, getHandleMethod);
    if (env->ExceptionCheck()) {
        env->ExceptionDescribe();
        env->ExceptionClear();
        if (jHandle) {
            env->DeleteLocalRef(jHandle);
        }
        throw std::runtime_error("Failed to get StreamStateHandle object");
    }

    std::shared_ptr<StreamStateHandle> handle;
    if (!jHandle) {
        return handle;
    }

    try {
        switch (getHandleType(env, jHandle)) {
            case StreamStateHandleType::FileStateHandle:
                handle = createFileStateHandle(env, jHandle);
                break;
            case StreamStateHandleType::RelativeFileStateHandle:
                handle = createRelativeFileStateHandle(env, jHandle);
                break;
            case StreamStateHandleType::ByteStreamStateHandle:
                handle = createByteStreamStateHandle(env, jHandle);
                break;
            default:
                throw std::runtime_error("Unknown StreamStateHandle type");
        }
    } catch (...) {
        // Release the local reference on every failure path before rethrowing.
        env->DeleteLocalRef(jHandle);
        throw;
    }

    env->DeleteLocalRef(jHandle);
    return handle;
}
/**
 * Converts a Java HandleAndLocalPath object into its C++ counterpart.
 *
 * Invokes getHandle() and getLocalPath() on jHandleObj, converts the handle
 * with getStreamStateHandle and the path with jstringToString, and combines
 * both through HandleAndLocalPath::of. A null local path becomes an empty
 * string.
 *
 * @param env        JNI environment of the calling thread.
 * @param jHandleObj Java object exposing getHandle() and getLocalPath().
 * @return The converted value.
 * @throws std::invalid_argument if env or jHandleObj is null.
 * @throws std::runtime_error if the class or method lookups fail, if a Java
 *         getter raises an exception, or if the handle conversion fails.
 */
HandleAndLocalPath convertJavaHandleAndLocalPath(JNIEnv* env, jobject jHandleObj)
{
    if (!env || !jHandleObj) {
        throw std::invalid_argument("Invalid JNI environment or HandleAndLocalPath object");
    }

    jclass handleClass = env->GetObjectClass(jHandleObj);
    if (!handleClass) {
        env->ExceptionDescribe();
        env->ExceptionClear();
        throw std::runtime_error("Failed to get HandleAndLocalPath class reference");
    }

    jmethodID getHandleMethod = env->GetMethodID(handleClass,
                                                 "getHandle",
                                                 "()Lorg/apache/flink/runtime/state/StreamStateHandle;");
    // Do not issue the second lookup while the first one's exception is pending.
    jmethodID getLocalPathMethod = getHandleMethod
        ? env->GetMethodID(handleClass, "getLocalPath", "()Ljava/lang/String;")
        : nullptr;
    if (!getHandleMethod || !getLocalPathMethod) {
        env->ExceptionDescribe();
        env->ExceptionClear();
        env->DeleteLocalRef(handleClass);
        // These are method IDs, not field IDs.
        throw std::runtime_error("Failed to get method IDs for HandleAndLocalPath");
    }

    std::shared_ptr<StreamStateHandle> handle;
    std::string localPath;
    try {
        handle = getStreamStateHandle(env, jHandleObj, getHandleMethod);

        jstring jLocalPath = static_cast<jstring>(env->CallObjectMethod(jHandleObj, getLocalPathMethod));
        if (env->ExceptionCheck()) {
            env->ExceptionDescribe();
            env->ExceptionClear();
            if (jLocalPath) {
                env->DeleteLocalRef(jLocalPath);
            }
            throw std::runtime_error("Failed to call HandleAndLocalPath.getLocalPath() method");
        }
        // Shared helper: handles null and clears the exception on failure.
        localPath = jstringToString(env, jLocalPath);
        if (jLocalPath) {
            env->DeleteLocalRef(jLocalPath);
        }
    } catch (...) {
        env->DeleteLocalRef(handleClass);
        throw;
    }

    env->DeleteLocalRef(handleClass);
    return HandleAndLocalPath::of(handle, localPath);
}
/**
 * Converts a java.util.List of HandleAndLocalPath objects into a std::vector.
 *
 * The list is read through size() and get(int); null elements are skipped and
 * every other element is converted with convertJavaHandleAndLocalPath.
 *
 * @param env      JNI environment of the calling thread.
 * @param javaList Java list to convert.
 * @return The converted elements in list order; empty when env or javaList is
 *         null.
 * @throws std::runtime_error if the List class or its methods cannot be
 *         resolved, or if size()/get() raises a Java exception. Exceptions
 *         thrown while converting an element are propagated after the local
 *         references held here have been released.
 */
std::vector<HandleAndLocalPath> convertJavaListToCppVector(JNIEnv* env, jobject javaList)
{
    std::vector<HandleAndLocalPath> result;
    if (!env || !javaList) {
        return result;
    }

    jclass listClass = env->FindClass("java/util/List");
    if (!listClass) {
        env->ExceptionDescribe();
        env->ExceptionClear();
        throw std::runtime_error("Failed to find java.util.List class");
    }

    jmethodID sizeMethod = env->GetMethodID(listClass, "size", "()I");
    // Do not issue the second lookup while the first one's exception is pending.
    jmethodID getMethod = sizeMethod ? env->GetMethodID(listClass, "get", "(I)Ljava/lang/Object;") : nullptr;
    if (!sizeMethod || !getMethod) {
        env->ExceptionDescribe();
        env->ExceptionClear();
        env->DeleteLocalRef(listClass);
        throw std::runtime_error("Failed to get List method IDs");
    }

    jint listSize = env->CallIntMethod(javaList, sizeMethod);
    if (env->ExceptionCheck()) {
        env->ExceptionDescribe();
        env->ExceptionClear();
        env->DeleteLocalRef(listClass);
        throw std::runtime_error ("Failed to call List.size () method");
    }

    try {
        if (listSize > 0) {
            result.reserve(static_cast<std::vector<HandleAndLocalPath>::size_type>(listSize));
        }
        for (jint i = 0; i < listSize; ++i) {
            jobject element = env->CallObjectMethod(javaList, getMethod, i);
            if (env->ExceptionCheck()) {
                env->ExceptionDescribe();
                env->ExceptionClear();
                if (element) {
                    env->DeleteLocalRef(element);
                }
                throw std::runtime_error("Failed to call List.get () method");
            }
            if (!element) {
                continue;
            }
            try {
                // Pushing the temporary avoids copying the converted element.
                result.push_back(convertJavaHandleAndLocalPath(env, element));
            } catch (...) {
                env->DeleteLocalRef(element);
                throw;
            }
            env->DeleteLocalRef(element);
        }
    } catch (...) {
        env->DeleteLocalRef(listClass);
        throw;
    }

    env->DeleteLocalRef(listClass);
    return result;
}
/**
 * Invokes uploadFilesToCheckpointFs on the given Java uploader instance and
 * converts the returned Java list into C++ objects.
 *
 * The Java call receives javaFiles, jCheckpointStreamFactory, the Java
 * CheckpointedStateScope constant matching stateScope (EXCLUSIVE, otherwise
 * SHARED) and two freshly constructed CloseableRegistry objects.
 *
 * @param env                      JNI environment of the calling thread.
 * @param stateScope               Scope that selects the Java enum constant.
 * @param uploaderClass            Java class on which the upload method is
 *                                 looked up.
 * @param uploaderInstance         Java uploader object to invoke.
 * @param javaFiles                Java list of files to upload.
 * @param jCheckpointStreamFactory Java CheckpointStreamFactory argument.
 * @return The converted result list (see convertJavaListToCppVector).
 * @throws std::invalid_argument if env, uploaderClass or uploaderInstance is
 *         null.
 * @throws std::runtime_error if a class, field, constructor or method lookup
 *         fails, if an object cannot be created, or if the Java upload call
 *         raises an exception. Pending Java exceptions are described and
 *         cleared, and all local references created here are released,
 *         before throwing.
 */
std::vector<HandleAndLocalPath> handleCheckpointResult(JNIEnv* env,
                                                       CheckpointedStateScope& stateScope,
                                                       jclass uploaderClass,
                                                       jobject uploaderInstance,
                                                       jobject javaFiles,
                                                       jobject jCheckpointStreamFactory)
{
    if (!env || !uploaderClass || !uploaderInstance) {
        throw std::invalid_argument("Invalid JNI environment or uploader reference");
    }

    jclass scopeClass = nullptr;
    jobject jStateScope = nullptr;
    jclass closeableClass = nullptr;
    jobject jCloseableRegistry = nullptr;
    jobject jTmpResourcesRegistry = nullptr;
    jobject resultList = nullptr;

    // Releases every local reference created so far (null entries are skipped).
    const auto releaseLocalRefs = [&]() {
        const jobject refs[] = {resultList, jTmpResourcesRegistry, jCloseableRegistry,
                                closeableClass, jStateScope, scopeClass};
        for (jobject ref : refs) {
            if (ref) {
                env->DeleteLocalRef(ref);
            }
        }
    };
    // Common failure path: report and clear the Java exception, clean up, throw.
    const auto fail = [&](const char* message) {
        if (env->ExceptionCheck()) {
            env->ExceptionDescribe();
            env->ExceptionClear();
        }
        releaseLocalRefs();
        throw std::runtime_error(message);
    };

    scopeClass = env->FindClass("org/apache/flink/runtime/state/CheckpointedStateScope");
    if (!scopeClass) {
        fail("Failed to find CheckpointedStateScope class");
    }
    jfieldID scopeField = env->GetStaticFieldID(scopeClass,
                                                stateScope == CheckpointedStateScope::EXCLUSIVE ? "EXCLUSIVE" : "SHARED",
                                                "Lorg/apache/flink/runtime/state/CheckpointedStateScope;");
    if (!scopeField) {
        fail("Failed to find CheckpointedStateScope constant");
    }
    jStateScope = env->GetStaticObjectField(scopeClass, scopeField);
    if (env->ExceptionCheck() || !jStateScope) {
        fail("Failed to read CheckpointedStateScope constant");
    }

    closeableClass = env->FindClass("org/apache/flink/core/fs/CloseableRegistry");
    if (!closeableClass) {
        fail("Failed to find CloseableRegistry class");
    }
    jmethodID closeableCtor = env->GetMethodID(closeableClass, "<init>", "()V");
    if (!closeableCtor) {
        fail("Failed to find CloseableRegistry constructor");
    }
    jCloseableRegistry = env->NewObject(closeableClass, closeableCtor);
    if (env->ExceptionCheck() || !jCloseableRegistry) {
        fail("Failed to create CloseableRegistry");
    }
    jTmpResourcesRegistry = env->NewObject(closeableClass, closeableCtor);
    if (env->ExceptionCheck() || !jTmpResourcesRegistry) {
        fail("Failed to create CloseableRegistry");
    }

    jmethodID uploadMethod = env->GetMethodID(uploaderClass,
                                              "uploadFilesToCheckpointFs",
                                              "(Ljava/util/List;"
                                              "Lorg/apache/flink/runtime/state/CheckpointStreamFactory;"
                                              "Lorg/apache/flink/runtime/state/CheckpointedStateScope;"
                                              "Lorg/apache/flink/core/fs/CloseableRegistry;"
                                              "Lorg/apache/flink/core/fs/CloseableRegistry;)"
                                              "Ljava/util/List;");
    if (!uploadMethod) {
        fail("Failed to find uploadFilesToCheckpointFs method");
    }

    resultList = env->CallObjectMethod(uploaderInstance,
                                       uploadMethod,
                                       javaFiles,
                                       jCheckpointStreamFactory,
                                       jStateScope,
                                       jCloseableRegistry,
                                       jTmpResourcesRegistry);
    // A Java-side failure must not be mistaken for an empty upload result.
    if (env->ExceptionCheck()) {
        fail("Exception occurred while calling uploadFilesToCheckpointFs");
    }

    std::vector<HandleAndLocalPath> chkResult;
    try {
        chkResult = convertJavaListToCppVector(env, resultList);
    } catch (...) {
        releaseLocalRefs();
        throw;
    }
    releaseLocalRefs();
    return chkResult;
}
/**
 * Uploads the given files through a newly constructed Java
 * org.apache.flink.contrib.streaming.state.RocksDBStateUploader.
 *
 * The Java uploader is created with numberOfSnapshottingThreads_, the files
 * are wrapped in a Java list (createJavaPathList) and the upload itself is
 * delegated to handleCheckpointResult.
 *
 * @param env                      JNI environment of the calling thread.
 * @param files                    Files to upload.
 * @param stateScope               Checkpoint scope forwarded to the upload.
 * @param jCheckpointStreamFactory Java CheckpointStreamFactory to upload with.
 * @return The handles and local paths reported by the Java uploader.
 * @throws std::invalid_argument if env is null.
 * @throws std::runtime_error if the Java class, constructor or instance cannot
 *         be obtained, if the Java file list cannot be created, or if the
 *         upload fails. Local references created here are released on every
 *         path.
 */
std::vector<HandleAndLocalPath> RocksDBStateUploader::callUploadFilesToCheckpointFs(
    JNIEnv* env,
    const std::vector<fs::path>& files,
    CheckpointedStateScope& stateScope,
    jobject jCheckpointStreamFactory)
{
    if (!env) {
        throw std::invalid_argument("Invalid JNI environment");
    }

    // Reports and clears a pending Java exception, if any.
    const auto dropPendingException = [env]() {
        if (env->ExceptionCheck()) {
            env->ExceptionDescribe();
            env->ExceptionClear();
        }
    };

    jclass uploaderClass = env->FindClass("org/apache/flink/contrib/streaming/state/RocksDBStateUploader");
    if (!uploaderClass || env->ExceptionCheck()) {
        dropPendingException();
        throw std::runtime_error("Failed to find RocksDBStateUploader class");
    }

    jmethodID uploaderCtor = env->GetMethodID(uploaderClass, "<init>", "(I)V");
    if (!uploaderCtor) {
        // NewObject must not be called with a null method ID.
        dropPendingException();
        env->DeleteLocalRef(uploaderClass);
        throw std::runtime_error("Failed to find RocksDBStateUploader constructor");
    }

    // The constructor signature is (I)V, so pass exactly a jint through the
    // variadic call.
    jobject uploaderInstance =
        env->NewObject(uploaderClass, uploaderCtor, static_cast<jint>(numberOfSnapshottingThreads_));
    if (env->ExceptionCheck() || !uploaderInstance) {
        dropPendingException();
        env->DeleteLocalRef(uploaderClass);
        throw std::runtime_error("Failed to NewObject");
    }

    jobject javaFiles = createJavaPathList(env, files);
    if (!javaFiles) {
        // Fail here instead of handing a null list to the Java uploader.
        dropPendingException();
        env->DeleteLocalRef(uploaderInstance);
        env->DeleteLocalRef(uploaderClass);
        throw std::runtime_error("Failed to create Java path list");
    }

    const auto releaseLocalRefs = [&]() {
        env->DeleteLocalRef(javaFiles);
        env->DeleteLocalRef(uploaderInstance);
        env->DeleteLocalRef(uploaderClass);
    };

    std::vector<HandleAndLocalPath> chkResult;
    try {
        chkResult = handleCheckpointResult(env,
                                           stateScope,
                                           uploaderClass,
                                           uploaderInstance,
                                           javaFiles,
                                           jCheckpointStreamFactory);
    } catch (...) {
        releaseLocalRefs();
        throw;
    }
    releaseLocalRefs();

    if (env->ExceptionCheck()) {
        env->ExceptionDescribe();
        env->ExceptionClear();
        throw std::runtime_error("Exception occurred during JNI call");
    }
    return chkResult;
}
/**
 * Uploads the given files through an OmniTaskBridge.
 *
 * Every std::filesystem path is converted to a Path built from its string
 * form, the bridge's CallUploadFilesToCheckpointFs is invoked with those
 * paths and numberOfSnapshottingThreads_, and the Java list it returns is
 * converted with convertJavaListToCppVector using the bridge's JNI
 * environment.
 *
 * @param bridge Bridge used for the upload call and to obtain the JNIEnv.
 * @param files  Files to upload.
 * @return The converted handles and local paths.
 */
std::vector<HandleAndLocalPath> RocksDBStateUploader::callUploadFilesToCheckpointFs(
    std::shared_ptr<omnistream::OmniTaskBridge> bridge,
    const std::vector<fs::path>& files)
{
    std::vector<Path> convertedPaths;
    convertedPaths.reserve(files.size());
    for (auto it = files.begin(); it != files.end(); ++it) {
        convertedPaths.emplace_back(it->string());
    }

    // The upload call is made before the JNI environment is requested, as in
    // the original sequence.
    auto uploadedJavaList = bridge->CallUploadFilesToCheckpointFs(convertedPaths, numberOfSnapshottingThreads_);
    auto jniEnv = bridge->getJNIEnv();
    return convertJavaListToCppVector(jniEnv, uploadedJavaList);
}
/**
 * Creates an uploader and stores the number of snapshotting threads.
 *
 * The stored value is later forwarded to the Java RocksDBStateUploader
 * constructor and to OmniTaskBridge::CallUploadFilesToCheckpointFs by the
 * callUploadFilesToCheckpointFs overloads.
 *
 * @param numberOfSnapshottingThreads Thread count to store.
 */
RocksDBStateUploader::RocksDBStateUploader(int numberOfSnapshottingThreads)
: numberOfSnapshottingThreads_(numberOfSnapshottingThreads) {}