* Copyright (c) Huawei Technologies Co., Ltd. 2012-2025. All rights reserved.
*/
#ifndef TASKSTATEMANAGERBRIDGEIMPL_H
#define TASKSTATEMANAGERBRIDGEIMPL_H
#include <jni.h>
#include <fstream>
#include <initializer_list>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <type_traits>
#include <vector>
#include <state/bridge/TaskStateManagerBridge.h>
#include <common/global.h>
#include "runtime/state/SnapshotResult.h"
#include "checkpoint/TaskStateSnapshotDeserializer.h"
namespace omnistream {
class TaskStateManagerBridgeImpl : public TaskStateManagerBridge {
public:
explicit TaskStateManagerBridgeImpl(jobject mGlobalTaskStateMgrRef)
{
this->m_globalTaskStateMgrRef=mGlobalTaskStateMgrRef;
}
void ReportTaskStateSnapshots(std::string &checkpointMetaDataJson,
std::string &checkpointMetricsJson,
std::string &acknowledgedStateJson,
std::string &localStateJson) override
{
JNIEnv* env;
jint res = g_OmniStreamJVM->AttachCurrentThread(reinterpret_cast<void**>(&env), nullptr);
if (res != JNI_OK) {
GErrorLog("Failed to attach C++ thread to JVM inside TaskStateManagerBridgeImpl::ReportTaskStateSnapshots");
return;
}
if (m_globalTaskStateMgrRef != nullptr) {
jclass taskStateManagerWrapperClass = env->GetObjectClass(m_globalTaskStateMgrRef);
if (taskStateManagerWrapperClass == nullptr) {
GErrorLog("Error: Could not get TaskStateManagerWrapper class.");
g_OmniStreamJVM->DetachCurrentThread();
return;
}
jmethodID reportMethodId = env->GetMethodID(taskStateManagerWrapperClass, "reportTaskStateSnapshots",
"(Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;Ljava/lang/String;)V");
if (reportMethodId == nullptr) {
GErrorLog("Error: Could not find method reportTaskStateSnapshots.");
env->DeleteLocalRef(taskStateManagerWrapperClass);
g_OmniStreamJVM->DetachCurrentThread();
return;
}
jstring checkpointMetaData = env->NewStringUTF(checkpointMetaDataJson.c_str());
jstring checkpointMetrics = env->NewStringUTF(checkpointMetricsJson.c_str());
jstring acknowledgedState = env->NewStringUTF(acknowledgedStateJson.c_str());
jstring localState = env->NewStringUTF(localStateJson.c_str());
env->CallVoidMethod(m_globalTaskStateMgrRef, reportMethodId, checkpointMetaData,
checkpointMetrics, acknowledgedState, localState);
if (env->ExceptionCheck()) {
GErrorLog("Error: Exception occurred during Java method invocation.");
env->ExceptionDescribe();
env->ExceptionClear();
}
env->DeleteLocalRef(taskStateManagerWrapperClass);
env->DeleteLocalRef(checkpointMetaData);
env->DeleteLocalRef(checkpointMetrics);
env->DeleteLocalRef(acknowledgedState);
env->DeleteLocalRef(localState);
} else {
GErrorLog("Error: Could not get TaskStateManagerWrapper class for JNI call");
}
g_OmniStreamJVM->DetachCurrentThread();
};
void notifyCheckpointAborted(std::string checkpointId) override
{
JNIEnv* env;
jint res = g_OmniStreamJVM->AttachCurrentThread(reinterpret_cast<void**>(&env), nullptr);
if (res != JNI_OK) {
GErrorLog("Failed to attach C++ thread to JVM inside TaskStateManagerBridgeImpl::ReportTaskStateSnapshots");
return;
}
if (m_globalTaskStateMgrRef != nullptr) {
jclass taskStateManagerWrapperClass = env->GetObjectClass(m_globalTaskStateMgrRef);
if (taskStateManagerWrapperClass == nullptr) {
GErrorLog("Error: Could not get TaskStateManagerWrapper class.");
g_OmniStreamJVM->DetachCurrentThread();
return;
}
jmethodID notifyCheckpointAbortedMethodId = env->GetMethodID(taskStateManagerWrapperClass, "notifyCheckpointAborted",
"(Ljava/lang/String;)V");
if (notifyCheckpointAbortedMethodId == nullptr) {
GErrorLog("Error: Could not find method notifyCheckpointAborted.");
env->DeleteLocalRef(taskStateManagerWrapperClass);
g_OmniStreamJVM->DetachCurrentThread();
return;
}
jstring checkpointIdstr = env->NewStringUTF(checkpointId.c_str());
env->CallVoidMethod(m_globalTaskStateMgrRef, notifyCheckpointAbortedMethodId, checkpointIdstr);
if (env->ExceptionCheck()) {
GErrorLog("Error: Exception occurred during Java method invocation.");
env->ExceptionDescribe();
env->ExceptionClear();
}
env->DeleteLocalRef(taskStateManagerWrapperClass);
env->DeleteLocalRef(checkpointIdstr);
} else {
GErrorLog("Error: Could not get TaskStateManagerWrapper class for JNI call");
}
g_OmniStreamJVM->DetachCurrentThread();
}
void NotifyCheckpointComplete(std::string checkpointId) override
{
JNIEnv* env;
jint res = g_OmniStreamJVM->AttachCurrentThread(reinterpret_cast<void**>(&env), nullptr);
if (res != JNI_OK) {
GErrorLog("Failed to attach C++ thread to JVM inside TaskStateManagerBridgeImpl::ReportTaskStateSnapshots");
return;
}
if (m_globalTaskStateMgrRef != nullptr) {
jclass taskStateManagerWrapperClass = env->GetObjectClass(m_globalTaskStateMgrRef);
if (taskStateManagerWrapperClass == nullptr) {
GErrorLog("Error: Could not get TaskStateManagerWrapper class.");
g_OmniStreamJVM->DetachCurrentThread();
return;
}
jmethodID notifyCheckpointCompleteMethodId = env->GetMethodID(taskStateManagerWrapperClass, "notifyCheckpointComplete",
"(Ljava/lang/String;)V");
if (notifyCheckpointCompleteMethodId == nullptr) {
GErrorLog("Error: Could not find method notifyCheckpointAborted.");
env->DeleteLocalRef(taskStateManagerWrapperClass);
g_OmniStreamJVM->DetachCurrentThread();
return;
}
jstring checkpointIdstr = env->NewStringUTF(checkpointId.c_str());
env->CallVoidMethod(m_globalTaskStateMgrRef, notifyCheckpointCompleteMethodId, checkpointIdstr);
if (env->ExceptionCheck()) {
GErrorLog("Error: Exception occurred during Java method invocation.");
env->ExceptionDescribe();
env->ExceptionClear();
}
env->DeleteLocalRef(taskStateManagerWrapperClass);
env->DeleteLocalRef(checkpointIdstr);
} else {
GErrorLog("Error: Could not get TaskStateManagerWrapper class for JNI call");
}
g_OmniStreamJVM->DetachCurrentThread();
}
std::shared_ptr<TaskStateSnapshot> RetrieveLocalState(long restoreCheckpointId)
{
GErrorLog("method RetrieveLocalState begin!");
JNIEnv* env;
jint res = g_OmniStreamJVM->AttachCurrentThread(reinterpret_cast<void**>(&env), nullptr);
if (res != JNI_OK) {
GErrorLog("Failed to attach C++ thread to JVM inside RetrieveLocalState");
return nullptr;
}
std::shared_ptr<TaskStateSnapshot> taskStateSnapshot = nullptr;
try {
if (m_globalTaskStateMgrRef != nullptr) {
jclass taskStateManagerWrapperClass = env->GetObjectClass(m_globalTaskStateMgrRef);
if (taskStateManagerWrapperClass == nullptr) {
GErrorLog("Error: Could not get TaskStateManagerWrapper class.");
g_OmniStreamJVM->DetachCurrentThread();
return nullptr;
}
jmethodID retrieveMethodId = env->GetMethodID(taskStateManagerWrapperClass, "retrieveLocalState",
"(J)Ljava/lang/String;");
if (retrieveMethodId == nullptr) {
GErrorLog("Error: Could not find method retrieveLocalState.");
env->DeleteLocalRef(taskStateManagerWrapperClass);
g_OmniStreamJVM->DetachCurrentThread();
return nullptr;
}
jstring ret = (jstring)env->CallObjectMethod(m_globalTaskStateMgrRef, retrieveMethodId, (jlong)restoreCheckpointId);
if (env->ExceptionCheck()) {
GErrorLog("Error: Exception occurred during Java method invocation.");
env->ExceptionDescribe();
env->ExceptionClear();
env->DeleteLocalRef(taskStateManagerWrapperClass);
g_OmniStreamJVM->DetachCurrentThread();
return nullptr;
}
if (ret != nullptr) {
const char* resultStr = env->GetStringUTFChars(ret, nullptr);
if (resultStr == nullptr){
GErrorLog("Error: resultStr is null");
env->ExceptionDescribe();
env->ExceptionClear();
env->DeleteLocalRef(taskStateManagerWrapperClass);
g_OmniStreamJVM->DetachCurrentThread();
return nullptr;
}
std::string snapshotInfoString(resultStr);
env->ReleaseStringUTFChars(ret, resultStr);
std::stringstream ss;
ss << "retrieve result for checkpoint " << restoreCheckpointId << ": " << snapshotInfoString;
GErrorLog(ss.str());
if (snapshotInfoString == "NULL") {
GErrorLog("Java side returned NULL - no snapshot available");
} else if (snapshotInfoString == "ERROR") {
GErrorLog("Java side returned ERROR - exception occurred");
} else if (!snapshotInfoString.empty()) {
try {
nlohmann::json snapshotJson = nlohmann::json::parse(snapshotInfoString);
taskStateSnapshot =
TaskStateSnapshotDeserializer::Deserialize(snapshotJson.dump());
std::stringstream taskStateSnapshotstr;
taskStateSnapshotstr << "make taskStateSnapshot:" << taskStateSnapshot->ToString() ;
GErrorLog(taskStateSnapshotstr.str());
} catch (const std::exception& e) {
std::stringstream errorMsg;
errorMsg << "Failed to parse JSON: " << e.what();
GErrorLog(errorMsg.str());
}
} else {
GErrorLog("Received empty string from Java side");
}
} else {
GErrorLog("Java method returned null string");
}
env->DeleteLocalRef(taskStateManagerWrapperClass);
if (ret != nullptr) {
env->DeleteLocalRef(ret);
}
} else {
GErrorLog("Error: m_globalTaskStateMgrRef is null");
}
} catch (const std::exception& e) {
std::stringstream errorMsg;
errorMsg << "Exception in RetrieveLocalState: " << e.what();
GErrorLog(errorMsg.str());
}
g_OmniStreamJVM->DetachCurrentThread();
return taskStateSnapshot;
}
private:
jobject m_globalTaskStateMgrRef;
};
}
#endif