/**
Copyright (c) 2025 Huawei Technologies Co., Ltd.
This program is free software, you can redistribute it and/or modify it under the terms and conditions of
CANN Open Software License Agreement Version 2.0 (the "License").
Please refer to the License for details. You may not use this file except in compliance with the License.
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
See LICENSE in the root of the software repository for the full text of the License.
*/
#pragma once
#include <stddef.h>
#include <stdint.h>
#if !defined(__CCE_KT_TEST__) && defined(__CCE_AICORE__)
#include "pto/comm/pto_comm_inst.hpp"
#endif
// NOTE(review): presumably the maximum number of ranks in a ring; not referenced in this header section.
constexpr int MAX_RING_RANKS = 16;
// Chunk-count window for the chunk-size heuristic. Only TARGET_CHUNKS_MAX is consulted by
// ComputeOptimalChunkSize; TARGET_CHUNKS_MIN is kept for any other users of this header.
constexpr int TARGET_CHUNKS_MIN = 64;
constexpr int TARGET_CHUNKS_MAX = 128;
// Bounds (in tiles) on the chunk size returned by ComputeOptimalChunkSize.
constexpr int MIN_CHUNK_SIZE = 4;
constexpr int MAX_CHUNK_SIZE = 64;

/**
 * Choose how many tiles to group into one flag chunk.
 *
 * Uses MIN_CHUNK_SIZE unless that would produce more than TARGET_CHUNKS_MAX chunks. In that case the
 * size becomes ceil(num_tiles_per_src / TARGET_CHUNKS_MAX), rounded up to a multiple of 4 and capped
 * at MAX_CHUNK_SIZE. Very large inputs can therefore still exceed TARGET_CHUNKS_MAX chunks.
 *
 * @param num_tiles_per_src Tiles contributed by each source rank; non-positive values yield MIN_CHUNK_SIZE.
 * @return Chunk size in tiles, within [MIN_CHUNK_SIZE, MAX_CHUNK_SIZE].
 */
inline int ComputeOptimalChunkSize(int num_tiles_per_src)
{
    if (num_tiles_per_src <= 0) {
        return MIN_CHUNK_SIZE;
    }
    // ceil(n / d) is written as (n - 1) / d + 1. This is exact for n > 0 and, unlike (n + d - 1) / d,
    // cannot overflow int (signed overflow is UB) when n is close to INT_MAX.
    const int chunks_at_min_size = (num_tiles_per_src - 1) / MIN_CHUNK_SIZE + 1;
    if (chunks_at_min_size <= TARGET_CHUNKS_MAX) {
        return MIN_CHUNK_SIZE;
    }
    int chunk_size = (num_tiles_per_src - 1) / TARGET_CHUNKS_MAX + 1;
    // Round up to a multiple of 4 tiles.
    chunk_size = ((chunk_size + 3) / 4) * 4;
    if (chunk_size > MAX_CHUNK_SIZE) {
        chunk_size = MAX_CHUNK_SIZE;
    }
    return chunk_size;
}
// Compile-time default chunk size (presumably tiles per chunk, as in ComputeOptimalChunkSize).
// Override it by defining STREAMING_CHUNK_SIZE before this header is included,
// e.g. -DSTREAMING_CHUNK_SIZE=8; otherwise it defaults to 4.
#if !defined(STREAMING_CHUNK_SIZE)
#define STREAMING_CHUNK_SIZE 4
#endif
constexpr int CHUNK_SIZE = STREAMING_CHUNK_SIZE;
// Fixed-size header of the chunk flag matrix: 7 int32 fields plus 9 int32 of padding,
// i.e. 16 * 4 = 64 bytes, matching alignas(64).
// Memory layout after the header:
//   - int32 flag storage, num_ranks rows of `stride` entries each. The flag for (src, chunk)
//     is at index src * stride + chunk (see GetChunkFlagPtr).
//   - optionally, one int32 summary counter per rank (see ChunkFlagMatrixSummaryOffset / GetSummaryBase).
// Field order and the padding size define that layout; keep them in sync with ChunkFlagMatrixInit.
struct alignas(64) ChunkFlagMatrix {
// Number of source ranks, i.e. number of flag rows.
int32_t num_ranks;
// Chunks in use per row: ceil(num_tiles_per_src / chunk_size) as computed by ChunkFlagMatrixInit.
int32_t num_chunks_per_src;
// Tiles each source contributes.
int32_t num_tiles_per_src;
// Tiles per chunk; ChunkFlagMatrixInit always stores a positive value here.
int32_t chunk_size;
// Row length in int32 entries: num_chunks_per_src rounded up to a multiple of 16.
int32_t stride;
// Set to -1 by ChunkFlagMatrixInit. NOTE(review): nothing in this view assigns it otherwise.
int32_t my_rank;
// Readiness threshold: IsChunkReady treats a flag >= epoch as ready. ChunkFlagMatrixInit sets it to 1.
int32_t epoch;
// Zeroed by ChunkFlagMatrixInit; pads the header to 64 bytes.
int32_t padding[9];
};
/**
 * Bytes needed for a ChunkFlagMatrix header plus its flag rows (without summary counters).
 *
 * @param num_ranks         Number of source ranks (flag rows).
 * @param num_tiles_per_src Tiles contributed by each source.
 * @param chunk_size        Tiles per chunk; a non-positive value selects ComputeOptimalChunkSize().
 * @return sizeof(ChunkFlagMatrix) + num_ranks * stride * sizeof(int32_t), where stride is the
 *         chunk count rounded up to a multiple of 16.
 */
