* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef KERNEL_SIMT_CPU_DEBUG
#define KERNEL_SIMT_CPU_DEBUG
#ifdef ASCENDC_CPU_DEBUG
#include <vector>
#include <atomic>
#include <mutex>
#include <thread>
#include <condition_variable>
#include <chrono>
#include "stub_def.h"
#include "kernel_log.h"
namespace cce {
/**
 * Three-component unsigned extent/index (x, y, z).
 *
 * Components that are not supplied to a constructor keep their in-class default of 1.
 * The single-argument constructor is intentionally left implicit so that a plain integer
 * can still be used wherever a dim3 is expected.
 */
struct dim3 {
    uint32_t x = 1u;
    uint32_t y = 1u;
    uint32_t z = 1u;

    // All members already default to 1, so default construction yields (1, 1, 1).
    dim3() = default;
    dim3(uint32_t x_) : x(x_) {}
    dim3(uint32_t x_, uint32_t y_) : x(x_), y(y_) {}
    dim3(uint32_t x_, uint32_t y_, uint32_t z_) : x(x_), y(y_), z(z_) {}
};
}
// Global launch-geometry variables for the CPU-debug build. They are `inline` so that every
// translation unit including this header shares one definition.
inline cce::dim3 blockDim{1u, 1u, 1u};
inline cce::dim3 blockIdx{0u, 0u, 0u};
// thread_local: each simulated thread has its own copy, which FuncWrapper fills in before
// running the kernel body.
inline thread_local cce::dim3 threadIdx{0u, 0u, 0u};
inline cce::dim3 gridDim{8u, 1u, 1u};
namespace AscendC {
namespace Simt {
// Number of threads grouped into one Warp object; also the size of each warp's thread array.
constexpr uint32_t THREAD_PER_WARP = 32u;
// Number of scratch slots kept per warp collective; a slot is picked with
// (generation % MEMORY_PIECE).
constexpr uint32_t MEMORY_PIECE = 2u;

// Entry point executed by every simulated thread. Only declared here so that Warp::Schedule
// can name it; the definition follows the ThreadBlock class.
template <typename Func>
void FuncWrapper(Func func, uint32_t warpId, uint32_t threadIndex);
class Warp {
public:
Warp() {}
~Warp();
template <typename Func>
void Schedule(Func func, uint32_t warpId, uint32_t idx)
{
threads[idx] = std::thread(FuncWrapper<decltype(func)>, func, warpId, idx);
}
void Done();
template <typename T, typename Func>
T WarpOp(T val, Func action)
{
std::unique_lock<std::mutex> lck(mtx_);
auto currGeneration = syncGeneration;
void* temp = reinterpret_cast<void *>(&data[currGeneration % MEMORY_PIECE]);
T &dataToUpdate = *reinterpret_cast<T *>(temp);
activeThreads--;
if (activeThreads == 0) {
syncGeneration++;
activeThreads = THREAD_PER_WARP;
dataToUpdate = action(val, dataToUpdate);
isReset = false;
cv_.notify_all();
} else {
if (!isReset) {
dataToUpdate = val;
isReset = true;
} else {
dataToUpdate = action(val, dataToUpdate);
}
bool notTimeout = cv_.wait_for(
lck, std::chrono::seconds(5), [this, currGeneration] { return currGeneration != syncGeneration; });
if (!notTimeout) {
KERNEL_LOG(KERNEL_ERROR,
"Warp operation timeout, CPU Debug only supports all 32 threads must be involved in the same "
"warp operation. If it has already satisfied this condition, maybe deadlock occurred.");
}
}
return dataToUpdate;
}
template <typename T>
T WarpShuffleOp(T val, uint32_t laneToWrite, uint32_t laneToRead)
{
std::unique_lock<std::mutex> lck(mtx_);
auto currGeneration = syncGeneration;
void* temp = reinterpret_cast<void *>(&shuffleData[laneToWrite][currGeneration % MEMORY_PIECE]);
T &dataToUpdate = *reinterpret_cast<T *>(temp);
dataToUpdate = val;
activeThreads--;
if (activeThreads == 0) {
syncGeneration++;
activeThreads = THREAD_PER_WARP;
cv_.notify_all();
} else {
bool notTimeout = cv_.wait_for(
lck, std::chrono::seconds(5), [this, currGeneration] { return currGeneration != syncGeneration; });
if (!notTimeout) {
KERNEL_LOG(KERNEL_ERROR,
"Shuffle Warp operation timeout, CPU Debug only supports all 32 threads must be involved in the same "
"warp operation. If it has already satisfied this condition, maybe deadlock occurred.");
}
}
void* temp2 = reinterpret_cast<void *>(&shuffleData[laneToRead][currGeneration % MEMORY_PIECE]);
return *reinterpret_cast<T *>(temp2);
}
private:
uint32_t activeThreads{THREAD_PER_WARP};
uint32_t syncGeneration{0};
bool isReset{false};
uint32_t shuffleData[THREAD_PER_WARP][MEMORY_PIECE];
uint64_t data[MEMORY_PIECE]{0};
std::mutex mtx_;
std::condition_variable cv_;
std::thread threads[THREAD_PER_WARP];
};
/**
 * CPU-debug model of a thread block: a collection of Warp objects plus block-level
 * synchronization state guarded by mtx_ / cv_.
 *
 * GetBlockInstance, Init, FinishJobs, SyncAllThreads and ThreadFinished are only declared here;
 * their definitions live outside this header.
 * NOTE(review): GetBlockInstance presumably hands out one shared instance and Init(num) presumably
 * sizes warps_ / warpNum_ — confirm against the definitions.
 */
class ThreadBlock {
public:
    static ThreadBlock &GetBlockInstance();
    void Init(uint32_t num);

