* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#include "KafkaWriter.h"
#include <chrono>
#include <cstddef>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>
#include "../bind_core_manager.h"
#include "include/common.h"
/**
 * Constructs the writer.
 *
 * Steps, in order: run Init(); when description["batch"] is true, build the dynamic
 * record serializer from description["inputFields"] / description["inputTypes"];
 * restore the last checkpoint id from the init context (if any); abort lingering
 * transactions when the guarantee is EXACTLY_ONCE; create the two producers and one
 * topic handle per producer; allocate the writer state; register the periodic timer task.
 *
 * @param deliveryGuarantee      delivery semantics requested for this writer.
 * @param kafkaProducerConfig    producer configuration handed to both producers (not owned here).
 * @param transactionalIdPrefix  prefix used to build transactional ids and the writer state.
 * @param topic                  name of the target topic.
 * @param description            JSON description; "batch", "inputFields" and "inputTypes" are read.
 * @param maxPushRecords         number of buffered records that triggers a push (stored in `limit`).
 * @param initContext            optional init context; may be nullptr.
 * @param states                 recovered writer states used for transaction cleanup.
 * @throws std::runtime_error    if a topic handle cannot be created.
 */
KafkaWriter::KafkaWriter(
    DeliveryGuarantee deliveryGuarantee,
    RdKafka::Conf* kafkaProducerConfig,
    std::string& transactionalIdPrefix,
    std::string& topic,
    const nlohmann::json& description,
    int64_t maxPushRecords,
    InitContextImpl<void*>* initContext,
    const std::vector<KafkaWriterState>& states)
    : kafkaProducerConfig(kafkaProducerConfig),
      topic(topic),
      deliveryGuarantee(deliveryGuarantee),
      transactionalIdPrefix(transactionalIdPrefix),
      description(description),
      limit(maxPushRecords)
{
    Init();
    if (description["batch"]) {
        inputFields = description["inputFields"].get<std::vector<std::basic_string<char>>>();
        inputTypes = description["inputTypes"].get<std::vector<std::basic_string<char>>>();
        recordSerializer = new DynamicKafkaRecordSerializationSchema(inputFields, inputTypes);
    }
    if (initContext != nullptr && initContext->getRestoredCheckpointId().has_value()) {
        lastCheckpointId = initContext->getRestoredCheckpointId().value();
    }
    if (deliveryGuarantee == DeliveryGuarantee::EXACTLY_ONCE) {
        abortLingeringTransactions(states, lastCheckpointId);
    }
    currentProducer1 =
        std::make_shared<FlinkKafkaInternalProducer>(kafkaProducerConfig, std::to_string(producerIndexOne));
    currentProducer2 =
        std::make_shared<FlinkKafkaInternalProducer>(kafkaProducerConfig, std::to_string(producerIndexTwo));

    // librdkafka copies the topic configuration both in Conf::set("default_topic_conf", ...)
    // and in Topic::create(), so the temporary Conf object must be released here; previously
    // it was leaked on every construction.
    std::unique_ptr<RdKafka::Conf> tconf(RdKafka::Conf::create(RdKafka::Conf::CONF_TOPIC));
    std::string errstr;
    this->kafkaProducerConfig->set("default_topic_conf", tconf.get(), errstr);
    rd_topic1 = RdKafka::Topic::create(currentProducer1->getKafkaProducer(), topic, tconf.get(), errstr);
    // Skip the second handle if the first failed so errstr still describes the first failure.
    rd_topic2 = (rd_topic1 != nullptr)
        ? RdKafka::Topic::create(currentProducer2->getKafkaProducer(), topic, tconf.get(), errstr)
        : nullptr;
    if (rd_topic1 == nullptr || rd_topic2 == nullptr) {
        // Topic::create() returns NULL on error. The destructor does not run when a
        // constructor throws, so release the handle created so far before reporting.
        delete rd_topic1;
        rd_topic1 = nullptr;
        delete rd_topic2;
        rd_topic2 = nullptr;
        INFO_RELEASE("KafkaWriter failed to create topic handle, topic: " << topic << ", error: " << errstr);
        throw std::runtime_error("KafkaWriter: failed to create topic handle for '" + topic + "': " + errstr);
    }
    partitionNum = rd_topic1->get_partition_num();
    kafkaWriterState = new KafkaWriterState(transactionalIdPrefix);
    // Registered last so the callback never observes a partially constructed writer.
    // The 5000 argument is the period passed to addPeriodicTask (presumably milliseconds - confirm).
    taskId = omnistream::TimerThreadPool::GetTimerThreadPoolInstance()->addPeriodicTask(
        5000, [](KafkaWriter* kafkaWriter) { kafkaWriter->timer_thread(); }, this);
}
/**
 * Shuts the writer down: cancels the periodic timer task, stops and joins the worker
 * thread, then releases the writer state, the serializer and both topic handles.
 */