inline size_t ChunkFlagMatrixSize(int num_ranks, int num_tiles_per_src, int chunk_size)
{
    const int effective_chunk = chunk_size > 0 ? chunk_size : ComputeOptimalChunkSize(num_tiles_per_src);
    const int chunk_count = (num_tiles_per_src + effective_chunk - 1) / effective_chunk;
    // Each row is padded to a multiple of 16 int32 entries.
    const int row_stride = (chunk_count + 15) / 16 * 16;
    // Compute the row area in size_t so large rank/row counts do not overflow int.
    const size_t flag_bytes = static_cast<size_t>(num_ranks) * row_stride * sizeof(int32_t);
    return flag_bytes + sizeof(ChunkFlagMatrix);
}
/**
 * Byte offset, from the start of the header, of the per-rank summary counters.
 *
 * The summary array is placed immediately after the header and all flag rows, so the offset
 * equals the size of the plain flag matrix.
 */
inline size_t ChunkFlagMatrixSummaryOffset(int num_ranks, int num_tiles_per_src, int chunk_size)
{
    const size_t matrix_bytes = ChunkFlagMatrixSize(num_ranks, num_tiles_per_src, chunk_size);
    return matrix_bytes;
}
/**
 * Total bytes for the header, the flag rows, and one int32 summary counter per rank.
 *
 * @return ChunkFlagMatrixSummaryOffset(...) + num_ranks * sizeof(int32_t).
 */
inline size_t ChunkFlagMatrixWithSummarySize(int num_ranks, int num_tiles_per_src, int chunk_size)
{
    const size_t summary_bytes = sizeof(int32_t) * static_cast<size_t>(num_ranks);
    const size_t summary_offset = ChunkFlagMatrixSummaryOffset(num_ranks, num_tiles_per_src, chunk_size);
    return summary_offset + summary_bytes;
}
/**
 * Initialize a ChunkFlagMatrix header and zero all of its flag rows.
 *
 * The caller must provide at least ChunkFlagMatrixSize(num_ranks, num_tiles_per_src, chunk_size)
 * bytes at `flags`. Summary counters, if any, are not touched (see ChunkFlagMatrixSummaryInit).
 *
 * @param flags             Destination header; the flag rows follow it directly in memory.
 * @param num_ranks         Number of source ranks (flag rows).
 * @param num_tiles_per_src Tiles contributed by each source.
 * @param chunk_size        Tiles per chunk; a non-positive value selects ComputeOptimalChunkSize().
 */
