* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "fl/server/round.h"
#include <memory>
#include <string>
#include "fl/server/server.h"
#include "fl/server/iteration.h"
namespace mindspore {
namespace fl {
namespace server {
// Forward declarations. This file calls members of both classes (Server::GetInstance, Iteration::GetInstance), so the
// complete definitions must already come from the server.h / iteration.h includes above; these are redundant but
// harmless.
class Server;
class Iteration;
// Stores the round configuration only; communicator, callbacks, timer and counter are set up later in Initialize.
//   name                        - round identifier; Initialize uses it as the message-callback key and counter name.
//   check_timeout / time_window - whether the round owns an IterationTimer, and the window passed to it (converted
//                                 with std::chrono::milliseconds when the timer is started).
//   check_count / threshold_count - whether a distributed counter is registered, and its threshold.
//   server_num_as_threshold     - when true, ReInitForScaling overwrites threshold_count with the server number.
Round::Round(const std::string &name, bool check_timeout, size_t time_window, bool check_count, size_t threshold_count,
bool server_num_as_threshold)
: name_(name),
check_timeout_(check_timeout),
time_window_(time_window),
check_count_(check_count),
threshold_count_(threshold_count),
server_num_as_threshold_(server_num_as_threshold) {}
// Wires this round into the server infrastructure.
//   communicator        - must be non-null (throws otherwise); messages named `name_` are routed to
//                         LaunchRoundKernel, and finalize_cb_ stops it.
//   timeout_cb          - invoked (with a round-specific reason) when the round timer fires; only used when
//                         check_timeout_ is set.
//   finish_iteration_cb - invoked (with a round-specific reason) when the bound kernel finishes the iteration.
// When check_count_ is set, a distributed counter named `name_` with threshold threshold_count_ is registered, whose
// first/last count events are forwarded to OnFirstCountEvent/OnLastCountEvent.
void Round::Initialize(const std::shared_ptr<ps::core::CommunicatorBase> &communicator, const TimeOutCb &timeout_cb,
                       const FinishIterCb &finish_iteration_cb) {
  MS_EXCEPTION_IF_NULL(communicator);
  communicator_ = communicator;

  // All stored callbacks below only touch members, so they capture `this` explicitly instead of `[&]`.
  communicator_->RegisterMsgCallBack(name_, [this](std::shared_ptr<ps::core::MessageHandler> message) {
    MS_ERROR_IF_NULL_WO_RET_VAL(message);
    LaunchRoundKernel(message);
  });

  // The reason supplied by the kernel is ignored and replaced by a round-specific message.
  finish_iteration_cb_ = [this, finish_iteration_cb](bool is_iteration_valid, const std::string &) -> void {
    const std::string reason = "Round " + name_ + " finished! This iteration is valid. Proceed to next iteration.";
    finish_iteration_cb(is_iteration_valid, reason);
  };

  finalize_cb_ = [this]() -> void {
    MS_ERROR_IF_NULL_WO_RET_VAL(communicator_);
    (void)communicator_->Stop();
  };

  if (check_timeout_) {
    iter_timer_ = std::make_shared<IterationTimer>();
    MS_EXCEPTION_IF_NULL(iter_timer_);
    iter_timer_->SetTimeOutCallBack([this, timeout_cb](bool is_iteration_valid, const std::string &) -> void {
      const std::string reason = "Round " + name_ + " timeout! This iteration is invalid. Proceed to next iteration.";
      timeout_cb(is_iteration_valid, reason);
    });
    stop_timer_cb_ = [this]() -> void {
      MS_ERROR_IF_NULL_WO_RET_VAL(iter_timer_);
      MS_LOG(INFO) << "Round " << name_ << " kernel stops its timer.";
      iter_timer_->Stop();
    };
  }

  if (check_count_) {
    auto first_count_handler = [this](const std::shared_ptr<ps::core::MessageHandler> &message) {
      OnFirstCountEvent(message);
    };
    auto last_count_handler = [this](const std::shared_ptr<ps::core::MessageHandler> &message) {
      OnLastCountEvent(message);
    };
    DistributedCountService::GetInstance().RegisterCounter(name_, threshold_count_,
                                                           {first_count_handler, last_count_handler});
  }
}
bool Round::ReInitForScaling(uint32_t server_num) {
if (server_num_as_threshold_) {
MS_LOG(INFO) << "Round " << name_ << " uses up-to-date server number " << server_num << " as its threshold count.";
threshold_count_ = server_num;
}
if (check_count_) {
auto first_count_handler = std::bind(&Round::OnFirstCountEvent, this, std::placeholders::_1);
auto last_count_handler = std::bind(&Round::OnLastCountEvent, this, std::placeholders::_1);
DistributedCountService::GetInstance().RegisterCounter(name_, threshold_count_,
{first_count_handler, last_count_handler});
}
MS_ERROR_IF_NULL_W_RET_VAL(kernel_, false);
kernel_->InitKernel(threshold_count_);
return true;
}
// Applies new hyper-parameters to this round: the time window and the count threshold.
// The counter is re-initialized with the new threshold (only when check_count_), then the bound kernel.
// Returns false if the counter re-initialization fails or no kernel is bound.
// NOTE: time_window_ and threshold_count_ are overwritten before the counter is re-initialized, so they keep the new
// values even when this returns false.
bool Round::ReInitForUpdatingHyperParams(size_t updated_threshold_count, size_t updated_time_window) {
  time_window_ = updated_time_window;
  threshold_count_ = updated_threshold_count;

  // Short-circuit: ReInitCounter is only called when this round counts at all.
  const bool counter_reinitialized =
    !check_count_ || DistributedCountService::GetInstance().ReInitCounter(name_, threshold_count_);
  if (!counter_reinitialized) {
    MS_LOG(ERROR) << "Reinitializing count for " << name_ << " failed.";
    return false;
  }

  MS_ERROR_IF_NULL_W_RET_VAL(kernel_, false);
  kernel_->InitKernel(threshold_count_);
  return true;
}
// Makes `kernel` (must be non-null; throws otherwise) the kernel this round launches, and hands it the round's
// stop-timer and finish-iteration callbacks. stop_timer_cb_ is only populated by Initialize when check_timeout_ is set.
void Round::BindRoundKernel(const std::shared_ptr<kernel::RoundKernel> &kernel) {
  MS_EXCEPTION_IF_NULL(kernel);
  kernel_ = kernel;
  kernel_->set_stop_timer_cb(stop_timer_cb_);
  kernel_->set_finish_iteration_cb(finish_iteration_cb_);
}
// Handles one message addressed to this round. It first checks that the server can serve requests (otherwise the
// rejection reason is sent back). It then runs the bound round kernel on the message payload and sends the kernel
// output, or an "empty output" notice, back to the sender. The response is sent regardless of Launch's return value.
// If Launch reported failure and its output was delivered, the current iteration is ended via
// Iteration::NotifyNext(false, ...).
//
// Iteration::running_round_num_ is held incremented while the kernel runs. A scope guard releases it on every exit
// path, including the empty-output and failed-send returns and exceptions, all of which previously skipped the
// decrement.
void Round::LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &message) {
  MS_ERROR_IF_NULL_WO_RET_VAL(message);
  MS_ERROR_IF_NULL_WO_RET_VAL(kernel_);
  MS_ERROR_IF_NULL_WO_RET_VAL(communicator_);

  std::string reason = "";
  if (!IsServerAvailable(&reason)) {
    // The kernel is not run; the client only learns why the request was rejected.
    if (!communicator_->SendResponse(reason.c_str(), reason.size(), message)) {
      MS_LOG(ERROR) << "Sending response failed.";
    }
    return;
  }

  AddressPtr input = std::make_shared<Address>();
  AddressPtr output = std::make_shared<Address>();
  MS_ERROR_IF_NULL_WO_RET_VAL(input);
  MS_ERROR_IF_NULL_WO_RET_VAL(output);
  input->addr = message->data();
  input->size = message->len();

  // Increments Iteration::running_round_num_ on construction and decrements it on destruction, keeping the counter
  // balanced on every return below and if Launch/SendResponse throw.
  struct RunningRoundGuard {
    RunningRoundGuard() { ++Iteration::GetInstance().running_round_num_; }
    ~RunningRoundGuard() { --Iteration::GetInstance().running_round_num_; }
    RunningRoundGuard(const RunningRoundGuard &) = delete;
    RunningRoundGuard &operator=(const RunningRoundGuard &) = delete;
  } running_round_guard;

  bool ret = kernel_->Launch({input}, {}, {output});
  if (output->size == 0) {
    reason = "The output of the round " + name_ + " is empty.";
    MS_LOG(WARNING) << reason;
    if (!communicator_->SendResponse(reason.c_str(), reason.size(), message)) {
      MS_LOG(ERROR) << "Sending response failed.";
    }
    return;
  }

  const bool sent = communicator_->SendResponse(output->addr, output->size, message);
  // Hand the kernel output back once the send attempt is over, whether or not it succeeded, so a failed send does
  // not skip the release (presumably leaking the kernel-owned buffer).
  kernel_->Release(output);
  if (!sent) {
    MS_LOG(ERROR) << "Sending response failed.";
    return;
  }

  if (!ret) {
    reason = "Launching round kernel of round " + name_ + " failed.";
    Iteration::GetInstance().NotifyNext(false, reason);
  }
}
// Resets the bound round kernel; logs an error and does nothing if no kernel is bound. The kernel's return value
// is intentionally discarded.
void Round::Reset() {
  MS_ERROR_IF_NULL_WO_RET_VAL(kernel_);
  static_cast<void>(kernel_->Reset());
}
// Read-only accessors for the round configuration.
auto Round::name() const -> const std::string & { return name_; }
auto Round::threshold_count() const -> size_t { return threshold_count_; }
auto Round::check_timeout() const -> bool { return check_timeout_; }
auto Round::time_window() const -> size_t { return time_window_; }
// Handler registered for the first count of this round's distributed counter. When timeouts are enabled it starts
// the round timer for time_window_ milliseconds, then forwards the event to the bound kernel. It returns early
// (after logging) if the kernel, or a required timer, is missing.
void Round::OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) {
  MS_ERROR_IF_NULL_WO_RET_VAL(kernel_);
  MS_LOG(INFO) << "Round " << name_ << " first count event is triggered.";
  if (check_timeout_) {
    MS_ERROR_IF_NULL_WO_RET_VAL(iter_timer_);
    const std::chrono::milliseconds window(time_window_);
    iter_timer_->Start(window);
  }
  kernel_->OnFirstCountEvent(message);
}
// Handler registered for the last (threshold-reaching) count of this round's distributed counter. When timeouts are
// enabled it stops the round timer, then forwards the event to the bound kernel. It returns early (after logging)
// if the kernel, or a required timer, is missing.
void Round::OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message) {
  MS_ERROR_IF_NULL_WO_RET_VAL(kernel_);
  MS_LOG(INFO) << "Round " << name_ << " last count event is triggered.";
  if (check_timeout_) {
    MS_ERROR_IF_NULL_WO_RET_VAL(iter_timer_);
    iter_timer_->Stop();
  }
  kernel_->OnLastCountEvent(message);
}
bool Round::IsServerAvailable(std::string *reason) {
MS_ERROR_IF_NULL_W_RET_VAL(reason, false);
if (Iteration::GetInstance().instance_state() == InstanceState::kFinish && name_ == "getModel") {
return true;
}
if (Iteration::GetInstance().instance_state() == InstanceState::kDisable ||
Iteration::GetInstance().instance_state() == InstanceState::kFinish) {
MS_LOG(WARNING) << "The server's training job is disabled or finished, please retry " + name_ + " later.";
*reason = ps::kJobNotAvailable;
return false;
}
if (Server::GetInstance().IsSafeMode()) {
MS_LOG(WARNING) << "The cluster is still in process of scaling, please retry " << name_ << " later.";
*reason = ps::kClusterSafeMode;
return false;
}
return true;
}
}
}
}