/**
 * Copyright (c) 2026 Huawei Technologies Co., Ltd.
 * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
 * CANN Open Software License Agreement Version 2.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */
#ifndef ATVOSS_EXPR_LINEARIZER_H
#define ATVOSS_EXPR_LINEARIZER_H
#include <cstddef>
#include <type_traits>
#include "expression/expr_template.h"
#include "utils/utility.h"
#include "expr_cast_eliminate.h"
namespace Atvoss {
namespace Detail {
/**
 * @brief Substitutes every occurrence of the type `Target` inside the expression type `Node` with `Replacement`.
 *
 * `ReplaceExpr<Node, Target, Replacement>::Type` is `Replacement` when `Node` is `Target`; otherwise it is
 * `Node` rebuilt by ReplaceRecursive with each operand rewritten. Atvoss::Param nodes are leaves and are
 * returned unchanged, even when they are the `Target`.
 *
 * @tparam Node        expression type to rewrite
 * @tparam Target      type to be replaced
 * @tparam Replacement type substituted for `Target`
 */
template <typename Node, typename Target, typename Replacement>
struct ReplaceExpr;

/**
 * @brief Recursion step of ReplaceExpr: rebuilds an operator node with its operands rewritten.
 *
 * The primary template treats `Node` as an opaque leaf and returns it unchanged. The partial
 * specializations below cover the operator shapes `Op<Lhs, Rhs>`, `Op<Inner>`, `Op<N, Inner>` and
 * `Op<N, R, Inner>`; the leading non-type argument `N` and the type `R` are kept as they are.
 *
 * @tparam Node        expression node to rebuild
 * @tparam Target      type to be replaced
 * @tparam Replacement type substituted for `Target`
 */
template <typename Node, typename Target, typename Replacement>
struct ReplaceRecursive {
    using Type = Node;
};

// Binary operator: rewrite both operands.
template <template <typename, typename> class Op, typename Lhs, typename Rhs, typename Target, typename Replacement>
struct ReplaceRecursive<Op<Lhs, Rhs>, Target, Replacement> {
    using newLhs = typename ReplaceExpr<Lhs, Target, Replacement>::Type;
    using newRhs = typename ReplaceExpr<Rhs, Target, Replacement>::Type;
    using Type = Op<newLhs, newRhs>;
};

// Unary operator: rewrite the single operand.
template <template <typename> class Op, typename Inner, typename Target, typename Replacement>
struct ReplaceRecursive<Op<Inner>, Target, Replacement> {
    using newInner = typename ReplaceExpr<Inner, Target, Replacement>::Type;
    using Type = Op<newInner>;
};

// Operator with a leading non-type argument: keep N, rewrite the operand.
template <template <auto, typename> class Op, auto N, typename Inner, typename Target, typename Replacement>
struct ReplaceRecursive<Op<N, Inner>, Target, Replacement> {
    using newInner = typename ReplaceExpr<Inner, Target, Replacement>::Type;
    using Type = Op<N, newInner>;
};

// Operator with a non-type argument and an extra type (e.g. OpCast<N, R, Inner>): only Inner is rewritten.
template <template <auto, typename, typename> class Op, auto N, typename R, typename Inner, typename Target,
          typename Replacement>
struct ReplaceRecursive<Op<N, R, Inner>, Target, Replacement> {
    using newInner = typename ReplaceExpr<Inner, Target, Replacement>::Type;
    using Type = Op<N, R, newInner>;
};

template <typename Node, typename Target, typename Replacement>
struct ReplaceExpr {
    // ReplaceRecursive is evaluated unconditionally; the match test only selects which result is used.
    using rebuilt = typename ReplaceRecursive<Node, Target, Replacement>::Type;
    static constexpr bool isTarget = std::is_same_v<Node, Target>;
    using Type = std::conditional_t<isTarget, Replacement, rebuilt>;
};

// Param nodes are never substituted.
template <size_t N, typename T, auto Usage, size_t R, typename Target, typename Replacement>
struct ReplaceExpr<Atvoss::Param<N, T, Usage, R>, Target, Replacement> {
    using Type = Atvoss::Param<N, T, Usage, R>;
};
/**
 * @brief Flattens an expression tree into a post-order TypeList of its operator nodes.
 *
 * Operands are listed before the node that uses them; Param and LocalVar leaves contribute no
 * entries. For OpAssign, the right-hand operand's entries come before the left-hand operand's.
 * OpAndThen contributes only its operands, never itself. The second template parameter is the
 * SFINAE hook used to select the binary-operator case.
 */
template <typename Expr, typename = void>
struct ExtractTypeListPostOrder;

// Param leaf: no entries.
template <size_t N, typename T, auto Usage, size_t R>
struct ExtractTypeListPostOrder<Atvoss::Param<N, T, Usage, R>> {
    using Type = Atvoss::Util::TypeList<>;
};

// LocalVar leaf: no entries.
template <size_t N, typename T, typename U>
struct ExtractTypeListPostOrder<Atvoss::LocalVar<N, T, U>> {
    using Type = Atvoss::Util::TypeList<>;
};

// Binary operator recognised by IsBinaryOp_v: operand entries first, then (usually) the node itself.
template <template <typename, typename> class Op, typename Lhs, typename Rhs>
struct ExtractTypeListPostOrder<Op<Lhs, Rhs>, std::enable_if_t<IsBinaryOp_v<Op<Lhs, Rhs>>>> {
    using lhsList = typename ExtractTypeListPostOrder<Lhs>::Type;
    using rhsList = typename ExtractTypeListPostOrder<Rhs>::Type;
    // OpAssign lists its right-hand side first; every other binary operator lists left, then right.
    // Only the selected Concatenate's ::Type is evaluated below.
    using operandConcat = std::conditional_t<Atvoss::Util::IsSpecializationOf_v<Atvoss::OpAssign, Op<Lhs, Rhs>>,
                                             Atvoss::Util::Concatenate<rhsList, lhsList>,
                                             Atvoss::Util::Concatenate<lhsList, rhsList>>;
    // OpAndThen only sequences its operands and is not emitted as an entry of its own.
    using selfEntry = std::conditional_t<Atvoss::Util::IsSpecializationOf_v<Atvoss::OpAndThen, Op<Lhs, Rhs>>,
                                         Atvoss::Util::TypeList<>, Atvoss::Util::TypeList<Op<Lhs, Rhs>>>;
    using concatList = typename Atvoss::Util::Concatenate<typename operandConcat::Type, selfEntry>::Type;
    using Type = concatList;
};

// Unary operator: operand entries, then the node.
template <template <typename> class Op, typename Inner>
struct ExtractTypeListPostOrder<Op<Inner>> {
    using innerList = typename ExtractTypeListPostOrder<Inner>::Type;
    using Type = typename Atvoss::Util::Concatenate<innerList, Atvoss::Util::TypeList<Op<Inner>>>::Type;
};

// Operator with a leading non-type argument: operand entries, then the node.
template <template <auto, typename> class Op, auto Pattern, typename Inner>
struct ExtractTypeListPostOrder<Op<Pattern, Inner>> {
    using innerList = typename ExtractTypeListPostOrder<Inner>::Type;
    using Type = typename Atvoss::Util::Concatenate<innerList, Atvoss::Util::TypeList<Op<Pattern, Inner>>>::Type;
};

// Operator with a non-type argument and an extra type (e.g. OpCast): only Inner is descended into.
template <template <auto, typename, typename> class Op, auto Pattern, typename R, typename Inner>
struct ExtractTypeListPostOrder<Op<Pattern, R, Inner>> {
    using innerList = typename ExtractTypeListPostOrder<Inner>::Type;
    using Type = typename Atvoss::Util::Concatenate<innerList, Atvoss::Util::TypeList<Op<Pattern, R, Inner>>>::Type;
};
// Lets the definitions below spell Atvoss::Util::TypeList unqualified.
using Atvoss::Util::TypeList;

/**
 * @brief Compile-time record binding LocalVar number `ID` to the expression `Expr` assigned to it.
 */
template <size_t ID, typename Expr>
struct VarDef {};

/**
 * @brief Looks up the expression recorded for LocalVar number `ID` in a TypeList of VarDef entries.
 *
 * `Type` is the expression of the first VarDef whose number equals `ID`, or `void` when none does.
 */
template <size_t ID, typename DefList>
struct FindVarDef;

// Exhausted list: nothing is recorded for ID.
template <size_t ID>
struct FindVarDef<ID, TypeList<>> {
    using Type = void;
};

// Head entry has the requested number (more specialized than the mismatch case below).
template <size_t ID, typename Bound, typename... Tail>
struct FindVarDef<ID, TypeList<VarDef<ID, Bound>, Tail...>> {
    using Type = Bound;
};

// Head entry belongs to another LocalVar: keep searching the tail.
template <size_t ID, size_t HeadID, typename Bound, typename... Tail>
struct FindVarDef<ID, TypeList<VarDef<HeadID, Bound>, Tail...>> {
    using Type = typename FindVarDef<ID, TypeList<Tail...>>::Type;
};
/**
 * @brief `value` is true when some statement in `StmtList` has the exact shape
 *        `OpAssign<Param<...>, LocalVar<ID, ...>>`, i.e. assigns LocalVar number `ID` to a Param.
 *
 * NOTE(review): the pattern spells Param with three arguments, so only Params whose fourth template
 * argument is the default are recognised — confirm this is intended.
 */
template <size_t ID, typename StmtList>
struct IsLocalVarReferencedByParam;

template <size_t ID>
struct IsLocalVarReferencedByParam<ID, TypeList<>> {
    static constexpr bool value = false;
};

template <size_t ID, typename NextStmt, typename... OtherStmts>
struct IsLocalVarReferencedByParam<ID, TypeList<NextStmt, OtherStmts...>> {
    // Any statement of another shape does not count.
    template <typename Stmt>
    struct CheckStmt {
        static constexpr bool value = false;
    };
    // `Param = LocalVar`: counts when the LocalVar number is ID.
    template <size_t PID, typename PT, ParamUsage PU, size_t LID, typename LT, typename LSrc>
    struct CheckStmt<OpAssign<Param<PID, PT, PU>, LocalVar<LID, LT, LSrc>>> {
        static constexpr bool value = (ID == LID);
    };

