* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef __CODEGEN_KERNEL_LOOP_H__
#define __CODEGEN_KERNEL_LOOP_H__
#include <set>
#include <utility>
#include <sstream>
#include "ascir.h"
#include "ascgen_log.h"
#include "ascir_ops_utils.h"
#include "codegen_api_param/codegen_api_param.h"
namespace codegen {
class Tiler;
class TPipe;
class Tensor;
class ApiCall;
struct Loop;
class Axis;
enum class LoopType : int8_t {
CALL = 0,
LOOP
};
enum class BoolType : int8_t {
FALSE = 0,
TRUE = 1,
FAILED = 2
};
struct LoopBody {
LoopType type;
union {
ApiCall *call;
Loop *loop;
};
};
struct ApiTensor {
ascir::TensorId id;
ascir::ReuseId reuse_id;
struct ApiTensor* reuse_from;
struct ApiTensor* reuse_next;
struct ApiTensor* share_prev;
struct ApiTensor* share_next;
mutable int32_t share_order;
const ApiCall* write;
std::vector<const ApiCall*> reads;
ApiTensor();
};
enum class ApiScene : int8_t {
kDefault = 0,
kCVFuseUBLoad,
};
enum class ComputeStage : int8_t {
kDefault = 0,
kCVFuseStage1,
kCVFuseStage2,
};
struct ApiCallContext {
ApiScene scene = ApiScene::kDefault;
ComputeStage stage = ComputeStage::kDefault;
bool isCVFusion() const {
return scene != ApiScene::kDefault;
}
};
class ApiCall {
public:
virtual ~ApiCall() = default;
explicit ApiCall(const std::string &api_name) noexcept : api_name_(api_name) {}
virtual Status Init(const ascir::NodeView &node);
virtual Status ParseAttr(const ascir::NodeView &node) {
(void) node;
return af::SUCCESS;
}
virtual Status BuildApiParam(const TPipe &tpipe, const std::vector<ascir::AxisId> ¤t_axis,
const std::vector<std::reference_wrapper<const Tensor>> &input,
const std::vector<std::reference_wrapper<const Tensor>> &output) const;
virtual Status GenerateApiCallString(std::string &result) const;
virtual Status GenDimensionParam(const CodegenApiParam &api_param, std::stringstream &ss) const;
virtual Status PreProcess(const TPipe &tpipe, const std::vector<ascir::AxisId> ¤t_axis,
const std::vector<std::reference_wrapper<const Tensor>> &outputs,
std::string &result) const;
virtual Status GenerateFuncDefinition(const TPipe &tpipe, const Tiler &tiler, std::stringstream &ss) const;
virtual Status Generate(const TPipe &tpipe, const std::vector<ascir::AxisId> ¤t_axis,
const std::vector<std::reference_wrapper<const Tensor>> &input,
const std::vector<std::reference_wrapper<const Tensor>> &output, std::string &result) const;
virtual Status PostProcess(const TPipe &tpipe, const std::vector<ascir::AxisId> ¤t_axis,
const std::vector<std::reference_wrapper<const Tensor>> &outputs,
std::string &result) const;
virtual Status Generate(const TPipe &tpipe, const std::vector<ascir::AxisId> ¤t_axis,
std::string &result) const;
virtual Status GenerateMacro(std::string &result) const {
(void)result;
return af::SUCCESS;
}
virtual bool AreContiguousBufsPreferred() const {
return false;
}
bool FreeInputs(const TPipe &tpipe, std::stringstream &ss) const;
bool FreeUnusedOutputs(const TPipe &tpipe, std::stringstream &ss) const;
bool SyncOutputs(const TPipe &tpipe, std::stringstream &ss) const;
bool WaitInputs(const TPipe &tpipe, std::stringstream &ss) const;
bool IsReadOutersideWrite(ascir::AxisId &target_id) const;
Status AllocOutputs(const TPipe &tpipe, std::stringstream &ss) const;
bool IsUnitLastRead(const ApiTensor &tensor) const;
std::string api_name_;
ascir::AxisId axis;
std::string type;
int64_t depth;
ascir::ComputeUnit unit;
ascir::ComputeType compute_type;
std::vector<ApiTensor> outputs;
std::vector<const ApiTensor *> inputs;
bool enable_cache{false};
bool is_input_tbuf_contiguous = false;
std::string enable_cache_with_condition;
af::ExecuteCondition exec_condition;
ApiCallContext api_call_context = {ApiScene::kDefault, ComputeStage::kDefault};
std::unordered_map<int64_t, int64_t> tmp_buf_id;
af::AscNodePtr node;
std::string graph_name;
std::string node_name;
private:
bool WaitInputVector(const TPipe &tpipe, const ApiTensor *in, const Tensor &t, std::stringstream &ss) const;
bool WaitInputMte(const TPipe &tpipe, const ApiTensor *in, const Tensor &t, std::stringstream &ss) const;
BoolType WaitShareInputs(const TPipe &tpipe, const ApiTensor *in, const Tensor t, std::stringstream &ss) const;
BoolType AllocShareOutputs(const TPipe &tpipe, const ApiTensor &out, const Tensor t, std::stringstream &ss) const;
Status HandleVecOutAlloc(const TPipe &tpipe, const ApiTensor &out, const Tensor &t, std::stringstream &ss,
bool with_define) const;
};
struct Loop {
ascir::AxisId axis_id;
struct Loop* parent;
std::vector<LoopBody> bodys;
std::set<const ApiCall *> used_calls = {};
bool is_graph_has_reduce_node = false;
bool is_ar = false;
explicit Loop(const ascir::AxisId axis);
ComputeStage compute_stage = ComputeStage::kDefault;
void AddLoop(Loop *loop);
void AddCall(ApiCall *call);
Status ConstructFromNodes(ascir::NodeViewVisitorConst nodes, const Tiler &tiler, TPipe& tpipe);
void Destruct();
Status Generate(const Tiler& tiler, const TPipe& tpipe, std::string &result,
ComputeStage stage = ComputeStage::kDefault);
const Tensor* GetReduceOutputTensor(const TPipe &tpipe) const;
const Tensor* GetReduceInputTensor(const TPipe &tpipe) const;
void CollectTensorCrossLoop(std::map<ascir::AxisId, std::vector<ApiCall *>> &api_calls);
Status ActualSizeDefine(const Tiler &tiler, const TPipe &tpipe, std::string dtype_name, std::string &result);
private:
Status GenerateLoop(const Tiler& tiler, const TPipe& tpipe, std::vector<ascir::AxisId>& current_axis, std::stringstream& ss);
Status GenerateBody(const Tiler& tiler, const TPipe& tpipe, std::vector<ascir::AxisId>& current_axis,
std::stringstream& ss);
void GenerateEnCacheCondition(const Tiler &tiler, const TPipe &tpipe, const Axis &axis, std::stringstream &ss) const;
bool IsFindInUsedCalls(const ApiCall *call) const;
std::string GetReduceType() const;
bool IsHaveReduceType(const std::string &type) const;
bool IsBodyContainLoop() const;
};
}
#endif