/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef THESTRAL_PLUGIN_MASTER_JNI_COMMON_H
#define THESTRAL_PLUGIN_MASTER_JNI_COMMON_H

#include <jni.h>
#include "util/omni_exception.h"
#include "common/common.h"
#include "compute/ResultIterator.h"
#include "compute/Runtime.h"


/**
 * Native-side adapter that exposes a Java columnar-batch iterator object
 * (`jColumnarBatchItr`) through the omniruntime::ColumnarBatchIterator
 * interface.
 *
 * Instances are neither copyable nor movable (all copy/move operations are
 * deleted below). Member function definitions are not in this header;
 * presumably Next() drives the Java iterator through the cached
 * hasNext/next method IDs declared below — confirm against the .cpp.
 */
class JniColumnarBatchIterator : public omniruntime::ColumnarBatchIterator {
public:
    // @param env               JNIEnv of the constructing thread.
    // @param jColumnarBatchItr Java iterator to wrap (reference handling is done in the .cpp — confirm
    //                          whether a global ref is taken).
    explicit JniColumnarBatchIterator(JNIEnv *env, jobject jColumnarBatchItr);

    JniColumnarBatchIterator(const JniColumnarBatchIterator &) = delete;

    JniColumnarBatchIterator(JniColumnarBatchIterator &&) = delete;

    JniColumnarBatchIterator &operator=(const JniColumnarBatchIterator &) = delete;

    JniColumnarBatchIterator &operator=(JniColumnarBatchIterator &&) = delete;

    ~JniColumnarBatchIterator() override;

    // Returns the next batch. NOTE(review): end-of-iteration signalling and ownership of the
    // returned pointer are not visible in this header — confirm in the implementation.
    omniruntime::vec::VectorBatch *Next() override;

private:
    // Presumably cached so a JNIEnv can be obtained for the current thread later (confirm).
    JavaVM *vm_;
    // The wrapped Java iterator object.
    jobject jColumnarBatchItr_;
    // Cached class and method IDs of the Java iterator, resolved once instead of per call.
    jclass serializedColumnarBatchIteratorClass_;
    jmethodID serializedColumnarBatchIteratorHasNext_;
    jmethodID serializedColumnarBatchIteratorNext_;
};

static jint jniVersion = JNI_VERSION_1_8;
/**
 * Stores the calling thread's JNIEnv into `*out`, attaching the thread to the
 * JVM as a daemon first when it is currently detached.
 *
 * @param vm  JVM to query/attach to.
 * @param out Receives the JNIEnv for the current thread.
 * @throws std::runtime_error if GetEnv reports an error other than
 *         JNI_EDETACHED, or if AttachCurrentThreadAsDaemon fails.
 */
static inline void AttachCurrentThreadAsDaemonOrThrow(JavaVM* vm, JNIEnv** out) {
    void** envSlot = reinterpret_cast<void**>(out);
    const int envStatus = vm->GetEnv(envSlot, jniVersion);

    // Fast path: the thread already has an environment.
    if (envStatus == JNI_OK) {
        return;
    }
    // Any failure other than "not attached" is unrecoverable here.
    if (envStatus != JNI_EDETACHED) {
        throw std::runtime_error("Failed to attach current thread to JVM.");
    }
    // The thread is detached: attach it as a daemon thread.
    if (vm->AttachCurrentThreadAsDaemon(envSlot, nullptr) != JNI_OK) {
        throw std::runtime_error("Failed to reattach current thread to JVM.");
    }
}

/**
 * Copies a Java string into a std::string holding its modified-UTF-8 bytes.
 *
 * @param env    JNIEnv of the current thread.
 * @param string Java string to convert; nullptr yields an empty string.
 * @return the modified-UTF-8 encoding of `string`, without a trailing NUL.
 */
static inline std::string JStringToCString(JNIEnv *env, jstring string)
{
    if (string == nullptr) {
        return {};
    }
    const int32_t utfLength = env->GetStringUTFLength(string);
    const int32_t charLength = env->GetStringLength(string);
    // The buffer is heap-backed rather than a variable-length array. VLAs are not
    // standard C++ and can blow the stack for long strings. One spare byte is
    // reserved because HotSpot's GetStringUTFRegion NUL-terminates its output,
    // which would otherwise write past a buffer sized exactly utfLength.
    std::string result(static_cast<std::string::size_type>(utfLength) + 1, '\0');
    env->GetStringUTFRegion(string, 0, charLength, &result[0]);
    result.resize(static_cast<std::string::size_type>(utfLength));
    return result;
}