inline void ChunkFlagMatrixInit(ChunkFlagMatrix *flags, int num_ranks, int num_tiles_per_src, int chunk_size)
{
    // Same geometry as ChunkFlagMatrixSize, so the zeroed area matches the sized buffer.
    int actual_chunk_size = (chunk_size > 0) ? chunk_size : ComputeOptimalChunkSize(num_tiles_per_src);
    int num_chunks = (num_tiles_per_src + actual_chunk_size - 1) / actual_chunk_size;
    int stride = ((num_chunks + 15) / 16) * 16;
    flags->num_ranks = num_ranks;
    flags->num_chunks_per_src = num_chunks;
    flags->num_tiles_per_src = num_tiles_per_src;
    flags->chunk_size = actual_chunk_size;
    flags->stride = stride;
    flags->my_rank = -1;
    // Flags become "ready" once they reach the epoch (see IsChunkReady).
    flags->epoch = 1;
    // Derive the padding length from the array instead of hard-coding it.
    for (size_t i = 0; i < sizeof(flags->padding) / sizeof(flags->padding[0]); ++i) {
        flags->padding[i] = 0;
    }
    int32_t *base = reinterpret_cast<int32_t *>(reinterpret_cast<uint8_t *>(flags) + sizeof(ChunkFlagMatrix));
    // Count flags in 64 bits. num_ranks * stride can overflow int, while ChunkFlagMatrixSize sizes the
    // buffer with size_t arithmetic. A non-positive product still clears nothing.
    const int64_t flag_count = static_cast<int64_t>(num_ranks) * stride;
    for (int64_t i = 0; i < flag_count; ++i) {
        base[i] = 0;
    }
}
/**
 * Zero the per-rank summary counters.
 *
 * @param summary_base First of `num_ranks` int32 counters.
 * @param num_ranks    Number of counters; a non-positive value writes nothing.
 */
inline void ChunkFlagMatrixSummaryInit(int32_t *summary_base, int num_ranks)
{
    int32_t *slot = summary_base;
    int remaining = num_ranks;
    while (remaining > 0) {
        *slot = 0;
        ++slot;
        --remaining;
    }
}
/**
 * Zero every chunk flag of an already-initialized matrix. Header fields, including epoch, are left unchanged.
 *
 * @param flags Header previously set up by ChunkFlagMatrixInit. Its num_ranks and stride determine
 *              how many int32 flags following the header are cleared.
 */
inline void ChunkFlagMatrixReset(ChunkFlagMatrix *flags)
{
    int32_t *base = reinterpret_cast<int32_t *>(reinterpret_cast<uint8_t *>(flags) + sizeof(ChunkFlagMatrix));
    // Count flags in 64 bits. num_ranks * stride can overflow int, while ChunkFlagMatrixSize sizes the
    // buffer with size_t arithmetic. A non-positive product still clears nothing.
    const int64_t flag_count = static_cast<int64_t>(flags->num_ranks) * flags->stride;
    for (int64_t i = 0; i < flag_count; ++i) {
        base[i] = 0;
    }
}
/**
 * Mark every chunk in row `my_rank` as ready by writing 1 to each of its num_chunks_per_src flags.
 *
 * ChunkFlagMatrixInit sets epoch to 1, so IsChunkReady reports these chunks as ready while the
 * epoch is still 1.
 *
 * @param flags   Matrix initialized by ChunkFlagMatrixInit.
 * @param my_rank Row to fill; ranks outside [0, num_ranks) are ignored.
 */