KafkaWriter::~KafkaWriter()
{
    timer_worker_thread_flag.store(false);
    omnistream::TimerThreadPool::GetTimerThreadPoolInstance()->cancel(taskId);
    {
        // Intentionally empty critical section: blocks until whoever currently holds gMtx
        // has released it (presumably an in-flight timer callback - confirm against timer_thread()).
        std::lock_guard<std::mutex> lock(gMtx);
    }
    {
        // Publish the stop request while holding the queue mutex. Without the lock the
        // worker could evaluate its wait predicate, see stop_flag == false, and only then
        // block on cv - after notify_all() below has already fired - which would make the
        // join() hang forever. Assumes WorkerThreadFunc waits on cv under queueMutex, the
        // mutex handleRecord() uses when enqueueing tasks.
        std::lock_guard<std::mutex> lock(queueMutex);
        stop_flag.store(true);
    }
    cv.notify_all();
    if (worker_thread.joinable()) {
        worker_thread.join();
    }
    delete kafkaWriterState;
    kafkaWriterState = nullptr;
    delete recordSerializer;
    recordSerializer = nullptr;
    // Topic handles are deleted here, in the destructor body, i.e. before the producer
    // members they were created from are destroyed.
    delete rd_topic1;
    rd_topic1 = nullptr;
    delete rd_topic2;
    rd_topic2 = nullptr;
}
/**
 * Serializes a String element and hands it to ProduceRecord().
 *
 * On the first call only, BindDirectCore() is invoked when a non-negative bindCore has
 * been configured. After the record has been handed over, putRefCount() is called on
 * the element.
 *
 * @param element  element to write; must not be null.
 */
void KafkaWriter::write(String* element)
{
    // One-shot setup guarded by the `binded` flag.
    if (unlikely(!binded)) {
        const bool hasCoreAssigned = bindCore >= 0;
        if (hasCoreAssigned) {
            omnistream::BindCoreManager::GetInstance()->BindDirectCore(bindCore);
        }
        binded = true;
    }

    auto serialized = recordSerializer->Serialize(element);
    ProduceRecord(serialized);
    element->putRefCount();
}
/**
 * Serializes a Row element with the configured record serializer and buffers the
 * result through ProduceRecord().
 *
 * @param element  row to write.
 */
void KafkaWriter::write(Row* element)
{
    // A named local is required: ProduceRecord() takes its argument by non-const reference.
    auto serializedRow = recordSerializer->Serialize(element);
    ProduceRecord(serializedRow);
}
/**
 * Serializes a RowData element with the configured record serializer and buffers the
 * result through ProduceRecord().
 *
 * @param element  row data to write.
 */
void KafkaWriter::write(RowData* element)
{
    // A named local is required: ProduceRecord() takes its argument by non-const reference.
    auto serializedRowData = recordSerializer->Serialize(element);
    ProduceRecord(serializedRowData);
}
/**
 * Serializes one row of a vector batch and buffers the result through ProduceRecord().
 *
 * @param input     batch holding the row.
 * @param rowIndex  index of the row inside `input` to serialize.
 */
void KafkaWriter::write(omnistream::VectorBatch* input, int rowIndex)
{
    // A named local is required: ProduceRecord() takes its argument by non-const reference.
    auto serializedBatchRow = recordSerializer->Serialize(input, rowIndex);
    ProduceRecord(serializedBatchRow);
}
/**
 * Pushes all buffered records and flushes both producers.
 *
 * Nothing happens when the delivery guarantee is NONE and the input has not ended.
 *
 * @param endOfInput  true when no more input will arrive; forces the flush even for NONE.
 */