    static constexpr bool current = CheckStmt<NextStmt>::value;
    static constexpr bool rest = IsLocalVarReferencedByParam<ID, TypeList<OtherStmts...>>::value;
    static constexpr bool value = rest || current;
};
* 把无效的localVar简化掉,比如localVar1 = in1*in1, out = localVar1简化成 out = in1*in1.
* @tparam TypeList 需要简化的表达式的初始列表
* @tparam DefList VarDef的列表,存放的是LocalVar的number与其右边的表达式的映射关系
* @tparam Result 化简后的结果
*/
template <typename TypeList, typename DefList = Atvoss::Util::TypeList<>, typename Result = Atvoss::Util::TypeList<>>
struct SimplifyImpl;
template <typename... Defs, typename... ResultTypes>
struct SimplifyImpl<TypeList<>, TypeList<Defs...>, TypeList<ResultTypes...>> {
using result = TypeList<ResultTypes...>;
};
template <typename First, typename... Rest, typename... Defs, typename... ResultTypes>
struct SimplifyImpl<TypeList<First, Rest...>, TypeList<Defs...>, TypeList<ResultTypes...>> {
private:
template <size_t ID, typename T, typename Likes, typename Expr>
struct ProcessLocalVarDef {
static constexpr bool willBeReplaced = IsLocalVarReferencedByParam<ID, TypeList<Rest...>>::value;
using newDefs = TypeList<Defs..., VarDef<ID, Expr>>;
using Type = std::conditional_t<
willBeReplaced, std::integral_constant<int, 0>,
OpAssign<LocalVar<ID, T, Likes>, Expr>
>;
};
template <size_t PID, typename PT, ParamUsage PU, size_t LID, typename LT, typename Likes, typename RestList>
struct ProcessParamAssign {
using found = FindVarDef<LID, TypeList<Defs...>>;
using newDefs = TypeList<Defs...>;
using Type = std::conditional_t<
!std::is_same_v<typename found::Type, void>,
OpAssign<Param<PID, PT, PU>, typename found::Type>,
OpAssign<Param<PID, PT, PU>, LocalVar<LID, LT, Likes>>
>;
template <typename T>
using ReplaceLocalVar2Param = ReplaceExpr<T, LocalVar<LID, LT, Likes>, Param<PID, PT, PU>>;
using restList = Util::Map_t<ReplaceLocalVar2Param, RestList>;
};
template <typename Stmt>
struct ProcessOther {
using newDefs = TypeList<Defs...>;
using Type = Stmt;
};
template <typename Stmt, typename RestList>
struct ProcessStmt;
template <size_t ID, typename T, typename Likes, typename Expr, typename RestList>
struct ProcessStmt<OpAssign<LocalVar<ID, T, Likes>, Expr>, RestList> {
using processor = ProcessLocalVarDef<ID, T, Likes, Expr>;
using newDefs = typename processor::newDefs;
using resultType = typename processor::Type;
using restList = RestList;
};
template <size_t PID, typename PT, ParamUsage PU, size_t LID, typename LT, typename Likes, typename RestList>
struct ProcessStmt<OpAssign<Param<PID, PT, PU>, LocalVar<LID, LT, Likes>>, RestList> {
using processor = ProcessParamAssign<PID, PT, PU, LID, LT, Likes, RestList>;
using newDefs = typename processor::newDefs;
using resultType = typename processor::Type;
using restList = typename processor::restList;
};
template <typename Stmt, typename RestList>
struct ProcessStmt {
using processor = ProcessOther<Stmt>;
using newDefs = typename processor::newDefs;
using resultType = typename processor::Type;
using restList = RestList;
};
using processor = ProcessStmt<First, TypeList<Rest...>>;
using newDefs = typename processor::newDefs;
using currentResult = typename processor::resultType;
using nextResult = std::conditional_t<
std::is_same_v<currentResult, std::integral_constant<int, 0>>,
TypeList<ResultTypes...>,
TypeList<ResultTypes..., currentResult>>;
using next = SimplifyImpl<typename processor::restList, newDefs, nextResult>;
public:
using result = typename next::result;
};
/**
 * @brief Runs the LocalVar simplification pass (SimplifyImpl) over a statement list.
 *
 * `Type` is the simplified statement list; the pass starts with no recorded definitions and no output.
 */
template <typename List>
struct Simplify {
    using Type = typename SimplifyImpl<List, Atvoss::Util::TypeList<>, Atvoss::Util::TypeList<>>::result;
};
/**
 * @brief Introduces a LocalVar for each intermediate expression of a post-order expression list.
 *
 * For every entry `E` that is neither a Param nor an OpAssign, the statement `LocalVar<k> = E` is
 * emitted and every occurrence of `E` in the remaining entries is rewritten to `LocalVar<k>`, so later
 * uses read the LocalVar instead of repeating `E`. OpAssign entries are emitted unchanged and Param
 * entries are skipped. Each new LocalVar uses `Param<1, RetType, ParamUsage 0>` as its third argument.
 *
 * @tparam ExprList       entries still to process
 * @tparam LocalVarNumber number given to the next LocalVar that is created
 * @tparam Processed      statements emitted so far
 */
template <typename ExprList, size_t LocalVarNumber, typename Processed>
struct OptimizeWithLocalVarsImpl;

// No entries left: the emitted statements are the result.
template <size_t LocalVarNumber, typename Processed>
struct OptimizeWithLocalVarsImpl<Atvoss::Util::TypeList<>, LocalVarNumber, Processed> {
    using Type = Processed;
};

template <typename First, typename... Rest, size_t LocalVarNumber, typename Processed>
struct OptimizeWithLocalVarsImpl<Atvoss::Util::TypeList<First, Rest...>, LocalVarNumber, Processed> {
private:
    static constexpr bool shouldCache = !Atvoss::IsParam_v<First>;
    static constexpr bool firstIsOpAssign = Atvoss::Util::IsSpecializationOf_v<Atvoss::OpAssign, First>;
    // A LocalVar number is consumed only when a new `LocalVar = First` statement is created.
    static constexpr auto NextLocalVarNumber = firstIsOpAssign ? LocalVarNumber : LocalVarNumber + 1;
    using localVarType = Atvoss::LocalVar<LocalVarNumber, typename First::RetType,
        Atvoss::Param<1ul, typename First::RetType, static_cast<Atvoss::ParamUsage>(0)>>;
    using assignmentType = std::conditional_t<firstIsOpAssign, First, Atvoss::OpAssign<localVarType, First>>;
    // NOTE(review): when First is itself an OpAssign, occurrences of it in Rest are still rewritten to
    // localVarType although that LocalVar number is not consumed; presumably an OpAssign never appears
    // nested inside a later entry — confirm.
    template <typename Expr>
    using ReplaceOne = typename ReplaceExpr<Expr, First, localVarType>::Type;
    using updatedRest = Atvoss::Util::TypeList<ReplaceOne<Rest>...>;
    using nextProcessed = typename std::conditional_t<
        shouldCache, Atvoss::Util::Append<Processed, assignmentType>, Util::TypeWrapper<Processed>>::Type;
    using nextState = std::conditional_t<
        shouldCache, OptimizeWithLocalVarsImpl<updatedRest, NextLocalVarNumber, nextProcessed>,
        OptimizeWithLocalVarsImpl<Atvoss::Util::TypeList<Rest...>, LocalVarNumber, nextProcessed>>;

public:
    using Type = typename nextState::Type;
};

/**
 * @brief Entry point: LocalVar numbering starts at 1 and the output list starts empty.
 */
template <typename ExprList>
struct OptimizeWithLocalVars : OptimizeWithLocalVarsImpl<ExprList, 1, Atvoss::Util::TypeList<>> {};
/**
 * @brief Keeps only the OpAssign entries of a post-order expression list, normalising each one with
 *        SimplifyAssign; unlike OptimizeWithLocalVars it introduces no new LocalVars.
 *
 * @tparam ExprList  entries still to scan
 * @tparam Processed statements collected so far
 */
template <typename ExprList, typename Processed>
struct OptimizeBindBuffExprImpl;

// Nothing left to scan: the collected statements are the result.
template <typename Collected>
struct OptimizeBindBuffExprImpl<Atvoss::Util::TypeList<>, Collected> {
    using Type = Collected;
};
/**
 * @brief Maps an operator operand to the value it denotes: an `OpAssign<Target, Source>` operand is
 *        replaced by its assignment target; any other operand is returned unchanged.
 */
template <typename Operand>
struct SimplifyAssignOfOpParam {
    using Type = Operand;
};

template <typename Target, typename Source>
struct SimplifyAssignOfOpParam<Atvoss::OpAssign<Target, Source>> {
    using Type = Target;
};
/**
 * @brief Normalises a statement kept by OptimizeBindBuffExprImpl.
 *
 * - `OpAssign<Lhs, Rhs>`: Rhs is normalised recursively.
 * - Operator nodes: every operand that is itself an OpAssign is replaced by its target
 *   (see SimplifyAssignOfOpParam); other operands are left as they are.
 * - Param and LocalVar leaves are returned unchanged.
 */
template <typename T>
struct SimplifyAssign;

template <size_t N, typename T, auto Usage, size_t R>
struct SimplifyAssign<Atvoss::Param<N, T, Usage, R>> {
    using Type = Atvoss::Param<N, T, Usage, R>;
};

// A LocalVar on the right of an assignment (e.g. `out = localVar`) is a leaf as well. Without this
// specialization SimplifyAssign<OpAssign<X, LocalVar<...>>> would name an incomplete type.
template <size_t N, typename T, typename U>
struct SimplifyAssign<Atvoss::LocalVar<N, T, U>> {
    using Type = Atvoss::LocalVar<N, T, U>;
};

template <typename LHS, typename RHS>
struct SimplifyAssign<Atvoss::OpAssign<LHS, RHS>> {
    using Type = Atvoss::OpAssign<LHS, typename SimplifyAssign<RHS>::Type>;
};

template <template <typename, typename> class Op, typename LHS, typename RHS>
struct SimplifyAssign<Op<LHS, RHS>> {
    using Type = Op<typename SimplifyAssignOfOpParam<LHS>::Type, typename SimplifyAssignOfOpParam<RHS>::Type>;
};

template <template <typename> class Op, typename Expr>
struct SimplifyAssign<Op<Expr>> {
    using Type = Op<typename SimplifyAssignOfOpParam<Expr>::Type>;
};

template <template <auto, typename> class Op, auto N, typename Expr>
struct SimplifyAssign<Op<N, Expr>> {
    using Type = Op<N, typename SimplifyAssignOfOpParam<Expr>::Type>;
};

template <auto N, typename R, typename Expr>
struct SimplifyAssign<Atvoss::OpCast<N, R, Expr>> {
    using Type = Atvoss::OpCast<N, R, typename SimplifyAssignOfOpParam<Expr>::Type>;
};
template <typename First, typename... Rest, typename Processed>
struct OptimizeBindBuffExprImpl<Atvoss::Util::TypeList<First, Rest...>, Processed> {
private:
static constexpr bool shouldCache = !Atvoss::IsParam_v<First>;
using localVarType = First;
using assignSimplified = std::conditional_t<
Atvoss::Util::IsSpecializationOf_v<Atvoss::OpAssign, First>, TypeList<typename SimplifyAssign<First>::Type>,
TypeList<>>;
using nextProcessed = Atvoss::Util::Concatenate_t<Processed, assignSimplified>;
using nextState = OptimizeBindBuffExprImpl<Atvoss::Util::TypeList<Rest...>, nextProcessed>;
public:
using Type = typename nextState::Type;
};
template <typename ExprList>
struct OptimizeBindBuffExpr : OptimizeBindBuffExprImpl<ExprList, Atvoss::Util::TypeList<>> {};
template <typename Expr>
struct ExprLinearizer {
using postOrderList = Atvoss::Util::Unique_t<typename ExtractTypeListPostOrder<Expr>::Type>;
using removeRedundantCastExprList =
decltype(Atvoss::Graph::RemoveRedundantCast<typename std::conditional_t<
Util::Size<typename LocalVars<Expr>::Type>::value == 0, OptimizeWithLocalVars<postOrderList>,
OptimizeBindBuffExpr<postOrderList>>::Type>());
using optimizedList = typename Simplify<removeRedundantCastExprList>::Type;
};
}
/**
 * @brief Linearizes an expression into the statement expression built by ToOpAndThenExpr.
 *
 * Only the argument's type is inspected: the result is a value-initialized object of the type
 * ToOpAndThenExpr produces from ExprLinearizer<Expr::Type>::optimizedList.
 *
 * @param expr expression object; only `Expr::Type` is used
 */
template <typename Expr>
__host_aicore__ constexpr auto ToLinearizerExpr([[maybe_unused]] Expr expr)
{
    using StatementList = typename Detail::ExprLinearizer<typename Expr::Type>::optimizedList;
    using Chain = typename ToOpAndThenExpr<StatementList>::Type;
    return Chain();
}
}
#endif