/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
#ifndef CCU_LOOP_HPP
#define CCU_LOOP_HPP
#include <vector>
#include "ccu_types.h"
#include "ccu_variable.hpp"
#include "ccu_func.hpp"
#include "ccu_primitives_impl.h"
#include "ccu_utils.hpp"
namespace AscendC {
namespace ccu {
class Loop {
public:
Loop(Variable &loopCfg, const Func &func)
{
ComposeLoopBody(func);
isVarBased_ = true;
loopParamVar_ = &loopCfg;
}
Loop(const CcuLoopConfig &loopCfg, const Func &func)
{
ComposeLoopBody(func);
isVarBased_ = false;
config_ = loopCfg;
}
CcuLoop Handle() const
{
return handle_;
}
bool IsVarBased() const
{
return isVarBased_;
}
Variable *LoopParamVar() const
{
return loopParamVar_;
}
const CcuLoopConfig *Config() const
{
return &config_;
}
private:
void ComposeLoopBody(const Func &func)
{
if (func.NumIn() != 0) {
throw ::AscendC::ccu::detail::CcuException(CcuResult::CCU_E_PARA,
"ccu::Loop requires a no-argument ccu::Func");
}
CCU_THROW_IF_FAILED(::CcuLoopCreate(&handle_), "CcuLoopCreate failed");
CCU_THROW_IF_FAILED(::_CcuLoopBodyEnter(handle_), "_CcuLoopBodyEnter failed");
try {
func.RunBody(nullptr);
} catch (...) {
(void)::_CcuLoopBodyExit(handle_);
throw;
}
CCU_THROW_IF_FAILED(::_CcuLoopBodyExit(handle_), "_CcuLoopBodyExit failed");
}
CcuLoop handle_{0};
bool isVarBased_{false};
Variable *loopParamVar_{nullptr};
CcuLoopConfig config_{};
};
class LoopGroup {
public:
LoopGroup(Variable ¶llelCfg, Variable &offsetCfg, uint32_t maxLoopNum,
const std::vector<Loop> &loops)
{
CCU_THROW_IF_FAILED(
::CcuLoopGroupCreateFromVar(&handle_, maxLoopNum,
parallelCfg.handle, offsetCfg.handle),
"CcuLoopGroupCreateFromVar failed");
AddLoops(loops);
}
LoopGroup(const CcuLoopGroupConfig &loopGroupCfg, uint32_t maxLoopNum,
const std::vector<Loop> &loops)
{
CcuLoopGroupConfig localCfg = loopGroupCfg;
CCU_THROW_IF_FAILED(
::CcuLoopGroupCreate(&handle_, maxLoopNum, &localCfg),
"CcuLoopGroupCreate failed");
AddLoops(loops);
}
CcuLoopGroup Handle() const
{
return handle_;
}
private:
void AddLoops(const std::vector<Loop> &loops)
{
for (const auto &loop : loops) {
if (loop.IsVarBased()) {
auto *loopParamVar = loop.LoopParamVar();
if (loopParamVar == nullptr) {
throw ::AscendC::ccu::detail::CcuException(CcuResult::CCU_E_PARA,
"ccu::Loop var-based loop has null loop parameter");
}
CCU_THROW_IF_FAILED(
::CcuLoopGroupAddLoopFromVar(handle_, loop.Handle(), loopParamVar->handle),
"CcuLoopGroupAddLoopFromVar failed");
} else {
CCU_THROW_IF_FAILED(
::CcuLoopGroupAddLoop(handle_, loop.Handle(), loop.Config()),
"CcuLoopGroupAddLoop failed");
}
}
}
CcuLoopGroup handle_{0};
};
}
}
#endif