/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_FL_SERVER_ROUND_H_
#define MINDSPORE_CCSRC_FL_SERVER_ROUND_H_
#include <memory>
#include <string>
#include "ps/core/communicator/communicator_base.h"
#include "fl/server/common.h"
#include "fl/server/iteration_timer.h"
#include "fl/server/distributed_count_service.h"
#include "fl/server/kernel/round/round_kernel.h"
namespace mindspore {
namespace fl {
namespace server {
// Round groups the configuration of one named server round (timeout settings, counting settings) with the
// communicator, round kernel, iteration timer and callbacks it holds as members.
//
// Only declarations are visible in this header: every member function except the defaulted destructor is defined
// out of line. Behavioural notes below that go beyond the signatures are therefore inferred from names and are
// explicitly marked as assumptions to be confirmed against the definitions.
class Round {
public:
// Constructs a round identified by `name`. Every other argument is optional:
//   check_timeout            defaults to true; presumably enables a time limit for the round (confirm).
//   time_window              defaults to 3000; presumably the length of that time limit. The unit is not stated
//                            in this header -- NOTE(review): looks like milliseconds, confirm.
//   check_count              defaults to false; presumably enables counting of round messages (confirm).
//   threshold_count          defaults to 8; presumably the count that has to be reached when counting is enabled.
//   server_num_as_threshold  defaults to false; presumably makes the number of servers act as the threshold.
// The constructor is `explicit` because it is callable with a single std::string, which would otherwise allow an
// implicit std::string -> Round conversion.
explicit Round(const std::string &name, bool check_timeout = true, size_t time_window = 3000,
bool check_count = false, size_t threshold_count = 8, bool server_num_as_threshold = false);
// Defaulted destructor: the members below release themselves.
~Round() = default;
// Supplies the communicator, the timeout callback and the finish-iteration callback used by this round.
// TimeOutCb and FinishIterCb are not declared in this header; presumably they come from one of the included
// fl/server headers.
void Initialize(const std::shared_ptr<ps::core::CommunicatorBase> &communicator, const TimeOutCb &timeout_cb,
const FinishIterCb &finish_iteration_cb);
// Re-initializes the round for a new server count. NOTE(review): the bool result presumably reports success;
// confirm in the definition.
bool ReInitForScaling(uint32_t server_num);
// Re-initializes the round with an updated threshold count and time window. NOTE(review): the bool result
// presumably reports success; confirm in the definition.
bool ReInitForUpdatingHyperParams(size_t updated_threshold_count, size_t updated_time_window);
// Attaches the round kernel that this round launches. NOTE(review): whether this must be called after
// Initialize() cannot be told from this header -- confirm.
void BindRoundKernel(const std::shared_ptr<kernel::RoundKernel> &kernel);
// Entry point for an incoming round message; presumably forwards `message` to the bound round kernel.
void LaunchRoundKernel(const std::shared_ptr<ps::core::MessageHandler> &message);
// Resets the round. NOTE(review): which state is cleared is only visible in the definition.
void Reset();
// Read-only accessors. Presumably they return the same-named private members (name_, threshold_count_,
// check_timeout_, time_window_); name() returns a reference, so the result must not outlive this Round.
const std::string &name() const;
size_t threshold_count() const;
bool check_timeout() const;
size_t time_window() const;
private:
// Handlers taking the triggering message. Judging by the names and the distributed_count_service.h include, they
// are presumably invoked for the first and the last counted message of the round -- confirm.
void OnFirstCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message);
void OnLastCountEvent(const std::shared_ptr<ps::core::MessageHandler> &message);
// Availability check. `reason` is an out-parameter, presumably filled with an explanation when the result is
// false. NOTE(review): whether a null `reason` is tolerated is not visible here.
bool IsServerAvailable(std::string *reason);
// Non-static data members are initialized in declaration order, so reordering the members below can change what
// the constructor observes while it runs.
// Name of this round, taken from the constructor's `name` argument (presumably; definition not visible).
std::string name_;
// Timeout configuration; counterparts of the constructor's check_timeout / time_window arguments.
bool check_timeout_;
size_t time_window_;
// Counting configuration; counterparts of the constructor's check_count / threshold_count /
// server_num_as_threshold arguments.
bool check_count_;
size_t threshold_count_;
bool server_num_as_threshold_;
// Shared handle to the communicator; presumably the one passed to Initialize().
std::shared_ptr<ps::core::CommunicatorBase> communicator_;
// Shared handle to the round kernel; presumably the one passed to BindRoundKernel().
std::shared_ptr<kernel::RoundKernel> kernel_;
// Shared handle to the timer for this round; presumably only used when check_timeout_ is true.
std::shared_ptr<IterationTimer> iter_timer_;
// Stored callbacks. Their types are declared outside this header; who invokes them is not visible here.
StopTimerCb stop_timer_cb_;
FinishIterCb finish_iteration_cb_;
FinalizeCb finalize_cb_;
};
}
}
}
#endif