* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#ifndef OMNISTREAM_TASKLOCALSTATESTORE_H
#define OMNISTREAM_TASKLOCALSTATESTORE_H
#include <map>
#include <mutex>
#include <utility>
#include <set>
#include "core/utils/threads/CompletableFutureV2.h"
#include "runtime/executiongraph/JobIDPOD.h"
#include "runtime/jobgraph/JobVertexID.h"
#include "runtime/checkpoint/TaskStateSnapshot.h"
#include "runtime/checkpoint/TaskStateSnapshotDeserializer.h"
#include "runtime/checkpoint/TaskStateSnapshotSerializer.h"
#include "LocalRecoveryConfig.h"
#include "LocalRecoveryDirectoryProviderImpl.h"
namespace fs = std::filesystem;
class TaskLocalStateStore {
public:
TaskLocalStateStore() = default;
TaskLocalStateStore(
omnistream::JobIDPOD jobID,
omnistream::JobVertexID jobVertexId,
int subtaskIndex,
std::shared_ptr<LocalRecoveryConfig> localRecoveryConfig
);
CompletableFutureV2<void> *dispose();
void storeLocalState(long checkpointId, std::shared_ptr<TaskStateSnapshot> localState);
std::shared_ptr<TaskStateSnapshot> retrieveLocalState(long checkpointID);
void pruneMatchingCheckpoints(std::function<bool(long)> matcher);
void confirmCheckpoint(long confirmedCheckpointId);
void abortCheckpoint(long abortedCheckpointId);
static std::shared_ptr<TaskStateSnapshot> nullDummy();
std::shared_ptr<LocalRecoveryDirectoryProvider> getLocalRecoveryDirectoryProvider()
{
auto provider = localRecoveryConfig_->GetLocalStateDirectoryProvider();
if (provider == nullptr) {
throw std::invalid_argument("Local recovery must be enabled.");
}
return provider;
}
const std::string taskStateSnapshotFilename = "_task_state_snapshot";
std::shared_ptr<LocalRecoveryConfig> getLocalRecoveryConfig()
{
return localRecoveryConfig_;
}
void setLocalRecoveryConfig(std::shared_ptr<LocalRecoveryConfig> localRecoveryConfig)
{
localRecoveryConfig_= std::move(localRecoveryConfig);
}
protected:
void pruneCheckpoints(std::function<bool(long)> matcher, bool breakOnceCheckerFalse);
std::map<long, std::shared_ptr<TaskStateSnapshot>> storedTaskStateByCheckpointID;
private:
void persistLocalStateMetadata(long checkpointId, std::shared_ptr<TaskStateSnapshot> localState);
void asyncDiscardLocalStateForCollection(std::set<std::pair<long, std::shared_ptr<TaskStateSnapshot>>> toDiscard);
void syncDiscardLocalStateForCollection(std::set<std::pair<long, std::shared_ptr<TaskStateSnapshot>>> toDiscard);
void discardLocalStateForCheckpoint(long checkpointID, std::shared_ptr<TaskStateSnapshot> o);
std::shared_ptr<TaskStateSnapshot> loadTaskStateSnapshot(long checkpointID);
std::shared_ptr<TaskStateSnapshot> tryLoadTaskStateSnapshotFromDisk(long checkpointID);
std::filesystem::path getTaskStateSnapshotFile(long checkpointId)
{
return getCheckpointDirectory(checkpointId) / taskStateSnapshotFilename;
}
std::filesystem::path getCheckpointDirectory(long checkpointId)
{
return getLocalRecoveryDirectoryProvider()->SubtaskSpecificCheckpointDirectory(checkpointId);
}
void createFolderOrFail(const std::filesystem::path& checkpointDirectory)
{
if (!std::filesystem::exists(checkpointDirectory) &&
!std::filesystem::create_directories(checkpointDirectory)) {
throw std::runtime_error(
std::string("Could not create the checkpoint directory '") +
checkpointDirectory.string() + "'");
}
}
omnistream::JobIDPOD jobID_;
omnistream::JobVertexID jobVertexId_;
int subtaskIndex_;
bool disposed_;
std::mutex mutexLock;
std::shared_ptr<LocalRecoveryConfig> localRecoveryConfig_;
};
#endif