void KafkaWriter::Flush(bool endOfInput)
{
    if (deliveryGuarantee == DeliveryGuarantee::NONE && !endOfInput) {
        return;
    }

    {
        // handleRecord() works on the shared buffers, so it runs under gMtx.
        std::lock_guard<std::mutex> bufferGuard(gMtx);
        handleRecord();
    }
    // The lock is released before waiting so the buffers stay usable while queued tasks drain.
    waitForPendingTasks();
    currentProducer1->Flush();
    currentProducer2->Flush();
}
std::vector<KafkaCommittable> KafkaWriter::prepareCommit()
{
if (deliveryGuarantee == DeliveryGuarantee::EXACTLY_ONCE) {
auto committables = std::vector<KafkaCommittable>{
KafkaCommittable::of(currentProducer1.get(), [this](FlinkKafkaInternalProducer* producer) {
producerPool.push_back(static_cast<const std::shared_ptr<FlinkKafkaInternalProducer>>(producer));
})};
return committables;
}
return {};
}
/**
 * Aborts the transaction of currentProducer1, if one is open.
 */
void KafkaWriter::AbortCurrentProducer()
{
    const bool transactionOpen = currentProducer1->IsInTransaction();
    if (!transactionOpen) {
        return;
    }
    currentProducer1->AbortTransaction();
}
void KafkaWriter::abortLingeringTransactions(
const std::vector<KafkaWriterState>& recoveredStates, long startCheckpointId)
{
auto prefixesToAbort = std::vector<std::string>{transactionalIdPrefix};
if (!recoveredStates.empty()) {
const auto& lastState = recoveredStates.front();
if (lastState.getTransactionalIdPrefix() != transactionalIdPrefix) {
prefixesToAbort.push_back(lastState.getTransactionalIdPrefix());
}
}
TransactionAborter transactionAborter(
[this](const std::string& transactionalId) { return getOrCreateTransactionalProducer(transactionalId); },
[this](const std::shared_ptr<FlinkKafkaInternalProducer>& producer) { producerPool.push_back(producer); });
transactionAborter.abortLingeringTransactions(prefixesToAbort, startCheckpointId);
}
std::shared_ptr<FlinkKafkaInternalProducer> KafkaWriter::getTransactionalProducer(long checkpointId)
{
std::shared_ptr<FlinkKafkaInternalProducer> producer;
for (long id = lastCheckpointId + 1; id <= checkpointId; id++) {
auto transactionalId = TransactionalIdFactory::buildTransactionalId(transactionalIdPrefix, 0, id);
producer = getOrCreateTransactionalProducer(transactionalId);
}
lastCheckpointId = checkpointId;
return producer;
}
std::shared_ptr<FlinkKafkaInternalProducer> KafkaWriter::getOrCreateTransactionalProducer(
const std::string& transactionalId)
{
auto producer = producerPool.empty() ? nullptr : producerPool.front();
if (!producer) {
producer = std::make_shared<FlinkKafkaInternalProducer>(kafkaProducerConfig, transactionalId);
} else {
producerPool.pop_front();
}
return producer;
}
/**
 * Appends the record's value pointer and length to the pending buffers and pushes the
 * buffered records through handleRecord() once `cur` reaches `limit`.
 *
 * @param record  serialized record; only `value` and `valueLen` are read.
 */
void KafkaWriter::ProduceRecord(KeyValueByteContainer& record)
{
    // gMtx guards values / valuesLens / cur for the whole call, including handleRecord().
    std::lock_guard<std::mutex> bufferGuard(gMtx);
    values.push_back(record.value);
    valuesLens.push_back(record.valueLen);
    if (++cur >= limit) {
        handleRecord();
    }
}
/**
 * Pushes all buffered records to Kafka and resets the buffers.
 *
 * The buffered records are split in two halves: the first half is produced immediately
 * on the calling thread through currentProducer1, the second half is queued as a task
 * that produces through currentProducer2. Both visible callers (ProduceRecord and Flush)
 * hold gMtx while calling this function.
 *
 * If the value and length buffers disagree in size, the mismatch is logged and the
 * buffers are left untouched.
 */