/**
 * Converts a pending Java exception, if any, into a C++ OmniException.
 *
 * When the current thread has a pending Java exception, it is cleared and
 * described via org.apache.gluten.exception.JniExceptionDescriber#describe.
 * An omniruntime::exception::OmniException carrying that description is then
 * thrown. Does nothing when no Java exception is pending.
 *
 * Every JNI call below that can raise a new Java exception is checked and
 * cleared. Calling further JNI functions, including the ThrowNew in the
 * JNI_FUNC_END* macros, while an exception is pending is undefined behavior
 * per the JNI specification.
 *
 * @param env JNIEnv of the current thread.
 * @throws omniruntime::exception::OmniException if a Java exception was pending.
 */
static inline void CheckException(JNIEnv* env) {
    if (!env->ExceptionCheck()) {
        return;
    }
    jthrowable t = env->ExceptionOccurred();
    env->ExceptionClear();

    // Fallback used when the describer cannot be resolved or itself fails.
    std::string description = "<unable to describe Java exception>";
    jclass describerClass = env->FindClass("org/apache/gluten/exception/JniExceptionDescriber");
    if (describerClass == nullptr) {
        // FindClass failed and raised its own Java error; drop it.
        env->ExceptionClear();
    } else {
        jmethodID describeMethod =
            env->GetStaticMethodID(describerClass, "describe", "(Ljava/lang/Throwable;)Ljava/lang/String;");
        if (describeMethod == nullptr) {
            // GetStaticMethodID failed and raised its own Java error; drop it.
            env->ExceptionClear();
        } else {
            auto jDescription =
                static_cast<jstring>(env->CallStaticObjectMethod(describerClass, describeMethod, t));
            // Check BEFORE touching the result: it is invalid if the describer threw.
            if (env->ExceptionCheck()) {
                LogWarn("Fatal: Uncaught Java exception during calling the Java exception describer method! ");
                env->ExceptionClear();
            } else if (jDescription != nullptr) {
                description = JStringToCString(env, jDescription);
            }
            if (jDescription != nullptr) {
                env->DeleteLocalRef(jDescription);
            }
        }
        env->DeleteLocalRef(describerClass);
    }
    if (t != nullptr) {
        env->DeleteLocalRef(t);
    }
    throw omniruntime::exception::OmniException("Error during calling Java code from native code: " + description);
}

/**
 * Resolves `className` with FindClass and promotes it to a JNI global reference.
 *
 * @param env       JNIEnv of the current thread.
 * @param className JNI binary class name (slash-separated, e.g. "java/lang/String").
 * @return a new global reference owned by the caller, or nullptr on failure.
 *         When FindClass fails, the Java exception it raised is left pending
 *         for the caller to handle.
 */
static inline jclass createGlobalClassReference(JNIEnv *env, const char *className)
{
    jclass localClass = env->FindClass(className);
    if (localClass == nullptr) {
        // FindClass failed and left a pending Java exception. Calling NewGlobalRef
        // now would be undefined behavior per the JNI spec, so bail out early.
        return nullptr;
    }
    auto globalClass = static_cast<jclass>(env->NewGlobalRef(localClass));
    // The local reference is no longer needed once the global one exists.
    env->DeleteLocalRef(localClass);
    return globalClass;
}

/**
 * Like createGlobalClassReference, but fails loudly instead of returning nullptr.
 *
 * @param env       JNIEnv of the current thread.
 * @param className JNI binary class name (slash-separated).
 * @return a new global reference to the class, owned by the caller.
 * @throws std::runtime_error naming the class when it cannot be resolved.
 */
static inline jclass createGlobalClassReferenceOrError(JNIEnv *env, const char *className)
{
    jclass globalClass = createGlobalClassReference(env, className);
    if (globalClass == nullptr) {
        // Separate "for" from the class name; the old message glued them together.
        throw std::runtime_error("Unable to CreateGlobalClassReferenceOrError for " + std::string(className));
    }
    return globalClass;
}

/**
 * Thin wrapper over JNIEnv::GetMethodID for instance methods.
 *
 * @param env       JNIEnv of the current thread.
 * @param thisClass Class to search.
 * @param name      Method name.
 * @param sig       JNI method signature, e.g. "()Z".
 * @return the method ID, or nullptr if the lookup fails.
 */
static inline jmethodID getMethodId(JNIEnv *env, jclass thisClass, const char *name, const char *sig)
{
    return env->GetMethodID(thisClass, name, sig);
}

/**
 * Like getMethodId, but fails loudly instead of returning nullptr.
 *
 * @param env       JNIEnv of the current thread.
 * @param thisClass Class to search.
 * @param name      Method name.
 * @param sig       JNI method signature.
 * @return the (non-null) method ID.
 * @throws std::runtime_error naming the method and signature when not found.
 */
static inline jmethodID getMethodIdOrError(JNIEnv *env, jclass thisClass, const char *name, const char *sig)
{
    jmethodID ret = getMethodId(env, thisClass, name, sig);
    if (ret == nullptr) {
        // Spaces on both sides of the signature label; the old message glued sig to the text.
        throw std::runtime_error("Unable to find method " + std::string(name) + " with signature " +
            std::string(sig));
    }
    return ret;
}

