* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#include "TaskStateSnapshot.h"
const std::shared_ptr<TaskStateSnapshot> TaskStateSnapshot::finishedOnRestore = std::make_shared<TaskStateSnapshot>(
std::unordered_map<OperatorID, std::shared_ptr<OperatorSubtaskState>>(), true, true);
std::shared_ptr<OperatorSubtaskState> TaskStateSnapshot::GetSubtaskStateByOperatorID(const OperatorID& operatorID) const
{
auto it = subtaskStatesByOperatorID.find(operatorID);
if (it != subtaskStatesByOperatorID.end()) {
return it->second;
}
return nullptr;
}
* Maps the given operator id to the given subtask state. Returns the subtask state of a
* previous mapping, if such a mapping existed or null otherwise.
*/
std::shared_ptr<OperatorSubtaskState> TaskStateSnapshot::PutSubtaskStateByOperatorID(
const OperatorID& operatorID, std::shared_ptr<OperatorSubtaskState> state)
{
auto it = subtaskStatesByOperatorID.find(operatorID);
std::shared_ptr<OperatorSubtaskState> previousState = nullptr;
if (it != subtaskStatesByOperatorID.end()) {
previousState = it->second;
}
subtaskStatesByOperatorID[operatorID] = state;
return previousState;
}
std::set<std::pair<OperatorID, std::shared_ptr<OperatorSubtaskState>>> TaskStateSnapshot::GetSubtaskStateMappings()
const
{
std::set<std::pair<OperatorID, std::shared_ptr<OperatorSubtaskState>>> result;
for (const auto& entry : subtaskStatesByOperatorID) {
result.insert(entry);
}
return result;
}
* Returns true if at least one {@link OperatorSubtaskState} in subtaskStatesByOperatorID has
* state.
*/
bool TaskStateSnapshot::HasState() const
{
for (const auto& pair : subtaskStatesByOperatorID) {
std::shared_ptr<OperatorSubtaskState> operatorSubtaskState = pair.second;
if (operatorSubtaskState != nullptr && operatorSubtaskState->HasState()) {
return true;
}
}
return isTaskDeployedAsFinished;
}
* Returns the input channel mapping for rescaling with in-flight data or {@link
* InflightDataRescalingDescriptor#noRescale}.
*/
std::shared_ptr<InflightDataRescalingDescriptor> TaskStateSnapshot::GetInputRescalingDescriptor() const
{
std::vector<std::shared_ptr<InflightDataRescalingDescriptor>> list;
for (const auto& subtaskState : subtaskStatesByOperatorID) {
std::shared_ptr<InflightDataRescalingDescriptor> descriptor =
subtaskState.second->getInputRescalingDescriptor();
if (!InflightDataRescalingDescriptor::noRescale->Equals(&*descriptor)) {
list.emplace_back(descriptor);
}
}
INFO_RELEASE("GetInputRescalingDescriptor descriptor size: " << list.size());
int index = 0;
for (const auto& inflightDataRescalingDescriptor : list) {
INFO_RELEASE(
"InputRescaling index: " << ++index << ", descriptor: " << inflightDataRescalingDescriptor->ToString());
}
if (list.empty()) {
INFO_RELEASE("InputRescaling is noRescale.");
return InflightDataRescalingDescriptor::noRescale;
}
return list[0];
}
* Returns the output channel mapping for rescaling with in-flight data or {@link
* InflightDataRescalingDescriptor#noRescale}.
*/
std::shared_ptr<InflightDataRescalingDescriptor> TaskStateSnapshot::GetOutputRescalingDescriptor() const
{
std::vector<std::shared_ptr<InflightDataRescalingDescriptor>> list;
for (const auto& subtaskState : subtaskStatesByOperatorID) {
std::shared_ptr<InflightDataRescalingDescriptor> descriptor =
subtaskState.second->getOutputRescalingDescriptor();
if (!InflightDataRescalingDescriptor::noRescale->Equals(&*descriptor)) {
list.emplace_back(descriptor);
}
}
LOG("GetOutputRescalingDescriptor descriptor size: " << list.size());
int index = 0;
for (const auto& inflightDataRescalingDescriptor : list) {
LOG("OutputRescaling index: " << ++index << ", descriptor: " << inflightDataRescalingDescriptor->ToString());
}
if (list.empty()) {
LOG("OutputRescaling is noRescale.");
return InflightDataRescalingDescriptor::noRescale;
}
return list[0];
}
void TaskStateSnapshot::DiscardState()
{
std::vector<std::shared_ptr<OperatorSubtaskState>> values;
for (const auto& pair : subtaskStatesByOperatorID) {
values.push_back(pair.second);
}
for (auto value : values) {
value->DiscardState();
}
}
long TaskStateSnapshot::GetStateSize() const
{
long size = 0L;
for (const auto& pair : subtaskStatesByOperatorID) {
std::shared_ptr<OperatorSubtaskState> subtaskState = pair.second;
if (subtaskState != nullptr) {
size += subtaskState->GetStateSize();
}
}
return size;
}
void TaskStateSnapshot::RegisterSharedStates(SharedStateRegistry& stateRegistry, long checkpointID)
{
for (const auto& pair : subtaskStatesByOperatorID) {
std::shared_ptr<OperatorSubtaskState> operatorSubtaskState = pair.second;
if (operatorSubtaskState != nullptr) {
operatorSubtaskState->RegisterSharedStates(stateRegistry, checkpointID);
}
}
}
size_t TaskStateSnapshot::HashCode() const
{
return 0;
}
std::string TaskStateSnapshot::ToString() const
{
nlohmann::json j;
j["stateHandleName"] = "TaskStateSnapshot";
j["isTaskDeployedAsFinished"] = isTaskDeployedAsFinished;
j["isTaskFinished"] = isTaskFinished;
j["stateSize"] = GetStateSize();
nlohmann::json subtask_states_map;
for (const auto& pair : subtaskStatesByOperatorID) {
std::string operatorIdHex = pair.first.toString();
if (pair.second != nullptr) {
subtask_states_map[operatorIdHex] = nlohmann::json::parse(pair.second->ToString());
} else {
subtask_states_map[operatorIdHex] = nullptr;
}
}
j["subtaskStatesByOperatorID"] = subtask_states_map;
return j.dump();
}
std::shared_ptr<InflightDataRescalingDescriptor> TaskStateSnapshot::GetMapping(
std::function<std::shared_ptr<InflightDataRescalingDescriptor>(const std::shared_ptr<OperatorSubtaskState>&)>
mappingExtractor) const
{
std::vector<std::shared_ptr<InflightDataRescalingDescriptor>> mappings;
for (const auto& pair : subtaskStatesByOperatorID) {
if (pair.second != nullptr) {
std::shared_ptr<InflightDataRescalingDescriptor> mapping = mappingExtractor(pair.second);
if (!(mapping == InflightDataRescalingDescriptor::noRescale)) {
mappings.push_back(mapping);
}
}
}
if (mappings.size() == 1) {
return mappings[0];
} else if (mappings.empty()) {
return InflightDataRescalingDescriptor::noRescale;
} else {
throw std::runtime_error("getMapping gets more than 1 result!");
}
}
long TaskStateSnapshot::GetCheckpointedSize()
{
long size = 0L;
for (auto it = subtaskStatesByOperatorID.begin(); it != subtaskStatesByOperatorID.end(); it++) {
auto subtaskState = it->second;
if (subtaskState != nullptr) {
size += subtaskState->GetCheckpointedSize();
}
}
return size;
}
void TaskStateSnapshot::SetIsTaskFinished(bool hasTaskFinished)
{
this->isTaskFinished = hasTaskFinished;
}
void TaskStateSnapshot::SetIsTaskDeployedAsFinished(bool hasTaskDeployedAsFinished)
{
this->isTaskDeployedAsFinished = hasTaskDeployedAsFinished;
}