void KafkaWriter::handleRecord()
{
    const size_t recordCount = values.size();
    if (recordCount == 0) {
        valuesLens.clear();
        cur = 0;
        return;
    }
    if (recordCount != valuesLens.size()) {
        INFO_RELEASE(
            "KafkaWriter::handleRecord cached record size mismatch, values size: "
            << recordCount << ", valuesLens size: " << valuesLens.size() << ", cur: " << cur);
        return;
    }

    // With a single buffered record `mid` is 0: the first half is empty and the record
    // goes to the queued task.
    const size_t mid = recordCount / 2;
    std::vector<char*> first_half(values.begin(), values.begin() + mid);
    std::vector<size_t> first_half_lens(valuesLens.begin(), valuesLens.begin() + mid);
    std::vector<char*> last_half(values.begin() + mid, values.end());
    std::vector<size_t> last_half_lens(valuesLens.begin() + mid, valuesLens.end());

    auto kafkaProducer1 = currentProducer1->getKafkaProducer();
    produce(kafkaProducer1, rd_topic1, first_half, first_half_lens);

    {
        std::lock_guard<std::mutex> lock(queueMutex);
        // The second half is moved into the closure instead of being copied: the local
        // vectors are not used again, so this saves one copy of both vectors per batch.
        tasks.emplace([this, payloads = std::move(last_half), lengths = std::move(last_half_lens)]() {
            auto kafkaProducer2 = currentProducer2->getKafkaProducer();
            produce(kafkaProducer2, rd_topic2, payloads, lengths);
        });
    }
    cv.notify_one();

    values.clear();
    valuesLens.clear();
    cur = 0;
}
/**
 * Blocks until the task queue is empty and no task is in flight.
 *
 * Waits on tasksDrainedCv under queueMutex; the condition is re-checked after every wake-up.
 */
void KafkaWriter::waitForPendingTasks()
{
    std::unique_lock<std::mutex> queueLock(queueMutex);
    while (!(tasks.empty() && inFlightTasks == 0)) {
        tasksDrainedCv.wait(queueLock);
    }
}
void KafkaWriter::SetSubTaskIdx(int32_t subtaskIdx)
{
this->instanceId = subtaskIdx;
if (omnistream::BindCoreManager::GetInstance()->NeedBindSink()) {
bindCore = omnistream::BindCoreManager::GetInstance()->GetSinkCore(subtaskIdx);
}
worker_thread = std::thread(&KafkaWriter::WorkerThreadFunc, this);
}
/**
 * Produces a batch of values to one topic through the given producer.
 *
 * The target partition is `instanceId % partitionNum`, or PARTITION_UA when the topic
 * reports zero partitions. A failed produce call is logged; afterwards the producer is
 * polled once without blocking.
 *
 * @param kafkaProducer  producer used to send the batch.
 * @param rd_topic       topic handle belonging to `kafkaProducer`.
 * @param value          payload pointers of the batch.
 * @param valuesLen      payload lengths, parallel to `value`.
 *
 * NOTE(review): the member partitionNum is rewritten on every call, and handleRecord()
 * runs this function both on the calling thread and inside a queued task - confirm that
 * concurrent writes to partitionNum are acceptable.
 */
void KafkaWriter::produce(
    RdKafka::Producer* kafkaProducer,
    RdKafka::Topic* rd_topic,
    const std::vector<char*>& value,
    const std::vector<size_t>& valuesLen)
{
    partitionNum = rd_topic->get_partition_num();

    int32_t targetPartition = RdKafka::Topic::PARTITION_UA;
    if (partitionNum != 0) {
        targetPartition = instanceId % partitionNum;
    }

    const int msgFlags = RdKafka::Producer::RK_MSG_FREE | RdKafka::Producer::RK_MSG_BLOCK;
    const RdKafka::ErrorCode status =
        kafkaProducer->produce(rd_topic, targetPartition, msgFlags, value, valuesLen, nullptr, 0, nullptr);
    if (status != RdKafka::ERR_NO_ERROR) {
        LOG("Produce failed:" << RdKafka::err2str(status));
    }

    // Zero timeout: poll without blocking.
    kafkaProducer->poll(0);
}
/**
 * Takes the writer-state snapshot for a checkpoint.
 *
 * For EXACTLY_ONCE, a transactional producer for checkpoint `checkpointId + 1` is
 * obtained and a transaction is begun on it.
 *
 * @param checkpointId  id of the checkpoint being snapshotted.
 * @return a vector holding a copy of the current writer state.
 * @throws std::runtime_error if no transactional producer is available because the
 *         checkpoint id did not advance past the last recorded one.
 */
std::vector<KafkaWriterState> KafkaWriter::snapshotState(long checkpointId)
{
    if (deliveryGuarantee == DeliveryGuarantee::EXACTLY_ONCE) {
        auto currentProducer = getTransactionalProducer(checkpointId + 1);
        // getTransactionalProducer() returns an empty pointer when the requested id is not
        // greater than the last recorded checkpoint id; fail explicitly instead of
        // dereferencing a null pointer.
        if (!currentProducer) {
            throw std::runtime_error(
                "KafkaWriter::snapshotState: no transactional producer for checkpoint " +
                std::to_string(checkpointId) + " (checkpoint id did not advance)");
        }
        currentProducer->BeginTransaction();
    }
    return {*kafkaWriterState};
}