// Helpers declared here and implemented in another translation unit; their
// definitions are not visible in this header.

// Presumably maps a Java-side codec name to spark::CompressionKind — confirm in the .cpp.
spark::CompressionKind GetCompressionType(JNIEnv* env, jstring codec_jstr);

// NOTE(review): appears to overlap with createGlobalClassReference in this header; verify both are needed.
jclass CreateGlobalClassReference(JNIEnv* env, const char* class_name);

// NOTE(review): appears to overlap with getMethodId in this header; verify both are needed.
jmethodID GetMethodID(JNIEnv* env, jclass this_class, const char* name, const char* sig);

// Presumably returns the omniruntime::Runtime associated with a Java runtime-aware object (confirm).
omniruntime::Runtime *GetRuntime(JNIEnv *env, jobject runtimeAware);

// Opens the `try` block of a JNI entry point. Must be closed by one of the
// JNI_FUNC_END* macros, which catch std::exception and rethrow it to Java
// via env->ThrowNew.
#define JNI_FUNC_START try {

// Closes a JNI_FUNC_START block in an entry point that returns a value.
// A caught std::exception is rethrown to Java as `exceptionClass` (carrying
// e.what()) and the function returns 0. Requires a JNIEnv* named `env` in scope.
#define JNI_FUNC_END(exceptionClass)                    \
    } catch (const std::exception &e) {                 \
        env->ThrowNew((exceptionClass), e.what());      \
        return 0;                                       \
    }


// Closes a JNI_FUNC_START block in a void entry point. A caught
// std::exception is rethrown to Java as `exceptionClass` (carrying e.what())
// and the function returns. Requires a JNIEnv* named `env` in scope.
#define JNI_FUNC_END_VOID(exceptionClass)               \
    } catch (const std::exception &e) {                 \
        env->ThrowNew((exceptionClass), e.what());      \
        return;                                         \
    }

// Closes a JNI_FUNC_START block in an entry point that returns a value and
// owns a VectorBatch. On std::exception the batch is released through
// VectorHelper::FreeVecBatch, the exception is rethrown to Java as
// `exceptionClass` (carrying e.what()), and the function returns 0.
// Requires a JNIEnv* named `env` in scope.
#define JNI_FUNC_END_WITH_VECBATCH(exceptionClass, toDeleteVecBatch) \
    } catch (const std::exception &e) {                              \
        VectorHelper::FreeVecBatch((toDeleteVecBatch));              \
        env->ThrowNew((exceptionClass), e.what());                   \
        return 0;                                                    \
    }

// Closes a JNI_FUNC_START block in a void entry point that owns heap-allocated
// vectors. On std::exception every element of `vectors` is deleted, the
// exception is rethrown to Java as the caller-supplied `exceptionClass`
// (carrying e.what()), and the function returns. Requires a JNIEnv* named
// `env` in scope.
#define JNI_FUNC_END_WITH_VECTORS(exceptionClass, vectors)       \
    } catch (const std::exception &e) {                          \
        for (auto vec : vectors) {                               \
            delete vec;                                          \
        }                                                        \
        env->ThrowNew((exceptionClass), e.what());               \
        return;                                                  \
    }


// Process-wide cached JNI class references and method IDs shared by the JNI
// entry points. They are only declared here; their definitions and
// initialization are in another translation unit (presumably set up once when
// the library is loaded — confirm).
extern jclass runtimeExceptionClass;
extern jclass splitResultClass;
extern jclass jsonClass;
extern jclass arrayListClass;
extern jclass threadClass;
extern jclass serializedColumnarBatchIteratorClass;
extern jclass vecBatchCls;
extern jclass infoCls;
extern jclass runtimeAwareClass;
extern jclass metricsBuilderClass;

extern jmethodID jsonMethodInt;
extern jmethodID jsonMethodLong;
extern jmethodID jsonMethodHas;
extern jmethodID jsonMethodString;
extern jmethodID jsonMethodJsonObj;
extern jmethodID arrayListGet;
extern jmethodID arrayListSize;
extern jmethodID jsonMethodObj;
extern jmethodID splitResultConstructor;
extern jmethodID currentThread;
extern jmethodID threadGetId;
extern jmethodID serializedColumnarBatchIteratorHasNext;
extern jmethodID serializedColumnarBatchIteratorNext;
extern jmethodID vecBatchInitMethodId;
// NOTE(review): `method` is a very generic global name and its target is not visible here; confirm its purpose.
extern jmethodID method;
extern jmethodID runtimeAwareCtxHandle;
extern jmethodID metricsBuilderConstructor;
#endif //THESTRAL_PLUGIN_MASTER_JNI_COMMON_H