/*
* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* MindIE is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#ifndef DYNAMIC_BATCH_RECORDER_H
#define DYNAMIC_BATCH_RECORDER_H
#include <cstddef>
#include <cstdint>
#include <memory>
#include <mutex>
#include <unordered_map>
namespace mindie_llm {
class LatencyPredictor;
class DecodeBatchSizeTracker;

/**
 * @brief DynamicBatchRecorder records batch statistics for dynamic batch size across DPs.
 *
 * One recorder instance exists per DP rank (a keyed singleton). Instances are obtained through
 * GetInstance() and are kept in a static registry (`instances_`) keyed by DP rank. Because the
 * registry stores them as std::unique_ptr, an instance's address does not change when the map
 * rehashes.
 *
 * Data from all DPs can be collected and aggregated (e.g., take the maximum) for global batch
 * size decisions in DP aggregation scenarios; see AggregateAllFromAllDPs().
 *
 * Instances are neither copyable nor movable.
 *
 * NOTE(review): the only synchronization primitive declared here is the static `mutex_`.
 * Confirm in the implementation which members it protects. The per-instance setters and
 * getters may race with AggregateAllFromAllDPs() if they do not lock it.
 */
class DynamicBatchRecorder {
public:
    DynamicBatchRecorder(const DynamicBatchRecorder &) = delete;
    DynamicBatchRecorder &operator=(const DynamicBatchRecorder &) = delete;
    // Moves were already unavailable, because the copy operations are user-declared. Deleting them
    // explicitly makes the Rule of Five visible and documents that an instance must never be
    // relocated out of the registry.
    DynamicBatchRecorder(DynamicBatchRecorder &&) = delete;
    DynamicBatchRecorder &operator=(DynamicBatchRecorder &&) = delete;
    // Public so that the registry's std::unique_ptr can destroy instances.
    ~DynamicBatchRecorder() = default;

    /**
     * @brief Get the instance of DynamicBatchRecorder for a specific DP rank.
     * @param localDPRank The DP rank ID
     * @return Reference to the DynamicBatchRecorder instance
     * @note Presumably creates the instance on first request for a rank; confirm in the .cpp.
     */
    static DynamicBatchRecorder &GetInstance(size_t localDPRank);

    /**
     * @brief Set the LatencyPredictor for this DP rank.
     * @param predictor Shared pointer to LatencyPredictor
     */
    void SetLatencyPredictor(const std::shared_ptr<LatencyPredictor> &predictor);

    /**
     * @brief Get the LatencyPredictor for this DP rank.
     * @return Shared pointer to LatencyPredictor (nullptr if none has been set)
     */
    std::shared_ptr<LatencyPredictor> GetLatencyPredictor() const;

    /**
     * @brief Set the DecodeBatchSizeTracker for this DP rank.
     * @param tracker Shared pointer to DecodeBatchSizeTracker
     */
    void SetDecodeBatchSizeTracker(const std::shared_ptr<DecodeBatchSizeTracker> &tracker);

    /**
     * @brief Get the DecodeBatchSizeTracker for this DP rank.
     * @return Shared pointer to DecodeBatchSizeTracker (nullptr if none has been set)
     */
    std::shared_ptr<DecodeBatchSizeTracker> GetDecodeBatchSizeTracker() const;

    /**
     * @brief Get the DP rank for this recorder instance.
     * @return DP rank ID
     */
    size_t GetLocalDPRank() const;

    /**
     * @brief Set the running request count for this DP rank.
     * @param runningSize Number of running requests
     */
    void SetRunningSize(size_t runningSize);

    /**
     * @brief Get the running request count for this DP rank.
     * @return Number of running requests
     */
    size_t GetRunningSize() const;

    /**
     * @brief Aggregate latency, batch size and decode request count from all DPs.
     * @param forwardNum Window size for calculating average
     * @param maxDecodeLatency Output: Maximum decode latency across all DPs
     * @param maxBatchSize Output: Maximum batch size across all DPs
     * @param maxDecodeRequestNum Output: Maximum decode request count across all DPs
     * @return Number of valid DPs that contributed data
     */
    static size_t AggregateAllFromAllDPs(size_t forwardNum, double &maxDecodeLatency, uint64_t &maxBatchSize,
                                         size_t &maxDecodeRequestNum);

private:
    // Private so that instances can only be obtained through GetInstance().
    explicit DynamicBatchRecorder(size_t localDPRank);

    size_t localDPRank_{0};  // DP rank this recorder belongs to
    size_t runningSize_{0};  // running request count, as last set via SetRunningSize()
    std::shared_ptr<LatencyPredictor> predictor_{nullptr};
    std::shared_ptr<DecodeBatchSizeTracker> decodeBatchSizeTracker_{nullptr};

    // Registry of per-rank instances, keyed by DP rank. Defined in the implementation file.
    static std::unordered_map<size_t, std::unique_ptr<DynamicBatchRecorder>> instances_;
    static std::mutex mutex_;
};
}  // namespace mindie_llm
#endif