* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#include "SingleCheckpointBarrierHandler.h"
#include <iostream>
#include <algorithm>
#include <stdexcept>
#include "common.h"
#include "io/network/partition/consumer/CheckpointableInput.h"
#include "runtime/jobgraph/tasks/CheckpointableTask.h"
#include "runtime/io/checkpointing/BarrierAlignmentUtil.h"
#include "runtime/io/checkpointing/AlternatingWaitingForFirstBarrierUnaligned.h"
#include "WaitingForFirstBarrier.h"
#include "runtime/io/checkpointing/AlternatingWaitingForFirstBarrier.h"
namespace omnistream::runtime {
SingleCheckpointBarrierHandler::SingleCheckpointBarrierHandler(
const std::string& taskName,
CheckpointableTask* toNotifyOnCheckpoint,
SubtaskCheckpointCoordinator* subTaskCheckpointCoordinator,
Clock& clock,
int numOpenChannels,
BarrierHandlerState* currentState,
bool alternating,
BarrierAlignmentUtil::DelayableTimer<std::function<void()>>* registerTimer,
const std::vector<CheckpointableInput*>& inputs,
bool enableCheckpointAfterTasksFinished)
: CheckpointBarrierHandler(toNotifyOnCheckpoint, clock, enableCheckpointAfterTasksFinished),
taskName_(taskName),
registerTimer_(registerTimer),
subTaskCheckpointCoordinator_(subTaskCheckpointCoordinator),
inputs_(inputs),
currentCheckpointId_(-1L),
targetChannelCount_(0),
lastCancelledOrCompletedCheckpointId_(-1L),
numOpenChannels_(numOpenChannels),
currentState_(currentState),
alternating_(alternating)
{
LOG_DEBUG("SingleCheckpointBarrierHandler init")
context_ = new ControllerImpl(this, subTaskCheckpointCoordinator_);
}
SingleCheckpointBarrierHandler::~SingleCheckpointBarrierHandler()
{
delete currentState_;
delete context_;
}
void SingleCheckpointBarrierHandler::ProcessBarrier(const CheckpointBarrier& barrier,
const InputChannelInfo& channelInfo,
bool isRpcTriggered)
{
int64_t barrierId = barrier.GetId();
LOG(taskName_ + ": Received barrier from channel " + channelInfo.toString() + " @" + std::to_string(barrierId));
if (barrierId <= lastCancelledOrCompletedCheckpointId_) {
if (!barrier.GetCheckpointOptions()->IsUnalignedCheckpoint()) {
inputs_[channelInfo.getGateIdx()]->ResumeConsumption(channelInfo);
}
LOG("Func:ProcessBarrier returns, because barrierId <= lastCancelledOrCompletedCheckpointId_");
return;
}
if (currentCheckpointId_ > barrierId ||
(currentCheckpointId_ == barrierId && !IsCheckpointPending())) {
if (!barrier.GetCheckpointOptions()->IsUnalignedCheckpoint()) {
inputs_[channelInfo.getGateIdx()]->ResumeConsumption(channelInfo);
}
LOG("Func:ProcessBarrier returns, because urrentCheckpointId_ > barrierId ||(currentCheckpointId_ == barrierId && !IsCheckpointPending())");
return;
}
CheckNewCheckpoint(barrier);
if (currentCheckpointId_ != barrierId) {
throw std::runtime_error("Current checkpoint ID mismatch");
}
if (alignedChannels_.find(channelInfo) != alignedChannels_.end()) {
LOG("Func:ProcessBarrier returns, because alignedChannels_.find(channelInfo) != alignedChannels_.end()");
return;
}
const bool markChannelBlocked = (!isRpcTriggered) && !currentCheckpointUnaligned_ &&
!barrier.GetCheckpointOptions()->IsUnalignedCheckpoint();
CheckpointBarrier* barrierPtr = const_cast<CheckpointBarrier*>(&barrier);
auto stateTransformer = [this, &channelInfo, barrierPtr, markChannelBlocked](
BarrierHandlerState *state) -> BarrierHandlerState* {
LOG(">>>>> stateTransformer")
return state->BarrierReceived(context_, channelInfo, barrierPtr, markChannelBlocked);
};
MarkCheckpointAlignedAndTransformState(channelInfo, *barrierPtr, stateTransformer);
}
void SingleCheckpointBarrierHandler::MarkCheckpointAlignedAndTransformState(
const InputChannelInfo& alignedChannel,
const CheckpointBarrier& barrier,
const std::function<BarrierHandlerState*(BarrierHandlerState*)>& stateTransformer)
{
LOG_DEBUG("Func:MarkCheckpointAlignedAndTransformState start")
alignedChannels_.insert(alignedChannel);
const bool isAlwaysUnalignedHandler = (!alternating_ && subTaskCheckpointCoordinator_ != nullptr);
const bool shouldTrackAlignment = !isAlwaysUnalignedHandler &&
!currentCheckpointUnaligned_ &&
!barrier.GetCheckpointOptions()->IsUnalignedCheckpoint();
const bool isLastBarrierForAlignment =
shouldTrackAlignment && (alignedChannels_.size() == static_cast<size_t>(targetChannelCount_));
if (alternating_ && isLastBarrierForAlignment) {
ResetAlignmentTimer();
if (!allBarriersReceivedFuture_V2->IsDone()) {
allBarriersReceivedFuture_V2->Complete();
}
}
if (alignedChannels_.size() == 1 && shouldTrackAlignment) {
if (targetChannelCount_ == 1) {
MarkAlignmentStartAndEnd(barrier.GetId(), barrier.GetTimestamp());
} else {
MarkAlignmentStart(barrier.GetId(), barrier.GetTimestamp());
}
}
if (alignedChannels_.size() == (unsigned int)targetChannelCount_ && shouldTrackAlignment) {
if (targetChannelCount_ > 1) {
MarkAlignmentEnd();
}
}
try {
auto* old = currentState_;
auto* next = stateTransformer(old);
if (next != old) {
delete old;
}
currentState_ = next;
} catch (const CheckpointException& e) {
AbortInternal(currentCheckpointId_, e);
return;
} catch (const std::exception& e) {
CheckpointException checkpointException(
CheckpointFailureReason::CHECKPOINT_DECLINED,
e);
AbortInternal(currentCheckpointId_, checkpointException);
throw std::runtime_error("Error in state transformation: " + std::string(e.what()));
} catch (...) {
CheckpointException checkpointException(CheckpointFailureReason::CHECKPOINT_DECLINED);
AbortInternal(currentCheckpointId_, checkpointException);
throw std::runtime_error("Error in state transformation: unknown exception");
}
if (alignedChannels_.size() == (unsigned int)targetChannelCount_) {
alignedChannels_.clear();
lastCancelledOrCompletedCheckpointId_ = currentCheckpointId_;
LOG_DEBUG(taskName_ + ": All the channels are aligned for checkpoint " + std::to_string(currentCheckpointId_))
ResetAlignmentTimer();
if (!allBarriersReceivedFuture_V2->IsDone()) {
allBarriersReceivedFuture_V2->Complete();
}
}
}
void SingleCheckpointBarrierHandler::TriggerCheckpoint(const CheckpointBarrier& trigger)
{
LOG(taskName_ + " : Triggering checkpoint " + std::to_string(trigger.GetId()) +
" on the barrier announcement at " + std::to_string(trigger.GetTimestamp()));
NotifyCheckpoint(trigger);
}
void SingleCheckpointBarrierHandler::ProcessBarrierAnnouncement(const CheckpointBarrier& announcedBarrier,
int sequenceNumber,
const InputChannelInfo& channelInfo)
{
LOG_DEBUG("start ProcessBarrierAnnouncement, barrier Id: " << announcedBarrier.GetId())
if (announcedBarrier.GetId() <= lastCancelledOrCompletedCheckpointId_) {
LOG_DEBUG("announcedBarrier.GetId() <= lastCancelledOrCompletedCheckpointId_. barrier Id: " << announcedBarrier.GetId())
return;
}
if (announcedBarrier.GetCheckpointOptions() == nullptr ||
announcedBarrier.GetCheckpointOptions()->IsUnalignedCheckpoint()) {
LOG_DEBUG("announcedBarrier.GetCheckpointOptions() == nullptr || announcedBarrier.GetCheckpointOptions()->IsUnalignedCheckpoint(), barrier Id: " << announcedBarrier.GetId())
return;
}
CheckNewCheckpoint(announcedBarrier);
int64_t barrierId = announcedBarrier.GetId();
if (currentCheckpointId_ > barrierId || (currentCheckpointId_ == barrierId && !IsCheckpointPending())) {
LOG_DEBUG("scurrentCheckpointId_ > barrierId || (currentCheckpointId_ == barrierId && !IsCheckpointPending()), barrier Id: " << announcedBarrier.GetId())
std::cout << taskName_ << ": Obsolete announcement of checkpoint "
<< barrierId << " for channel " << channelInfo.toString() << std::endl;
return;
}
try {
auto* oldState = currentState_;
auto* nextState = currentState_->AnnouncementReceived(dynamic_cast<Controller*>(context_),
channelInfo,
sequenceNumber);
if (nextState != oldState) {
delete oldState;
}
currentState_ = nextState;
} catch (...) {
}
}
void SingleCheckpointBarrierHandler::ProcessCancellationBarrier(
const CancelCheckpointMarker& cancelBarrier,
const InputChannelInfo& channelInfo)
{
int64_t cancelledId = cancelBarrier.getCheckpointId();
if (cancelledId <= lastCancelledOrCompletedCheckpointId_) {
return;
}
std::cout << taskName_ << ": Received cancellation " << cancelledId << std::endl;
AbortInternal(cancelledId, CheckpointFailureReason::CHECKPOINT_DECLINED_ON_CANCELLATION_BARRIER);
}
void SingleCheckpointBarrierHandler::ProcessEndOfPartition(const InputChannelInfo& channelInfo)
{
numOpenChannels_--;
}
void SingleCheckpointBarrierHandler::CheckNewCheckpoint(const CheckpointBarrier& barrier)
{
LOG_DEBUG(">>>>>" << " id " << barrier.GetId()<< " timestamp " << barrier.GetTimestamp())
int64_t barrierId = barrier.GetId();
if (currentCheckpointId_ >= barrierId) {
return;
}
if (IsCheckpointPending()) {
CancelSubsumedCheckpoint(barrierId);
}
currentCheckpointId_ = barrierId;
pendingCheckpointBarrier_ = const_cast<CheckpointBarrier*>(&barrier);
alignedChannels_.clear();
targetChannelCount_ = numOpenChannels_;
currentCheckpointUnaligned_ = barrier.GetCheckpointOptions()->IsUnalignedCheckpoint();
LOG("CheckNewCheckpoint cp="<<barrierId<<" create future");
allBarriersReceivedFuture_V2 = std::make_shared<CompletableFutureV2<void>>();
}
void SingleCheckpointBarrierHandler::RegisterAlignmentTimer(const CheckpointBarrier& announcedBarrier)
{
LOG("Func:RegisterAlignmentTimer start.")
if (!announcedBarrier.GetCheckpointOptions()->IsTimeoutable()) {
return;
}
ResetAlignmentTimer();
int64_t timerDelay = BarrierAlignmentUtil::getTimerDelay(GetClock().RelativeTimeMillis(), announcedBarrier);
const int64_t barrierId = announcedBarrier.GetId();
currentAlignmentTimerCheckpointId_ = barrierId;
auto timerTask = [this, announcedBarrier, barrierId]() mutable {
try {
if (currentAlignmentTimerCheckpointId_ != barrierId) {
return;
}
if (currentCheckpointId_ != barrierId || GetAllBarriersReceivedFuture(barrierId)->IsDone()) {
return;
}
if (currentCheckpointUnaligned_) {
return;
}
if (GetAllBarriersReceivedFuture(barrierId)->IsDone()) {
return;
}
LOG("Timeout, start transition.")
currentCheckpointUnaligned_ = true;
auto* old = currentState_;
auto* next = old->AlignedCheckpointTimeout(
dynamic_cast<Controller*>(context_),
const_cast<CheckpointBarrier*>(&announcedBarrier));
if (next != old) {
delete old;
}
currentState_ = next;
} catch (const CheckpointException& ex) {
AbortInternal(barrierId, ex);
} catch (const std::exception& e) {
throw std::runtime_error("Timer task error: " + std::string(e.what()));
}
if (currentAlignmentTimerCheckpointId_ == barrierId) {
currentAlignmentTimer_.reset();
currentAlignmentTimerCheckpointId_ = -1;
}
};
currentAlignmentTimer_.reset(registerTimer_->RegisterTask(timerTask, std::chrono::milliseconds(timerDelay)));
}
void SingleCheckpointBarrierHandler::AbortInternal(int64_t cancelledId, CheckpointFailureReason reason)
{
CheckpointException exception(reason);
AbortInternal(cancelledId, exception);
}
void SingleCheckpointBarrierHandler::AbortInternal(int64_t cancelledId, const CheckpointException& exception)
{
if (cancelledId <= lastCancelledOrCompletedCheckpointId_) {
return;
}
LOG(taskName_ + ": Aborting checkpoint " + std::to_string(cancelledId) +
" due to " + omnistream::toString(exception.GetCheckpointFailureReason()));
ResetAlignmentTimer();
for (auto* input : inputs_) {
for (const auto& info : input->GetChannelInfos()) {
input->ResumeConsumption(info);
}
input->CheckpointStopped(cancelledId);
}
alignedChannels_.clear();
pendingCheckpointBarrier_ = nullptr;
currentCheckpointUnaligned_ = false;
lastCancelledOrCompletedCheckpointId_ = cancelledId;
ResetToWaitingState();
try {
NotifyAbort(cancelledId, exception);
} catch (...) {
}
if (allBarriersReceivedFuture_V2 && !allBarriersReceivedFuture_V2->IsDone()) {
allBarriersReceivedFuture_V2->Cancel();
}
}
void SingleCheckpointBarrierHandler::ResetToWaitingState()
{
if (currentState_ != nullptr) {
delete currentState_;
currentState_ = nullptr;
}
ChannelState channelState(inputs_);
if (subTaskCheckpointCoordinator_ == nullptr) {
currentState_ = new WaitingForFirstBarrier(std::move(channelState));
return;
}
if (alternating_) {
currentState_ = new AlternatingWaitingForFirstBarrier(std::move(channelState));
} else {
currentState_ = new AlternatingWaitingForFirstBarrierUnaligned(false, std::move(channelState));
}
}
void SingleCheckpointBarrierHandler::CancelSubsumedCheckpoint(int64_t barrierId)
{
std::cout << taskName_ << ": Received checkpoint barrier for checkpoint "
<< barrierId << " before completing current checkpoint "
<< currentCheckpointId_ << ". Skipping current checkpoint." << std::endl;
LOG(taskName_ + ": Received checkpoint barrier for checkpoint " + std::to_string(barrierId) +
" before completing current checkpoint " + std::to_string(currentCheckpointId_) +
". Skipping current checkpoint.")
AbortInternal(currentCheckpointId_, CheckpointFailureReason::CHECKPOINT_DECLINED_SUBSUMED);
}
void SingleCheckpointBarrierHandler::ResetAlignmentTimer()
{
LOG("ResetAlignmentTimer")
if (currentAlignmentTimer_) {
currentAlignmentTimer_->Cancel();
currentAlignmentTimer_.reset();
}
currentAlignmentTimerCheckpointId_ = -1;
}
// Reset all barriers received future
void SingleCheckpointBarrierHandler::ResetAllBarriersReceivedFuture()
{
allBarriersReceivedPromise_ = std::promise<void>();
allBarriersReceivedFuture_ = allBarriersReceivedPromise_.get_future();
}
*/
int64_t SingleCheckpointBarrierHandler::GetLatestCheckpointId() const
{
return currentCheckpointId_;
}
BarrierHandlerState *SingleCheckpointBarrierHandler::GetCurrentState() const
{
return currentState_;
}
bool SingleCheckpointBarrierHandler::IsCheckpointPending() const
{
return currentCheckpointId_ != lastCancelledOrCompletedCheckpointId_ && currentCheckpointId_ >= 0;
}
std::shared_ptr<CompletableFutureV2<void>> SingleCheckpointBarrierHandler::GetAllBarriersReceivedFuture(int64_t checkpointId)
{
if (checkpointId < currentCheckpointId_ || numOpenChannels_ == 0) {
if (!completed_V2->IsDone()) {
completed_V2->Complete();
}
return completed_V2;
}
if (checkpointId > currentCheckpointId_) {
throw std::runtime_error("Checkpoint " + std::to_string(checkpointId) + " has not been started at all");
}
LOG_DEBUG("SingleCheckpointBarrierHandler GetAllBarriersReceivedFuture checkpointId: " << checkpointId
<< ", currentCheckpointId: " << currentCheckpointId_);
return allBarriersReceivedFuture_V2;
}
int SingleCheckpointBarrierHandler::GetNumOpenChannels() const
{
return numOpenChannels_;
}
void SingleCheckpointBarrierHandler::Close()
{
ResetAlignmentTimer();
if (allBarriersReceivedFuture_V2) {
allBarriersReceivedFuture_V2->Cancel();
}
CheckpointBarrierHandler::Close();
}
std::string SingleCheckpointBarrierHandler::ToString() const
{
return taskName_ + ": current checkpoint: " + std::to_string(currentCheckpointId_) +
", current aligned channels: " + std::to_string(alignedChannels_.size()) +
", target channel count: " + std::to_string(targetChannelCount_);
}
std::unique_ptr<SingleCheckpointBarrierHandler> SingleCheckpointBarrierHandler::createUnalignCheckpointBarrierHandler(
SubtaskCheckpointCoordinator* checkpointCoordinator,
const std::string& taskName,
CheckpointableTask* toNotifyOnCheckpoint,
Clock& clock,
bool enableCheckpointsAfterTasksFinish,
const std::vector<CheckpointableInput*>& inputs)
{
int channelCount = 0;
for (const auto& input : inputs) {
channelCount += input->GetChannelInfos().size();
}
auto timer = std::make_unique<BarrierAlignmentUtil::ThrowingDelayableTimer<std::function<void()>>>();
return unaligned(
taskName,
toNotifyOnCheckpoint,
checkpointCoordinator,
clock,
channelCount,
timer.get(),
enableCheckpointsAfterTasksFinish,
inputs);
}
std::unique_ptr<SingleCheckpointBarrierHandler> SingleCheckpointBarrierHandler::unaligned(
const std::string& taskName,
CheckpointableTask* toNotifyOnCheckpoint,
SubtaskCheckpointCoordinator* checkpointCoordinator,
Clock& clock,
int numOpenChannels,
BarrierAlignmentUtil::DelayableTimer<std::function<void()>>* registerTimer,
bool enableCheckpointAfterTasksFinished,
const std::vector<CheckpointableInput*>& inputs)
{
ChannelState channelState(inputs);
auto state = new AlternatingWaitingForFirstBarrierUnaligned(false, std::move(channelState));
return std::make_unique<SingleCheckpointBarrierHandler>(
taskName,
toNotifyOnCheckpoint,
checkpointCoordinator,
clock,
numOpenChannels,
state,
false,
registerTimer,
inputs,
enableCheckpointAfterTasksFinished);
}
std::unique_ptr<SingleCheckpointBarrierHandler> SingleCheckpointBarrierHandler::aligned(const std::string& taskName,
CheckpointableTask* toNotifyOnCheckpoint, Clock& clock, int numOpenChannels,
BarrierAlignmentUtil::DelayableTimer<std::function<void()>>* registerTimer, bool enableCheckpointAfterTasksFinished,
const std::vector<CheckpointableInput*>& inputs)
{
auto channelState = ChannelState(inputs);
auto state = new WaitingForFirstBarrier(std::move(channelState));
return std::make_unique<SingleCheckpointBarrierHandler>(
taskName,
toNotifyOnCheckpoint,
nullptr,
clock,
numOpenChannels,
state,
false,
registerTimer,
inputs,
enableCheckpointAfterTasksFinished);
}
std::unique_ptr<SingleCheckpointBarrierHandler> SingleCheckpointBarrierHandler::alternating(
const std::string& taskName,
CheckpointableTask* toNotifyOnCheckpoint, SubtaskCheckpointCoordinator* checkpointCoordinator, Clock& clock,
int numOpenChannels, BarrierAlignmentUtil::DelayableTimer<std::function<void()>>* registerTimer,
bool enableCheckpointAfterTasksFinished, const std::vector<CheckpointableInput*>& inputs)
{
ChannelState channelState(inputs);
auto state = new AlternatingWaitingForFirstBarrier(std::move(channelState));
return std::make_unique<SingleCheckpointBarrierHandler>(
taskName,
toNotifyOnCheckpoint,
checkpointCoordinator,
clock,
numOpenChannels,
state,
true,
registerTimer,
inputs,
enableCheckpointAfterTasksFinished);
}
}