* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#include "ChannelStateWriterImpl.h"
namespace omnistream {
/**
 * Builds the channel state writer for one subtask and wires up its request pipeline.
 *
 * A serializer is created, passed to a newly created dispatcher together with the checkpoint
 * storage and the stream factory resolver, and the dispatcher is wrapped in a newly created
 * executor. This constructor does not call start() on the executor; open() does that.
 *
 * @param jobVertexID                     stored and attached to every request built by this writer.
 * @param taskName                        stored; used as the prefix of the error thrown by Start().
 * @param subtaskIndex                    stored and attached to every request built by this writer.
 * @param checkpointStorage               forwarded to the dispatcher.
 * @param streamFactoryResolver           forwarded to the dispatcher.
 * @param maxCheckpoints                  upper bound on registered results checked in Start().
 * @param maxSubtasksPerChannelStateFile  accepted but not used by this constructor.
 */
ChannelStateWriterImpl::ChannelStateWriterImpl(
    const JobVertexID &jobVertexID,
    const std::string &taskName,
    int subtaskIndex,
    std::shared_ptr<CheckpointStorage> checkpointStorage,
    std::shared_ptr<CheckpointStorageWorkerView> streamFactoryResolver,
    int maxCheckpoints,
    int maxSubtasksPerChannelStateFile)
    : jobVertexID_(jobVertexID),
      taskName_(taskName),
      subtaskIndex_(subtaskIndex),
      maxCheckpoints_(maxCheckpoints),
      wasClosed_(false)
{
    serializer_ = std::make_shared<ChannelStateSerializerImpl>();

    // The dispatcher is given a default-constructed JobIDPOD.
    auto requestDispatcher = std::make_shared<ChannelStateWriteRequestDispatcherImpl>(
        checkpointStorage, JobIDPOD{}, serializer_, streamFactoryResolver);

    executor_ = std::make_shared<ChannelStateWriteRequestExecutorImpl>(requestDispatcher);
}
/**
 * Registers a write result for a checkpoint and enqueues the matching start request.
 *
 * If a result is already registered for checkpointId the call returns without doing anything.
 * The whole operation, including the enqueue, runs while resultsMutex_ is held.
 *
 * @param checkpointId id of the checkpoint to start.
 * @param options      not used by this implementation.
 * @throws std::runtime_error when the number of registered results has already reached
 *         maxCheckpoints_.
 */
void ChannelStateWriterImpl::Start(long checkpointId, const CheckpointOptions &options)
{
    std::lock_guard<std::mutex> lock(resultsMutex_);

    // A result already exists for this checkpoint: nothing to do.
    if (results_.find(checkpointId) != results_.end()) {
        return;
    }
    if (results_.size() >= static_cast<size_t>(maxCheckpoints_)) {
        throw std::runtime_error(taskName_ + " exceeded max checkpoints");
    }

    auto channelStateWriteResult = ChannelStateWriter::ChannelStateWriteResult::CreateEmpty();
    channelStateWriteResult->setIsNeedsChannelState(true);
    results_.emplace(checkpointId, channelStateWriteResult);

    // NOTE(review): the result is registered before the request is enqueued; if enqueue
    // throws, the entry stays in results_ until Abort(cleanup = true), GetAndRemoveWriteResult
    // or Close removes it.
    enqueue(
        ChannelStateWriteRequest::start(
            jobVertexID_,
            subtaskIndex_,
            checkpointId,
            channelStateWriteResult,
            "Start"),
        false);
}
/**
 * Enqueues a non-priority request that writes input-channel buffers for a checkpoint.
 *
 * @param checkpointId id of the checkpoint the buffers belong to.
 * @param info         input channel the buffers belong to.
 * @param startSeqNum  not used by this implementation.
 * @param data         buffers handed to the write-input request.
 */
void ChannelStateWriterImpl::AddInputData(
    long checkpointId,
    const InputChannelInfo &info,
    int startSeqNum,
    std::vector<Buffer*> data)
{
    validateCheckpointId(checkpointId);

    auto writeRequest = ChannelStateWriteRequest::writeInput(
        jobVertexID_, subtaskIndex_, checkpointId, info, data);
    enqueue(std::move(writeRequest), false);
}
/**
 * Enqueues a non-priority request that writes result-subpartition buffers for a checkpoint.
 *
 * @param checkpointId id of the checkpoint the buffers belong to.
 * @param info         result subpartition the buffers belong to.
 * @param startSeqNum  not used by this implementation.
 * @param data         buffers handed to the write-output request.
 */
void ChannelStateWriterImpl::AddOutputData(
    long checkpointId,
    const ResultSubpartitionInfoPOD &info,
    int startSeqNum,
    std::vector<Buffer*> &data)
{
    validateCheckpointId(checkpointId);

    auto writeRequest = ChannelStateWriteRequest::writeOutput(
        jobVertexID_, subtaskIndex_, checkpointId, info, data);
    enqueue(std::move(writeRequest), false);
}
/**
 * Enqueues a non-priority request that writes result-subpartition buffers delivered
 * through a future.
 *
 * @param checkpointId id of the checkpoint the buffers belong to.
 * @param info         result subpartition the buffers belong to.
 * @param startSeqNum  not used by this implementation.
 * @param data         future handed to the write-output-future request.
 */
void ChannelStateWriterImpl::AddOutputDataFuture(
    long checkpointId,
    const ResultSubpartitionInfoPOD &info,
    int startSeqNum,
    std::shared_ptr<CompletableFutureV2<std::vector<Buffer*>>> data)
{
    validateCheckpointId(checkpointId);

    auto writeRequest = ChannelStateWriteRequest::writeOutputFuture(
        jobVertexID_, subtaskIndex_, checkpointId, info, data);
    enqueue(std::move(writeRequest), false);
}
/**
 * Enqueues a non-priority complete-input request for a checkpoint.
 *
 * @param checkpointId id of the checkpoint whose input is complete.
 */
void ChannelStateWriterImpl::FinishInput(long checkpointId)
{
    validateCheckpointId(checkpointId);

    auto completionRequest =
        ChannelStateWriteRequest::completeInput(jobVertexID_, subtaskIndex_, checkpointId);
    enqueue(std::move(completionRequest), false);
}
/**
 * Enqueues a non-priority complete-output request for a checkpoint.
 *
 * @param checkpointId id of the checkpoint whose output is complete.
 */
void ChannelStateWriterImpl::FinishOutput(long checkpointId)
{
    validateCheckpointId(checkpointId);

    auto completionRequest =
        ChannelStateWriteRequest::completeOutput(jobVertexID_, subtaskIndex_, checkpointId);
    enqueue(std::move(completionRequest), false);
}
/**
 * Enqueues terminate requests for a checkpoint and optionally drops its registered result.
 *
 * Two separate terminate requests are built and submitted: the first with priority, the
 * second without.
 * NOTE(review): presumably the priority request handles a checkpoint that is already being
 * processed while the regular one handles a start that is still queued - confirm against the
 * executor implementation.
 *
 * @param checkpointId id of the checkpoint to abort.
 * @param cause        exception attached to both terminate requests.
 * @param cleanup      when true, the result registered for checkpointId is erased.
 */
void ChannelStateWriterImpl::Abort(long checkpointId, const std::exception_ptr &cause, bool cleanup)
{
    // Each call produces a fresh request object; the two submissions never share one.
    auto makeTerminateRequest = [&]() {
        return ChannelStateWriteRequest::terminate(jobVertexID_, subtaskIndex_, checkpointId, cause);
    };

    enqueue(makeTerminateRequest(), true);
    enqueue(makeTerminateRequest(), false);

    if (!cleanup) {
        return;
    }
    std::lock_guard<std::mutex> guard(resultsMutex_);
    results_.erase(checkpointId);
}
/**
 * Removes the write result registered for a checkpoint and returns it.
 *
 * The lookup and removal run while resultsMutex_ is held.
 *
 * @param checkpointId id of the checkpoint whose result is requested.
 * @return the result that was registered for checkpointId.
 * @throws std::invalid_argument when no result is registered for checkpointId.
 */
std::shared_ptr<ChannelStateWriter::ChannelStateWriteResult> ChannelStateWriterImpl::GetAndRemoveWriteResult(long checkpointId)
{
    std::lock_guard<std::mutex> guard(resultsMutex_);

    auto entry = results_.find(checkpointId);
    if (entry != results_.end()) {
        // Copy the pointer out before the map entry is destroyed.
        auto writeResult = entry->second;
        results_.erase(entry);
        return writeResult;
    }
    throw std::invalid_argument("Checkpoint result not found");
}
/**
 * Closes the writer: drops every registered result and enqueues a release-subtask request.
 *
 * Only the first call has an effect; later calls see wasClosed_ already set and return
 * immediately.
 */
void ChannelStateWriterImpl::Close()
{
    const bool alreadyClosed = wasClosed_.exchange(true);
    if (alreadyClosed) {
        return;
    }

    {
        // The lock covers only the map; the enqueue below runs without it.
        std::lock_guard<std::mutex> guard(resultsMutex_);
        results_.clear();
    }

    auto releaseRequest = ChannelStateWriteRequest::releaseSubtask(jobVertexID_, subtaskIndex_);
    enqueue(std::move(releaseRequest), false);
}
// Starts the request executor created in the constructor.
void ChannelStateWriterImpl::open()
{
executor_->start();
}
// Hook called by the Add*/Finish* methods before they enqueue a request.
// NOTE(review): the body is empty, so no validation of checkpointId is performed at present.
void ChannelStateWriterImpl::validateCheckpointId(long checkpointId)
{
}
/**
 * Logs the request and hands it to the executor.
 *
 * @param request  request to submit; it is dereferenced for logging.
 * @param priority when true the request goes through submitPriority(), otherwise submit().
 */
void ChannelStateWriterImpl::enqueue(std::shared_ptr<ChannelStateWriteRequest> request, bool priority)
{
    LOG_DEBUG(" enqueue " << request->getName() << " checkpointid " << request->getCheckpointId());

    if (!priority) {
        executor_->submit(request);
        return;
    }
    executor_->submitPriority(request);
}
}