* Copyright (c) 2025 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef AIV_SYNC_INTERFACE_H
#define AIV_SYNC_INTERFACE_H
#include "kernel_operator.h"
using namespace AscendC;
// Number of int32_t elements in one flag slot: 8 * sizeof(int32_t) = 32 bytes.
// Used in this header as a DataCopy element count and as the stride between
// consecutive flags written by SetFlagBatchValue.
constexpr uint64_t UB_FLAG_PAD_COUNT = 8;
// Not referenced in the visible part of this header.
// NOTE(review): presumably the padded element count of an address entry — confirm against its users.
constexpr uint64_t UB_ADDRESS_PAD_COUNT = 4;
/**
 * @brief Store a signal value into a flag slot in global memory.
 *
 * Writes `value` into element 0 of `localTensor`, then copies one flag slot
 * (UB_FLAG_PAD_COUNT int32_t elements) from `localTensor` to `gmSignalAddr`.
 * Only element 0 is set here; the remaining elements of the slot are copied
 * with whatever `localTensor` currently holds.
 *
 * @param gmSignalAddr Global-memory address of the flag slot.
 * @param localTensor  Staging tensor; needs at least UB_FLAG_PAD_COUNT elements.
 * @param value        Signal value to store in element 0 of the slot.
 */
__aicore__ inline void SetSignalValue(__gm__ int32_t* gmSignalAddr, LocalTensor<int32_t>& localTensor, int32_t value)
{
    GlobalTensor<int32_t> signalGm;
    signalGm.SetGlobalBuffer(gmSignalAddr);

    localTensor.SetValue(0, value);
    // Barrier between the scalar store above and the copy below.
    pipe_barrier(PIPE_ALL);

    // The DataCopy count is in elements, not bytes. Copy exactly one slot:
    // a larger count would overwrite neighbouring flags when slots are laid
    // out at a UB_FLAG_PAD_COUNT stride (as SetFlagBatchValue does) and would
    // push uninitialised staging data to global memory.
    DataCopy(signalGm, localTensor, UB_FLAG_PAD_COUNT);
    pipe_barrier(PIPE_ALL);
}
/**
 * @brief Atomically add a value to a flag slot in global memory.
 *
 * Fills the first UB_FLAG_PAD_COUNT elements of `localTensor` with `value`
 * and copies them to `gmSignalAddr` while atomic-add mode is enabled, so every
 * element of the slot is incremented by `value`. Atomic mode is reset
 * afterwards.
 *
 * @param gmSignalAddr Global-memory address of the flag slot.
 * @param localTensor  Staging tensor; needs at least UB_FLAG_PAD_COUNT elements.
 * @param value        Increment; converted to int32_t before use.
 */
__aicore__ inline void AddSignalValue(__gm__ int32_t* gmSignalAddr, LocalTensor<int32_t>& localTensor, uint64_t value)
{
    GlobalTensor<int32_t> signalGm;
    signalGm.SetGlobalBuffer(gmSignalAddr);

    // The tensor element type is int32_t, so the 64-bit argument is narrowed.
    const int32_t increment = static_cast<int32_t>(value);
    Duplicate<int32_t>(localTensor, increment, UB_FLAG_PAD_COUNT);

    SetAtomicAdd<int32_t>();
    PipeBarrier<PIPE_ALL>();
    DataCopy(signalGm, localTensor, UB_FLAG_PAD_COUNT);
    SetAtomicNone();
}
/**
 * @brief Busy-wait until a flag slot in global memory equals an expected value.
 *
 * Repeatedly copies one flag slot (UB_FLAG_PAD_COUNT int32_t elements) from
 * `gmSignalAddr` into `localTensor` and returns once element 0 equals
 * `expectedValue`. There is no timeout: the loop spins until the value appears.
 *
 * @param gmSignalAddr  Global-memory address of the flag slot.
 * @param localTensor   Staging tensor; needs at least UB_FLAG_PAD_COUNT elements.
 *                      It is overwritten on every poll.
 * @param expectedValue Value to wait for.
 */
__aicore__ inline void WaitSignalValue(__gm__ int32_t *gmSignalAddr, LocalTensor<int32_t>& localTensor, int32_t expectedValue)
{
    GlobalTensor<int32_t> signalGm;
    signalGm.SetGlobalBuffer(gmSignalAddr);

    do {
        // Only element 0 is inspected, so one slot per poll is enough; this
        // matches the element count used by the other signal helpers here.
        DataCopy(localTensor, signalGm, UB_FLAG_PAD_COUNT);
        // Barrier before the scalar read of the freshly copied data.
        pipe_barrier(PIPE_ALL);
    } while (localTensor.GetValue(0) != expectedValue);
}
/**
 * @brief Busy-wait until a flag slot in global memory is >= a threshold.
 *
 * Repeatedly copies one flag slot (UB_FLAG_PAD_COUNT int32_t elements) from
 * `gmSignalAddr` into `localTensor` and returns once element 0 is greater than
 * or equal to `value` (signed comparison). There is no timeout.
 *
 * @param gmSignalAddr Global-memory address of the flag slot.
 * @param localTensor  Staging tensor; needs at least UB_FLAG_PAD_COUNT elements.
 *                     It is overwritten on every poll.
 * @param value        Threshold the flag must reach.
 */