    /**
     * Starts the simulated thread with block-wide index `idx` on the warp that owns it.
     *
     * @param func  callable the thread will run
     * @param idx   block-wide thread index; warp = idx / THREAD_PER_WARP,
     *              lane = idx % THREAD_PER_WARP
     */
    template <typename Func>
    void Schedule(Func func, uint32_t idx)
    {
        // Checked through ASCENDC_ASSERT: the owning warp index must be below warpNum_.
        ASCENDC_ASSERT((idx / THREAD_PER_WARP < warpNum_),
                       { KERNEL_LOG(KERNEL_ERROR, "thread idx %u exceeds warp count %u", idx, warpNum_); });
        const uint32_t warpId = idx / THREAD_PER_WARP;
        const uint32_t laneId = idx % THREAD_PER_WARP;
        Warp &owner = warps_[warpId];
        owner.Schedule<Func>(func, warpId, laneId);
    }

    /**
     * Runs `action()` while holding the block mutex, so calls made through AtomicOp are
     * serialized against each other.
     */
    template <typename Func>
    void AtomicOp(Func action)
    {
        std::lock_guard<std::mutex> guard(mtx_);
        action();
    }

    void FinishJobs();
    void SyncAllThreads();
    void ThreadFinished();

    // Counters start at zero through their in-class initializers below.
    ThreadBlock() {}

    // Tearing down the block first calls FinishJobs().
    ~ThreadBlock()
    {
        FinishJobs();
    }

    std::vector<Warp> warps_;        // indexed by idx / THREAD_PER_WARP in Schedule
    std::mutex mtx_;                 // taken by AtomicOp
    std::condition_variable cv_;
    uint32_t activeThreads{0};
    uint32_t syncGeneration{0};
    uint32_t threadThreshold{0};
    uint32_t warpNum_{0};            // upper bound for warp indices accepted by Schedule
};
/**
 * Body of every simulated thread.
 *
 * Turns (warpId, threadIndex) into a block-wide linear index, splits that index into x/y/z using
 * g_threadDimX / g_threadDimY, stores the result both in g_threadIdxX/Y/Z and in the thread-local
 * threadIdx, runs `func`, and finally reports completion through ThreadBlock::ThreadFinished().
 *
 * @param func         kernel body to execute
 * @param warpId       index of the warp this thread belongs to
 * @param threadIndex  index of the thread inside its warp
 */
template <typename Func>
void FuncWrapper(Func func, uint32_t warpId, uint32_t threadIndex)
{
    const uint32_t linearIdx = warpId * THREAD_PER_WARP + threadIndex;

    // x varies fastest, then y, then z.
    g_threadIdxX = linearIdx % g_threadDimX;
    threadIdx.x = g_threadIdxX;

    g_threadIdxY = (linearIdx / g_threadDimX) % g_threadDimY;
    threadIdx.y = g_threadIdxY;

    g_threadIdxZ = linearIdx / (g_threadDimY * g_threadDimX);
    threadIdx.z = g_threadIdxZ;

    func();

    ThreadBlock::GetBlockInstance().ThreadFinished();
}
// Free functions of the AscendC::Simt CPU-debug API. All four are only declared in this header;
// their definitions are not visible here.
// NOTE(review): judging by the names, GetThreadIdx / GetLaneId / GetWarpId presumably report the
// calling thread's block-wide index, its index within the warp, and its warp index, and Sync
// presumably is a block-level barrier — confirm against the definitions.
uint32_t GetThreadIdx();
uint32_t GetLaneId();
uint32_t GetWarpId();
void Sync();
}
}
#endif
#endif