#ifndef OMNIFLINK_KAFKA_SINK_WRITER_STATE_H
#define OMNIFLINK_KAFKA_SINK_WRITER_STATE_H
#include <memory>
#include <string>
#include <vector>
#include <stdexcept>
#include "connector/kafka/sink/KafkaSink.h"
#include "connector/kafka/sink/KafkaWriter.h"
#include "connector/kafka/sink/KafkaWriterState.h"
#include "connector/kafka/sink/KafkaWriterStateSerializer.h"
#include "core/api/common/state/ListStateDescriptor.h"
#include "core/api/common/state/ListState.h"
#include "core/typeutils/BytePrimitiveArraySerializer.h"
#include "streaming/api/operators/util/SimpleVersionedListState.h"
#include "streaming/api/operators/sink/InitContextImpl.h"
#include "runtime/state/StateInitializationContextImpl.h"
/**
 * Base type for sink-writer state handlers.
 *
 * Supplies do-nothing defaults for writer creation and state snapshotting.
 * Only the destructor is virtual: createWriter() is a member template and
 * snapshotState() is a plain member function, so same-named functions in a
 * derived class hide these rather than override them.
 * NOTE(review): callers presumably invoke them through the concrete handler
 * type — confirm before relying on calls made through a base pointer.
 */
class SinkWriterStateHandler {
public:
    virtual ~SinkWriterStateHandler() = default;

    /**
     * Default writer factory: creates nothing.
     *
     * @param initContext sink init context (ignored here)
     * @param context     state initialization context (ignored here)
     * @return always nullptr
     */
    template <class K1>
    KafkaWriter* createWriter(InitContextImpl<K1>* initContext, StateInitializationContextImpl* context) {
        return nullptr;
    }

    /**
     * Default snapshot hook: there is no state to persist, so this is a no-op.
     *
     * @param checkpointId id of the checkpoint being taken (ignored here)
     */
    void snapshotState(long checkpointId) {}
};
class KafkaSinkWriterStateHandler : public SinkWriterStateHandler {
private:
static std::string writerRawStatesName;
static ListStateDescriptor<std::vector<uint8_t>> WRITER_RAW_STATES_DESC;
std::shared_ptr<KafkaWriterStateSerializer> writerStateSerializer;
std::vector<std::string> previousSinkStateNames;
KafkaSink* sink;
std::vector<ListState<KafkaWriterState>*> previousSinkStates;
std::shared_ptr<SimpleVersionedListState<KafkaWriterState>> writerState;
KafkaWriter* kafkaWriter;
public:
explicit KafkaSinkWriterStateHandler(KafkaSink* sink)
: sink(sink), kafkaWriter(nullptr) {
if (!sink) {
throw std::invalid_argument("sink cannot be null");
}
writerStateSerializer = std::make_shared<KafkaWriterStateSerializer>();
}
virtual ~KafkaSinkWriterStateHandler() override {
for (auto state : previousSinkStates) {
delete state;
}
delete kafkaWriter;
}
template <typename K1>
KafkaWriter* createWriter(InitContextImpl<K1>* initContext, StateInitializationContextImpl* context) {
INFO_RELEASE("savepoint: KafkaSinkWriterStateHandler createWriter")
auto* operatorStateBackend = static_cast<DefaultOperatorStateBackend*>(context->getOperatorStateBackend());
auto rawState = operatorStateBackend->getListState(&WRITER_RAW_STATES_DESC);
writerState = std::make_shared<SimpleVersionedListState<KafkaWriterState>>(
rawState,
writerStateSerializer);
if (context->isRestored()) {
auto states = writerState->get();
std::vector<KafkaWriterState> statesList;
if (states) {
statesList = *states;
}
INFO_RELEASE("savepoint: KafkaSinkWriterStateHandler createWriter, statesList size: " << statesList.size())
kafkaWriter = sink->RestoreWriter(initContext, statesList);
} else {
INFO_RELEASE("savepoint: KafkaSinkWriterStateHandler createWriter, create new KafkaWriter")
kafkaWriter = sink->CreateWriter(initContext);
}
return kafkaWriter;
}
void snapshotState(long checkpointId) {
if (!kafkaWriter) {
throw std::runtime_error("kafkaWriter has not been created yet");
}
std::vector<KafkaWriterState> currentState = kafkaWriter->snapshotState(checkpointId);
writerState->update(currentState);
for (auto state : previousSinkStates) {
state->clear();
}
}
};
// Name under which the writer's raw byte-array list state is registered.
// The descriptor below reads this variable in its initializer, so this
// definition must stay above it.
inline std::string KafkaSinkWriterStateHandler::writerRawStatesName = "writer_raw_states";
// Descriptor for the raw writer list state (elements are std::vector<uint8_t>).
// NOTE(review): the serializer is allocated with a raw `new` and is never freed
// in this file; presumably ListStateDescriptor takes ownership — confirm.
inline ListStateDescriptor<std::vector<uint8_t>> KafkaSinkWriterStateHandler::WRITER_RAW_STATES_DESC(
KafkaSinkWriterStateHandler::writerRawStatesName, new BytePrimitiveArraySerializer(nullptr));
#endif