* Copyright (c) 2026 Huawei Technologies Co., Ltd.
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
* CANN Open Software License Agreement Version 2.0 (the "License").
* Please refer to the License for details. You may not use this file except in compliance with the License.
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
* See LICENSE in the root of the software repository for the full text of the License.
*/
#ifndef ATVOSS_GRAPH_EXPR_FLATTEN_H
#define ATVOSS_GRAPH_EXPR_FLATTEN_H
#include "utils/utility.h"
#include "expression/expr_template.h"
namespace Atvoss::Graph {
using Atvoss::IsParam_v;
using Atvoss::Util::Append_t;
using Atvoss::Util::IsTypeList_v;
using Atvoss::Util::Size_v;
using ParaConst = Param<1, int, ParamUsage::IN>;
template <typename Op, typename... New>
struct ReplaceOp;
template <template <typename> class Op, typename Old, typename New>
struct ReplaceOp<Op<Old>, New> {
using Type = Op<New>;
};
template <template <auto, typename> class Op, auto Arg, typename Old, typename New>
struct ReplaceOp<Op<Arg, Old>, New> {
using Type = Op<Arg, New>;
};
template <template <typename, typename> class Op, typename OldL, typename OldR, typename NewL, typename NewR>
struct ReplaceOp<Op<OldL, OldR>, NewL, NewR> {
using Type = Op<NewL, NewR>;
};
template <typename LocalVarList, typename TargetOp>
constexpr bool FindLocalVarIndex = false;
template <typename TargetOp>
constexpr bool FindLocalVarIndex<TypeList<>, TargetOp> = false;
template <typename Head, typename... Tail, typename TargetOp>
constexpr bool FindLocalVarIndex<TypeList<Head, Tail...>, TargetOp> =
std::is_same_v<typename Head::Like, TargetOp> ? true : FindLocalVarIndex<TypeList<Tail...>, TargetOp>;
template <typename LocalVarList, typename TargetOp>
constexpr size_t GetLocalVarIndex = 0;
template <typename TargetOp>
constexpr size_t GetLocalVarIndex<TypeList<>, TargetOp> = 0;
template <typename Head, typename... Tail, typename TargetOp>
constexpr size_t GetLocalVarIndex<TypeList<Head, Tail...>, TargetOp> =
std::is_same_v<typename Head::Like, TargetOp> ? Head::number : GetLocalVarIndex<TypeList<Tail...>, TargetOp>;
template <typename Operand, size_t Idx, typename M>
constexpr auto CreateLocalVar(M map)
{
if constexpr (IsParam_v<Operand>) {
return std::make_pair(map, Operand{});
} else if constexpr (FindLocalVarIndex<M, Operand>) {
constexpr size_t idx = GetLocalVarIndex<M, Operand>;
using LocalVarType = LocalVar<idx, float, ParaConst>;
return std::make_pair(map, LocalVarType{});
} else {
using NewLocalVar = LocalVar<Idx, float, Operand>;
using MergedMap = Append_t<M, NewLocalVar>;
using LocalVarType = LocalVar<Idx, float, ParaConst>;
return std::make_pair(MergedMap{}, LocalVarType{});
}
}
template <typename Lhs, typename Rhs, typename... Operands>
using MakeAssignOp = OpAssign<Lhs, typename ReplaceOp<Rhs, Operands...>::Type>;
template <typename Op, typename M, typename Expr>
constexpr auto FlattenOp(Op op, M map, Expr expr);
template <typename Operand, typename M, typename Expr>
constexpr auto FlattenOperand(Operand operand, M map, Expr expr)
{
auto [mapAfterOperand, localVarForOperand] = CreateLocalVar<Operand, Size_v<M> + 1>(map);
constexpr bool needFlattenOperand = !IsParam_v<Operand> && !FindLocalVarIndex<M, Operand>;
if constexpr (needFlattenOperand) {
auto [mapAfterFlattenOperand, exprAfterFlattenOperand] = FlattenOp(operand, mapAfterOperand, expr);
return std::tuple(localVarForOperand, mapAfterFlattenOperand, exprAfterFlattenOperand);
} else {
return std::tuple(localVarForOperand, mapAfterOperand, expr);
}
}
template <typename Op, typename M, typename Expr>
constexpr auto FlattenUnaryOp(Op op, M map, Expr expr)
{
using Operand = typename Op::DataType;
auto [localVar, mapAfterFlattenOperand, exprAfterFlattenOperand] = FlattenOperand(Operand{}, map, expr);
using NewAssign = MakeAssignOp<LocalVar<Size_v<M>, float, ParaConst>, Op, decltype(localVar)>;
using FinalExpr = Append_t<decltype(exprAfterFlattenOperand), NewAssign>;
return std::make_pair(mapAfterFlattenOperand, FinalExpr{});
}
template <typename Op, typename M, typename Expr>
constexpr auto FlattenBinaryOp(Op op, M map, Expr expr)
{
using LhsOperand = typename Op::LhsType;
using RhsOperand = typename Op::RhsType;
auto [localVarForLhs, mapAfterFlattenLhsOperand, exprAfterFlattenLhsOperand] =
FlattenOperand(LhsOperand{}, map, expr);
auto [localVarForRhs, mapAfterFlattenRhsOperand, exprAfterFlattenRhsOperand] =
FlattenOperand(RhsOperand{}, mapAfterFlattenLhsOperand, exprAfterFlattenLhsOperand);
using NewAssign =
MakeAssignOp<LocalVar<Size_v<M>, float, ParaConst>, Op, decltype(localVarForLhs), decltype(localVarForRhs)>;
using FinalExpr = Append_t<decltype(exprAfterFlattenRhsOperand), NewAssign>;
return std::make_pair(mapAfterFlattenRhsOperand, FinalExpr{});
}
template <typename Op, typename M, typename Expr>
constexpr auto FlattenOp(Op op, M map, Expr expr)
{
static_assert(IsUnaryOp_v<Op> || IsBinaryOp_v<Op>, "Op must be either UnaryOp or BinaryOp.");
if constexpr (IsUnaryOp_v<Op>) {
return FlattenUnaryOp(op, map, expr);
} else {
return FlattenBinaryOp(op, map, expr);
}
}
template <typename Lhs, typename Rhs, typename M, typename Expr>
constexpr auto FlattenOpAssign(OpAssign<Lhs, Rhs>, M map, Expr expr)
{
if constexpr (IsParam_v<Rhs>) {
using FinalExpr = Append_t<Expr, OpAssign<Lhs, Rhs>>;
return std::make_pair(map, FinalExpr{});
} else if constexpr (IsUnaryOp_v<Rhs>) {
using Operand = typename Rhs::DataType;
auto [localVar, mapAfterFlattenOperand, exprAfterFlattenOperand] = FlattenOperand(Operand{}, map, expr);
using NewAssign = MakeAssignOp<Lhs, Rhs, decltype(localVar)>;
using FinalExpr = Append_t<decltype(exprAfterFlattenOperand), NewAssign>;
return std::make_pair(mapAfterFlattenOperand, FinalExpr{});
} else if constexpr (IsBinaryOp_v<Rhs>) {
using LhsOperand = typename Rhs::LhsType;
using RhsOperand = typename Rhs::RhsType;
auto [localVarForLhs, mapAfterFlattenLhsOperand, exprAfterFlattenLhsOperand] =
FlattenOperand(LhsOperand{}, map, expr);
auto [localVarForRhs, mapAfterFlattenRhsOperand, exprAfterFlattenRhsOperand] =
FlattenOperand(RhsOperand{}, mapAfterFlattenLhsOperand, exprAfterFlattenLhsOperand);
using NewAssign = MakeAssignOp<Lhs, Rhs, decltype(localVarForLhs), decltype(localVarForRhs)>;
using FinalExpr = Append_t<decltype(exprAfterFlattenRhsOperand), NewAssign>;
return std::make_pair(mapAfterFlattenRhsOperand, FinalExpr{});
} else {
static_assert(true, "OpAssign's Right only support Param/UnaryOp/BinaryOp");
}
}
template <typename M, typename Expr>
constexpr auto FlattenTypeList(TypeList<>, M map, Expr expr)
{
return std::make_pair(map, expr);
}
template <typename Head, typename... Tail, typename M, typename Expr>
constexpr auto FlattenTypeList(TypeList<Head, Tail...>, M map, Expr expr)
{
static_assert(IsOpAssign_v<Head>, "Each one of TypeList must be a OpAssign");
auto [curMap, curExpr] = FlattenOpAssign(Head{}, map, expr);
return FlattenTypeList(TypeList<Tail...>{}, curMap, curExpr);
}
template <typename Expr>
constexpr auto FlattenExprRecursively(Expr expr)
{
if constexpr (IsOpAssign_v<Expr>) {
auto [map, finalExpr] = FlattenOpAssign(expr, TypeList<>{}, TypeList<>{});
return finalExpr;
} else if constexpr (IsTypeList_v<Expr>) {
auto [map, finalExpr] = FlattenTypeList(expr, TypeList<>{}, TypeList<>{});
return finalExpr;
} else {
static_assert(true, "Expr must be a OpAssign or TypeList");
}
}
}
#endif