* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef CCU_PRIMITIVES_HPP
#define CCU_PRIMITIVES_HPP
#include <vector>
#include "ccu_primitives_impl.h"
#include "ccu_control_flow_macro.h"
#include "ccu_variable.hpp"
#include "ccu_address.hpp"
#include "ccu_event.hpp"
#include "ccu_buffer.hpp"
#include "ccu_local_addr.hpp"
#include "ccu_remote_addr.hpp"
#include "ccu_array.hpp"
#include "ccu_func.hpp"
#include "ccu_loop.hpp"
namespace AscendC {
namespace ccu {
// Namespace-local aliases for the global-scope loop configuration types, so that
// they can be referred to as ccu::LoopConfig and ccu::LoopGroupConfig.
using LoopConfig = ::CcuLoopConfig;
using LoopGroupConfig = ::CcuLoopGroupConfig;
/**
 * @brief Primary template of GetResByChannel; it only exists to reject unsupported types.
 *
 * Instantiating this template for any T triggers the static_assert below, because the
 * asserted condition depends on T but can never be true. Supported types must provide
 * an explicit specialization.
 *
 * @tparam T Requested resource type.
 */
template <typename T>
inline T GetResByChannel(ChannelHandle /*channel*/, uint32_t /*index*/)
{
    // Dependent on T so the assertion is evaluated only on instantiation.
    constexpr bool isSpecializedType = (sizeof(T) == 0);
    static_assert(isSpecializedType,
                  "ccu::GetResByChannel<T> is not specialized for this type T; "
                  "currently supported: Variable.");
}
/**
 * @brief GetResByChannel specialization for Variable.
 *
 * Builds a Variable using detail::NoAllocTag and then asks
 * CcuVariableCreateByChannel to fill in its handle for the given channel and index.
 *
 * @param channel  Channel handle forwarded to CcuVariableCreateByChannel.
 * @param varIndex Index forwarded to CcuVariableCreateByChannel.
 * @return The Variable whose handle was written by CcuVariableCreateByChannel.
 */
template <>
inline Variable GetResByChannel<Variable>(ChannelHandle channel, uint32_t varIndex) {
// NoAllocTag constructor is used here; presumably it skips allocating a handle
// because the handle is supplied by the call below - TODO confirm in ccu_variable.hpp.
Variable v{detail::NoAllocTag{}};
// Failure handling is delegated to CCU_THROW_IF_FAILED (defined elsewhere;
// judging by its name it throws on a failed result - confirm).
CCU_THROW_IF_FAILED(CcuVariableCreateByChannel(channel, varIndex, &v.handle),
"CcuVariableCreateByChannel: failed");
return v;
}
/**
 * @brief Calls CcuEventRecord on the handle of @p e.
 * @param e    Event whose handle is passed to the underlying call.
 * @param mask Mask value forwarded unchanged; defaults to 1.
 * @return The CcuResult produced by CcuEventRecord.
 */
inline CcuResult EventRecord(Event e, uint16_t mask = 1)
{
    return CcuEventRecord(e.handle, mask);
}
/**
 * @brief Calls CcuEventWait on the handle of @p e.
 * @param e    Event whose handle is passed to the underlying call.
 * @param mask Mask value forwarded unchanged; defaults to 1.
 * @return The CcuResult produced by CcuEventWait.
 */
inline CcuResult EventWait(Event e, uint16_t mask = 1)
{
    return CcuEventWait(e.handle, mask);
}
/**
 * @brief Tag-based overload of EventRecord; forwards to CcuLocalNotifyRecord.
 * @param notifyTag C string passed unchanged to CcuLocalNotifyRecord (not checked here).
 * @param mask      Mask value forwarded unchanged; defaults to 1.
 * @return The CcuResult produced by CcuLocalNotifyRecord.
 */
inline CcuResult EventRecord(const char *notifyTag, uint16_t mask = 1)
{
    return CcuLocalNotifyRecord(notifyTag, mask);
}
/**
 * @brief Tag-based overload of EventWait; forwards to CcuLocalNotifyWait.
 * @param notifyTag C string passed unchanged to CcuLocalNotifyWait (not checked here).
 * @param mask      Mask value forwarded unchanged; defaults to 1.
 * @return The CcuResult produced by CcuLocalNotifyWait.
 */
inline CcuResult EventWait(const char *notifyTag, uint16_t mask = 1)
{
    return CcuLocalNotifyWait(notifyTag, mask);
}
/**
 * @brief Forwards to CcuNotifyRecord for the given channel.
 * @param channel         Channel handle passed to the underlying call.
 * @param remoteNotifyIdx Notify index passed to the underlying call.
 * @param mask            Mask value forwarded unchanged; defaults to 1.
 * @return The CcuResult produced by CcuNotifyRecord.
 */
inline CcuResult NotifyRecord(ChannelHandle channel, uint32_t remoteNotifyIdx, uint16_t mask = 1)
{
    return CcuNotifyRecord(channel, remoteNotifyIdx, mask);
}
/**
 * @brief Forwards to CcuNotifyWait for the given channel.
 * @param channel        Channel handle passed to the underlying call.
 * @param localNotifyIdx Notify index passed to the underlying call.
 * @param mask           Mask value forwarded unchanged; defaults to 1.
 * @return The CcuResult produced by CcuNotifyWait.
 */
inline CcuResult NotifyWait(ChannelHandle channel, uint32_t localNotifyIdx, uint16_t mask = 1)
{
    return CcuNotifyWait(channel, localNotifyIdx, mask);
}
/**
 * @brief Forwards to CcuWriteVariableWithNotify using the handle of @p var.
 * @param channel         Channel handle passed to the underlying call.
 * @param var             Variable whose handle is passed to the underlying call.
 * @param remoteVarIdx    Variable index passed to the underlying call.
 * @param remoteNotifyIdx Notify index passed to the underlying call.
 * @param mask            Mask value forwarded unchanged; defaults to 1.
 * @return The CcuResult produced by CcuWriteVariableWithNotify.
 */
inline CcuResult WriteVariableWithNotify(ChannelHandle channel, Variable var, uint32_t remoteVarIdx,
                                         uint32_t remoteNotifyIdx, uint16_t mask = 1)
{
    return CcuWriteVariableWithNotify(channel, var.handle, remoteVarIdx, remoteNotifyIdx, mask);
}
/**
 * @brief Forwards to CcuLoadArg with the handle of @p v and the given argument id.
 * @param v     Variable whose handle is passed to the underlying call.
 * @param argId Argument identifier passed to the underlying call.
 * @return The CcuResult produced by CcuLoadArg.
 */
inline CcuResult LoadArg(Variable v, uint32_t argId)
{
    return CcuLoadArg(v.handle, argId);
}
/**
 * @brief Array overload of Load taking a literal address; forwards to CcuLoadVar.
 *
 * Only the handle of the first array element is handed to CcuLoadVar, together with
 * @p num.
 * NOTE(review): vArr[0] is accessed unconditionally and @p num is not checked against
 * the array here - presumably the caller guarantees vArr holds at least num elements;
 * confirm.
 *
 * @param addr Address value passed to the underlying call.
 * @param vArr Array whose first element's handle is passed to the underlying call.
 * @param num  Count passed to the underlying call.
 * @return The CcuResult produced by CcuLoadVar.
 */
inline CcuResult Load(uint64_t addr, Array<Variable>& vArr, uint32_t num)
{
    return CcuLoadVar(addr, vArr[0].handle, num);
}
/**
 * @brief Single-variable overload of Load taking a literal address.
 *
 * Forwards to CcuLoadVar with a fixed count of 1.
 *
 * @param addr Address value passed to the underlying call.
 * @param v    Variable whose handle is passed to the underlying call.
 * @return The CcuResult produced by CcuLoadVar.
 */
inline CcuResult Load(uint64_t addr, Variable v)
{
    return CcuLoadVar(addr, v.handle, 1);
}
/**
 * @brief Array overload of Load whose address is supplied through a Variable.
 *
 * Forwards to CcuLoadVarFromVarAddr with the handle of @p addrVar, the handle of the
 * first array element, and @p num.
 * NOTE(review): vArr[0] is accessed unconditionally and @p num is not checked against
 * the array here - presumably the caller guarantees vArr holds at least num elements;
 * confirm.
 *
 * @param addrVar Variable whose handle is passed as the address argument.
 * @param vArr    Array whose first element's handle is passed to the underlying call.
 * @param num     Count passed to the underlying call.
 * @return The CcuResult produced by CcuLoadVarFromVarAddr.
 */
inline CcuResult Load(Variable addrVar, Array<Variable>& vArr, uint32_t num)
{
    return CcuLoadVarFromVarAddr(addrVar.handle, vArr[0].handle, num);
}
/**
 * @brief Single-variable overload of Load whose address is supplied through a Variable.
 *
 * Forwards to CcuLoadVarFromVarAddr with a fixed count of 1.
 *
 * @param addrVar Variable whose handle is passed as the address argument.
 * @param v       Variable whose handle is passed to the underlying call.
 * @return The CcuResult produced by CcuLoadVarFromVarAddr.
 */
inline CcuResult Load(Variable addrVar, Variable v)
{
    return CcuLoadVarFromVarAddr(addrVar.handle, v.handle, 1);
}
/**
 * @brief Array overload of Store taking a literal address; forwards to CcuStoreVar.
 *
 * Only the handle of the first array element is handed to CcuStoreVar, together with
 * @p num.
 * NOTE(review): vArr[0] is accessed unconditionally and @p num is not checked against
 * the array here - presumably the caller guarantees vArr holds at least num elements;
 * confirm.
 *
 * @param addr Address value passed to the underlying call.
 * @param vArr Array whose first element's handle is passed to the underlying call.
 * @param num  Count passed to the underlying call.
 * @return The CcuResult produced by CcuStoreVar.
 */
inline CcuResult Store(uint64_t addr, Array<Variable>& vArr, uint32_t num)
{
    return CcuStoreVar(addr, vArr[0].handle, num);
}
/**
 * @brief Single-variable overload of Store taking a literal address.
 *
 * Forwards to CcuStoreVar with a fixed count of 1.
 *
 * @param addr Address value passed to the underlying call.
 * @param v    Variable whose handle is passed to the underlying call.
 * @return The CcuResult produced by CcuStoreVar.
 */
inline CcuResult Store(uint64_t addr, Variable v)
{
    return CcuStoreVar(addr, v.handle, 1);
}
/**
 * @brief Array overload of Store whose address is supplied through a Variable.
 *
 * Forwards to CcuStoreVarToVarAddr with the handle of @p addrVar, the handle of the
 * first array element, and @p num.
 * NOTE(review): vArr[0] is accessed unconditionally and @p num is not checked against
 * the array here - presumably the caller guarantees vArr holds at least num elements;
 * confirm.
 *
 * @param addrVar Variable whose handle is passed as the address argument.
 * @param vArr    Array whose first element's handle is passed to the underlying call.
 * @param num     Count passed to the underlying call.
 * @return The CcuResult produced by CcuStoreVarToVarAddr.
 */
inline CcuResult Store(Variable addrVar, Array<Variable>& vArr, uint32_t num)
{
    return CcuStoreVarToVarAddr(addrVar.handle, vArr[0].handle, num);
}
/**
 * @brief Single-variable overload of Store whose address is supplied through a Variable.
 *
 * Forwards to CcuStoreVarToVarAddr with a fixed count of 1.
 *
 * @param addrVar Variable whose handle is passed as the address argument.
 * @param v       Variable whose handle is passed to the underlying call.
 * @return The CcuResult produced by CcuStoreVarToVarAddr.
 */
inline CcuResult Store(Variable addrVar, Variable v)
{
    return CcuStoreVarToVarAddr(addrVar.handle, v.handle, 1);
}
/**
 * @brief LocalCopy overload for LocalAddr destination and LocalAddr source.
 *
 * Forwards the handles of all object arguments to CcuLocalCopyMemToMem.
 *
 * @param dst   Destination address object; its handle is forwarded.
 * @param src   Source address object; its handle is forwarded.
 * @param len   Variable whose handle is forwarded as the length argument.
 * @param event Event whose handle is forwarded.
 * @param mask  Mask value forwarded unchanged; defaults to 1.
 * @return The CcuResult produced by CcuLocalCopyMemToMem.
 */
inline CcuResult LocalCopy(LocalAddr dst, LocalAddr src, Variable len, Event event, uint16_t mask = 1)
{
    return CcuLocalCopyMemToMem(dst.handle, src.handle, len.handle, event.handle, mask);
}
/**
 * @brief LocalCopy overload for CcuBuffer destination and LocalAddr source.
 *
 * Forwards the handles of all object arguments to CcuLocalCopyMemToBuffer.
 *
 * @param dst   Destination buffer; its handle is forwarded.
 * @param src   Source address object; its handle is forwarded.
 * @param len   Variable whose handle is forwarded as the length argument.
 * @param event Event whose handle is forwarded.
 * @param mask  Mask value forwarded unchanged; defaults to 1.
 * @return The CcuResult produced by CcuLocalCopyMemToBuffer.
 */
inline CcuResult LocalCopy(CcuBuffer dst, LocalAddr src, Variable len, Event event, uint16_t mask = 1)
{
    return CcuLocalCopyMemToBuffer(dst.handle, src.handle, len.handle, event.handle, mask);
}
/**
 * @brief LocalCopy overload for LocalAddr destination and CcuBuffer source.
 *
 * Forwards the handles of all object arguments to CcuLocalCopyBufferToMem.
 *
 * @param dst   Destination address object; its handle is forwarded.
 * @param src   Source buffer; its handle is forwarded.
 * @param len   Variable whose handle is forwarded as the length argument.
 * @param event Event whose handle is forwarded.
 * @param mask  Mask value forwarded unchanged; defaults to 1.
 * @return The CcuResult produced by CcuLocalCopyBufferToMem.
 */
inline CcuResult LocalCopy(LocalAddr dst, CcuBuffer src, Variable len, Event event, uint16_t mask = 1)
{
    return CcuLocalCopyBufferToMem(dst.handle, src.handle, len.handle, event.handle, mask);
}
/**
 * @brief LocalReduce overload operating on two LocalAddr objects.
 *
 * Forwards the handles of the object arguments, plus the data type and reduce
 * operation, to CcuLocalMemReduce.
 *
 * @param dst      Destination address object; its handle is forwarded.
 * @param src      Source address object; its handle is forwarded.
 * @param len      Variable whose handle is forwarded as the length argument.
 * @param dataType Data type forwarded unchanged.
 * @param opType   Reduce operation forwarded unchanged.
 * @param event    Event whose handle is forwarded.
 * @param mask     Mask value forwarded unchanged; defaults to 1.
 * @return The CcuResult produced by CcuLocalMemReduce.
 */
inline CcuResult LocalReduce(LocalAddr dst, LocalAddr src, Variable len, HcclDataType dataType,
                             HcclReduceOp opType, Event event, uint16_t mask = 1)
{
    return CcuLocalMemReduce(dst.handle, src.handle, len.handle, dataType, opType, event.handle, mask);
}
/**
 * @brief LocalReduce overload operating on an array of CcuBuffer objects.
 *
 * Collects the handle of each of the @p count buffers into a temporary contiguous
 * array and passes that array to CcuLocalBufferReduce.
 *
 * @param buffers        Pointer to the first of @p count buffers.
 * @param count          Number of buffers readable through @p buffers.
 * @param dataType       Data type forwarded unchanged.
 * @param outputDataType Output data type forwarded unchanged.
 * @param opType         Reduce operation forwarded unchanged.
 * @param len            Variable whose handle is forwarded as the length argument.
 * @param event          Event whose handle is forwarded.
 * @param mask           Mask value forwarded unchanged; defaults to 1.
 * @return CcuResult::CCU_E_PARA when @p buffers is null or @p count is zero;
 *         otherwise the CcuResult produced by CcuLocalBufferReduce.
 */
inline CcuResult LocalReduce(CcuBuffer* buffers, uint32_t count, HcclDataType dataType,
                             HcclDataType outputDataType, HcclReduceOp opType, Variable len, Event event,
                             uint16_t mask = 1)
{
    const bool hasBuffers = (buffers != nullptr) && (count != 0);
    if (!hasBuffers) {
        return CcuResult::CCU_E_PARA;
    }

    // The underlying call takes a plain array of handles, so gather them first.
    std::vector<CcuBufferHandle> handleList;
    handleList.reserve(count);
    CcuBuffer* const lastBuffer = buffers + count;
    for (CcuBuffer* current = buffers; current != lastBuffer; ++current) {
        handleList.push_back(current->handle);
    }

    return CcuLocalBufferReduce(handleList.data(), count, dataType, outputDataType, opType, len.handle,
                                event.handle, mask);
}
/**
 * @brief Read overload with a LocalAddr local side; forwards to CcuReadMemToMem.
 * @param ch     Channel handle passed to the underlying call.
 * @param local  Local address object; its handle is forwarded.
 * @param remote Remote address object; its handle is forwarded.
 * @param len    Variable whose handle is forwarded as the length argument.
 * @param event  Event whose handle is forwarded.
 * @param mask   Mask value forwarded unchanged; defaults to 1.
 * @return The CcuResult produced by CcuReadMemToMem.
 */
inline CcuResult Read(ChannelHandle ch, LocalAddr local, RemoteAddr remote, Variable len, Event event,
                      uint16_t mask = 1)
{
    return CcuReadMemToMem(ch, local.handle, remote.handle, len.handle, event.handle, mask);
}
/**
 * @brief Read overload with a CcuBuffer local side; forwards to CcuReadMemToBuffer.
 * @param ch     Channel handle passed to the underlying call.
 * @param local  Local buffer; its handle is forwarded.
 * @param remote Remote address object; its handle is forwarded.
 * @param len    Variable whose handle is forwarded as the length argument.
 * @param event  Event whose handle is forwarded.
 * @param mask   Mask value forwarded unchanged; defaults to 1.
 * @return The CcuResult produced by CcuReadMemToBuffer.
 */
inline CcuResult Read(ChannelHandle ch, CcuBuffer local, RemoteAddr remote, Variable len, Event event,
                      uint16_t mask = 1)
{
    return CcuReadMemToBuffer(ch, local.handle, remote.handle, len.handle, event.handle, mask);
}
/**
 * @brief Forwards to CcuReadMemToMemReduce with the handles of the object arguments.
 * @param ch       Channel handle passed to the underlying call.
 * @param local    Local address object; its handle is forwarded.
 * @param remote   Remote address object; its handle is forwarded.
 * @param len      Variable whose handle is forwarded as the length argument.
 * @param dataType Data type forwarded unchanged.
 * @param opType   Reduce operation forwarded unchanged.
 * @param event    Event whose handle is forwarded.
 * @param mask     Mask value forwarded unchanged; defaults to 1.
 * @return The CcuResult produced by CcuReadMemToMemReduce.
 */
inline CcuResult ReadReduce(ChannelHandle ch, LocalAddr local, RemoteAddr remote, Variable len,
                            HcclDataType dataType, HcclReduceOp opType, Event event, uint16_t mask = 1)
{
    return CcuReadMemToMemReduce(ch, local.handle, remote.handle, len.handle, dataType, opType,
                                 event.handle, mask);
}
/**
 * @brief Write overload with a LocalAddr local side; forwards to CcuWriteMemToMem.
 * @param ch     Channel handle passed to the underlying call.
 * @param remote Remote address object; its handle is forwarded.
 * @param local  Local address object; its handle is forwarded.
 * @param len    Variable whose handle is forwarded as the length argument.
 * @param event  Event whose handle is forwarded.
 * @param mask   Mask value forwarded unchanged; defaults to 1.
 * @return The CcuResult produced by CcuWriteMemToMem.
 */
inline CcuResult Write(ChannelHandle ch, RemoteAddr remote, LocalAddr local, Variable len, Event event,
                       uint16_t mask = 1)
{
    return CcuWriteMemToMem(ch, remote.handle, local.handle, len.handle, event.handle, mask);
}
/**
 * @brief Write overload with a CcuBuffer local side; forwards to CcuWriteBufferToMem.
 * @param ch     Channel handle passed to the underlying call.
 * @param remote Remote address object; its handle is forwarded.
 * @param local  Local buffer; its handle is forwarded.
 * @param len    Variable whose handle is forwarded as the length argument.
 * @param event  Event whose handle is forwarded.
 * @param mask   Mask value forwarded unchanged; defaults to 1.
 * @return The CcuResult produced by CcuWriteBufferToMem.
 */
inline CcuResult Write(ChannelHandle ch, RemoteAddr remote, CcuBuffer local, Variable len, Event event,
                       uint16_t mask = 1)
{
    return CcuWriteBufferToMem(ch, remote.handle, local.handle, len.handle, event.handle, mask);
}
/**
 * @brief Forwards to CcuWriteMemToMemReduce with the handles of the object arguments.
 * @param ch       Channel handle passed to the underlying call.
 * @param remote   Remote address object; its handle is forwarded.
 * @param local    Local address object; its handle is forwarded.
 * @param len      Variable whose handle is forwarded as the length argument.
 * @param dataType Data type forwarded unchanged.
 * @param opType   Reduce operation forwarded unchanged.
 * @param event    Event whose handle is forwarded.
 * @param mask     Mask value forwarded unchanged; defaults to 1.
 * @return The CcuResult produced by CcuWriteMemToMemReduce.
 */
inline CcuResult WriteReduce(ChannelHandle ch, RemoteAddr remote, LocalAddr local, Variable len,
                             HcclDataType dataType, HcclReduceOp opType, Event event, uint16_t mask = 1)
{
    return CcuWriteMemToMemReduce(ch, remote.handle, local.handle, len.handle, dataType, opType,
                                  event.handle, mask);
}
}
}
#endif