__aicore__ inline void WaitSignalGEValue(__gm__ int32_t *gmSignalAddr, LocalTensor<int32_t>& localTensor, int32_t value)
{
    GlobalTensor<int32_t> signalGm;
    signalGm.SetGlobalBuffer(gmSignalAddr);

    do {
        // Only element 0 is inspected, so one slot per poll is enough; this
        // matches the element count used by the other signal helpers here.
        DataCopy(localTensor, signalGm, UB_FLAG_PAD_COUNT);
        // Barrier before the scalar read of the freshly copied data.
        pipe_barrier(PIPE_ALL);
    } while (localTensor.GetValue(0) < value);
}
/**
 * @brief Read the current value of a flag slot in global memory.
 *
 * Copies one flag slot (UB_FLAG_PAD_COUNT int32_t elements) from
 * `gmSignalAddr` into `localTensor` and returns element 0.
 *
 * @param gmSignalAddr Global-memory address of the flag slot.
 * @param localTensor  Staging tensor; needs at least UB_FLAG_PAD_COUNT elements.
 * @return Element 0 of the slot, read as int32_t and converted to uint64_t
 *         (a negative flag value therefore becomes a large unsigned value).
 */
__aicore__ inline uint64_t GetSignalValue(__gm__ int32_t *gmSignalAddr, LocalTensor<int32_t>& localTensor)
{
    GlobalTensor<int32_t> signalGm;
    signalGm.SetGlobalBuffer(gmSignalAddr);

    DataCopy(localTensor, signalGm, UB_FLAG_PAD_COUNT);
    // Barrier before the scalar read of the freshly copied data.
    pipe_barrier(PIPE_ALL);

    const int32_t signal = localTensor.GetValue(0);
    return static_cast<uint64_t>(signal);
}
/**
 * @brief Busy-wait until a flag slot in global memory differs from a value.
 *
 * Repeatedly copies one flag slot (UB_FLAG_PAD_COUNT int32_t elements) from
 * `gmSignalAddr` into `localTensor` and returns once element 0 is no longer
 * equal to `value`. There is no timeout.
 *
 * @param gmSignalAddr Global-memory address of the flag slot.
 * @param localTensor  Staging tensor; needs at least UB_FLAG_PAD_COUNT elements.
 *                     It is overwritten on every poll.
 * @param value        Value the flag is expected to move away from.
 */
__aicore__ inline void WaitSignalNotEqValue(__gm__ int32_t *gmSignalAddr, LocalTensor<int32_t>& localTensor, uint64_t value)
{
    GlobalTensor<int32_t> signalGm;
    signalGm.SetGlobalBuffer(gmSignalAddr);

    do {
        DataCopy(localTensor, signalGm, UB_FLAG_PAD_COUNT);
        // Barrier before the scalar read of the freshly copied data.
        pipe_barrier(PIPE_ALL);
        // The int32_t flag is widened to uint64_t for the comparison, so a
        // negative flag compares as a large unsigned number.
        // NOTE(review): confirm callers never pass values above INT32_MAX.
    } while (static_cast<uint64_t>(localTensor.GetValue(0)) == value);
}
/**
 * @brief Write the same value to a run of consecutive flag slots in global memory.
 *
 * Allocates a staging tensor from `batchQue`, stores `setValue` in the first
 * element of each of `count` slots (slot stride: UB_FLAG_PAD_COUNT int32_t
 * elements), copies `count * UB_FLAG_PAD_COUNT` elements to `ctrlFlagGM`, and
 * returns the tensor to the queue. Only the first element of each slot is set;
 * the remaining elements are copied with whatever the staging tensor holds.
 *
 * @param ctrlFlagGM Global-memory address of the first flag slot.
 * @param batchQue   Queue providing the staging tensor; its buffer must hold at
 *                   least `count * UB_FLAG_PAD_COUNT` int32_t elements.
 * @param setValue   Value to store; converted to int32_t before use.
 * @param count      Number of flag slots to write.
 */
__aicore__ inline void SetFlagBatchValue(__gm__ int32_t *ctrlFlagGM, TQue<QuePosition::VECOUT, 1> &batchQue, uint64_t setValue, uint64_t count)
{
    GlobalTensor<int32_t> flagsGm;
    flagsGm.SetGlobalBuffer(ctrlFlagGM);

    LocalTensor<int32_t> staging = batchQue.AllocTensor<int32_t>();
    // The tensor element type is int32_t, so the 64-bit argument is narrowed.
    const int32_t flagValue = static_cast<int32_t>(setValue);
    // Index with the same unsigned type as `count` to avoid a signed/unsigned
    // comparison and signed overflow of the loop counter.
    for (uint64_t slot = 0; slot < count; ++slot) {
        staging.SetValue(slot * UB_FLAG_PAD_COUNT, flagValue);
    }
    // Barrier between the scalar stores above and the copy below.
    pipe_barrier(PIPE_ALL);
    DataCopy(flagsGm, staging, UB_FLAG_PAD_COUNT * count);
    batchQue.FreeTensor(staging);
}
#endif