/**
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
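/*!
* \file hccl_ccu_v0_prepare.h
* \brief CCU (v0) preparation routines that fill the xnData_ parameter slots for each collective operation.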
*/
#ifndef IMPL_HCCL_CCU_V0_PREEPARE_H
#define IMPL_HCCL_CCU_V0_PREEPARE_H
#include "hccl_ccu_xn_utils.h"
#include "../impl/platform_v310/hccl_ccu_v0_def.h"
namespace AscendC {
/**
 * \brief Fill xnData_ for an AllToAllV request.
 *
 * Slots written:
 *   [0] op id (GetOpId),
 *   [1] sendBuf advanced by count * repeatIndex * element size,
 *   [2] recvBuf advanced by the same offset,
 *   [3], [4] cleared to 0,
 *   [5] ccuParam_.ccuMsgExt + CCU_MSG_EXT_RANK_OFFSET * (alltoallvCnt before this call).
 * Side effects: calls AssembleHcclMsgExtForCCU and advances ccuParam_.alltoallvCnt by one,
 * so consecutive calls address consecutive CCU_MSG_EXT_RANK_OFFSET-spaced entries.
 */
template<const auto &config>
__aicore__ inline void
HcclImpl<HcclServerType::HCCL_SERVER_TYPE_CCU, config>::CcuPrepareForAllToAllV(__gm__ CommonPrepareParamCcu *commParam, __gm__ AlltoAllVParamCcu *allToAllVParam)
{
    xnData_[0] = GetOpId(commParam);

    // Byte offset of the current repeat inside the user buffers.
    const auto elemBytes = DATA_TYPE_MAP[commParam->dataType];
    const uint64_t repeatOffset = commParam->count * ccuParam_.repeatIndex * elemBytes;
    xnData_[1] = (uint64_t)commParam->sendBuf + repeatOffset;
    xnData_[2] = (uint64_t)commParam->recvBuf + repeatOffset;
    xnData_[3] = 0;
    xnData_[4] = 0;

    AssembleHcclMsgExtForCCU(ccuParam_, commParam, allToAllVParam);

    // Take the current counter value as this call's slot, then advance it for the next call.
    const auto slotIndex = ccuParam_.alltoallvCnt++;
    xnData_[5] = reinterpret_cast<uint64_t>(ccuParam_.ccuMsgExt) + CCU_MSG_EXT_RANK_OFFSET * slotIndex;
}
/**
 * \brief Fill xnData_ for an AllToAll request.
 *
 * Slots written:
 *   [0] op id (GetOpId),
 *   [1], [2] sendBuf / recvBuf advanced by count * repeatIndex * element size,
 *   [3] per-rank slice size in bytes (count * element size),
 *   [4] distance between rank slices in bytes: the slice size when strideCount == 0,
 *       otherwise strideCount * element size,
 *   [5] 0,
 *   [6] [4] * rankId.
 */
template<const auto &config>
__aicore__ inline void
HcclImpl<HcclServerType::HCCL_SERVER_TYPE_CCU, config>::CcuPrepareForAllToAll(__gm__ CommonPrepareParamCcu *commParam)
{
    const uint64_t elemBytes = DATA_TYPE_MAP[commParam->dataType];
    xnData_[0] = GetOpId(commParam);

    const uint64_t repeatOffset = commParam->count * ccuParam_.repeatIndex * elemBytes;
    xnData_[1] = (uint64_t)commParam->sendBuf + repeatOffset;
    xnData_[2] = (uint64_t)commParam->recvBuf + repeatOffset;

    const uint64_t sliceBytes = commParam->count * elemBytes;
    // Slices are packed back-to-back unless an explicit stride is supplied.
    uint64_t strideBytes = sliceBytes;
    if (commParam->strideCount != 0) {
        strideBytes = commParam->strideCount * elemBytes;
    }

    xnData_[3] = sliceBytes;
    xnData_[4] = strideBytes;
    xnData_[5] = 0;
    xnData_[6] = strideBytes * ccuParam_.rankId;
}
/**
 * \brief Fill xnData_ for an AllToAllV-write request.
 *
 * Slots written:
 *   [0] op id (GetOpId),
 *   [1] sendBuf base address,
 *   [2] wParamExt.sendSizes,
 *   [3] wParamExt.sendOffsets,
 *   [4] loopSize * m,
 *   [5] m,
 *   [8] wParamExt.remoteWinOffset,
 * then CalcLoopGroupParam(xnData_, m, n, p) fills the loop-group slots. Here commParam->count
 * is split into m full loops of loopSize = CCU_LOOP_COUNT_ATAVW * CCU_MEMSLICE_SIZE,
 * n further CCU_MEMSLICE_SIZE pieces, and a remainder p.
 *
 * NOTE(review): unlike the other prepare routines, no repeatIndex offset is applied to sendBuf,
 * and count is split as-is (not multiplied by the element size). Presumably count is already a
 * byte count for this op; confirm.
 */
template<const auto &config>
__aicore__ inline void
HcclImpl<HcclServerType::HCCL_SERVER_TYPE_CCU, config>::CcuPrepareForAllToAllVWrite(__gm__ CommonPrepareParamCcu *commParam)
{
    xnData_[0] = GetOpId(commParam);
    xnData_[1] = reinterpret_cast<uint64_t>(commParam->sendBuf);
    xnData_[2] = commParam->wParamExt.sendSizes;
    // Written before CalcLoopGroupParam, exactly as before, in case that helper touches this slot.
    xnData_[8] = commParam->wParamExt.remoteWinOffset;

    const uint64_t sliceSize = commParam->count;
    const uint64_t loopCount = CCU_LOOP_COUNT_ATAVW;
    const uint64_t loopSize = loopCount * CCU_MEMSLICE_SIZE;
    // m full loops, then n memslices, then the remainder p.
    uint64_t m = sliceSize / loopSize;
    uint64_t n = (sliceSize - m * loopSize) / CCU_MEMSLICE_SIZE;
    uint64_t p = sliceSize - m * loopSize - n * CCU_MEMSLICE_SIZE;

    xnData_[3] = commParam->wParamExt.sendOffsets;
    xnData_[4] = loopSize * m;
    xnData_[5] = m;
    CalcLoopGroupParam(xnData_, m, n, p);
}
/**
 * \brief Fill xnData_ for an AllReduce request.
 *
 * count is divided across rankNum ranks with floor division (evenCount per rank), and the last
 * rank additionally takes the remainder. This rank's slice in bytes is split into m loops of
 * loopSize = CCU_LOOP_COUNT * CCU_MEMSLICE_SIZE, n CCU_MEMSLICE_SIZE pieces and a remainder p,
 * which are passed to CalcLoopGroupParam.
 *
 * Slots written:
 *   [0] op id (GetOpId),
 *   [1], [2] sendBuf / recvBuf advanced by count * repeatIndex * element size,
 *   [3] byte offset of this rank's slice: evenCount (or strideCount, if non-zero) elements per rank,
 *   [4] loopSize * m,
 *   [5] m,
 * plus whatever CalcLoopGroupParam writes.
 */
template<const auto &config>
__aicore__ inline void
HcclImpl<HcclServerType::HCCL_SERVER_TYPE_CCU, config>::CcuPrepareForAllReduce(__gm__ CommonPrepareParamCcu *commParam)
{
    xnData_[0] = GetOpId(commParam);
    const auto dataSize = DATA_TYPE_MAP[static_cast<uint64_t>(commParam->dataType)];

    const uint64_t repeatOffset = commParam->count * ccuParam_.repeatIndex * dataSize;
    xnData_[1] = (uint64_t)commParam->sendBuf + repeatOffset;
    xnData_[2] = (uint64_t)commParam->recvBuf + repeatOffset;

    const uint64_t evenCount = commParam->count / ccuParam_.rankNum;
    uint64_t mySliceCount = evenCount;
    if (ccuParam_.rankId == ccuParam_.rankNum - 1) {
        // The last rank absorbs whatever the floor division left over.
        mySliceCount = commParam->count - (ccuParam_.rankNum - 1) * evenCount;
    }
    const uint64_t mySliceBytes = mySliceCount * dataSize;

    const uint64_t loopCount = CCU_LOOP_COUNT;
    const uint64_t loopSize = loopCount * CCU_MEMSLICE_SIZE;
    uint64_t m = mySliceBytes / loopSize;
    uint64_t n = (mySliceBytes - m * loopSize) / CCU_MEMSLICE_SIZE;
    uint64_t p = mySliceBytes - m * loopSize - n * CCU_MEMSLICE_SIZE;

    if (commParam->strideCount == 0) {
        xnData_[3] = evenCount * dataSize * ccuParam_.rankId;
    } else {
        xnData_[3] = commParam->strideCount * dataSize * ccuParam_.rankId;
    }
    xnData_[4] = loopSize * m;
    xnData_[5] = m;
    CalcLoopGroupParam(xnData_, m, n, p);
}
/**
 * \brief Fill xnData_ for an AllGather request (CcuPrepareForReduceScatter forwards here too).
 *
 * The per-rank payload (count elements) is split into m loops of
 * loopSize = CCU_LOOP_COUNT * CCU_MEMSLICE_SIZE, n CCU_MEMSLICE_SIZE pieces and a remainder p,
 * which are passed to CalcLoopGroupParam.
 *
 * Slots written:
 *   [0] op id (GetOpId),
 *   [1], [2] sendBuf / recvBuf advanced by count * repeatIndex * element size,
 *   [3] byte offset of this rank: slice size (or strideCount elements, if non-zero) times rankId,
 *   [4] loopSize * m,
 *   [5] m,
 * plus whatever CalcLoopGroupParam writes.
 */
template<const auto &config>
__aicore__ inline void
HcclImpl<HcclServerType::HCCL_SERVER_TYPE_CCU, config>::CcuPrepareForAllGather(__gm__ CommonPrepareParamCcu *commParam)
{
    xnData_[0] = GetOpId(commParam);
    auto dataSize = DATA_TYPE_MAP[static_cast<uint64_t>(commParam->dataType)];
    uint64_t offset = commParam->count * ccuParam_.repeatIndex * dataSize;
    xnData_[1] = (uint64_t)commParam->sendBuf + offset;
    xnData_[2] = (uint64_t)commParam->recvBuf + offset;

    const uint64_t loopCount = CCU_LOOP_COUNT;
    // Widen to 64 bits before scaling by the element size.
    const uint64_t sliceCount = commParam->count;
    const uint64_t sliceSize = sliceCount * dataSize;
    const uint64_t loopSize = loopCount * CCU_MEMSLICE_SIZE;
    uint64_t m = sliceSize / loopSize;
    uint64_t n = (sliceSize - m * loopSize) / CCU_MEMSLICE_SIZE;
    uint64_t p = sliceSize - m * loopSize - n * CCU_MEMSLICE_SIZE;

    xnData_[3] = (commParam->strideCount == 0) ? sliceSize * ccuParam_.rankId :
        (commParam->strideCount * dataSize * ccuParam_.rankId);
    xnData_[4] = loopSize * m;
    xnData_[5] = m;
    CalcLoopGroupParam(xnData_, m, n, p);
}
/**
 * \brief Fill xnData_ for a ReduceScatter request.
 *
 * ReduceScatter uses the same xnData_ layout as AllGather, so the work is delegated to
 * CcuPrepareForAllGather with the same parameters.
 */
template<const auto &config>
__aicore__ inline void
HcclImpl<HcclServerType::HCCL_SERVER_TYPE_CCU, config>::CcuPrepareForReduceScatter(__gm__ CommonPrepareParamCcu *commParam)
{
    // Identical slot layout to AllGather: forward unchanged.
    this->CcuPrepareForAllGather(commParam);
}
/**
 * \brief Fill xnData_ for a mesh (M2M) AllReduce request.
 *
 * count is split into "normal" slices of ceil(count / rankNum) elements. Ranks with
 * rankId < normalRankCount own a normal slice; the others use lastSliceSize.
 *
 * Slots written:
 *   [0] op id (GetOpId),
 *   [1], [2] sendBuf / recvBuf advanced by count * repeatIndex * element size,
 *   [3] ccuParam_.scratchAddr,
 *   [4], [5] rankSliceOffset,
 *   [6] normal slice bytes,
 *   [7] last slice bytes,
 *   [8] this rank's slice bytes,
 *   [9] rankId * normal slice bytes,
 *   [10] 0,
 *   slots from &xnData_[11] are written by CalcGoSize.
 *
 * NOTE(review): with strideCount == 0, the last rank's rankSliceOffset is rankId * (its own,
 * possibly shorter, slice bytes). That differs from [9] whenever slices are uneven; confirm
 * which offset the CCU program expects.
 * NOTE(review): when more than one rank is short (e.g. count = 5, rankNum = 4), the last rank's
 * sliceCount wraps around and the per-rank sizes no longer sum to count. Presumably callers
 * avoid such shapes; confirm.
 */
template<const auto &config>
__aicore__ inline void
HcclImpl<HcclServerType::HCCL_SERVER_TYPE_CCU, config>::CcuPrepareForAllReduceM2M(__gm__ CommonPrepareParamCcu *commParam)
{
    xnData_[0] = GetOpId(commParam);
    auto dataSize = DATA_TYPE_MAP[static_cast<uint64_t>(commParam->dataType)];
    uint64_t offset = commParam->count * ccuParam_.repeatIndex * dataSize;
    xnData_[1] = (uint64_t)commParam->sendBuf + offset;
    xnData_[2] = (uint64_t)commParam->recvBuf + offset;
    uint64_t loopCount = CCU_LOOP_COUNT_M2M_RE;

    // Elements per normal slice: ceil(count / rankNum).
    uint64_t normalSliceCount = (commParam->count + ccuParam_.rankNum - 1) / ccuParam_.rankNum;
    uint64_t sliceCount = (ccuParam_.rankId == ccuParam_.rankNum - 1) ?
        (commParam->count - (ccuParam_.rankNum - 1) * normalSliceCount) : normalSliceCount;
    uint64_t sliceSize = sliceCount * dataSize;
    xnData_[3] = ccuParam_.scratchAddr;
    uint64_t rankSliceOffset =
        ccuParam_.rankId * ((commParam->strideCount == 0) ? sliceSize : (commParam->strideCount * dataSize));
    // uint64_t values must go through %llu / %llx with an explicit cast; %d reads the wrong vararg width.
    KERNEL_LOG(KERNEL_INFO, "ApiClient CcuPrepareForAllReduceM2M scratchAddr:0x%llx, rankSliceOffset:%llu",
        static_cast<unsigned long long>(ccuParam_.scratchAddr), static_cast<unsigned long long>(rankSliceOffset));
    xnData_[4] = rankSliceOffset;
    xnData_[5] = rankSliceOffset;

    uint64_t normalSliceSize = normalSliceCount * dataSize;
    // Number of ranks that receive a full normal slice. When count == 0, normalSliceCount is 0;
    // guard that case so it does not become a division by zero.
    uint64_t normalRankCount = (normalSliceCount == 0) ? 0 : commParam->count / normalSliceCount;
    uint64_t lastSliceSize = 0;
    if (normalRankCount == ccuParam_.rankNum) {
        // Even split: the "last" slice is just another normal slice.
        lastSliceSize = (commParam->count * dataSize) - ((normalRankCount - 1) * normalSliceSize);
    } else {
        lastSliceSize = (commParam->count * dataSize) - (normalRankCount * normalSliceSize);
    }
    KERNEL_LOG(KERNEL_INFO, "ApiClient CcuPrepareForAllReduceM2M normalSliceSize:%llu, lastSliceSize:%llu",
        static_cast<unsigned long long>(normalSliceSize), static_cast<unsigned long long>(lastSliceSize));
    xnData_[6] = normalSliceSize;
    xnData_[7] = lastSliceSize;
    xnData_[8] = ccuParam_.rankId < normalRankCount ? normalSliceSize : lastSliceSize;
    xnData_[9] = ccuParam_.rankId * normalSliceSize;
    xnData_[10] = 0;
    if (ccuParam_.rankId == ccuParam_.rankNum - 1) {
        CalcGoSize(lastSliceSize, loopCount, CCU_MEMSLICE_SIZE, &xnData_[11]);
    } else {
        CalcGoSize(normalSliceSize, loopCount, CCU_MEMSLICE_SIZE, &xnData_[11]);
    }
}
/**
 * \brief Fill xnData_ for a mesh (M2M) AllGather request.
 *
 * Slots written:
 *   [0] op id (GetOpId),
 *   [1], [2] sendBuf / recvBuf advanced by count * repeatIndex * element size,
 *   [3] byte offset of this rank: slice size (or strideCount elements, if non-zero) times rankId,
 *   [4] per-rank slice size in bytes (count * element size),
 *   slots from &xnData_[5] are written by CalcGoSize(slice size, CCU_LOOP_COUNT_M2M_AG,
 *   CCU_MEMSLICE_SIZE * 8, ...).
 */
template<const auto &config>
__aicore__ inline void
HcclImpl<HcclServerType::HCCL_SERVER_TYPE_CCU, config>::CcuPrepareForAllGatherM2M(__gm__ CommonPrepareParamCcu *commParam)
{
    xnData_[0] = GetOpId(commParam);
    const auto elemBytes = DATA_TYPE_MAP[static_cast<uint64_t>(commParam->dataType)];

    const uint64_t repeatOffset = commParam->count * ccuParam_.repeatIndex * elemBytes;
    xnData_[1] = (uint64_t)commParam->sendBuf + repeatOffset;
    xnData_[2] = (uint64_t)commParam->recvBuf + repeatOffset;

    uint64_t loopCount = CCU_LOOP_COUNT_M2M_AG;
    const uint64_t elemCount = commParam->count;
    uint64_t sliceBytes = elemCount * elemBytes;

    // Ranks are laid out back-to-back unless an explicit stride is supplied.
    uint64_t perRankStride = sliceBytes;
    if (commParam->strideCount != 0) {
        perRankStride = commParam->strideCount * elemBytes;
    }
    xnData_[3] = ccuParam_.rankId * perRankStride;
    xnData_[4] = sliceBytes;
    CalcGoSize(sliceBytes, loopCount, CCU_MEMSLICE_SIZE * 8, &xnData_[5]);
}
/**
 * \brief Fill xnData_ for a mesh (M2M) ReduceScatter request.
 *
 * Slots written:
 *   [0] op id (GetOpId),
 *   [1], [2] sendBuf / recvBuf advanced by count * repeatIndex * element size,
 *   [3] ccuParam_.scratchAddr,
 *   [4] rankSliceOffset: slice size (or strideCount elements, if non-zero) times rankId,
 *   [7] per-rank slice size in bytes (count * element size),
 *   [8] UINT64_MAX - 1,
 *   slots from &xnData_[9] are written by CalcGoSize.
 *
 * NOTE(review): [5] and [6] are not written here. Confirm the CCU program ignores them for
 * this op, or that they are initialised elsewhere.
 * NOTE(review): the meaning of the UINT64_MAX - 1 value in [8] is not visible from this file.
 */
template<const auto &config>
__aicore__ inline void
HcclImpl<HcclServerType::HCCL_SERVER_TYPE_CCU, config>::CcuPrepareForReduceScatterM2M(__gm__ CommonPrepareParamCcu *commParam)
{
    xnData_[0] = GetOpId(commParam);
    auto dataSize = DATA_TYPE_MAP[static_cast<uint64_t>(commParam->dataType)];
    uint64_t offset = commParam->count * ccuParam_.repeatIndex * dataSize;
    xnData_[1] = (uint64_t)commParam->sendBuf + offset;
    xnData_[2] = (uint64_t)commParam->recvBuf + offset;
    uint64_t loopCount = CCU_LOOP_COUNT_M2M_RE;
    uint64_t sliceCount = commParam->count;
    uint64_t sliceSize = sliceCount * dataSize;
    xnData_[3] = ccuParam_.scratchAddr;
    uint64_t rankSliceOffset =
        ccuParam_.rankId * ((commParam->strideCount == 0) ? sliceSize : (commParam->strideCount * dataSize));
    // uint64_t values must go through %llu / %llx with an explicit cast; %d reads the wrong vararg width.
    KERNEL_LOG(KERNEL_INFO, "ApiClient CcuPrepareForReduceScatterM2M scratchAddr:0x%llx, rankSliceOffset:%llu",
        static_cast<unsigned long long>(ccuParam_.scratchAddr), static_cast<unsigned long long>(rankSliceOffset));
    xnData_[4] = rankSliceOffset;
    xnData_[7] = sliceSize;
    xnData_[8] = UINT64_MAX - 1;
    CalcGoSize(sliceSize, loopCount, CCU_MEMSLICE_SIZE, &xnData_[9]);
}
}
#endif