inline void ChunkFlagMatrixSetLocalReady(ChunkFlagMatrix *flags, int my_rank)
{
    // Same bounds policy as SetRemoteChunkFlagReady. An out-of-range rank would write past the flag
    // rows (into the summary counters or beyond the allocation), so ignore it.
    if (my_rank < 0 || my_rank >= flags->num_ranks) {
        return;
    }
    int32_t *base = reinterpret_cast<int32_t *>(reinterpret_cast<uint8_t *>(flags) + sizeof(ChunkFlagMatrix));
    int32_t *row = base + my_rank * flags->stride;
    for (int c = 0; c < flags->num_chunks_per_src; ++c) {
        row[c] = 1;
    }
}
#if !defined(__CCE_KT_TEST__) && defined(__CCE_AICORE__)
// Device-side size of the header plus all flag rows, read from the header's num_ranks and stride.
// Uses size_t arithmetic so the product cannot overflow int.
AICORE inline size_t ChunkFlagMatrixBytes(volatile __gm__ ChunkFlagMatrix *flags)
{
    const size_t rows = static_cast<size_t>(flags->num_ranks);
    const size_t row_len = static_cast<size_t>(flags->stride);
    const size_t flag_bytes = rows * row_len * sizeof(int32_t);
    return sizeof(ChunkFlagMatrix) + flag_bytes;
}
// Address of the flag for (src_rank, chunk_idx). Flags are stored row-major directly after the
// header, with one row of `stride` int32 entries per source rank. No bounds checking is done.
AICORE inline volatile __gm__ int32_t *GetChunkFlagPtr(volatile __gm__ ChunkFlagMatrix *flags, int32_t src_rank,
                                                       int32_t chunk_idx)
{
    volatile __gm__ uint8_t *raw = reinterpret_cast<volatile __gm__ uint8_t *>(flags) + sizeof(ChunkFlagMatrix);
    volatile __gm__ int32_t *flag_rows = reinterpret_cast<volatile __gm__ int32_t *>(raw);
    const int32_t row_start = src_rank * flags->stride;
    return flag_rows + (row_start + chunk_idx);
}
// First per-rank summary counter. The counters start right after the header and all flag rows.
AICORE inline volatile __gm__ int32_t *GetSummaryBase(volatile __gm__ ChunkFlagMatrix *flags)
{
    const size_t summary_offset = ChunkFlagMatrixBytes(flags);
    volatile __gm__ uint8_t *raw = reinterpret_cast<volatile __gm__ uint8_t *>(flags);
    return reinterpret_cast<volatile __gm__ int32_t *>(raw + summary_offset);
}
// Add 1 to the flag for (src_rank, chunk_idx) via TNOTIFY with NotifyOp::AtomicAdd.
// Unlike SetRemoteChunkFlagReady, no bounds checking is done here.
AICORE inline void SetChunkFlagReady(volatile __gm__ ChunkFlagMatrix *flags, int32_t src_rank, int32_t chunk_idx)
{
    // The Signal is built from a non-volatile __gm__ pointer, so strip the volatile qualifier.
    __gm__ int32_t *flag_word = const_cast<__gm__ int32_t *>(GetChunkFlagPtr(flags, src_rank, chunk_idx));
    pto::comm::Signal flag_signal(flag_word);
    pto::comm::TNOTIFY(flag_signal, 1, pto::comm::NotifyOp::AtomicAdd);
}
// Atomically add 1 to the flag for (src_rank, chunk_idx) in `remote_flags`. If a summary counter
// pointer is supplied, add 1 to it as well. Coordinates outside [0, num_ranks) x
// [0, num_chunks_per_src), as read from the target header, are silently ignored.
AICORE inline void SetRemoteChunkFlagReady(__gm__ ChunkFlagMatrix *remote_flags, int32_t src_rank, int32_t chunk_idx,
                                           __gm__ int32_t *remote_summary_src_ptr = nullptr)
{
    volatile __gm__ ChunkFlagMatrix *matrix = reinterpret_cast<volatile __gm__ ChunkFlagMatrix *>(remote_flags);
    if (src_rank < 0 || src_rank >= matrix->num_ranks) {
        return;
    }
    if (chunk_idx < 0 || chunk_idx >= matrix->num_chunks_per_src) {
        return;
    }
    volatile __gm__ int32_t *flag_word = GetChunkFlagPtr(matrix, src_rank, chunk_idx);
    pto::comm::Signal flag_signal(const_cast<__gm__ int32_t *>(flag_word));
    pto::comm::TNOTIFY(flag_signal, 1, pto::comm::NotifyOp::AtomicAdd);
    if (remote_summary_src_ptr == nullptr) {
        return;
    }
    pto::comm::Signal summary_signal(remote_summary_src_ptr);
    pto::comm::TNOTIFY(summary_signal, 1, pto::comm::NotifyOp::AtomicAdd);
}
// Store `value` into the summary counter for `src_rank` with a plain (non-atomic) store.
// The store is bracketed by dcci on the counter's cache line and by compiler memory barriers.
// It is a no-op when summary_base is null or src_rank is negative; there is no upper bound check on src_rank.
// NOTE(review): dcci is presumably clean+invalidate of one data-cache line, so the stored value reaches
// GM and no stale line is used. Confirm against the Ascend C dcci documentation.
AICORE inline void SetLocalSummaryReady(volatile __gm__ int32_t *summary_base, int32_t src_rank, int32_t value)
{
if (summary_base == nullptr || src_rank < 0)
return;
volatile __gm__ int32_t *ptr = summary_base + src_rank;
// Cache maintenance on the target line before writing.
dcci((__gm__ void *)ptr, SINGLE_CACHE_LINE);
// Compiler-only barrier: prevents the compiler from moving memory accesses across this point.
__asm__ __volatile__("" ::: "memory");
*ptr = value;
// Cache maintenance after writing, presumably to push the new value out to GM (see NOTE above).
dcci((__gm__ void *)ptr, SINGLE_CACHE_LINE);
__asm__ __volatile__("" ::: "memory");
}
// Return true when the flag for (src_rank, chunk_idx) is >= the matrix's epoch.
// dcci is applied to the flag's cache line before the read; no bounds checking is done.
// NOTE(review): flags->epoch is read without a dcci on the header line. That is fine if epoch is only
// written at init time (ChunkFlagMatrixInit sets it to 1); confirm if epoch is ever advanced concurrently.
AICORE inline bool IsChunkReady(volatile __gm__ ChunkFlagMatrix *flags, int32_t src_rank, int32_t chunk_idx)
{
volatile __gm__ int32_t *ptr = GetChunkFlagPtr(flags, src_rank, chunk_idx);
int32_t epoch = flags->epoch;
// Cache maintenance on the flag line so the read below is not served from a stale copy
// (presumed dcci semantics; see Ascend C docs).
dcci((__gm__ void *)ptr, SINGLE_CACHE_LINE);
// Compiler-only barrier: keeps the flag read below after the dcci.
__asm__ __volatile__("" ::: "memory");
return (*ptr >= epoch);
}
// Return true when the summary counter for src_rank is >= 1. SetRemoteChunkFlagReady adds 1 to a
// summary counter per chunk when it is given one, so this means at least one chunk from that source
// was counted.
// Returns false for a null summary_base or a negative src_rank; there is no upper bound check.
AICORE inline bool IsAnyReadyFromSrc(volatile __gm__ int32_t *summary_base, int32_t src_rank)
{
if (summary_base == nullptr || src_rank < 0)
return false;
volatile __gm__ int32_t *ptr = summary_base + src_rank;
// Cache maintenance on the counter's line before reading it (presumed dcci semantics).
dcci((__gm__ void *)ptr, SINGLE_CACHE_LINE);
// Compiler-only barrier: keeps the read below after the dcci.
__asm__ __volatile__("" ::: "memory");
return (*ptr >= 1);
}
// Return the current value of the summary counter for src_rank.
// Returns 0 for a null summary_base or a negative src_rank; there is no upper bound check.
AICORE inline int32_t GetReadyCountFromSrc(volatile __gm__ int32_t *summary_base, int32_t src_rank)
{
if (summary_base == nullptr || src_rank < 0)
return 0;
volatile __gm__ int32_t *ptr = summary_base + src_rank;
// Cache maintenance on the counter's line before reading it (presumed dcci semantics).
dcci((__gm__ void *)ptr, SINGLE_CACHE_LINE);
// Compiler-only barrier: keeps the read below after the dcci.
__asm__ __volatile__("" ::: "memory");
return *ptr;
}
// Wait, via TWAIT with WaitCmp::GE, until the summary counter for src_rank is at least `expected`.
// Returns immediately when there is no summary array, src_rank is negative, or expected <= 0.
AICORE inline void WaitReadyCountFromSrc(volatile __gm__ int32_t *summary_base, int32_t src_rank, int32_t expected)
{
    const bool nothing_to_wait = summary_base == nullptr || src_rank < 0 || expected <= 0;
    if (nothing_to_wait) {
        return;
    }
    volatile __gm__ int32_t *counter = summary_base + src_rank;
    pto::comm::Signal counter_signal(const_cast<__gm__ int32_t *>(counter));
    pto::comm::TWAIT(counter_signal, expected, pto::comm::WaitCmp::GE);
}
#endif