* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* MindIE is licensed under Mulan PSL v2.
* You can use this software according to the terms and conditions of the Mulan PSL v2.
* You may obtain a copy of Mulan PSL v2 at:
* http://license.coscl.org.cn/MulanPSL2
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
* EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
* MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
* See the Mulan PSL v2 for more details.
*/
#ifndef THREAD_GROUP_CC_H
#define THREAD_GROUP_CC_H
#include <algorithm>
#include <any>
#include <condition_variable>
#include <cstdint>
#include <exception>
#include <iostream>
#include <mutex>
#include <stdexcept>
#include <string>
#include <thread>
#include <vector>
namespace mindie_llm {
// Identifiers for collective operation kinds, stored in a single signed byte.
// NOTE(review): nothing in this header consumes this enum; the meaning of each
// enumerator is inferred only from its name.
enum class CollectiveType : int8_t {
    BARRIER = 0,
    BROADCAST = 1,
    GATHER = 2,
    SCATTER = 3,
    ALL_TO_ALL = 4,
    REDUCE = 5
};
// Identifiers for reduction operators (default underlying type, int).
// NOTE(review): no code in this header consumes ReduceOp; semantics are inferred
// from the enumerator names only.
enum class ReduceOp {
    SUM = 0,
    PRODUCT = 1,
    MAX = 2,
    MIN = 3
};
class ThreadGroupCC {
public:
static ThreadGroupCC &GetInstance(size_t numThreads = 2);
~ThreadGroupCC() = default;
template <typename T>
void AllGather(const std::vector<T> &sendData, std::vector<std::vector<T>> &recvData, size_t idx);
protected:
explicit ThreadGroupCC(size_t numThreads);
template <typename T>
void AllGatherSend_(const std::vector<T> &sendData, size_t idx);
template <typename T>
void AllGatherRecv_(std::vector<std::vector<T>> &recvData, size_t idx);
template <typename T>
static void CopyData2Buf_(const std::vector<T> &src, std::vector<std::any> &dst);
template <typename T>
static void CopyBuf2Data_(const std::vector<std::any> &src, std::vector<T> &dst);
private:
size_t numThreads_;
std::mutex barrierMtx_;
size_t barrierCount_{0};
size_t barrierPhase_{0};
std::condition_variable barrierCv_;
std::mutex broadcastMtx_;
std::condition_variable broadcastDataReadyCv_;
std::condition_variable broadcastAllReadCv_;
std::vector<std::any> broadcastBuf_;
std::vector<bool> broadcastReadyVec_;
size_t broadcastReadersDone_{0};
std::mutex gatherMtx_;
std::condition_variable gatherDataReadyCv_;
std::condition_variable gatherAllReadCv_;
std::vector<std::vector<std::any>> gatherBuf_;
std::vector<bool> gatherReadyVec_;
std::vector<bool> gatherReaderDoneVec_;
std::mutex scatterMtx_;
std::condition_variable scatterDataReadyCv_;
std::condition_variable scatterAllReadCv_;
std::vector<std::vector<std::any>> scatterBuf_;
std::vector<bool> scatterReadyVec_;
size_t scatterReaderDone_{0};
std::mutex allGatherMtx_;
std::condition_variable allGatherDataReadyCv_;
std::condition_variable allGatherAllReadCv_;
std::vector<std::vector<std::any>> allGatherBuf_;
std::vector<std::vector<bool>> allGatherReadyVec_;
std::vector<std::vector<bool>> allGatherReaderDoneVec_;
std::mutex reduceMtx_;
std::condition_variable reduceDataReadyCv_;
std::condition_variable reduceAllReadCv_;
std::vector<std::vector<std::any>> reduceBuf_;
std::vector<bool> reduceReadyVec_;
std::vector<bool> reduceReaderDoneVec_;
};
template <typename T>
void ThreadGroupCC::AllGather(const std::vector<T> &sendData, std::vector<std::vector<T>> &recvData, size_t idx) {
AllGatherSend_(sendData, idx);
AllGatherRecv_(recvData, idx);
}
/**
 * @brief Publish this participant's all-gather contribution.
 *
 * Copies @p sendData into allGatherBuf_[idx]. Then, under allGatherMtx_, it marks slot
 * @p idx as ready and not-yet-read for every reader, and wakes all threads waiting on
 * allGatherDataReadyCv_.
 *
 * @param sendData  data contributed by the caller.
 * @param idx       caller's participant index.
 * @throws std::out_of_range   if idx is not a valid slot of allGatherBuf_.
 * @throws std::runtime_error  if some reader's flag row is too short for idx. In that
 *                             case no flag has been modified.
 */
template <typename T>
void ThreadGroupCC::AllGatherSend_(const std::vector<T> &sendData, size_t idx) {
    if (idx >= allGatherBuf_.size()) {
        throw std::out_of_range("AllGather index out of range: " + std::to_string(idx) +
                                " >= " + std::to_string(allGatherBuf_.size()));
    }
    // The buffer is written without the lock. This assumes the caller's previous
    // AllGatherRecv_ has returned. That call waits until every reader has marked slot
    // idx done, so no reader can still be copying this slot. The unlock of
    // allGatherMtx_ below then publishes the write to readers.
    CopyData2Buf_(sendData, allGatherBuf_[idx]);

    std::lock_guard<std::mutex> lock(allGatherMtx_);
    // Validate every reader row before changing any flag. Otherwise a failure midway
    // would leave slot idx ready for some readers and not for others.
    for (size_t i = 0; i < numThreads_; ++i) {
        const size_t rowSize = std::min(allGatherReadyVec_[i].size(), allGatherReaderDoneVec_[i].size());
        if (idx >= rowSize) {
            throw std::runtime_error("AllGather index out of range: " + std::to_string(idx) + " >= " +
                                     std::to_string(rowSize) + " (flag row " + std::to_string(i) + ")");
        }
    }
    for (size_t i = 0; i < numThreads_; ++i) {
        allGatherReadyVec_[i][idx] = true;
        allGatherReaderDoneVec_[i][idx] = false;
    }
    allGatherDataReadyCv_.notify_all();
}
template <typename T>
void ThreadGroupCC::AllGatherRecv_(std::vector<std::vector<T>> &recvData, size_t idx) {
{
std::unique_lock<std::mutex> lock(allGatherMtx_);
allGatherDataReadyCv_.wait(lock, [this, idx] {
return std::all_of(allGatherReadyVec_[idx].begin(), allGatherReadyVec_[idx].end(),
[](bool ready) { return ready; });
});
}
recvData.resize(numThreads_);
for (size_t i = 0; i < numThreads_; ++i) {
CopyBuf2Data_(allGatherBuf_[i], recvData[i]);
}
std::unique_lock<std::mutex> lock(allGatherMtx_);
std::fill(allGatherReadyVec_[idx].begin(), allGatherReadyVec_[idx].end(), false);
std::fill(allGatherReaderDoneVec_[idx].begin(), allGatherReaderDoneVec_[idx].end(), true);
allGatherAllReadCv_.notify_all();
allGatherAllReadCv_.wait(lock, [this, idx] {
for (size_t i = 0; i < numThreads_; ++i) {
if (!allGatherReaderDoneVec_[i][idx]) {
return false;
}
}
return true;
});
}
/**
 * @brief Type-erase every element of @p src into @p dst.
 *
 * @p dst is resized to src.size(). Slot k is then replaced by a std::any holding a copy
 * of src[k].
 */
template <typename T>
void ThreadGroupCC::CopyData2Buf_(const std::vector<T> &src, std::vector<std::any> &dst) {
    dst.resize(src.size());
    auto slot = dst.begin();
    for (const auto &value : src) {
        *slot = std::make_any<T>(value);
        ++slot;
    }
}
/**
 * @brief Recover every element of @p src as a T into @p dst.
 *
 * @p dst is resized to src.size(). Each element is extracted with std::any_cast<T>,
 * which throws std::bad_any_cast if a slot does not hold a T.
 */
template <typename T>
void ThreadGroupCC::CopyBuf2Data_(const std::vector<std::any> &src, std::vector<T> &dst) {
    dst.resize(src.size());
    auto out = dst.begin();
    for (const std::any &item : src) {
        *out = std::any_cast<T>(item);
        ++out;
    }
}
}
#endif