fix(triton-to-linalg): add precheck for interleave optimization to avoid assert failure
Co-authored-by: luobaiqing <luobaiqing1@huawei.com>
# message auto-generated for no-merge-commit merge:
!1179 merge interleave into main
fix(triton-to-linalg): add precheck for interleave optimization to avoid assert failure
Created-by: luobaiqing
Commit-by: luobaiqing
Merged-by: ascend-robot
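Description: "Interleave" refers to the following case. We need to load a tensor in interleaved order: both loads read from the same `src`, but one loads the even-indexed elements and the other loads the odd-indexed elements. Without optimization this takes two memory accesses, each with stride = 2, which is an unfriendly access pattern. The optimization instead loads the data once (i.e. `tl.load(src + tl.arange(0, BLOCK))`) and then uses `extract_slice` to extract the even and odd elements separately. The unoptimized pattern looks like this: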
```
dim_range = tl.arange(0, BLOCK // 2)
last_dim_even_range = dim_range * 2
last_dim_odd_range = dim_range * 2 + 1
even = tl.load(src+last_dim_even_range)
odd = tl.load(src+last_dim_odd_range)
```
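A minimal sketch of why the two strided loads can be replaced by one contiguous load followed by two stride-2 slices. Global memory is modeled as a plain Python list; `strided_loads` and `single_load_then_slice` are hypothetical illustration helpers, not Triton or compiler APIs.

```python
def strided_loads(src, block):
    """Unoptimized pattern: two separate loads, each with stride 2."""
    dim_range = range(block // 2)
    even = [src[i * 2] for i in dim_range]      # src + last_dim_even_range
    odd = [src[i * 2 + 1] for i in dim_range]   # src + last_dim_odd_range
    return even, odd


def single_load_then_slice(src, block):
    """Optimized pattern: one contiguous load, then two stride-2 slices."""
    loaded = src[0:block]   # roughly tl.load(src + tl.arange(0, BLOCK))
    even = loaded[0::2]     # even elements: offset 0, stride 2
    odd = loaded[1::2]      # odd elements: offset 1, stride 2
    return even, odd


src = list(range(100, 108))  # stand-in for global memory
BLOCK = 8                    # assumed even block size
assert strided_loads(src, BLOCK) == single_load_then_slice(src, BLOCK)
print(single_load_then_slice(src, BLOCK))  # ([100, 102, 104, 106], [101, 103, 105, 107])
```

Both versions read the same elements; the optimized one touches memory once, contiguously.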
The current interleave optimization has a problem:
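The current implementation only supports `last_dim_even_range = dim_range * 2` and `last_dim_odd_range = dim_range * 2 + 1`. It decides whether a load reads the odd or the even elements by checking whether the load's offset comes from `addi src, c1` (the so-called "add constant one").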
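However, the implementation takes this pattern for granted: it uses an `assert` to require that, whenever the offset comes from an `AddIOp` with a constant operand, that constant is 1. A third-party operator introduced a new case. It also performs interleaved loads, and its offsets also come from an `AddIOp`, but instead of `addi src, c1` they are `addi src, c64` and `addi src, c65`, i.e. `last_dim_even_range = dim_range * 2 + 64` and `last_dim_odd_range = dim_range * 2 + 65`. This trips the assertion and compilation fails. Likewise, the assertion also fails when the static offset is neither 0 nor 1.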
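To illustrate the third-party case, the sketch below computes the offsets for a base of 64 (the helper `interleaved_offsets` is hypothetical). The even and odd offsets still cover one contiguous window starting at `src + 64`, so the access is still interleaved; only the added constants are 64/65 instead of 0/1, which is exactly what the old assertion rejects.

```python
def interleaved_offsets(block, base):
    """Offsets of the even/odd loads: dim_range * 2 + base and dim_range * 2 + base + 1."""
    dim_range = range(block // 2)
    even = [i * 2 + base for i in dim_range]
    odd = [i * 2 + base + 1 for i in dim_range]
    return even, odd


BLOCK = 8
even, odd = interleaved_offsets(BLOCK, 64)  # addi ..., c64 / addi ..., c65
print(even)  # [64, 66, 68, 70]
print(odd)   # [65, 67, 69, 71]
# Together they cover one contiguous window starting at src + 64.
assert sorted(even + odd) == list(range(64, 64 + BLOCK))
```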
This bugfix adds a precheck before the interleave optimization: it checks whether the offset comes from `addi src, c1`, and if the `AddIOp` does not add 1, the optimization is skipped for now.
In my view, supporting the new case described above would not fit well into the current interleave optimization framework, so for now it is only worked around.
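The precheck rule can be modeled roughly as follows. This is a hypothetical Python sketch, not the C++ code in `InterleaveOptimization.cpp`: the offset-defining ops are reduced to toy `Const`, `Value` and `AddI` classes, `precheck_interleave_offset` is an illustrative name, and only the op that directly defines the offset is inspected. It returns `False` (skip the optimization) in the situations where the original `traceOffset` would assert.

```python
from dataclasses import dataclass
from typing import Union


@dataclass
class Const:
    """Toy stand-in for an arith.constant."""
    value: int


@dataclass
class Value:
    """Toy stand-in for a non-constant SSA value, e.g. dim_range * 2."""
    name: str


@dataclass
class AddI:
    """Toy stand-in for an arith.addi."""
    lhs: "Expr"
    rhs: "Expr"


Expr = Union[Const, Value, AddI]


def precheck_interleave_offset(offset: Expr) -> bool:
    """Return True if the interleave optimization may be applied to `offset`.

    An addi with a constant operand is accepted only when that constant is 1
    (the "add constant one" pattern). Any other constant makes the precheck
    fail, so the caller skips the optimization instead of hitting the assert.
    """
    if isinstance(offset, AddI):
        for operand in (offset.lhs, offset.rhs):  # LHS first, like the C++ code
            if isinstance(operand, Const):
                return operand.value == 1
    return True


even_0 = Value("dim_range * 2")
odd_1 = AddI(Value("dim_range * 2"), Const(1))
odd_65 = AddI(Value("dim_range * 2"), Const(65))
print(precheck_interleave_offset(even_0))  # True
print(precheck_interleave_offset(odd_1))   # True
print(precheck_interleave_offset(odd_65))  # False: optimization skipped
```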
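Because interleaved access produces stride = 2 loads, and bishengir also has a bug when handling that case, no unit test is added to guard this fix for now.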
For reference, the assertions that fail are in the original `traceOffset` lambda:
```
// triton-ascend\third_party\ascend\lib\Utils\InterleaveOptimization.cpp::recountReinterpretCastOffset()
...
// To trace value type offset
std::function<bool(Operation *)> traceOffset = [&](Operation *op) -> bool {
  // Consider constant one in add constant one operation
  if (llvm::isa<arith::ConstantOp>(op))
    return false;
  if (llvm::isa<arith::AddIOp>(op)) {
    auto addOp = llvm::cast<arith::AddIOp>(op);
    // These asserts fire for offsets such as `addi src, c64` / `addi src, c65`.
    if (auto constLHS = addOp.getLhs().getDefiningOp<arith::ConstantOp>()) {
      assert(dyn_cast<IntegerAttr>(constLHS.getValueAttr()).getInt() == 1 &&
             "Arith::constant value of addi's operand must be 1 when "
             "calculate deinterleave offset");
      return false;
    }
    if (auto constRHS = addOp.getRhs().getDefiningOp<arith::ConstantOp>()) {
      assert(dyn_cast<IntegerAttr>(constRHS.getValueAttr()).getInt() == 1 &&
             "Arith::constant value of addi's operand must be 1 when "
             "calculate deinterleave offset");
      return false;
    }
  }
  return true;
};
...
```
See merge request: Ascend/triton-ascend!1179