| 【feat】: ATT支持多Group场景的CacheLine冲突建模
Co-authored-by: zhang_shengjie<804425610@qq.com>
# message auto-generated for no-merge-commit merge:
!440 merge feat/cache-line-conflict-detection into develop
【feat】: ATT支持多Group场景的CacheLine冲突建模
Created-by: zhang_shengjie
Commit-by: zhang_shengjie
Merged-by: cann-robot
Description:
# Pull Request
## 描述
### 一、主要解决的问题
#### 1.1 Cache Line冲突导致性能建模不准确
在ATT Tiling代码生成的**组并行(group-parallel)**场景中,多个group并行执行时,perf建模原先采用统一的并行合并策略(GenPerfUpdateCode):所有group的perf通过UpdateCurPerfAndBlockByGroup进行并行归并。
然而,当某些group的数据搬运的**单次搬运字节数小于硬件Cache Line大小**(如128字节)时,会产生**Cache Line冲突(false sharing)**。此时,这些group的实际执行性能不再是并行关系,而是**串行叠加**的。原有建模将冲突组perf按并行归并,导致性能评估偏高。
**影响范围**:
- Load(GM→UB)和Store(UB→GM)两个方向均可能产生冲突
- 冲突组的perf应按**串行求和**而非**并行归并**计算
#### 1.2 CacheLineConfig缺少方向信息和表达式语义不精确
| 问题 | 描述 |
|------|------|
| CacheLineConfig缺少方向信息 | 无法区分GM→UB读操作和UB→GM写操作 |
| 表达式语义不精确 | GetBlockCount使用总字节数而非单次搬运字节数 |
| 无法区分两种表达式的用途 | 冲突检测需要单次搬运量,Solver需要总量,共用一个字段 |
---
### 二、修改方案
#### 2.1 CacheLineDirection枚举与CacheLineConfig扩展(model_info.h)
新增CacheLineDirection枚举:
```cpp
enum class CacheLineDirection : uint8_t {
kUnknown = 0, // 未知方向,不参与冲突检测
kGmToUb = 1, // GM→UB 读操作(Load)
kUbToGm = 2, // UB→GM 写操作(Store)
};
```
扩展CacheLineConfig结构体:
| 字段/方法 | 说明 |
|-----------|------|
| direction | 数据搬运方向 |
| solver_cache_line_expr | 总搬运字节数表达式(accumulate(dims) * dtype_size),编译期常量时为空 |
| cache_line_expr | 含义收窄为单次搬运字节数(transfer_len * dtype_size) |
| IsCacheLineConflictCandidate() | 判断是否需参与冲突检测(kGmToUb || kUbToGm,kUnknown返回false) |
**关键设计**:将cache_line_expr(单次搬运量,用于冲突检测)和solver_cache_line_expr(总量,用于Solver UB约束)分离。
#### 2.2 Perf注册层重构(ascendc_api_perf_v2.cpp)
将GetBlockCount重构为AppendCacheLineConfig:
| 维度 | 原实现 | 新实现 |
|------|--------|--------|
| 签名 | GetBlockCount(node_info, config) | AppendCacheLineConfig(node_info, direction, config) |
| 表达式 | accumulate(dims) * dtype_size(总量) | 分离:transfer_len * dtype_size(单次)+ 总量 |
| 方向 | 无 | 参数化传入 |
调用点:Load→kGmToUb,Store→kUbToGm,Nddma→kGmToUb。
#### 2.3 冲突检测Helper函数生成(tiling_code_gen_impl.cpp/h)
在tiling codegen阶段,为每个Group的每个tiling_case生成冲突检测lambda。
**新增6个方法**:
| 方法 | 职责 |
|------|------|
| IsConflictCacheLineConfig | 静态辅助:判断CacheLineConfig是否需参与冲突检测 |
| GenConflictGroupHelpers | 入口:遍历所有group,生成冲突检测lambda |
| GenConflictGroupHelper | 为单个group的特定tiling_case生成冲突检测lambda |
| GenConflictExprContextCode | 解析表达式中的自由符号,生成变量声明上下文 |
| GenConflictGroupInvoke | 生成运行时switch-case分派代码 |
| GenMixedPerfUpdateCode | 替换原GenPerfUpdateCode,实现混合perf聚合 |
##### GenConflictGroupHelper — 生成单个冲突检测Lambda
生成格式:IsConflictGroup_{asc_graph_id}_{impl_graph_id}_{group_id}_{tiling_case_id}
**检测逻辑**:
1. **检查Schedule Table**:为空或未启用cache line检查 → fallback返回false
2. **过滤方向**:只处理IsConflictCacheLineConfig()为true的配置(即kGmToUb或kUbToGm,kUnknown被跳过)
3. **解析表达式符号**:调用GenConflictExprContextCode解析表达式中的自由变量
4. **冲突判定**:if (cache_line_expr < cfg_cache_line_size) return true;,其中cfg_cache_line_size优先使用自身值(> 0时),回退到ScheduleTable全局值
5. **异常处理**:无可有效配置 → fallback,表达式不可codegen → fallback
##### GenConflictExprContextCode — 符号解析
使用emit_decl lambda消除重复代码,按优先级查找符号来源:
| 符号类型 | 数据源 | 判断方式 |
|---------|--------|---------|
| block_dim | tiling_data | 名称匹配 |
| 硬件约束符号 | tiling_data | std::any_of遍历hardware_cons |
| 容器/张量表达式 | group_tiling_data | count()查找 |
| 输入变量 | group_tiling_data | ArgsManager获取input_vars |
| axis列表变量 | group_tiling_data | std::any_of遍历arg_list |
| 未知符号 | — | 返回{ "", false }触发fallback |
去重机制:declared_symbols set确保同一符号只声明一次。
##### GenConflictGroupInvoke — 运行时分派
生成IIFE风格的switch-case,按final_tiling_key分发到对应helper。去重防护:重复key直接fallback。
##### GenMixedPerfUpdateCode — 混合Perf聚合
```
conflict_perf_sum = 0.0
normal_perf_merged = 0.0
for each group:
if (is_conflict): conflict_perf_sum += perf // 冲突组:直接求和
else if (!has_normal_group): 初始化首个普通组
else: UpdateCurPerfAndBlockByGroup(...) // 后续普通组:并行归并
if (has_normal_group): normal_perf_merged += cur_tmp_perf
cur_perf = conflict_perf_sum + normal_perf_merged
```
#### 2.4 Solver约束修正(axes_reorder_solver_gen.cpp)
使用solver_cache_line_expr(总量)替代cache_line_expr(单次量),并增加IsValid检查:
```cpp
// 修正前: if (c.cache_line_size > 0)
// 修正后: if (c.cache_line_size > 0 && IsValid(c.solver_cache_line_expr))
```
---
### 三、代码修改流程图
#### 3.1 文件依赖关系图
```mermaid
graph TD
MI["model_info.h<br/>CacheLineDirection枚举<br/>CacheLineConfig扩展"]
TCI_H["tiling_code_gen_impl.h<br/>6个新方法声明"]
TCI_CPP["tiling_code_gen_impl.cpp<br/>冲突检测核心实现"]
PERF["ascendc_api_perf_v2.cpp<br/>AppendCacheLineConfig重构"]
SOLVER["axes_reorder_solver_gen.cpp<br/>solver_cache_line_expr"]
STUB["stub_solver_model_info.h/cpp<br/>测试辅助类"]
UT["att_generator_unittest.cpp<br/>11个单元测试"]
ST["test_concat.cpp<br/>7输入concat基线场景"]
PERF_UT["test_ascir_perf_v2.cpp<br/>Store CacheLine表达式测试"]
MI --> TCI_CPP
MI --> PERF
MI --> SOLVER
MI --> STUB
TCI_H --> TCI_CPP
PERF --> PERF_UT
STUB --> UT
style MI fill:#FFD700,stroke:#B8860B,stroke-width:2px
style TCI_CPP fill:#FFD700,stroke:#B8860B,stroke-width:2px
style PERF fill:#FFD700,stroke:#B8860B,stroke-width:2px
style SOLVER fill:#FFD700,stroke:#B8860B,stroke-width:2px
style STUB fill:#90EE90,stroke:#006400,stroke-width:2px
style UT fill:#90EE90,stroke:#006400,stroke-width:2px
style ST fill:#90EE90,stroke:#006400,stroke-width:2px
style PERF_UT fill:#90EE90,stroke:#006400,stroke-width:2px
```
#### 3.2 冲突检测函数生成流程
```mermaid
flowchart TD
Start([GenGetScheduleResult]) --> SecondTiling[二次Tiling生成]
SecondTiling --> GenHelpers[GenConflictGroupHelpers]
GenHelpers --> LoopGroups{遍历每个group}
LoopGroups --> LoopCases{遍历每个model_info}
LoopCases --> Dedup{tiling_case_id<br/>去重检查}
Dedup -->|重复| Skip[跳过]
Dedup -->|首次| CheckSchedule{检查Schedule Table}
CheckSchedule -->|nullptr或未启用| Fallback1["OP_LOGD: unavailable<br/>return false"]
CheckSchedule -->|已启用| FilterDirection{IsConflictCandidate?<br/>kGmToUb or kUbToGm}
FilterDirection -->|kUnknown 跳过| Fallback2["OP_LOGD: no valid gm双向ub expr<br/>return false"]
FilterDirection -->|有效方向| ResolveSymbols[GenConflictExprContextCode<br/>解析符号来源]
ResolveSymbols -->|解析失败| Fallback3["OP_LOGD: not codegenable<br/>return false"]
ResolveSymbols -->|解析成功| GenerateCheck["生成: if expr lt cfg_cache_line_size<br/>return true"]
GenerateCheck --> LoopCases
Fallback1 --> LoopCases
Fallback2 --> LoopCases
Fallback3 --> LoopCases
Skip --> LoopCases
LoopCases -->|所有case完成| LoopGroups
LoopGroups -->|所有group完成| PerfCalc[perf计算和更新]
PerfCalc --> CollectFlags[GenConflictGroupInvoke<br/>收集groups_conflict_flags]
CollectFlags --> MixedPerf[GenMixedPerfUpdateCode<br/>混合聚合]
style Start fill:#90EE90
style Fallback1 fill:#FF6347
style Fallback2 fill:#FF6347
style Fallback3 fill:#FF6347
style MixedPerf fill:#FF69B4,stroke:#C71585,stroke-width:2px
```
#### 3.3 混合性能聚合时序
```mermaid
sequenceDiagram
participant CG as CodeGen
participant RT as Runtime
participant H0 as Helper Group0
participant H1 as Helper Group1
participant HN as Helper GroupN
CG->>RT: 生成IsConflictGroup_x_x_0_0 lambda
CG->>RT: 生成IsConflictGroup_x_x_1_1 lambda
CG->>RT: 生成switch-case分发代码
RT->>RT: 遍历所有groups
RT->>H0: 调用冲突检测
H0-->>RT: true (conflict)
RT->>RT: conflict_perf_sum += perf0
RT->>H1: 调用冲突检测
H1-->>RT: false (normal)
RT->>RT: cur_tmp_perf = perf1 (首个普通组)
RT->>HN: 调用冲突检测
HN-->>RT: false (normal)
RT->>RT: UpdateCurPerfAndBlockByGroup归并
RT->>RT: cur_perf = conflict_perf_sum + normal_perf_merged
```
---
### 四、单元测试覆盖(11个代码生成UT + 1个Perf UT)
| # | 测试名称 | 验证场景 | 预期行为 |
|---|---------|---------|---------|
| 1 | AllConflict_UseSumAggregation | 两个Group都有冲突 | 使用求和聚合,验证conflict_perf_sum和合并代码存在 |
| 2 | BoundaryEqualCacheLine_StaysNormal | 表达式值恰好等于cache_line_size | 不判定为冲突(< 128返回false) |
| 3 | FirstConflictSecondNormal_InitFromFirstNormal | Group0冲突,Group1正常 | 从第一个正常组初始化,验证has_normal_group逻辑 |
| 4 | FinalTilingKeyDispatch_UsesFinalCaseHelper | 多case tiling key分发 | switch-case按case_id正确分发到对应helper |
| 5 | ByteExprDoesNotMultiplyDtypeAgain | 字节表达式不应重复乘dtype_size | 验证生成的代码不包含dtype_size |
| 6 | MissingScheduleTable_FallbackToNormalWithLog | Schedule Table为nullptr | fallback到正常组,验证日志输出 |
| 7 | GmToUbConflict_UseSumAggregation | Group0为kGmToUb方向(Load冲突) | Load方向也触发冲突检测,使用求和聚合 |
| 8 | DuplicateFinalKey_FallbackToNormalWithLog | 重复的tiling_case_id | 检测到重复key,验证日志输出 |
| 9 | DynamicInputSizeSymbols_GenerateContext | 动态Shape符号(s1, s20) | 正确解析符号并生成context代码 |
| 10 | UnknownDirection_FallbackToNormal | kUnknown方向 | fallback到正常组,验证日志输出 |
| 11 | MultiWriteExprs_DeduplicateContext | 同group多个CacheLine配置共享符号 | 符号声明去重,CountSubstr == 1 |
另有一个Perf注册级别测试(test_ascir_perf_v2.cpp):TestStoreApiCacheLineExprUsesSingleTransferLen,验证Store操作的cache_line_expr使用**最内层传输长度**而非全维度乘积。
---
## 变更类型
- [ ] 🐛 Bug 修复
- [x] ✨ 新功能
- [ ] 💄 代码风格更新
- [x] ♻️ 重构(GenConflictExprContextCode重构)
- [ ] 📦 构建过程或辅助工具的变动
- [ ] 📝 文档内容更新
## 关联的Issue
(待补充)
## 如何测试
### 一 测试用例说明
#### 1.1 单元测试
运行ATT generator单元测试:
```bash
# 冲突检测相关测试(11个用例)
./att_generator_unittest --gtest_filter="*CacheLine*"
# API perf级别测试
./utest_ascir_perf_v2 --gtest_filter="*CacheLineExpr*"
```
#### 1.2 系统测试
运行concat场景测试,验证七输入中轴concat的基线行为:
```bash
# concat场景端到端测试
./test_concat --gtest_filter="*SevenInputsMiddle*"
```
---
## 核对清单
- [x] 我的代码遵循了项目的代码风格
- [x] 我已对代码进行了自测
- [x] 我已更新了相关的文档
- [x] 我在标题中使用了合适的类型标签(feat:)
- [x] 我已经详细阅读了贡献指南
## 其他信息
### 验证方法
1. 编译通过,所有新增单元测试通过
2. 生成的tiling代码中包含IsConflictGroup_*lambda函数
3. 冲突场景下perf计算使用求和策略,普通场景保持并行合并
4. 方向元数据正确传播:Load→kGmToUb, Store→kUbToGm
5. 双向冲突检测覆盖:kGmToUb和kUbToGm方向均参与检测
### 注意事项
- CacheLineConfig的direction字段默认值为kUnknown,IsCacheLineConflictCandidate()对kUnknown返回false(不参与检测),确保向后兼容
- 每个CacheLineConfig可拥有独立的cache_line_size阈值(当> 0时使用自身值),否则回退到ScheduleTable全局值
- 重复tiling_key会被自动检测并fallback,不会导致运行时错误
- 动态Shape符号需要正确注册在model_info的sizes/tensor_exprs/container_exprs中才能被解析
- 冲突检测覆盖双向(GM→UB + UB→GM),仅kUnknown方向被排除
See merge request: cann/graph-autofusion!440 | 11 天前 |