* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "ascend/include/TritonToLinalg/BlockPtrAnalysis.h"
#include "ascend/include/TritonToLinalg/TritonToLinalgPass.h"
#include "ascend/include/Utils/Utils.h"
#include "bishengir/Dialect/Annotation/IR/Annotation.h"
#include "bishengir/Dialect/HIVM/IR/HIVM.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/Types.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Dialect/Triton/IR/Types.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/SmallVectorExtras.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/ErrorHandling.h"
#include <cassert>
#include <set>
#define DEBUG_TYPE "triton-block-ptr-analysis"
namespace mlir {
namespace triton {
/// Mutable accessors. Each returns a reference to the corresponding member so
/// that callers can edit the block description in place.
SmallVector<OpFoldResult> &BlockData::getOffsetsRef() { return offsets; }

SmallVector<OpFoldResult> &BlockData::getSizesRef() { return sizes; }

SmallVector<OpFoldResult> &BlockData::getStridesRef() { return strides; }

Value &BlockData::getSourceRef() { return source; }

OpFoldResult &BlockData::getScalarRef() { return scalar; }
/// Read-only accessors. Each returns a copy of the corresponding member.
SmallVector<OpFoldResult> BlockData::getOffsets() const { return offsets; }

SmallVector<OpFoldResult> BlockData::getSizes() const { return sizes; }

SmallVector<OpFoldResult> BlockData::getStrides() const { return strides; }

/// Per-dimension accessors; `index` is used unchecked to subscript the vector.
OpFoldResult BlockData::getOffset(int index) const { return offsets[index]; }

OpFoldResult BlockData::getSize(int index) const { return sizes[index]; }

OpFoldResult BlockData::getStride(int index) const { return strides[index]; }

OpFoldResult BlockData::getScalar() const { return scalar; }

Value BlockData::getSource() const { return source; }
/// Returns a copy of the memory access classification of this block.
MemAccType BlockData::getMemAccType() const { return memAccTy; }

/// Returns a mutable reference to the memory access classification.
MemAccType &BlockData::getMemAccTypeRef() { return memAccTy; }
bool BlockData::isScalar() const { return !(this->scalar).isNull(); }
bool BlockData::isEmpty() const {
return !(this->getRank() || this->source || !(this->scalar).isNull());
}
bool BlockData::hasSource() const { return this->source != nullptr; }
void BlockData::removeSource() { this->source = nullptr; };
bool BlockData::hasResElemTy() const { return this->resElemTy != nullptr; }
Type &BlockData::getResElemTyRef() { return this->resElemTy; }
Type BlockData::getResElemTy() const { return this->resElemTy; }
/// Returns the number of dimensions described by this block. The offsets,
/// sizes and strides vectors are expected to have the same length.
int64_t BlockData::getRank() const {
  const auto rank = offsets.size();
  assert(rank == sizes.size() && rank == strides.size());
  return static_cast<int64_t>(rank);
}
void BlockData::setResElemTy(const Type &Ty) { this->resElemTy = Ty; }
void BlockData::setScalar(const OpFoldResult &scalar) { this->scalar = scalar; }
void BlockData::setSource(const Value &src) { this->source = src; }
void BlockData::setOffsets(const SmallVector<OpFoldResult> &offsets) {
this->offsets = offsets;
}
void BlockData::setStrides(const SmallVector<OpFoldResult> &strides) {
this->strides = strides;
}
void BlockData::setSizes(const SmallVector<OpFoldResult> &szs) {
this->sizes = szs;
}
void BlockData::setMemAccTy(const MemAccType &v) { this->memAccTy = v; }
void BlockData::setMemAccVal(const MemAccVal v) { this->memAccTy.value = v; }
/// Sums all per-dimension offsets into a single offset, starting from the
/// index constant 0. Each addition goes through `addOpFoldResult`.
OpFoldResult BlockData::inferBlockOffset(const Location &loc,
                                         OpBuilder &builder) const {
  OpFoldResult total = builder.getIndexAttr(0);
  for (OpFoldResult dimOffset : offsets)
    total = addOpFoldResult(total, dimOffset, loc, builder);
  return total;
}
/// Builds the memref type for a view of `source` with the given static
/// `offset` and `resultShape`, using a strided layout derived from this
/// block's strides. The element type is taken from the source memref type.
MemRefType BlockData::getResultMemrefType(int64_t offset,
                                          ArrayRef<int64_t> resultShape) const {
  auto sourceTy = dyn_cast<BaseMemRefType>(source.getType());
  assert(sourceTy && "Invalid element type. It should be a base memref type.");

  // Split the mixed strides; only the static part feeds the layout attribute.
  SmallVector<int64_t> staticStrides;
  SmallVector<Value> dynamicStrides;
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);

  auto stridedLayout =
      StridedLayoutAttr::get(source.getContext(), offset, staticStrides);
  return MemRefType::get(resultShape, sourceTy.getElementType(),
                         stridedLayout);
}
/// Fills this (empty) block with the element-wise sum of `lBlock` and
/// `rBlock`: offsets and strides are added per dimension, sizes are taken from
/// `lBlock`, and the source / result element type are inherited from whichever
/// operand carries them (at most one may).
void BlockData::addBlock(BlockData &lBlock, BlockData &rBlock, Location loc,
                         ConversionPatternRewriter &rewriter) {
  assert(this->isEmpty() && lBlock.getRank() == rBlock.getRank());
  assert(!(lBlock.hasSource() && rBlock.hasSource()) &&
         "Don't support each BlockData has own base source pointer");
  if (lBlock.hasSource())
    source = lBlock.getSourceRef();
  else
    source = rBlock.getSourceRef();

  assert(!(lBlock.hasResElemTy() && rBlock.hasResElemTy()));
  if (lBlock.hasResElemTy()) {
    assert(lBlock.hasSource());
    resElemTy = lBlock.getResElemTyRef();
  } else if (rBlock.hasResElemTy()) {
    assert(rBlock.hasSource());
    resElemTy = rBlock.getResElemTyRef();
  }

  const bool bothScalar = lBlock.isScalar() && rBlock.isScalar();
  if (bothScalar) {
    scalar = addOpFoldResult(lBlock.getScalarRef(), rBlock.getScalarRef(), loc,
                             rewriter);
  } else if (lBlock.getRank() == 0) {
    // Rank-0 operands: forward whichever scalar is available.
    scalar = lBlock.isScalar() ? lBlock.getScalarRef() : rBlock.getScalarRef();
  }

  for (const auto &[lhsOffset, rhsOffset] :
       llvm::zip(lBlock.getOffsetsRef(), rBlock.getOffsetsRef())) {
    offsets.push_back(addOpFoldResult(lhsOffset, rhsOffset, loc, rewriter));
  }
  for (const auto &[lhsStride, rhsStride] :
       llvm::zip(lBlock.getStridesRef(), rBlock.getStridesRef())) {
    strides.push_back(addOpFoldResult(lhsStride, rhsStride, loc, rewriter));
  }
  sizes = lBlock.getSizesRef();

  MemAccType &accTy = getMemAccTypeRef();
  accTy.merge(lBlock.getMemAccTypeRef());
  accTy.merge(rBlock.getMemAccTypeRef());
}
/// Fills this (empty) block with the element-wise difference
/// `lBlock - rBlock`: offsets and strides are subtracted per dimension and
/// sizes are taken from `lBlock`. Source and result element type are not
/// propagated here.
void BlockData::subBlock(BlockData &lBlock, BlockData &rBlock, Location loc,
                         ConversionPatternRewriter &rewriter) {
  assert(this->isEmpty() && lBlock.getRank() == rBlock.getRank());

  const bool bothScalar = lBlock.isScalar() && rBlock.isScalar();
  if (bothScalar) {
    scalar = subOpFoldResult(lBlock.getScalarRef(), rBlock.getScalarRef(), loc,
                             rewriter);
  } else if (lBlock.getRank() == 0) {
    // Rank-0 operands: forward whichever scalar is available.
    scalar = lBlock.isScalar() ? lBlock.getScalarRef() : rBlock.getScalarRef();
  }

  for (const auto &[lhsOffset, rhsOffset] :
       llvm::zip(lBlock.getOffsetsRef(), rBlock.getOffsetsRef())) {
    offsets.push_back(subOpFoldResult(lhsOffset, rhsOffset, loc, rewriter));
  }
  for (const auto &[lhsStride, rhsStride] :
       llvm::zip(lBlock.getStridesRef(), rBlock.getStridesRef())) {
    strides.push_back(subOpFoldResult(lhsStride, rhsStride, loc, rewriter));
  }
  sizes = lBlock.getSizesRef();

  MemAccType &accTy = getMemAccTypeRef();
  accTy.merge(lBlock.getMemAccTypeRef());
  accTy.merge(rBlock.getMemAccTypeRef());
}
/// Fills this (empty) block with the product of `lBlock` and `rBlock`.
///
/// One operand supplies the per-dimension offsets/strides/sizes and the other
/// supplies the scalar multiplier: if `lBlock` is scalar the roles are
/// swapped, so the multiplier is always read from the operand that ends up on
/// the right. When both operands are scalar, the product of the two scalars is
/// also recorded as this block's scalar.
///
/// NOTE(review): if neither operand is scalar the multiplier is a null
/// OpFoldResult; this function does not guard against that — confirm callers
/// never reach it in that state.
void BlockData::mulBlock(BlockData &lBlock, BlockData &rBlock, Location loc,
                         ConversionPatternRewriter &rewriter) {
  assert(this->isEmpty() && lBlock.getRank() == rBlock.getRank());
  // Multiplying two base pointers is not meaningful.
  assert(!(lBlock.hasSource() && rBlock.hasSource()));

  if (lBlock.isScalar() && rBlock.isScalar()) {
    LLVM_DEBUG({
      llvm::dbgs() << "lBlock.scalar:" << lBlock.getScalar()
                   << " rBlock.scalar:" << rBlock.getScalar() << "\n";
    });
    this->scalar =
        mulOpFoldResult(lBlock.getScalar(), rBlock.getScalar(), loc, rewriter);
  }

  // Make `rb` the operand that provides the scalar multiplier.
  BlockData *lb = &lBlock;
  BlockData *rb = &rBlock;
  if (lb->isScalar()) {
    std::swap(lb, rb);
  }

  OpFoldResult rScalar = rb->getScalarRef();
  for (const auto &lOffset : lb->getOffsetsRef()) {
    this->offsets.push_back(mulOpFoldResult(lOffset, rScalar, loc, rewriter));
  }
  for (const auto &lStride : lb->getStridesRef()) {
    this->strides.push_back(mulOpFoldResult(lStride, rScalar, loc, rewriter));
  }
  this->sizes = lb->getSizesRef();

  this->getMemAccTypeRef().merge(lBlock.getMemAccTypeRef());
  this->getMemAccTypeRef().merge(rBlock.getMemAccTypeRef());
}
/// Fills this (empty) block with `lBlock / rBlock`. Both operands must be
/// scalar; the scalar, every offset and every stride of `lBlock` are divided
/// by the scalar of `rBlock`, and sizes are taken from `lBlock`.
void BlockData::divBlock(BlockData &lBlock, BlockData &rBlock, Location loc,
                         ConversionPatternRewriter &rewriter) {
  assert(this->isEmpty() && lBlock.getRank() == rBlock.getRank());
  assert(!(lBlock.hasSource() && rBlock.hasSource()));
  assert(lBlock.isScalar() && rBlock.isScalar());

  OpFoldResult divisor = rBlock.getScalar();
  scalar = divOpFoldResult(lBlock.getScalar(), divisor, loc, rewriter);

  for (OpFoldResult dividendOffset : lBlock.getOffsetsRef())
    offsets.push_back(divOpFoldResult(dividendOffset, divisor, loc, rewriter));
  for (OpFoldResult dividendStride : lBlock.getStridesRef())
    strides.push_back(divOpFoldResult(dividendStride, divisor, loc, rewriter));
  sizes = lBlock.getSizesRef();

  MemAccType &accTy = getMemAccTypeRef();
  accTy.merge(lBlock.getMemAccTypeRef());
  accTy.merge(rBlock.getMemAccTypeRef());
}
/// Emits a `memref.reinterpret_cast` of `source` described by this block.
///
/// The offset is the sum of all per-dimension offsets; it becomes a static
/// layout offset when it folds to a constant and a dynamic one otherwise.
/// For dimensions whose result extent is 1, a dynamic stride is replaced by
/// `max(stride, 1)`.
memref::ReinterpretCastOp BlockData::createCastOp(ArrayRef<int64_t> resultShape,
                                                  const Location &loc,
                                                  OpBuilder &builder) const {
  OpFoldResult blockOffset = inferBlockOffset(loc, builder);

  int64_t layoutOffset = ShapedType::kDynamic;
  if (isa<Attribute>(blockOffset))
    layoutOffset = getConstantIntValue(blockOffset).value();
  auto castTy = getResultMemrefType(layoutOffset, resultShape);

  SmallVector<OpFoldResult> adjustedStrides(this->strides);
  for (size_t dim = 0; dim < adjustedStrides.size(); ++dim) {
    if (resultShape[dim] != 1)
      continue;
    auto dynStride = dyn_cast<Value>(adjustedStrides[dim]);
    if (!dynStride)
      continue;
    auto one =
        builder.create<arith::ConstantOp>(loc, builder.getIndexAttr(1));
    adjustedStrides[dim] =
        builder.create<arith::MaxSIOp>(loc, dynStride, one).getResult();
  }

  return builder.create<memref::ReinterpretCastOp>(
      loc, castTy, this->source, blockOffset, this->sizes, adjustedStrides);
}
/// Prints every field of this block to `llvm::outs()` for debugging.
void BlockData::dump() const {
  auto &os = llvm::outs();
  os << "[INFO][BEG] BlockData info\n";

  os << "offsets has " << offsets.size() << " items\n";
  for (size_t i = 0, e = offsets.size(); i < e; ++i)
    os << "offsets[" << i << "] = " << offsets[i] << "\n";

  os << "sizes has " << sizes.size() << " items\n";
  for (size_t i = 0, e = sizes.size(); i < e; ++i)
    os << "sizes[" << i << "] = " << sizes[i] << "\n";

  os << "strides has " << strides.size() << " items\n";
  for (size_t i = 0, e = strides.size(); i < e; ++i)
    os << "strides[" << i << "] = " << strides[i] << "\n";

  os << "source = " << source << "\n";
  os << "scalar = " << scalar << "\n";
  os << "resElemTy = " << resElemTy << "\n";
  os << "memAccTy = " << memAccTy.toString() << "\n";
  os << "[INFO][END] BlockData info\n";
}
/// Returns a memref view for a scalar pointer.
///
/// If `ptr` is produced by `tt.addptr`, `memref` must already be the result of
/// a `memref.reinterpret_cast`, which is returned as is. Otherwise `ptr` must
/// be a block argument and a 1-element cast (offset 0, size 1, stride 1) of
/// `memref` is created.
Value BlockDataParser::getScalarMemRef(Value ptr, Value memref,
                                       const Location &loc,
                                       ConversionPatternRewriter &rewriter) {
  assert(isa<triton::PointerType>(ptr.getType()) && "expect a scalar pointer");

  if (ptr.getDefiningOp<triton::AddPtrOp>()) {
    auto existingCast = memref.getDefiningOp<memref::ReinterpretCastOp>();
    if (!existingCast)
      llvm_unreachable("pointer value is defined by an unexpected op");
    return existingCast.getResult();
  }

  assert(isa<BlockArgument>(ptr) &&
         "pointer should be produced by addptr or block argument");

  BlockData unitBlock;
  unitBlock.setSource(memref);
  unitBlock.getOffsetsRef().push_back(rewriter.getIndexAttr(0));
  unitBlock.getSizesRef().push_back(rewriter.getIndexAttr(1));
  unitBlock.getStridesRef().push_back(rewriter.getIndexAttr(1));
  return unitBlock.createCastOp(SmallVector<int64_t>(1, 1), loc, rewriter)
      .getResult();
}
/// Recursively analyses `operand` and records the resulting block description
/// in `data`.
///
/// Resolution order:
///   1. values already present in `known` are copied from there;
///   2. integer-typed values become a scalar;
///   3. pointer-typed values are dispatched on their defining op (or treated
///      as a plain source when they have none);
///   4. everything else is dispatched on the defining op of the value.
/// Unsupported producers abort via `llvm_unreachable`.
///
/// \param operand  value to analyse.
/// \param data     output block description.
/// \param loc      location used for any IR created during analysis.
/// \param rewriter rewriter used to create IR and look up remapped values.
/// \param known    cache of already analysed values.
void BlockDataParser::parse(
    Value operand, BlockData &data, const Location &loc,
    ConversionPatternRewriter &rewriter,
    const llvm::SmallDenseMap<Value, BlockData> &known) {
  // A single find() serves both the membership test and the value fetch.
  if (auto it = known.find(operand); it != known.end()) {
    data = it->second;
    return;
  }

  if (isa<IntegerType>(operand.getType())) {
    data.setScalar(getOpFoldResultOfLayoutInfo(operand, rewriter));
    return;
  }

  if (isa<triton::PointerType>(operand.getType())) {
    auto remappedPtr = rewriter.getRemappedValue(operand);
    assert(remappedPtr);
    if (auto op = operand.getDefiningOp()) {
      if (auto addPtrOp = dyn_cast<triton::AddPtrOp>(op)) {
        parseAddPtr(addPtrOp, data, loc, rewriter, known);
      } else if (auto bitcastOp = dyn_cast<triton::BitcastOp>(op)) {
        parseBitcast(bitcastOp, data, loc, rewriter, known);
      } else if (auto makeTensorPtrOp = dyn_cast<triton::MakeTensorPtrOp>(op)) {
        parseTensorPtr(makeTensorPtrOp, data, loc, rewriter, known);
      } else if (auto advanceOp = dyn_cast<triton::AdvanceOp>(op)) {
        parseTensorPtr(advanceOp, data, loc, rewriter, known);
      } else if (isa<triton::IntToPtrOp, hivm::CustomOp>(op)) {
        // These producers are opaque here: use the remapped pointer as base.
        data.setSource(remappedPtr);
      } else {
        LLVM_DEBUG({ llvm::dbgs() << operand << "\n"; });
        llvm_unreachable("Unexpected operand defining operation, a scalar "
                         "pointer can only be produced by AddPtrOp or direct block ptr or hivm CustomOp");
      }
    } else {
      // No defining op (e.g. a block argument): the pointer itself is the base.
      data.setSource(remappedPtr);
    }
    return;
  }

  if (auto addOp = operand.getDefiningOp<arith::AddIOp>()) {
    parseAdd(addOp, data, loc, rewriter, known);
  } else if (auto subOp = operand.getDefiningOp<arith::SubIOp>()) {
    parseSub(subOp, data, loc, rewriter, known);
  } else if (auto mulOp = operand.getDefiningOp<arith::MulIOp>()) {
    parseMul(mulOp, data, loc, rewriter, known);
  } else if (auto addPtrOp = operand.getDefiningOp<triton::AddPtrOp>()) {
    parseAddPtr(addPtrOp, data, loc, rewriter, known);
  } else if (auto constOp = operand.getDefiningOp<arith::ConstantOp>()) {
    parseConstSplat(constOp, data, loc, rewriter, known);
  } else if (auto broadcastOp = operand.getDefiningOp<triton::BroadcastOp>()) {
    parseBroadcast(broadcastOp, data, loc, rewriter, known);
  } else if (auto splatOp = operand.getDefiningOp<triton::SplatOp>()) {
    parseSplat(splatOp, data, loc, rewriter, known);
  } else if (auto expandDimsOp =
                 operand.getDefiningOp<triton::ExpandDimsOp>()) {
    parseExpandDims(expandDimsOp, data, loc, rewriter, known);
  } else if (auto remOp = operand.getDefiningOp<arith::RemSIOp>()) {
    parseRem(remOp, data, loc, rewriter, known);
  } else if (auto bitcastOp = operand.getDefiningOp<triton::BitcastOp>()) {
    parseBitcast(bitcastOp, data, loc, rewriter, known);
  } else if (auto extsiOp = operand.getDefiningOp<arith::ExtSIOp>()) {
    parseExtSI(extsiOp, data, loc, rewriter, known);
  } else if (auto divOp = operand.getDefiningOp<arith::DivSIOp>()) {
    parseDiv(divOp, data, loc, rewriter, known);
  } else if (auto makeRangeOp = operand.getDefiningOp<triton::MakeRangeOp>()) {
    parseMakeRange(makeRangeOp, data, loc, rewriter, known);
  } else if (auto reduceOp = operand.getDefiningOp<triton::ReduceOp>()) {
    parseReduce(reduceOp, data, loc, rewriter, known);
  } else if (auto loadOp = operand.getDefiningOp<triton::LoadOp>()) {
    parseIndirectLoad<triton::LoadOp>(loadOp, data, loc, rewriter, known);
  } else if (auto castOp = operand.getDefiningOp<arith::FPToSIOp>()) {
    parseIndirectLoad<arith::FPToSIOp>(castOp, data, loc, rewriter, known);
  } else if (auto extractSliceOp =
                 operand.getDefiningOp<tensor::ExtractSliceOp>()) {
    parseExtractSlice(extractSliceOp, data, loc, rewriter, known);
  } else if (auto forOp = operand.getDefiningOp<scf::ForOp>()) {
    // Multi-result op: forward which result is being analysed.
    auto opResult = dyn_cast<OpResult>(operand);
    assert(opResult && "expected OpResult for scf.for result");
    unsigned resultIdx = opResult.getResultNumber();
    parseIndirectLoad<scf::ForOp>(forOp, data, loc, rewriter, known, resultIdx);
  } else if (auto tensorCastOp = operand.getDefiningOp<tensor::CastOp>()) {
    // A tensor.cast is transparent for this analysis: look through it.
    parse(tensorCastOp.getSource(), data, loc, rewriter, known);
  } else if (auto fillOp = operand.getDefiningOp<linalg::FillOp>()) {
    parseFill(fillOp, data, loc, rewriter, known);
  } else if (auto selectOp = operand.getDefiningOp<arith::SelectOp>()) {
    parseSelect(selectOp, data, loc, rewriter, known);
  } else if (auto customOp = operand.getDefiningOp<hivm::CustomOp>()) {
    auto opResult = dyn_cast<OpResult>(operand);
    assert(opResult && "Expected operand to be an OpResult");
    unsigned resultIdx = opResult.getResultNumber();
    parseCustomOp(customOp, data, loc, rewriter, known, resultIdx);
  } else if (auto genericOp = operand.getDefiningOp<linalg::GenericOp>()) {
    // Only generics tagged as lowered from tt.make_range are understood.
    if (genericOp->hasAttr("tt.from_make_range")) {
      parseLinalgGenericFromMakeRange(genericOp, data, loc, rewriter, known);
    } else {
      operand.dump();
      llvm_unreachable("encountered AddPtrOp produced by unsupported operation");
    }
  } else if (auto atomicRMWOp = operand.getDefiningOp<triton::AtomicRMWOp>()) {
    parseAtomicRmw(atomicRMWOp, data, loc, rewriter, known);
  } else {
    operand.dump();
    llvm_unreachable("encountered AddPtrOp produced by unsupported operation");
  }
}
/// Describes the first result of a `tt.atomic_rmw` as a block.
///
/// - A `tensor<1x...>` result is extracted, cast to index and used as a
///   structured scalar offset.
/// - Any other shaped result is marked as unstructured access with one
///   (offset 0, size dim, stride 1) entry per dimension.
/// - A non-shaped result is marked as structured with a single unit entry.
/// In the last two cases the op result itself becomes the source.
void BlockDataParser::parseAtomicRmw(
    triton::AtomicRMWOp op, BlockData &data, const Location &loc,
    ConversionPatternRewriter &rewriter,
    const llvm::SmallDenseMap<Value, BlockData> &known)
{
  auto rmwResult = op->getResult(0);
  std::vector<int64_t> dims;

  auto shapedTy = dyn_cast<ShapedType>(rmwResult.getType());
  if (!shapedTy) {
    data.setMemAccVal(MemAccVal::StrucMemAcc);
    dims.push_back(1);
  } else {
    dims = shapedTy.getShape().vec();
    if (dims.size() == 1 && dims[0] == 1) {
      Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
      Value element =
          rewriter.create<tensor::ExtractOp>(loc, rmwResult, ValueRange{zero});
      Value asIndex = rewriter.create<arith::IndexCastOp>(
          loc, rewriter.getIndexType(), element);
      data.setMemAccVal(MemAccVal::StrucMemAcc);
      data.setScalar(asIndex);
      data.getSizesRef().push_back(rewriter.getIndexAttr(1));
      data.getStridesRef().push_back(rewriter.getIndexAttr(0));
      data.getOffsetsRef().push_back(asIndex);
      return;
    }
    data.setMemAccVal(MemAccVal::UnstrucMemAcc);
  }

  for (int64_t dim : dims) {
    data.getOffsetsRef().push_back(rewriter.getIndexAttr(0));
    data.getSizesRef().push_back(rewriter.getIndexAttr(dim));
    data.getStridesRef().push_back(rewriter.getIndexAttr(1));
  }
  data.setSource(rmwResult);
}
/// Analyses both operands of an `arith.addi` and stores their sum in `data`.
void BlockDataParser::parseAdd(
    arith::AddIOp op, BlockData &data, const Location &loc,
    ConversionPatternRewriter &rewriter,
    const llvm::SmallDenseMap<Value, BlockData> &known) {
  BlockData lhsData;
  BlockData rhsData;
  parse(op.getLhs(), lhsData, loc, rewriter, known);
  parse(op.getRhs(), rhsData, loc, rewriter, known);
  data.addBlock(lhsData, rhsData, loc, rewriter);
}
/// Analyses both operands of an `arith.subi` and stores their difference in
/// `data`.
void BlockDataParser::parseSub(
    arith::SubIOp op, BlockData &data, const Location &loc,
    ConversionPatternRewriter &rewriter,
    const llvm::SmallDenseMap<Value, BlockData> &known) {
  BlockData lhsData;
  BlockData rhsData;
  parse(op.getLhs(), lhsData, loc, rewriter, known);
  parse(op.getRhs(), rhsData, loc, rewriter, known);
  data.subBlock(lhsData, rhsData, loc, rewriter);
}
/// Analyses both operands of an `arith.muli` and stores their product in
/// `data`.
void BlockDataParser::parseMul(
    arith::MulIOp op, BlockData &data, const Location &loc,
    ConversionPatternRewriter &rewriter,
    const llvm::SmallDenseMap<Value, BlockData> &known) {
  BlockData lhsData;
  BlockData rhsData;
  parse(op.getLhs(), lhsData, loc, rewriter, known);
  parse(op.getRhs(), rhsData, loc, rewriter, known);
  data.mulBlock(lhsData, rhsData, loc, rewriter);
}
/// Analyses both operands of an `arith.divsi` and stores their quotient in
/// `data`.
void BlockDataParser::parseDiv(
    arith::DivSIOp op, BlockData &data, const Location &loc,
    ConversionPatternRewriter &rewriter,
    const llvm::SmallDenseMap<Value, BlockData> &known) {
  BlockData lhsData;
  BlockData rhsData;
  parse(op.getLhs(), lhsData, loc, rewriter, known);
  parse(op.getRhs(), rhsData, loc, rewriter, known);
  data.divBlock(lhsData, rhsData, loc, rewriter);
}
/// Placeholder for `arith.remsi`: modulo in an address expression is rejected
/// with an assertion and `data` is left untouched.
void BlockDataParser::parseRem(
    arith::RemSIOp op, BlockData &data, const Location &loc,
    ConversionPatternRewriter &rewriter,
    const llvm::SmallDenseMap<Value, BlockData> &known) {
  assert(false &&
         "Address expression with modulo is not supported yet, it "
         "shall be analysis at linearize.");
}
/// Describes a `tt.make_range` as a 1-D block starting at the range start,
/// with the result extent as size and a unit stride.
void BlockDataParser::parseMakeRange(
    triton::MakeRangeOp op, BlockData &data, const Location &loc,
    ConversionPatternRewriter &rewriter,
    const llvm::SmallDenseMap<Value, BlockData> &known) {
  assert(data.isEmpty());
  auto resultShape = dyn_cast<ShapedType>(op.getType()).getShape();
  auto rangeBegin = op.getStart();
  auto rangeEnd = op.getEnd();

  // The range is valid when it is non-decreasing and fits in the result
  // extent; that validity flag doubles as the stride value (true -> 1).
  const bool unitStride = (rangeEnd >= rangeBegin) &&
                          (rangeEnd - rangeBegin <= resultShape[0]);
  assert(unitStride &&
         "make_range op should always return a tensor of stride 1");

  data.getOffsetsRef().push_back(rewriter.getIndexAttr(rangeBegin));
  data.getSizesRef().push_back(rewriter.getIndexAttr(resultShape[0]));
  data.getStridesRef().push_back(rewriter.getIndexAttr(unitStride ? 1 : 0));
}
/// Describes a `linalg.generic` tagged with `tt.from_make_range` as a 1-D
/// block, reading the offset and size from the `tt.make_range_offset` and
/// `tt.make_range_size` integer attributes and using a stride of 1.
void BlockDataParser::parseLinalgGenericFromMakeRange(
    linalg::GenericOp op, BlockData &data, const Location &loc,
    ConversionPatternRewriter &rewriter,
    const llvm::SmallDenseMap<Value, BlockData> &known) {
  assert(data.isEmpty());
  assert(op->hasAttr("tt.from_make_range") &&
         "expected tt.from_make_range attribute");

  auto rangeOffsetAttr = op->getAttr("tt.make_range_offset");
  auto rangeSizeAttr = op->getAttr("tt.make_range_size");
  assert(rangeOffsetAttr && rangeSizeAttr &&
         "tt.make_range_offset and tt.make_range_size required");

  const int64_t rangeOffset = cast<IntegerAttr>(rangeOffsetAttr).getInt();
  const int64_t rangeSize = cast<IntegerAttr>(rangeSizeAttr).getInt();

  data.getOffsetsRef().push_back(rewriter.getIndexAttr(rangeOffset));
  data.getSizesRef().push_back(rewriter.getIndexAttr(rangeSize));
  data.getStridesRef().push_back(rewriter.getIndexAttr(1));
}
/// Analyses the source of a `tt.expand_dims` and inserts a unit dimension
/// (offset 0, size 1, stride 0) at the expanded axis.
void BlockDataParser::parseExpandDims(
    triton::ExpandDimsOp op, BlockData &data, const Location &loc,
    ConversionPatternRewriter &rewriter,
    const llvm::SmallDenseMap<Value, BlockData> &known) {
  assert(data.isEmpty());
  parse(op.getSrcMutable().get(), data, loc, rewriter, known);

  auto expandedShape =
      dyn_cast<ShapedType>(op.getResult().getType()).getShape();
  auto axis = op.getAxis();
  assert(expandedShape[axis] == 1 &&
         "The destiny shape of changed dimension should be 1");

  auto &offsets = data.getOffsetsRef();
  auto &sizes = data.getSizesRef();
  auto &strides = data.getStridesRef();
  offsets.insert(offsets.begin() + axis, rewriter.getIndexAttr(0));
  sizes.insert(sizes.begin() + axis, rewriter.getIndexAttr(1));
  strides.insert(strides.begin() + axis, rewriter.getIndexAttr(0));
}
/// Handles a `tensor.extract_slice` that picks a single element out of a
/// loaded index tensor so that it can be used as a scalar offset.
///
/// Requirements (any violation aborts with the scenario message):
///   - the slice source must resolve to a block whose source is produced by
///     `tt.load`;
///   - the slice result must be a ranked tensor of rank exactly 1 with an
///     extent of at most 1.
///
/// On success the slice result is cast to an index tensor and recorded as the
/// single offset of `data` (size = result extent, stride = 1). `data` must be
/// empty on entry.
void BlockDataParser::parseExtractSlice(
    tensor::ExtractSliceOp op, BlockData &data, const Location &loc,
    ConversionPatternRewriter &rewriter,
    const llvm::SmallDenseMap<Value, BlockData> &known) {
  const std::string scenarioMessages =
      "PtsAnalysis supports indirectly block load in the "
      "following scenario\n"
      "B = tl.load(Aptr + Aoffset) # B is 1D tensor\n"
      "s = tl.extract_slice(indices, offsets= (i,), sizes= "
      "(1,), strides= (1,)) # s is a tensor<1x$dtype>\n"
      "D = tl.load(Cptr + s + Coffset) # s is used as the "
      "scalar offset\n";

  auto extract_src = op->getOperand(0);
  BlockData srcBlock;
  parse(extract_src, srcBlock, loc, rewriter, known);
  if (!srcBlock.hasSource()) {
    llvm_unreachable(scenarioMessages.c_str());
  }
  if (!isa_and_nonnull<triton::LoadOp>(srcBlock.getSource().getDefiningOp())) {
    llvm_unreachable(scenarioMessages.c_str());
  }

  auto extract_result = op->getResult(0);
  auto shaped_ty = dyn_cast<RankedTensorType>(extract_result.getType());
  // Reject a non-ranked result before touching its shape.
  if (!shaped_ty) {
    llvm_unreachable(scenarioMessages.c_str());
  }
  auto shape = shaped_ty.getShape();
  // Require rank == 1 so that shape[0] is always a valid access; a rank-0
  // result previously reached shape[0] out of bounds.
  if (shape.size() != 1 || shape[0] > 1) {
    llvm_unreachable(scenarioMessages.c_str());
  }

  auto castOp = rewriter.create<arith::IndexCastOp>(
      loc, RankedTensorType::get(shape, rewriter.getIndexType()),
      extract_result);
  auto offset = castOp.getResult();
  if (data.isEmpty()) {
    data.getOffsetsRef().push_back(offset);
    data.getSizesRef().push_back(rewriter.getIndexAttr(shape[0]));
    data.getStridesRef().push_back(rewriter.getIndexAttr(1));
  } else {
    llvm_unreachable(
        "parseExtractSlice with offset already setup not yet supported");
  }
}
/// Analyses the source of a `tt.bitcast` and records the pointee type of the
/// cast result as the block's result element type.
///
/// For tensor-of-pointer results the pointee type of the element type is used.
/// For scalar pointers the result pointee type is used only when the op has a
/// `MetaUse` attribute or casts an `i1` block-argument pointer to `i8`;
/// otherwise the remapped value of the op becomes the source and the result
/// element type is left null.
void BlockDataParser::parseBitcast(
    triton::BitcastOp op, BlockData &data, const Location &loc,
    ConversionPatternRewriter &rewriter,
    const llvm::SmallDenseMap<Value, BlockData> &known) {
  assert(data.isEmpty());
  parse(op.getSrc(), data, loc, rewriter, known);

  Type pointeeTy = nullptr;
  auto castResultTy = op.getResult().getType();
  if (auto shapedResultTy = dyn_cast<ShapedType>(castResultTy)) {
    auto elemPtrTy =
        dyn_cast<triton::PointerType>(shapedResultTy.getElementType());
    pointeeTy = elemPtrTy.getPointeeType();
  } else {
    auto fromPointee =
        cast<triton::PointerType>(op.getSrc().getType()).getPointeeType();
    auto toPointee = cast<triton::PointerType>(castResultTy).getPointeeType();
    if (op->hasAttr("MetaUse") ||
        (isa<BlockArgument>(op.getSrc()) &&
         fromPointee == rewriter.getIntegerType(1) &&
         toPointee == rewriter.getIntegerType(8))) {
      pointeeTy = toPointee;
    } else {
      auto remappedValue = rewriter.getRemappedValue(op);
      data.setSource(remappedValue);
      LLVM_DEBUG({
        llvm::dbgs() << "Remapping bitcastOp:\n";
        llvm::dbgs() << op << "\nto \n";
        llvm::dbgs() << remappedValue << "\n";
      });
    }
  }
  data.setResElemTy(pointeeTy);
}
/// An `arith.extsi` does not change the block description: analyse its input
/// directly into `data`.
void BlockDataParser::parseExtSI(
    arith::ExtSIOp op, BlockData &data, const Location &loc,
    ConversionPatternRewriter &rewriter,
    const llvm::SmallDenseMap<Value, BlockData> &known) {
  assert(data.isEmpty());
  Value extended = op.getIn();
  parse(extended, data, loc, rewriter, known);
}
/// Analyses the source of a `tt.broadcast` and, for every dimension whose
/// extent grows, overwrites the recorded size with the destination extent.
void BlockDataParser::parseBroadcast(
    triton::BroadcastOp op, BlockData &data, const Location &loc,
    ConversionPatternRewriter &rewriter,
    const llvm::SmallDenseMap<Value, BlockData> &known) {
  assert(data.isEmpty());
  auto input = op.getSrcMutable().get();
  auto output = op.getResult();
  assert(isa<ShapedType>(input.getType()) &&
         "tt.broadcast's input should be a tensor");

  auto inputShape = dyn_cast<ShapedType>(input.getType()).getShape();
  auto outputShape = dyn_cast<ShapedType>(output.getType()).getShape();
  assert(inputShape.size() == outputShape.size() &&
         "rank of source shoule be equal to destnation");

  parse(input, data, loc, rewriter, known);

  for (const auto &[dim, extents] :
       llvm::enumerate(llvm::zip(inputShape, outputShape))) {
    const auto &[inExtent, outExtent] = extents;
    if (inExtent != outExtent) {
      assert(inExtent < outExtent &&
             "srcShape of broadcastOp must be less than dstShape.");
      data.getSizesRef()[dim] = rewriter.getIndexAttr(outExtent);
    }
  }
}
/// Analyses the source of a `tt.splat` and expands it to the result shape:
/// every dimension gets offset 0, the destination extent as size and stride 0.
/// When the source is a scalar, that scalar becomes the offset of dimension 0.
/// Sources that are neither integers nor pointers are reported as an error.
void BlockDataParser::parseSplat(
    triton::SplatOp op, BlockData &data, const Location &loc,
    ConversionPatternRewriter &rewriter,
    const llvm::SmallDenseMap<Value, BlockData> &known) {
  assert(data.isEmpty());
  auto input = op.getSrc();
  auto output = op.getResult();
  auto outputShape = dyn_cast<ShapedType>(output.getType()).getShape();
  parse(input, data, loc, rewriter, known);

  if (!isa<IntegerType, triton::PointerType>(input.getType())) {
    op->emitError("Block data Analysis: unsupported splat pattern");
    return;
  }

  auto &offsets = data.getOffsetsRef();
  auto &sizes = data.getSizesRef();
  auto &strides = data.getStridesRef();

  // Drop whatever per-dimension info the source analysis produced.
  if (!data.isEmpty()) {
    offsets.clear();
    sizes.clear();
    strides.clear();
  }
  for (auto extent : outputShape) {
    offsets.push_back(rewriter.getIndexAttr(0));
    sizes.push_back(rewriter.getIndexAttr(extent));
    strides.push_back(rewriter.getIndexAttr(0));
  }

  if (data.isScalar())
    offsets[0] = data.getScalarRef();
}
/// Describes a splat integer `arith.constant` tensor.
///
/// The splat value is recorded as the block's scalar and as the offset of
/// dimension 0; all remaining offsets are 0. Every dimension gets the result
/// extent as size and a stride of 0.
///
/// The constant must be a splat `DenseElementsAttr` with an integer element
/// type, and `data` must be empty on entry.
void BlockDataParser::parseConstSplat(
    arith::ConstantOp op, BlockData &data, const Location &loc,
    ConversionPatternRewriter &rewriter,
    const llvm::SmallDenseMap<Value, BlockData> &known) {
  assert(data.isEmpty());
  DenseElementsAttr denseAttr = dyn_cast<DenseElementsAttr>(op.getValue());
  assert(denseAttr && denseAttr.isSplat() &&
         isa<IntegerType>(denseAttr.getElementType()));

  auto innerVal = denseAttr.getValues<IntegerAttr>()[0].getValue();
  auto innerValIndexAttr = rewriter.getIndexAttr(innerVal.getSExtValue());
  data.setScalar(innerValIndexAttr);

  // The result of a dense constant is always shaped; query the shape once.
  const ArrayRef<int64_t> resShape =
      cast<ShapedType>(op.getResult().getType()).getShape();
  for (size_t i = 0, rank = resShape.size(); i < rank; ++i) {
    if (i == 0) {
      data.getOffsetsRef().push_back(innerValIndexAttr);
    } else {
      data.getOffsetsRef().push_back(rewriter.getIndexAttr(0));
    }
    data.getSizesRef().push_back(rewriter.getIndexAttr(resShape[i]));
    data.getStridesRef().push_back(rewriter.getIndexAttr(0));
  }
}
/// Handles `tt.make_tensor_ptr` / `tt.advance`: the op must already have been
/// remapped to a `memref.reinterpret_cast`, which is then analysed instead.
template <typename T>
std::enable_if_t<std::is_same_v<T, triton::MakeTensorPtrOp> ||
                 std::is_same_v<T, triton::AdvanceOp>>
BlockDataParser::parseTensorPtr(
    T op, BlockData &data, const Location &loc,
    ConversionPatternRewriter &rewriter,
    const llvm::SmallDenseMap<Value, BlockData> &known) {
  assert(data.isEmpty());
  Value converted = rewriter.getRemappedValue(op);
  auto reinterpretCast = converted.getDefiningOp<memref::ReinterpretCastOp>();
  if (!reinterpretCast)
    llvm_unreachable("the value should be mapped to memref.reinterpret_cast");
  parseReinterpretCast(reinterpretCast, data, loc, rewriter, known);
}
/// Analyses a `tt.addptr`: the pointer operand provides the base source, the
/// offset operand provides the displacement, and the two are added into
/// `data`.
///
/// If the offset analysis itself produced a source, its memory access type is
/// transferred to the pointer block and the source is dropped. A rank-0 offset
/// paired with a rank-1 pointer is promoted to a unit dimension carrying the
/// offset's scalar.
void BlockDataParser::parseAddPtr(
    triton::AddPtrOp op, BlockData &data, const Location &loc,
    ConversionPatternRewriter &rewriter,
    const llvm::SmallDenseMap<Value, BlockData> &known) {
  assert(data.isEmpty());
  const auto opLoc = op.getLoc();

  BlockData ptrBlock;
  BlockData offsetBlock;
  parse(op.getPtr(), ptrBlock, opLoc, rewriter, known);
  parse(op.getOffset(), offsetBlock, opLoc, rewriter, known);
  assert(ptrBlock.hasSource() &&
         "Ptr field should provide source/base pointer");

  if (offsetBlock.hasSource()) {
    ptrBlock.setMemAccTy(offsetBlock.getMemAccType());
    offsetBlock.removeSource();
  }

  const bool promoteScalarOffset =
      ptrBlock.getRank() == 1 && offsetBlock.getRank() == 0;
  if (promoteScalarOffset) {
    offsetBlock.getSizesRef().push_back(rewriter.getIndexAttr(1));
    offsetBlock.getOffsetsRef().push_back(offsetBlock.getScalarRef());
    offsetBlock.getStridesRef().push_back(rewriter.getIndexAttr(0));
  }
  assert(ptrBlock.getRank() == offsetBlock.getRank() &&
         "ptr and offset should have same rank");

  LLVM_DEBUG({
    auto &os = llvm::dbgs();
    os << "[parseAddPtr][BEG] =========================\n";
    os << "[parseAddPtr] op is " << op << "\n";
    for (int i = 0; i < ptrBlock.getRank(); i++) {
      os << "ptrBlock.getOffsetsRef()[" << i
         << "] = " << ptrBlock.getOffsetsRef()[i] << "\n";
      os << "ptrBlock.getSizesRef()[" << i
         << "] = " << ptrBlock.getSizesRef()[i] << "\n";
      os << "ptrBlock.getStridesRef()[" << i
         << "] = " << ptrBlock.getStridesRef()[i] << "\n";
      os << "offsetBlock.getOffsetsRef()[" << i
         << "] = " << offsetBlock.getOffsetsRef()[i] << "\n";
      os << "offsetBlock.getSizesRef()[" << i
         << "] = " << offsetBlock.getSizesRef()[i] << "\n";
      os << "offsetBlock.getStridesRef()[" << i
         << "] = " << offsetBlock.getStridesRef()[i] << "\n";
    }
    os << "[parseAddPtr][END] -------------------------\n";
  });

  data.addBlock(ptrBlock, offsetBlock, opLoc, rewriter);
}
/// Copies offsets, sizes, strides and source from a `memref.reinterpret_cast`
/// into `data`. The cast carries a single offset, so zero offsets are appended
/// for the remaining dimensions.
void BlockDataParser::parseReinterpretCast(
    memref::ReinterpretCastOp op, BlockData &data, const Location &loc,
    ConversionPatternRewriter &rewriter,
    const llvm::SmallDenseMap<Value, BlockData> &known) {
  assert(data.isEmpty());
  data.setOffsets(op.getMixedOffsets());
  data.setSizes(op.getMixedSizes());
  data.setStrides(op.getMixedStrides());
  data.setSource(op.getSource());

  auto &offsets = data.getOffsetsRef();
  assert(offsets.size() == 1);
  // One zero offset per dimension beyond the first.
  for (size_t remaining = data.getSizesRef().size(); remaining > 1;
       --remaining) {
    offsets.push_back(rewriter.getIndexAttr(0));
  }
}
/// Handles a `tt.reduce` whose result is used as an offset.
///
/// The reduce input must resolve to a block whose source is produced by
/// `tt.load`; the result must be a rank-1 tensor, the reduction axis must be 0
/// and the body must consist of exactly one `arith.minsi`. Any violation
/// aborts with the scenario message. On success the result is cast to an index
/// tensor and recorded as the single offset of `data`.
void BlockDataParser::parseReduce(
    triton::ReduceOp op, BlockData &data, const Location &loc,
    ConversionPatternRewriter &rewriter,
    const llvm::SmallDenseMap<Value, BlockData> &known) {
  const std::string scenarioMessages =
      "PtsAnalysis supports indirectly block load in the following scenario\n"
      "B = tl.load(Aptr + Aoffset) # B is 1D tensor\n"
      "s = tl.min(B) # s is a scalar\n"
      "D = tl.load(Cptr + s + Coffset) # s is used as the scalar offset\n";

  BlockData inputBlock;
  parse(op->getOperand(0), inputBlock, loc, rewriter, known);
  if (!inputBlock.hasSource() ||
      !isa_and_nonnull<triton::LoadOp>(
          inputBlock.getSource().getDefiningOp())) {
    llvm_unreachable(scenarioMessages.c_str());
  }

  auto reduced = op->getResult(0);
  auto reducedShape = dyn_cast<RankedTensorType>(reduced.getType()).getShape();
  auto bodyOps = llvm::map_to_vector(op.getBody()->without_terminator(),
                                     [](Operation &op) { return &op; });
  const bool supported = reducedShape.size() == 1 && op.getAxis() == 0 &&
                         bodyOps.size() == 1 &&
                         isa<arith::MinSIOp>(bodyOps.front());
  if (!supported) {
    llvm_unreachable(scenarioMessages.c_str());
  }

  auto indexOffset =
      rewriter
          .create<arith::IndexCastOp>(
              loc, RankedTensorType::get(reducedShape, rewriter.getIndexType()),
              reduced)
          .getResult();
  if (!data.isEmpty()) {
    llvm_unreachable("parseReduce with offset already setup not yet supported");
  }
  data.getOffsetsRef().push_back(indexOffset);
  data.getSizesRef().push_back(rewriter.getIndexAttr(reducedShape[0]));
  data.getStridesRef().push_back(rewriter.getIndexAttr(1));
}
/// Describes result `resultIdx` of `op` as a block in `data`.
///
/// Three cases are distinguished by the result type:
///   * shaped, with shape `[1]`, and `op` has exactly three operands: the
///     single element is extracted, cast to index and recorded as a scalar
///     offset (structured access, size 1, stride 0). No source is set in this
///     case.
///   * any other shaped result: unstructured access, one dimension per result
///     dimension with offset 0 and stride 1.
///   * non-shaped result: structured access with a single dimension of size 1.
/// In the last two cases the op result becomes the source of `data`.
///
/// `known` is not used by this parser.
template <typename OpTy>
void parseIndirectLoad(OpTy op, BlockData &data, const Location &loc,
                       ConversionPatternRewriter &rewriter,
                       const llvm::SmallDenseMap<Value, BlockData> &known,
                       unsigned resultIdx)
{
  assert(resultIdx < op->getNumResults() &&
         "resultIdx out of range for parseIndirectLoad");

  constexpr unsigned kExpectedOperandCount = 3;

  auto result = op->getResult(resultIdx);
  std::vector<int64_t> dims;

  auto shapedTy = dyn_cast<ShapedType>(result.getType());
  if (!shapedTy) {
    data.setMemAccVal(MemAccVal::StrucMemAcc);
    dims.push_back(1);
  } else {
    dims = shapedTy.getShape().vec();
    const bool isSingleElement = dims.size() == 1 && dims[0] == 1;
    if (isSingleElement && op->getNumOperands() == kExpectedOperandCount) {
      Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
      Value element =
          rewriter.create<tensor::ExtractOp>(loc, result, ValueRange{zero});
      Value elementAsIndex = rewriter.create<arith::IndexCastOp>(
          loc, rewriter.getIndexType(), element);
      data.setMemAccVal(MemAccVal::StrucMemAcc);
      data.setScalar(elementAsIndex);
      data.getSizesRef().push_back(rewriter.getIndexAttr(1));
      data.getStridesRef().push_back(rewriter.getIndexAttr(0));
      data.getOffsetsRef().push_back(elementAsIndex);
      return;
    }
    data.setMemAccVal(MemAccVal::UnstrucMemAcc);
  }

  for (int64_t extent : dims) {
    data.getOffsetsRef().push_back(rewriter.getIndexAttr(0));
    data.getSizesRef().push_back(rewriter.getIndexAttr(extent));
    data.getStridesRef().push_back(rewriter.getIndexAttr(1));
  }
  data.setSource(result);
}
/// Parses result `resultIdx` of a `hivm.custom` op.
///
/// The op must carry the `ConverterUtils::customSrcPtrIndexAttrName` array
/// attribute; entry `resultIdx` names the operand to parse for this result
/// (-1 is rejected by assertion). After parsing that operand, the source of
/// `data` is replaced by the remapped value of the op result.
void BlockDataParser::parseCustomOp(hivm::CustomOp op, BlockData &data, const Location &loc,
                                    ConversionPatternRewriter &rewriter,
                                    const llvm::SmallDenseMap<Value, BlockData> &known, unsigned resultIdx)
{
  auto srcPtrIndexAttr =
      op->getAttrOfType<DenseI32ArrayAttr>(ConverterUtils::customSrcPtrIndexAttrName);
  assert(srcPtrIndexAttr && "structure hivm.custom op should present src tensor<tt.ptr>");

  const auto srcOperandIdx = srcPtrIndexAttr.asArrayRef()[resultIdx];
  assert(srcOperandIdx != -1 && "tensor<tt.ptr> result should map to src tensor<tt.ptr>");

  Value srcOperand = op->getOperand(srcOperandIdx);
  parse(srcOperand, data, loc, rewriter, known);

  Value remappedResult = rewriter.getRemappedValue(op->getResult(resultIdx));
  data.setSource(remappedResult);
}
/// Parses a `linalg.fill` whose fill value is an integer scalar.
///
/// The fill value is parsed first; its dimension lists are then replaced by
/// one entry per destination dimension (offset 0, size = dimension extent,
/// stride 0). If the parsed value is a scalar, that scalar becomes the offset
/// of the first dimension.
///
/// Any other pattern (non-integer fill value, or a result that is not a
/// shaped type) emits an error on `op` and returns.
void BlockDataParser::parseFill(linalg::FillOp op, BlockData &data,
                                const Location &loc,
                                ConversionPatternRewriter &rewriter,
                                const llvm::SmallDenseMap<Value, BlockData> &known) {
  auto src = op.getInputs()[0];
  auto dst = op.getResult(0);

  // Check the cast result before use: a failed dyn_cast yields a null type.
  auto dstType = dyn_cast<ShapedType>(dst.getType());
  if (!dstType) {
    op->emitError("Block data Analysis: unsupported fillOp pattern");
    return;
  }
  auto dstShape = dstType.getShape();

  parse(src, data, loc, rewriter, known);
  if (!isa<IntegerType>(src.getType())) {
    op->emitError("Block data Analysis: unsupported fillOp pattern");
    return;
  }

  // Discard whatever per-dimension info the scalar parse produced; the
  // filled tensor's layout is defined by the destination shape.
  if (!data.isEmpty()) {
    data.getOffsetsRef().clear();
    data.getSizesRef().clear();
    data.getStridesRef().clear();
  }
  for (auto dstAxis : dstShape) {
    data.getOffsetsRef().push_back(rewriter.getIndexAttr(0));
    data.getSizesRef().push_back(rewriter.getIndexAttr(dstAxis));
    data.getStridesRef().push_back(rewriter.getIndexAttr(0));
  }

  // A rank-0 destination has no dimension to carry the scalar offset.
  if (data.isScalar() && !data.getOffsetsRef().empty()) {
    data.getOffsetsRef()[0] = data.getScalarRef();
  }
}
/// Parses an `arith.select` with a shaped integer/index result into a
/// scalar-offset block.
///
/// Two forms are handled:
///   * the condition is a scalar `i1` and both branches are splat dense
///     constants: a scalar `arith.select` between the two splat values is
///     created and used as the offset;
///   * otherwise the result shape must be all ones (asserted); its single
///     element is extracted and used as the offset.
///
/// The resulting scalar is recorded via `setScalar` and as the offset of
/// dimension 0; every dimension gets stride 0 and the result's extent as size.
/// `data` must be empty on entry. `known` is not used.
void BlockDataParser::parseSelect(
    arith::SelectOp op, BlockData &data, const Location &loc,
    ConversionPatternRewriter &rewriter,
    const llvm::SmallDenseMap<Value, BlockData> &known)
{
  assert(data.isEmpty());
  auto res = op.getResult();
  auto resType = dyn_cast<ShapedType>(res.getType());
  assert(resType && "arith.select result should be a ShapedType");
  assert(isa<IntegerType>(resType.getElementType()) ||
         isa<IndexType>(resType.getElementType()));

  OpFoldResult indexOfr;
  size_t loopLimit = resType.getShape().size();
  Value cond = op.getCondition();
  bool condIsScalarI1 = cond.getType().isInteger(1);

  // Value::getDefiningOp<OpTy>() tolerates block arguments (no defining op),
  // unlike dyn_cast<OpTy>(value.getDefiningOp()).
  auto trueConst = op.getTrueValue().getDefiningOp<arith::ConstantOp>();
  auto falseConst = op.getFalseValue().getDefiningOp<arith::ConstantOp>();
  auto trueDense =
      trueConst ? dyn_cast<DenseElementsAttr>(trueConst.getValue()) : DenseElementsAttr();
  auto falseDense =
      falseConst ? dyn_cast<DenseElementsAttr>(falseConst.getValue()) : DenseElementsAttr();

  // Only splat constants can be represented by their first element; a
  // non-splat constant falls through to the all-ones-shape path below.
  bool denseConstCase = condIsScalarI1 && trueDense && falseDense &&
                        trueDense.isSplat() && falseDense.isSplat();

  if (denseConstCase) {
    Attribute trueFirst = *trueDense.value_begin<Attribute>();
    Attribute falseFirst = *falseDense.value_begin<Attribute>();
    Value trueScalar = nullptr;
    Value falseScalar = nullptr;
    if (auto tInt = dyn_cast<IntegerAttr>(trueFirst)) {
      trueScalar = rewriter.create<arith::ConstantOp>(loc, tInt).getResult();
    } else {
      llvm_unreachable("unsupported true dense element attr in parseSelect");
    }
    if (auto fInt = dyn_cast<IntegerAttr>(falseFirst)) {
      falseScalar = rewriter.create<arith::ConstantOp>(loc, fInt).getResult();
    } else {
      llvm_unreachable("unsupported false dense element attr in parseSelect");
    }
    assert(trueScalar.getType() == falseScalar.getType() &&
           "scalarized true/false type mismatch");
    auto scalarSelect = rewriter.create<arith::SelectOp>(
        loc, trueScalar.getType(), cond, trueScalar, falseScalar);
    indexOfr = getOpFoldResultOfLayoutInfo(scalarSelect.getResult(), rewriter);
  } else {
    assert(llvm::all_of(resType.getShape(), [](int64_t dim) { return dim == 1; }) &&
           "parseSelect currently supports all-ones shape unless cond=i1 with dense constants");
    // Extract the only element, at index [0, ..., 0].
    SmallVector<Value> indices;
    indices.reserve(loopLimit);
    for (size_t i = 0; i < loopLimit; ++i) {
      indices.push_back(rewriter.create<arith::ConstantIndexOp>(loc, 0));
    }
    auto extractOp = rewriter.create<tensor::ExtractOp>(loc, res, indices);
    indexOfr = extractOp.getResult();
    if (isa<IntegerType>(extractOp.getType())) {
      indexOfr = getOpFoldResultOfLayoutInfo(extractOp.getResult(), rewriter);
    }
  }

  data.setScalar(indexOfr);
  for (size_t i = 0; i < loopLimit; ++i) {
    if (i == 0) {
      data.getOffsetsRef().push_back(indexOfr);
    } else {
      data.getOffsetsRef().push_back(rewriter.getIndexAttr(0));
    }
    data.getSizesRef().push_back(rewriter.getIndexAttr(resType.getShape()[i]));
    data.getStridesRef().push_back(rewriter.getIndexAttr(0));
  }
}
/// Rewrites a `tt.addptr` into a `memref.reinterpret_cast`.
///
/// The op is parsed into a BlockData. Unstructured accesses whose source is
/// not a `tt.int_to_ptr` are delegated to `rewriteAddPtrToUnstrucMemAcc`.
/// Otherwise:
///   * a block without dimensions is given a single dimension of size 1 whose
///     offset is the parsed scalar;
///   * the block is recorded in `known` under the op result;
///   * zero strides on size-1 dimensions (or on the dimension named by the
///     `hoist_dim` attribute) are replaced by the product of the sizes of the
///     dimensions after them;
///   * an `int_to_ptr` source is replaced by a `hivm.pointer_cast`, and a
///     requested result element type is applied through an
///     `unrealized_conversion_cast`;
///   * the op is replaced by the reinterpret cast built from the block.
///
/// The rewriter's insertion point is restored on every exit path.
void BlockDataParser::rewriteAddPtr(
    triton::AddPtrOp op, triton::AddPtrOp::Adaptor &adaptor,
    ConversionPatternRewriter &rewriter,
    llvm::SmallDenseMap<Value, BlockData> &known) {
  // RAII guard: also restores the insertion point on the early return below.
  OpBuilder::InsertionGuard insertionGuard(rewriter);
  rewriter.setInsertionPoint(op);

  BlockData data;
  parseAddPtr(op, data, op.getLoc(), rewriter, known);

  if (auto src = data.getSource();
      data.getMemAccTypeRef().isUnstructured() &&
      !(src && isa_and_nonnull<triton::IntToPtrOp>(src.getDefiningOp()))) {
    rewriteAddPtrToUnstrucMemAcc(op, adaptor, rewriter, data);
    LLVM_DEBUG({ llvm::dbgs() << *getModuleOpFromOperation(op) << "\n"; });
    return;
  }

  if (data.getSizesRef().size() == 0) {
    data.getSizesRef().push_back(rewriter.getIndexAttr(1));
    data.getStridesRef().push_back(rewriter.getIndexAttr(0));
    data.getOffsetsRef().push_back(data.getScalarRef());
  }

  // `resultShape` is a non-owning view; the scalar stub must outlive it.
  ArrayRef<int64_t> resultShape;
  SmallVector<int64_t> stubScalarTypeShape(1, 1);
  if (auto shapedType = dyn_cast<ShapedType>(op.getResult().getType())) {
    resultShape = shapedType.getShape();
  } else {
    assert(data.getRank() == 1);
    resultShape = stubScalarTypeShape;
  }

  known[op.getResult()] = data;

  // Walk dimensions innermost-first, tracking the product of the sizes seen
  // so far. Kept in 64 bits: sizes are int64_t and their product may not fit
  // in int.
  int64_t inferedSize = 1;
  auto hoistDim = op->getAttrOfType<IntegerAttr>("hoist_dim");
  for (int i = data.getSizesRef().size() - 1; i >= 0; i--) {
    auto strideConst = getConstantIntValue(data.getStridesRef()[i]);
    auto sizeConst = getConstantIntValue(data.getSizesRef()[i]);
    assert(sizeConst.has_value());
    bool shouldReplaceStride = (sizeConst.value() == 1) || (hoistDim && hoistDim.getValue() == i);
    if (shouldReplaceStride && strideConst && strideConst.value() == 0) {
      data.getStridesRef()[i] = rewriter.getIndexAttr(inferedSize);
    }
    inferedSize *= sizeConst.value();
  }

  if (auto intToPtrOp =
          dyn_cast_or_null<triton::IntToPtrOp>(data.getSourceRef().getDefiningOp())) {
    auto rtype = cast<triton::PointerType>(intToPtrOp.getResult().getType());
    auto memrefType =
        MemRefType::get({ShapedType::kDynamic}, rtype.getPointeeType());
    auto hivmPointCastOp = rewriter.create<hivm::PointerCastOp>(
        intToPtrOp.getLoc(), memrefType, ValueRange{intToPtrOp.getSrc()});
    data.setSource(hivmPointCastOp.getResult());
  }

  if (data.hasResElemTy()) {
    // cast<> asserts on a non-memref source instead of yielding a null type
    // that would then be dereferenced.
    auto memrefType = cast<BaseMemRefType>(data.getSourceRef().getType())
                          .cloneWith(std::nullopt, data.getResElemTyRef());
    UnrealizedConversionCastOp castOp =
        rewriter.create<mlir::UnrealizedConversionCastOp>(
            op.getLoc(), memrefType, data.getSourceRef());
    data.setSource(castOp.getOutputs()[0]);
  }

  memref::ReinterpretCastOp castOp =
      data.createCastOp(resultShape, op.getLoc(), rewriter);
  Value src = castOp.getResult();
  LLVM_DEBUG({
    llvm::dbgs() << "cast MemRefType:\n";
    castOp.getOperation()->print(llvm::dbgs(),
                                 OpPrintingFlags().printGenericOpForm());
    llvm::dbgs() << "\n";
  });
  rewriter.replaceOp(op, src);
}
/// Returns `offset`, increased by the offset of `base` when `base` is produced
/// by a `memref.reinterpret_cast`.
///
/// When `base` has any other producer (or none), `offset` is returned
/// unchanged. In the reinterpret-cast case the original base of `op` is
/// asserted to come from a `tt.addptr`.
OpFoldResult accumulatePotentialOffsetOnBase(
    triton::MakeTensorPtrOp op, Value base, OpFoldResult offset,
    ConversionPatternRewriter &rewriter) {
  auto baseRecast = base.getDefiningOp<memref::ReinterpretCastOp>();
  if (!baseRecast) {
    return offset;
  }

  assert(isa<triton::AddPtrOp>(op.getBase().getDefiningOp()) &&
         "base of MakeTensorPtrOp only comes from native ptr or AddPtrOp");
  OpFoldResult baseOffset = baseRecast.getConstifiedMixedOffset();
  return addOpFoldResult(offset, baseOffset, op.getLoc(), rewriter);
}
/// Rebuilds a `hivm.custom` op so that pointer-typed operands and results
/// become memrefs.
///
///   * Inputs: each input is remapped; scalar pointers and ranked tensors of
///     pointers are parsed into a BlockData and replaced by the reinterpret
///     cast built from it (shape `[?]` for scalar pointers, the tensor shape
///     otherwise). Other inputs are forwarded remapped.
///   * Outputs: outputs that are tensors of pointers are dropped; the rest
///     are forwarded remapped.
///   * Results: pointer / tensor-of-pointer result types become memref types
///     of the pointee type; other result types are kept.
///
/// The new op receives all attributes of the old one, except that its own
/// `operandSegmentSizes` is kept. The insertion point is restored on exit.
void BlockDataParser::rewriteCustomOp(hivm::CustomOp op, hivm::CustomOp::Adaptor &adaptor,
    ConversionPatternRewriter &rewriter, const llvm::SmallDenseMap<Value, BlockData> &known)
{
  OpBuilder::InsertionGuard restoreGuard(rewriter);
  rewriter.setInsertionPoint(op);
  auto loc = op.getLoc();

  // If the block's source is a tt.int_to_ptr, swap it for a
  // hivm.pointer_cast and make sure the block has at least one dimension.
  auto materializeIntToPtrSource = [&rewriter](BlockData &blockData) {
    auto intToPtr = dyn_cast_or_null<triton::IntToPtrOp>(
        blockData.getSourceRef().getDefiningOp());
    if (!intToPtr) {
      return;
    }
    auto ptrTy = cast<triton::PointerType>(intToPtr.getResult().getType());
    auto flatMemrefTy =
        MemRefType::get({ShapedType::kDynamic}, ptrTy.getPointeeType());
    auto pointerCast = rewriter.create<hivm::PointerCastOp>(
        intToPtr.getLoc(), flatMemrefTy, ValueRange{intToPtr.getSrc()});
    if (blockData.getSizesRef().size() == 0) {
      blockData.getSizesRef().push_back(rewriter.getIndexAttr(1));
      if (blockData.getScalarRef().isNull()) {
        blockData.getOffsetsRef().push_back(rewriter.getIndexAttr(0));
      } else {
        blockData.getOffsetsRef().push_back(blockData.getScalarRef());
      }
      blockData.getStridesRef().push_back(rewriter.getIndexAttr(1));
    }
    blockData.setSource(pointerCast.getResult());
  };

  llvm::SmallVector<Value> newInputs;
  for (auto origInput : op.getInputs()) {
    Value remapped = rewriter.getRemappedValue(origInput);
    BlockData blockData;
    Type inputTy = remapped.getType();

    // Decide whether this input is pointer-like and which shape its
    // replacement memref should have.
    bool isPointerLike = false;
    llvm::SmallVector<int64_t> castShape;
    if (llvm::isa<triton::PointerType>(inputTy)) {
      isPointerLike = true;
      castShape.push_back(ShapedType::kDynamic);
    } else if (auto tensorTy = llvm::dyn_cast<RankedTensorType>(inputTy)) {
      if (llvm::isa<triton::PointerType>(tensorTy.getElementType())) {
        isPointerLike = true;
        castShape.append(tensorTy.getShape().begin(), tensorTy.getShape().end());
      }
    }

    if (!isPointerLike) {
      newInputs.emplace_back(remapped);
      continue;
    }

    parse(remapped, blockData, loc, rewriter, known);
    materializeIntToPtrSource(blockData);
    Value castResult = blockData.createCastOp(castShape, loc, rewriter).getResult();
    newInputs.emplace_back(castResult);
  }

  llvm::SmallVector<Value> newOutputs;
  for (auto origOutput : op.getOutputs()) {
    auto outputTy = llvm::cast<RankedTensorType>(origOutput.getType());
    if (!llvm::isa<triton::PointerType>(outputTy.getElementType())) {
      newOutputs.emplace_back(rewriter.getRemappedValue(origOutput));
    }
  }

  llvm::SmallVector<Type> resultTypes;
  for (Type origTy : op->getResultTypes()) {
    Type convertedTy = origTy;
    if (auto ptrTy = llvm::dyn_cast<triton::PointerType>(origTy)) {
      convertedTy = MemRefType::get({ShapedType::kDynamic}, ptrTy.getPointeeType());
    } else if (auto tensorTy = llvm::dyn_cast<RankedTensorType>(origTy)) {
      if (auto elemPtrTy = llvm::dyn_cast<triton::PointerType>(tensorTy.getElementType())) {
        convertedTy = MemRefType::get(tensorTy.getShape(), elemPtrTy.getPointeeType());
      }
    }
    resultTypes.emplace_back(convertedTy);
  }

  auto newCustomOp = rewriter.create<hivm::CustomOp>(loc, resultTypes, op.getName(), newInputs, newOutputs,
                                                     adaptor.getTempBuffers());
  // Copying the old attributes would overwrite the segment sizes computed for
  // the new operand lists, so save them first and put them back afterwards.
  auto segmentSizes = newCustomOp->getAttr("operandSegmentSizes");
  newCustomOp->setAttrs(op->getAttrs());
  newCustomOp->setAttr("operandSegmentSizes", segmentSizes);

  rewriter.replaceOp(op, newCustomOp.getResults());
}
/// Builds a `memref.reinterpret_cast` that describes the full tensor declared
/// by `op` rather than the block in `data`.
///
/// The sizes and offsets of `data` are temporarily replaced: sizes come from
/// the shape operands of `op`; offsets are all zero except the first, which
/// picks up any offset already carried by the remapped base. After the cast
/// is created the original sizes and offsets of `data` are put back, so only
/// the IR is affected.
memref::ReinterpretCastOp
createRedundantOp(triton::MakeTensorPtrOp op,
                  ConversionPatternRewriter &rewriter,
                  BlockData &data) {
  auto loc = op.getLoc();

  // Snapshot the state that is restored before returning.
  auto savedSizes = data.getSizes();
  auto savedOffsets = data.getOffsets();

  auto &sizes = data.getSizesRef();
  auto &offsets = data.getOffsetsRef();
  sizes.clear();
  offsets.clear();

  sizes = llvm::map_to_vector(op.getShape(), [&](Value dim) {
    return getOpFoldResultOfLayoutInfo(dim, rewriter);
  });

  SmallVector<OpFoldResult> fullShapeOffsets(op.getOffsets().size(),
                                             rewriter.getIndexAttr(0));
  fullShapeOffsets.front() = accumulatePotentialOffsetOnBase(
      op, rewriter.getRemappedValue(op.getBase()), fullShapeOffsets.front(),
      rewriter);
  for (auto fullShapeOffset : fullShapeOffsets) {
    offsets.push_back(fullShapeOffset);
  }

  // Split the mixed sizes into static extents and dynamic values; only the
  // static part is handed to the cast builder.
  SmallVector<int64_t> staticShapes;
  SmallVector<Value> dynamicShapes;
  dispatchIndexOpFoldResults(sizes, dynamicShapes, staticShapes);
  auto castOp = data.createCastOp(staticShapes, loc, rewriter);

  sizes.clear();
  for (auto &savedSize : savedSizes) {
    sizes.push_back(savedSize);
  }
  offsets.clear();
  for (auto &savedOffset : savedOffsets) {
    offsets.push_back(savedOffset);
  }
  return castOp;
}
/// Rewrites a `tt.make_tensor_ptr` into `memref.reinterpret_cast` ops.
///
/// Steps:
///   1. Parse the base pointer. The block's source becomes either an
///      `unrealized_conversion_cast` to the requested result element type, or
///      the remapped base.
///   2. Take offsets from the op (each clamped to >= 0 via `arith.maxsi`) and
///      strides from the op, multiply them per dimension, and add any offset
///      carried by `base` to the first dimension.
///   3. Take sizes from the pointee tensor shape (or use shape `[1]` for a
///      non-shaped pointee).
///   4. Emit a full-shape cast (tagged `tensor_ptr_full_shape`) and use it as
///      the source; record the block in `known`; replace the op with the
///      block cast.
///   5. When `nd2nzFlag` is set, additionally emit a call to a private
///      function (declared on first use) whose name encodes the rank, shape,
///      element type and memref type of the result, passing the shape
///      operands and the result memref.
void BlockDataParser::rewriteMakeTensorPtrOp(
    triton::MakeTensorPtrOp op, Value base,
    ConversionPatternRewriter &rewriter,
    llvm::SmallDenseMap<Value, BlockData> &known) {
  Location loc = op.getLoc();
  BlockData data;
  BlockDataParser::parse(op.getBase(), data, loc, rewriter, known);

  if (data.hasResElemTy()) {
    // cast<> asserts on a non-memref source instead of yielding a null type
    // that would then be dereferenced.
    auto memrefType = cast<BaseMemRefType>(data.getSourceRef().getType())
                          .cloneWith(std::nullopt, data.getResElemTyRef());
    UnrealizedConversionCastOp castOp =
        rewriter.create<mlir::UnrealizedConversionCastOp>(loc, memrefType,
                                                          data.getSourceRef());
    data.setSource(castOp.getOutputs()[0]);
  } else {
    data.setSource(rewriter.getRemappedValue(op.getBase()));
  }

  data.getOffsetsRef() = llvm::map_to_vector(op.getOffsets(), [&](Value v) {
    // Clamp negative offsets to zero.
    auto zeroVal = rewriter.create<arith::ConstantOp>(loc, rewriter.getI32IntegerAttr(0));
    v = rewriter.create<arith::MaxSIOp>(loc, v, zeroVal);
    return getOpFoldResultOfLayoutInfo(v, rewriter);
  });
  data.getStridesRef() = llvm::map_to_vector(op.getStrides(), [&](Value v) {
    return getOpFoldResultOfLayoutInfo(v, rewriter);
  });

  // Scale each offset by the stride of its dimension.
  SmallVector<OpFoldResult> newOffsets;
  for (auto [offset, stride] :
       llvm::zip(data.getOffsetsRef(), data.getStridesRef()))
    newOffsets.push_back(mulOpFoldResult(offset, stride, loc, rewriter));
  newOffsets.front() = accumulatePotentialOffsetOnBase(
      op, base, newOffsets.front(), rewriter);
  data.getOffsetsRef().clear();
  for (auto offset : newOffsets) {
    data.getOffsetsRef().push_back(offset);
  }

  // `resultShape` is a non-owning view that is used until the final cast is
  // created, so its scalar backing storage must live at function scope.
  SmallVector<int64_t> scalarShape(1, 1);
  ArrayRef<int64_t> resultShape;
  auto pointerType = cast<mlir::triton::PointerType>(op.getResult().getType());
  if (auto shapedType = dyn_cast<ShapedType>(pointerType.getPointeeType())) {
    resultShape = shapedType.getShape();
    data.getSizesRef().clear();
    for (auto dim_size : resultShape) {
      data.getSizesRef().push_back(
          IntegerAttr::get(IntegerType::get(op.getContext(), 64), dim_size));
    }
  } else {
    resultShape = scalarShape;
    assert(data.getRank() == 1);
  }

  auto redundantOp = createRedundantOp(op, rewriter, data);
  redundantOp->setAttr("tensor_ptr_full_shape", rewriter.getUnitAttr());
  data.setSource(redundantOp.getResult());
  known[op.getResult()] = data;

  auto castOp = data.createCastOp(resultShape, loc, rewriter);
  rewriter.replaceOp(op, castOp.getResult());

  if (nd2nzFlag) {
    auto basePtr = castOp.getResult();
    int original_rank = op.getShape().size() + 1;
    auto baseMemrefType = mlir::dyn_cast<MemRefType>(basePtr.getType());
    assert(baseMemrefType && "basePtr is not a memref type");

    // "_<dim>" for every dimension of the result memref.
    std::string shapeStr;
    for (auto dim : baseMemrefType.getShape()) {
      shapeStr += llvm::formatv("_{0}", dim);
    }

    // Integer types are spelled "i<width>"; every other element type uses
    // its printed form.
    std::string elemTypeName;
    Type elemType = baseMemrefType.getElementType();
    if (auto intType = mlir::dyn_cast<mlir::IntegerType>(elemType)) {
      elemTypeName = llvm::formatv("i{0}", intType.getWidth());
    } else {
      std::string typeName;
      llvm::raw_string_ostream os(typeName);
      elemType.print(os);
      os.flush();
      elemTypeName = typeName;
    }

    // Printed memref type with everything but [0-9a-zA-Z_,\[\]] removed.
    std::string memrefTypeStr;
    llvm::raw_string_ostream os(memrefTypeStr);
    baseMemrefType.print(os);
    os.flush();
    std::string laydbgsuffix;
    for (char c : memrefTypeStr) {
      if ((c >= '0' && c <= '9') || (c >= 'a' && c <= 'z') ||
          (c >= 'A' && c <= 'Z') || c == '_' || c == ',' || c == '[' ||
          c == ']') {
        laydbgsuffix += c;
      }
    }

    auto funcName = rewriter.getStringAttr(
        llvm::formatv("__hmf_original_shape{0}d{1}_{2}_{3}", original_rank,
                      shapeStr, elemTypeName, laydbgsuffix));
    MemRefType targetMemrefType = MemRefType::get(
        baseMemrefType.getShape(), baseMemrefType.getElementType(),
        baseMemrefType.getLayout());

    // Signature and arguments: the shape operands followed by the memref.
    const int vectorSize = 4;
    SmallVector<Type, vectorSize> srcElemTys;
    for (auto sz : op.getShape()) {
      srcElemTys.push_back(sz.getType());
    }
    srcElemTys.push_back(targetMemrefType);
    Type dstElemTy = rewriter.getNoneType();
    FunctionType hintFuncType =
        FunctionType::get(rewriter.getContext(), srcElemTys, {dstElemTy});

    auto mod = SymbolTable::getNearestSymbolTable(op);
    auto extFunc = dyn_cast_or_null<SymbolOpInterface>(
        SymbolTable::lookupSymbolIn(mod, funcName));
    SmallVector<Value, vectorSize> args;
    for (auto sz : op.getShape()) {
      args.push_back(sz);
    }
    args.push_back(basePtr);

    if (!extFunc) {
      // Declare the function at the top of the enclosing symbol table; the
      // guard puts the insertion point back when this scope ends.
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointToStart(&mod->getRegion(0).front());
      extFunc = rewriter.create<func::FuncOp>(rewriter.getUnknownLoc(),
                                              funcName, hintFuncType);
      extFunc.setPrivate();
      extFunc->setAttr(LLVM::LLVMDialect::getReadnoneAttrName(),
                       UnitAttr::get(rewriter.getContext()));
    }
    rewriter.create<func::CallOp>(loc, funcName, dstElemTy, args);
  }
}
/// Rewrites a `tt.advance` into a `memref.reinterpret_cast` with updated
/// offsets.
///
/// The advanced pointer (operand 0) is parsed into a BlockData. Zero strides
/// on size-1 dimensions are replaced by the product of the sizes of the
/// dimensions after them. Each offset is then increased by
/// `increment * stride` for its dimension. The resulting block is recorded in
/// `known` for both the original result and the new cast result.
void BlockDataParser::rewriteAdvanceOp(
    triton::AdvanceOp op, ConversionPatternRewriter &rewriter,
    llvm::SmallDenseMap<Value, BlockData> &known) {
  OpBuilder::InsertionGuard insertionGuard(rewriter);
  rewriter.setInsertionPoint(op);
  auto loc = op.getLoc();

  BlockData blockData;
  parse(op.getOperand(0), blockData, loc, rewriter, known);
  known[op.getResult()] = blockData;

  // Kept in 64 bits: sizes are int64_t and their product may not fit in int.
  int64_t inferedSize = 1;
  for (int i = blockData.getSizesRef().size() - 1; i >= 0; i--) {
    auto strideConst = getConstantIntValue(blockData.getStridesRef()[i]);
    auto sizeConst = getConstantIntValue(blockData.getSizesRef()[i]);
    assert(sizeConst.has_value());
    if (sizeConst.value() == 1 && strideConst && strideConst.value() == 0) {
      blockData.getStridesRef()[i] = rewriter.getIndexAttr(inferedSize);
    }
    inferedSize *= sizeConst.value();
  }

  SmallVector<OpFoldResult> incrementOffsets =
      llvm::map_to_vector(op.getOffsets(), [&](Value offset) {
        return getOpFoldResultOfLayoutInfo(offset, rewriter);
      });

  // newOffset = increment * stride + originalOffset, per dimension.
  SmallVector<OpFoldResult> newOffsets;
  for (const auto [increment, originalOffset, stride] :
       llvm::zip(incrementOffsets, blockData.getOffsetsRef(),
                 blockData.getStridesRef())) {
    auto curDimOffset =
        addOpFoldResult(mulOpFoldResult(increment, stride, loc, rewriter),
                        originalOffset, loc, rewriter);
    newOffsets.push_back(curDimOffset);
  }
  blockData.getOffsetsRef().clear();
  for (auto offset : newOffsets)
    blockData.getOffsetsRef().push_back(offset);

  // `resultShape` is a non-owning view; `scalarShape` must outlive it.
  SmallVector<int64_t> scalarShape(1, 1);
  ArrayRef<int64_t> resultShape;
  auto pointerType = cast<mlir::triton::PointerType>(op.getResult().getType());
  if (auto shapedType = dyn_cast<ShapedType>(pointerType.getPointeeType())) {
    resultShape = shapedType.getShape();
  } else {
    resultShape = scalarShape;
    assert(blockData.getRank() == 1);
  }

  auto newOp = blockData.createCastOp(resultShape, loc, rewriter);
  rewriter.replaceOp(op, newOp.getResult());
  known[newOp.getResult()] = blockData;
}
// Rewrites an scf.yield / scf.condition terminator so that it also forwards
// the offsets and strides of block-data operands.
//
// The new operand list is built in two parts:
//   1. every original argument whose entry in `iterArgIdxMap` is not -1;
//   2. for every argument whose index is in `blockArgIdxSet`, all of its
//      offsets followed by all of its strides (constant attributes are
//      materialized as index constants).
//
// Only instantiable for T = scf::YieldOp or scf::ConditionOp.
template <typename T>
std::enable_if_t<std::is_same_v<T, scf::YieldOp> ||
std::is_same_v<T, scf::ConditionOp>>
BlockDataParser::rewriteTerminator(
T op, ConversionPatternRewriter &rewriter,
const llvm::SmallDenseSet<size_t> &blockArgIdxSet, ArrayRef<int64_t> iterArgIdxMap,
const llvm::SmallDenseMap<Value, BlockData> &known) {
OpBuilder::InsertionGuard insertionGuard{rewriter};
rewriter.setInsertionPoint(op);
auto adaptor = typename T::Adaptor(op);
// scf.yield forwards all of its operands; scf.condition forwards its args
// (the condition itself is handled separately when the op is rebuilt).
ValueRange args;
if constexpr (std::is_same_v<T, scf::YieldOp>) {
args = adaptor.getOperands();
} else {
args = adaptor.getArgs();
}
SmallVector<BlockData, 5> initArgState;
SmallVector<Value> operands;
operands.reserve(op->getNumOperands());
// Part 1: keep the arguments that are not marked -1 in iterArgIdxMap.
for (const auto &[oper, newIterArgIdx]: llvm::zip_equal(args, iterArgIdxMap)) {
if (newIterArgIdx != -1)
operands.push_back(oper);
}
// Collect a BlockData for every argument listed in blockArgIdxSet.
for (auto [i, v] : llvm::enumerate(args)) {
if (auto mappedV = rewriter.getRemappedValue(v)) {
// Pointer-producing ops are expected to have been converted to a
// memref.reinterpret_cast already.
if (v.getDefiningOp<triton::AddPtrOp>() ||
v.getDefiningOp<triton::AdvanceOp>() ||
v.getDefiningOp<triton::MakeTensorPtrOp>()) {
if (auto castOp = mappedV.getDefiningOp<memref::ReinterpretCastOp>()) {
v = castOp;
} else {
llvm_unreachable("mapped value defined by an unexpected op");
}
} else {
if (isa<TensorType>(mappedV.getType()) &&
isa<triton::PointerType>(
dyn_cast<TensorType>(mappedV.getType()).getElementType()))
llvm_unreachable("unsupported scenario where a value is a tensor of "
"pointers but not produced by AddPtrOp");
v = mappedV;
}
}
if (blockArgIdxSet.find(i) == blockArgIdxSet.end())
continue;
// A tracked argument is either a reinterpret_cast or a tensor of integers.
auto reintCastOp = v.getDefiningOp<memref::ReinterpretCastOp>();
assert(
reintCastOp ||
(isa<TensorType>(v.getType()) &&
isa<IntegerType>(dyn_cast<TensorType>(v.getType()).getElementType())));
BlockData state;
if (reintCastOp) {
parseReinterpretCast(reintCastOp, state, op.getLoc(), rewriter, known);
} else {
parse(v, state, op.getLoc(), rewriter, known);
}
initArgState.push_back(state);
}
// Part 2: append offsets, then strides, of each tracked argument.
for (auto state : initArgState) {
for (auto offset : state.getOffsetsRef()) {
if (isa<Attribute>(offset)) {
// A constant offset must be zero; it is forwarded as an index constant 0.
auto constOffset = cast<Attribute>(offset);
assert(isa<IntegerAttr>(constOffset) &&
dyn_cast<IntegerAttr>(constOffset).getInt() == 0 &&
"attribute offsets should be zeroes");
auto constOp = rewriter.create<arith::ConstantOp>(
op.getLoc(), rewriter.getIndexAttr(0));
operands.push_back(constOp.getResult());
} else {
operands.push_back(cast<Value>(offset));
}
}
auto sizesRef = state.getSizesRef();
size_t dimIdx = 0;
for (OpFoldResult stride : state.getStridesRef()) {
if (isa<Attribute>(stride)) {
// A constant stride must be 1, or 0 on a dimension of constant size 1.
// Either way it is forwarded as an index constant 1.
auto constStride = cast<Attribute>(stride);
assert(isa<IntegerAttr>(constStride) && "attribute strides should be IntegerAttr");
auto strideVal = dyn_cast<IntegerAttr>(constStride).getInt();
bool isSizeOne = (dimIdx < sizesRef.size() &&
isa<Attribute>(sizesRef[dimIdx]) &&
cast<IntegerAttr>(cast<Attribute>(sizesRef[dimIdx])).getInt() == 1);
assert((strideVal == 1 || (strideVal == 0 && isSizeOne)) &&
"attribute strides should be ones");
auto constOp = rewriter.create<arith::ConstantOp>(
op.getLoc(), rewriter.getIndexAttr(1));
operands.push_back(constOp.getResult());
} else {
operands.push_back(cast<Value>(stride));
}
dimIdx++;
}
}
// Build the replacement terminator after the original one.
rewriter.setInsertionPointAfter(op);
Operation *newOp;
if constexpr (std::is_same_v<T, scf::YieldOp>) {
newOp = rewriter.replaceOpWithNewOp<scf::YieldOp>(op, operands);
} else {
newOp = rewriter.replaceOpWithNewOp<scf::ConditionOp>(op, op.getCondition(), operands);
}
assert(op->getNumResults() == 0);
LLVM_DEBUG({
llvm::dbgs() << "new terminator: ";
newOp->print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm());
llvm::dbgs() << "\n";
});
}
/// Returns true if some (transitive) use of `v` satisfies `cond`.
///
/// Uses are followed through:
///   * the results of every user op;
///   * loop iter-args tied to an init operand (loops tagged
///     "ExtractedLoadOrStore" are not entered), with `depth` incremented;
///   * `scf.yield` (outside `scf.while`) to the matching parent result, only
///     when `depth` is non-zero, with `depth` decremented;
///   * `scf.condition` args to the matching `scf.while` result (only when
///     `depth` is non-zero) and to the matching after-region argument.
/// Users carrying `ConverterUtils::discreteAttrName` and `tensor.extract`
/// users are ignored entirely.
///
/// `visited` holds the values already explored so each is examined once; when
/// null, a set local to this call is used.
bool isUsedWithCondition(
    Value v,
    std::function<bool(OpOperand *)> cond,
    int depth = 0,
    llvm::SmallSetVector<Value, 8> *visited = nullptr) {
  llvm::SmallSetVector<Value, 8> localVisited;
  if (visited == nullptr) {
    visited = &localVisited;
  }
  // insert() reports whether the value was newly added.
  if (!visited->insert(v)) {
    return false;
  }

  for (auto &use : v.getUses()) {
    Operation *user = use.getOwner();
    const bool isIgnoredUser =
        user->hasAttr(ConverterUtils::discreteAttrName) ||
        isa<tensor::ExtractOp>(user);
    if (isIgnoredUser) {
      continue;
    }
    if (cond(&use)) {
      return true;
    }

    if (auto loopOp = dyn_cast<LoopLikeOpInterface>(user);
        loopOp && !loopOp->hasAttr("ExtractedLoadOrStore")) {
      // Init operand of a loop: continue inside the loop body.
      Value tiedArg = loopOp.getTiedLoopRegionIterArg(&use);
      if (tiedArg && isUsedWithCondition(tiedArg, cond, depth + 1, visited)) {
        return true;
      }
    } else if (isa<scf::YieldOp>(user) &&
               !isa<scf::WhileOp>(user->getParentOp())) {
      // Yielded out of a region: continue at the parent's matching result.
      if (depth &&
          isUsedWithCondition(
              user->getParentOp()->getResult(use.getOperandNumber()), cond,
              depth - 1, visited)) {
        return true;
      }
    } else if (isa<scf::ConditionOp>(user) && use.getOperandNumber() > 0) {
      // Operand 0 is the condition itself; forwarded args start at 1.
      auto whileOp = cast<scf::WhileOp>(user->getParentOp());
      const unsigned argIdx = use.getOperandNumber() - 1;
      if (depth && isUsedWithCondition(whileOp->getResult(argIdx), cond,
                                       depth - 1, visited)) {
        return true;
      }
      if (isUsedWithCondition(whileOp.getAfterArguments()[argIdx], cond, depth,
                              visited)) {
        return true;
      }
    }

    const bool usedThroughResult =
        llvm::any_of(user->getResults(), [&](Value res) {
          return isUsedWithCondition(res, cond, depth, visited);
        });
    if (usedThroughResult) {
      return true;
    }
  }
  return false;
}
/// Materializes a tensor of type `resType` from the offsets and strides in
/// `data`.
///
/// For every dimension `d` a 1-D tensor `range(0, shape[d])` is built
/// (sign-extended when the element type is not i32), multiplied by the stride
/// of `d` unless `isMaskIterArg` is set, and shifted by the offset of `d`.
/// It is then expanded and broadcast to `resType`, and the per-dimension
/// tensors are summed. The offsets and strides of `data` must be Values.
///
/// Returns a null Value when `resType` has rank 0.
Value createFromData(RankedTensorType resType, const BlockData &data, const Location &loc, OpBuilder &builder, bool isMaskIterArg) {
  const auto shape = resType.getShape();
  const size_t rank = shape.size();
  Value accumulated = nullptr;

  for (size_t dim = 0; dim < rank; dim++) {
    auto dimTensorTy = RankedTensorType::get({shape[dim]}, resType.getElementType());
    auto dimI32TensorTy = RankedTensorType::get({shape[dim]}, builder.getIntegerType(32));

    Value dimValues = builder.create<triton::MakeRangeOp>(loc, dimI32TensorTy, 0, shape[dim]);
    if (dimTensorTy != dimI32TensorTy) {
      dimValues = builder.create<arith::ExtSIOp>(loc, dimTensorTy, dimValues);
    }

    Value offsetIdx = cast<Value>(data.getOffset(dim));
    Value offsetSplat = builder.create<arith::IndexCastOp>(loc, resType.getElementType(), offsetIdx);
    offsetSplat = builder.create<triton::SplatOp>(loc, dimTensorTy, offsetSplat);

    // Fetched unconditionally so the Value requirement is checked even when
    // the stride is not applied.
    Value strideIdx = cast<Value>(data.getStride(dim));
    if (!isMaskIterArg) {
      Value strideSplat = builder.create<arith::IndexCastOp>(loc, resType.getElementType(), strideIdx);
      strideSplat = builder.create<triton::SplatOp>(loc, dimTensorTy, strideSplat);
      dimValues = builder.create<arith::MulIOp>(loc, dimValues, strideSplat);
    }
    dimValues = builder.create<arith::AddIOp>(loc, dimValues, offsetSplat);

    // Insert a unit axis for every other dimension, then broadcast.
    for (size_t axis = 0; axis < rank; axis++) {
      if (axis == dim) {
        continue;
      }
      dimValues = builder.create<triton::ExpandDimsOp>(loc, dimValues, axis);
    }
    dimValues = builder.create<triton::BroadcastOp>(loc, resType, dimValues);

    if (!accumulated) {
      accumulated = dimValues;
    } else {
      accumulated = builder.create<arith::AddIOp>(loc, accumulated, dimValues);
    }
  }
  return accumulated;
}
void BlockDataParser::rewriteLoopOp(
LoopLikeOpInterface op, ConversionPatternRewriter &rewriter,
llvm::SmallDenseMap<Value, BlockData> &known) {
SmallVector<Value> newInitArgs;
SmallVector<int64_t> iterArgIdxMap;
SmallVector<bool> maskIterArgs;
int64_t argCnt = 0;
SmallVector<std::pair<int, BlockData>, 5> initArgIndexIfBlockData;
SmallVector<std::pair<int, BlockData>, 5> knownPtrsTmp;
llvm::SmallDenseSet<size_t> blockArgIdxSet;
for (auto [i, arg] : llvm::enumerate(op.getInits())) {
auto mappedV = rewriter.getRemappedValue(arg);
memref::ReinterpretCastOp reintCastOp;
maskIterArgs.push_back(false);
if (mappedV) {
assert(!(isa<BlockArgument>(mappedV) &&
isa<UnrankedMemRefType>(mappedV.getType())) &&
"cannot take pointer block argument as init arg for for loop");
if (auto reinterpretCastOp = mappedV.getDefiningOp<memref::ReinterpretCastOp>()) {
reintCastOp = reinterpretCastOp;
newInitArgs.push_back(mappedV);
iterArgIdxMap.push_back(argCnt++);
} else {
newInitArgs.push_back(mappedV);
iterArgIdxMap.push_back(argCnt++);
}
} else {
newInitArgs.push_back(arg);
iterArgIdxMap.push_back(argCnt++);
}
auto indexTensor =
isa<TensorType>(arg.getType()) &&
isa<IntegerType>(cast<TensorType>(arg.getType()).getElementType()) &&
cast<IntegerType>(cast<TensorType>(arg.getType()).getElementType()).getWidth() != 1 &&
isUsedWithCondition(op.getRegionIterArgs()[i], [](OpOperand *use) {
auto *user = use->getOwner();
return isa<triton::AddPtrOp>(user) ||
(isa<triton::LoadOp>(user) && use->getOperandNumber() == 1) ||
(isa<triton::StoreOp>(user) && use->getOperandNumber() == 2);
});
if (!reintCastOp && !indexTensor)
continue;
BlockData data;
if (reintCastOp) {
parseReinterpretCast(reintCastOp, data, op.getLoc(), rewriter,
llvm::SmallDenseMap<Value, BlockData>(0));
} else {
parse(arg, data, op.getLoc(), rewriter,
llvm::SmallDenseMap<Value, BlockData>(0));
}
maskIterArgs[i] = indexTensor && isUsedWithCondition(op.getRegionIterArgs()[i], [](OpOperand *use) {
auto *user = use->getOwner();
return (isa<triton::LoadOp>(user) && use->getOperandNumber() == 1) ||
(isa<triton::StoreOp>(user) && use->getOperandNumber() == 2);
});
if (indexTensor) {
newInitArgs.back() = nullptr;
iterArgIdxMap.back() = -1;
argCnt--;
}
initArgIndexIfBlockData.push_back(std::make_pair(i, data));
}
auto origIp = rewriter.saveInsertionPoint();
rewriter.setInsertionPoint(op);
for (auto [i, data] : initArgIndexIfBlockData) {
for (auto &dataOffset : data.getOffsetsRef()) {
if (isa<Attribute>(dataOffset)) {
auto constDataOffset = cast<Attribute>(dataOffset);
assert(isa<IntegerAttr>(constDataOffset));
auto constOp = rewriter.create<arith::ConstantOp>(
op.getLoc(), rewriter.getIndexAttr(
dyn_cast<IntegerAttr>(constDataOffset).getInt()));
newInitArgs.push_back(constOp.getResult());
dataOffset = constOp.getResult();
} else {
assert(isa<IndexType>(cast<Value>(dataOffset).getType()));
newInitArgs.push_back(cast<Value>(dataOffset));
}
}
for (auto &dataStride : data.getStridesRef()) {
if (isa<Attribute>(dataStride)) {
auto constDataStride = cast<Attribute>(dataStride);
assert(isa<IntegerAttr>(constDataStride));
auto constOp = rewriter.create<arith::ConstantOp>(
op.getLoc(), rewriter.getIndexAttr(
dyn_cast<IntegerAttr>(constDataStride).getInt()));
newInitArgs.push_back(constOp.getResult());
dataStride = constOp.getResult();
} else {
assert(isa<IndexType>(cast<Value>(dataStride).getType()));
newInitArgs.push_back(cast<Value>(dataStride));
}
}
knownPtrsTmp.push_back(std::make_pair(i, data));
blockArgIdxSet.insert(i);
if (newInitArgs[i] && newInitArgs[i].getDefiningOp<memref::ReinterpretCastOp>()) {
SmallVector<int64_t> resultShape;
for (auto size : data.getSizesRef()) {
auto constSize = getConstantIntValue(size);
assert(constSize && "expected constant size");
resultShape.push_back(constSize.value());
}
auto castOp = data.createCastOp(resultShape, op.getLoc(), rewriter);
if (resultShape.size() > 1) {
auto originalOffset = dyn_cast<Value>(data.getOffsetsRef()[0]);
for (auto &offsets : newInitArgs) {
if (offsets == originalOffset) {
offsets = castOp.getOffsets()[0];
break;
}
}
data.getOffsetsRef()[0] = castOp.getOffsets()[0];
}
LLVM_DEBUG({
llvm::dbgs() << "new reinterpret_cast with dynamic sizes "
"and offsets:";
castOp->print(llvm::dbgs(), OpPrintingFlags().printGenericOpForm());
llvm::dbgs() << "\n";
});
newInitArgs[i] = castOp.getResult();
}
}
rewriter.restoreInsertionPoint(origIp);
IRMapping mapping;
LoopLikeOpInterface newOp;
auto newInits = to_vector(make_filter_range(newInitArgs, [](Value v) { return v != nullptr; }));
auto commonBodyBuilder = [&](OpBuilder &b, Location loc, bool useInit, ValueRange newRegionArgs, Region ®ion, Block::BlockArgListType regionArgs, ArrayRef<bool> isUsedForRegionArgs, ArrayRef<bool> maskIterArgs) {
auto newArgIter = newRegionArgs.begin();
for (const auto &[regionArg, isUsedForRegionArg]: llvm::zip(regionArgs, isUsedForRegionArgs)) {
if (isUsedForRegionArg) {
mapping.map(regionArg, *newArgIter);
++newArgIter;
}
}
if (useInit) {
for (auto [i, data] : knownPtrsTmp) {
for (auto &offset: data.getOffsetsRef()) {
offset = *newArgIter;
++newArgIter;
}
for (auto &stride: data.getStridesRef()) {
stride = *newArgIter;
++newArgIter;
}
auto regionArg = regionArgs[i];
auto key = mapping.lookupOrNull(regionArg);
if (!key) {
key = createFromData(cast<RankedTensorType>(regionArg.getType()), data, op.getLoc(), rewriter, maskIterArgs[i]);
mapping.map(regionArg, key);
}
known.insert(std::make_pair(key, data));
}
} else {
for (auto [i, isUsedForRegionArg]: llvm::enumerate(isUsedForRegionArgs)) {
if (!isUsedForRegionArg) {
BlockData data;
auto regionArg = regionArgs[i];
auto regionArgType = cast<RankedTensorType>(regionArg.getType());
data.getOffsetsRef().resize(regionArgType.getRank());
data.getStridesRef().resize(regionArgType.getRank());
for (auto &offset: data.getOffsetsRef()) {
offset = *newArgIter;
++newArgIter;
}
for (auto &dim: regionArgType.getShape()) {
data.getSizesRef().push_back(rewriter.getIndexAttr(dim));
}
for (auto &stride: data.getStridesRef()) {
stride = *newArgIter;
++newArgIter;
}
auto key = mapping.lookupOrNull(regionArg);
if (!key) {
key = createFromData(regionArgType, data, op.getLoc(), rewriter, maskIterArgs[i]);
mapping.map(regionArg, key);
}
known.insert(std::make_pair(key, data));
}
}
}
for (auto &bodyOp : region.getOps())
b.clone(bodyOp, mapping);
};
for (const auto &[initArg, newInitArg]: llvm::zip(op.getInits(), newInitArgs)) {
if (newInitArg) {
mapping.map(initArg, newInitArg);
}
}
if (auto forOp = dyn_cast<scf::ForOp>(op.getOperation())) {
SmallVector<bool> usedForRegionArgs;
for (auto newInitArg: newInitArgs) {
usedForRegionArgs.push_back(newInitArg ? true:false);
}
newOp = rewriter.create<scf::ForOp>(
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
forOp.getStep(), newInits,
[&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
mapping.map(forOp.getInductionVar(), iv);
commonBodyBuilder(b, loc, true, args, forOp.getRegion(), op.getRegionIterArgs(), usedForRegionArgs, maskIterArgs);
});
auto newResultIter = newOp->result_begin();
rewriter.setInsertionPointAfter(newOp);
for (const auto &[res, regionArg, newIterArgIdx, mask]: llvm::zip_equal(op->getResults(), op.getRegionIterArgs(), iterArgIdxMap, maskIterArgs)) {
if (newIterArgIdx != -1) {
rewriter.replaceAllUsesWith(res, *newResultIter);
++newResultIter;
} else {
auto key = mapping.lookup(regionArg);
auto data = known.at(key);
for (auto &offset : data.getOffsetsRef())
offset = newOp.getTiedLoopResult(cast<BlockArgument>(cast<Value>(offset)));
for (auto &stride : data.getStridesRef())
stride = newOp.getTiedLoopResult(cast<BlockArgument>(cast<Value>(stride)));
auto newRes = createFromData(cast<RankedTensorType>(regionArg.getType()), data, op.getLoc(), rewriter, mask);
rewriter.replaceAllUsesWith(res, newRes);
}
}
} else if (auto whileOp = dyn_cast<scf::WhileOp>(op.getOperation())) {
SmallVector<Type> resultTypes;
SmallVector<bool> usedForBeforeRegionArgs;
SmallVector<bool> usedForAfterRegionArgs;
llvm::SmallDenseSet<size_t> blockArgIdxSetForAfter;
SmallVector<int64_t> iterArgIdxMapForAfter;
SmallVector<bool> maskIterArgsForAfter(whileOp->getNumResults());
int64_t indexCnt = 0;
for (auto newInitArg: newInitArgs) {
usedForBeforeRegionArgs.push_back(newInitArg ? true:false);
}
for (size_t i = 0; i < whileOp->getNumResults(); i++) {
auto resType = whileOp->getResultTypes()[i];
auto indexTensor =
isa<RankedTensorType>(resType) &&
isa<IntegerType>(cast<RankedTensorType>(resType).getElementType()) &&
isUsedWithCondition(whileOp.getAfterArguments()[i], [](OpOperand *use) {
auto *user = use->getOwner();
return isa<triton::AddPtrOp>(user) ||
(isa<triton::LoadOp>(user) && use->getOperandNumber() == 1) ||
(isa<triton::StoreOp>(user) && use->getOperandNumber() == 2);
});
if (indexTensor) {
indexCnt += 2 * cast<RankedTensorType>(resType).getRank();
usedForAfterRegionArgs.push_back(false);
iterArgIdxMapForAfter.push_back(-1);
maskIterArgsForAfter[i] = isUsedWithCondition(whileOp.getAfterArguments()[i], [](OpOperand *use) {
auto *user = use->getOwner();
return (isa<triton::LoadOp>(user) && use->getOperandNumber() == 1) ||
(isa<triton::StoreOp>(user) && use->getOperandNumber() == 2);
});
blockArgIdxSetForAfter.insert(i);
} else {
resultTypes.push_back(resType);
usedForAfterRegionArgs.push_back(true);
iterArgIdxMapForAfter.push_back(argCnt++);
}
}
resultTypes.append(indexCnt, rewriter.getIndexType());
newOp = rewriter.create<scf::WhileOp>(
whileOp.getLoc(), resultTypes, newInits,
[&](OpBuilder &b, Location loc, ValueRange args) {
commonBodyBuilder(b, loc, true, args, whileOp.getBefore(), whileOp.getBeforeArguments(), usedForBeforeRegionArgs, maskIterArgs);
},
[&](OpBuilder &b, Location loc, ValueRange args) {
commonBodyBuilder(b, loc, false, args, whileOp.getAfter(), whileOp.getAfterArguments(), usedForAfterRegionArgs, maskIterArgsForAfter);
});
auto newResultIter = newOp->result_begin();
rewriter.setInsertionPointAfter(newOp);
for (const auto &[res, regionArg, newIterArgIdx, mask]: llvm::zip_equal(op->getResults(), whileOp.getAfterArguments(), iterArgIdxMapForAfter, maskIterArgsForAfter)) {
if (newIterArgIdx != -1) {
rewriter.replaceAllUsesWith(res, *newResultIter);
++newResultIter;
} else {
auto key = mapping.lookup(regionArg);
auto data = known.at(key);
for (auto &offset : data.getOffsetsRef())
offset = newOp->getResult(cast<BlockArgument>(cast<Value>(offset)).getArgNumber());
for (auto &stride : data.getStridesRef())
stride = newOp->getResult(cast<BlockArgument>(cast<Value>(stride)).getArgNumber());
auto newRes = createFromData(cast<RankedTensorType>(regionArg.getType()), data, op.getLoc(), rewriter, mask);
rewriter.replaceAllUsesWith(res, newRes);
}
}
auto conditionOp = cast<scf::WhileOp>(newOp.getOperation()).getConditionOp();
rewriteTerminator(conditionOp, rewriter, blockArgIdxSetForAfter, iterArgIdxMapForAfter, known);
}
newOp->setAttrs(op->getAttrs());
rewriter.eraseOp(op);
for (auto *region : newOp.getLoopRegions()) {
for (auto &bodyOp : region->getOps()) {
if (auto customOp = dyn_cast<hivm::CustomOp>(bodyOp)) {
auto adaptor = hivm::CustomOp::Adaptor(customOp);
rewriteCustomOp(customOp, adaptor, rewriter, known);
} else if (auto addptrOp = dyn_cast<triton::AddPtrOp>(bodyOp)) {
auto adaptor = triton::AddPtrOp::Adaptor(addptrOp);
rewriteAddPtr(addptrOp, adaptor, rewriter, known);
} else if (auto advanceOp = dyn_cast<triton::AdvanceOp>(bodyOp)) {
rewriteAdvanceOp(advanceOp, rewriter, known);
} else if (auto makeTensorPtrOp = dyn_cast<triton::MakeTensorPtrOp>(bodyOp)) {
ConversionPatternRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(makeTensorPtrOp);
rewriteMakeTensorPtrOp(makeTensorPtrOp, rewriter.getRemappedValue(makeTensorPtrOp.getBase()), rewriter, known);
} else if (auto loopOp = dyn_cast<LoopLikeOpInterface>(bodyOp);
loopOp && !loopOp->hasAttr("ExtractedLoadOrStore")) {
ConversionPatternRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(loopOp);
loopOp->removeAttr("UnhandledLoopOp");
rewriteLoopOp(loopOp, rewriter, known);
}
}
}
if (!op.getRegionIterArgs().empty()) {
auto yieldOp = cast<scf::YieldOp>(newOp.getLoopRegions().back()->back().getTerminator());
rewriteTerminator(yieldOp, rewriter, blockArgIdxSet, iterArgIdxMap, known);
}
LLVM_DEBUG({
llvm::dbgs() << "new loop\n";
newOp.getOperation()->print(llvm::dbgs(),
OpPrintingFlags().printGenericOpForm());
llvm::dbgs() << "\n";
});
}
/// Rewrites a `tt.addptr` into a per-element access inside a generated loop
/// nest.
///
/// The routine builds one loop per entry of `data`'s sizes (lower bound 0,
/// step 1, upper bound taken from the block size). Inside the innermost body
/// it extracts the scalar at the current induction variables from the
/// adaptor's offset operand, casts it to `index`, adds the sum of all offsets
/// stored in `data`, and materialises a reinterpret-cast for a single element
/// at that combined offset. `op` is replaced by that cast, and the single user
/// of the addptr result is moved right after it and tagged with the
/// "IndirectLoad" unit attribute.
///
/// \param op       The addptr operation being rewritten; its result must have
///                 exactly one use (asserted).
/// \param adaptor  Adaptor supplying the offset operand that is indexed
///                 element-wise.
/// \param rewriter Rewriter used for the constants, the empty tensor and the
///                 replacement of `op`.
/// \param data     Block description; it is modified in place and ends up with
///                 exactly one offset (the combined one), one size (1) and one
///                 stride (1).
void BlockDataParser::rewriteAddPtrToUnstrucMemAcc(
    triton::AddPtrOp op, triton::AddPtrOp::Adaptor &adaptor,
    ConversionPatternRewriter &rewriter, BlockData &data) {
  auto loc = op.getLoc();
  auto &dimOffsets = data.getOffsetsRef();
  auto &dimSizes = data.getSizesRef();
  auto &dimStrides = data.getStridesRef();
  Value offsetTensor = adaptor.getOffset();

  // Shared index constants used as the loop lower bound and step.
  Value c0 =
      rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(0));
  Value c1 =
      rewriter.create<arith::ConstantOp>(loc, rewriter.getIndexAttr(1));

  // The rewrite relocates the one consumer of the pointer, so it must be
  // unique.
  auto ptrResult = op.getResult();
  assert(ptrResult.hasOneUse() && "Invalid: tt.addptr has multiple users");
  auto soleUser = *(ptrResult.user_begin());

  // Element type behind the pointer elements of the addptr result.
  auto pointeeTy =
      cast<triton::PointerType>(
          cast<ShapedType>(ptrResult.getType()).getElementType())
          .getPointeeType();

  // Empty tensor of the block shape; it is the only iter-arg of the loop nest.
  Value emptyResult =
      rewriter.create<tensor::EmptyOp>(loc, dimSizes, pointeeTy);
  SmallVector<Value> loopInits{emptyResult};

  // One lower bound per offset, one upper bound per size, one step per stride.
  SmallVector<Value> lowerBounds(dimOffsets.size(), c0);
  SmallVector<Value> upperBounds;
  upperBounds.reserve(dimSizes.size());
  for (const auto &extent : dimSizes)
    upperBounds.push_back(
        getValueOrCreateConstantIndexOp(rewriter, loc, extent));
  SmallVector<Value> unitSteps(dimStrides.size(), c1);

  SmallVector<Value> inductionVars;
  // The loop nest is emitted through a dedicated builder anchored at `op`.
  OpBuilder nestBuilder(op);
  auto loopNest = createNestedLoops(
      nestBuilder, loc, 0, dimSizes.size(), lowerBounds, upperBounds,
      unitSteps, inductionVars, loopInits,
      [&](OpBuilder &inner, Location innerLoc, SmallVector<Value> &allIVs,
          ValueRange iterArgs) {
        OpBuilder::InsertionGuard restoreIp(inner);
        inner.setInsertionPointToStart(inner.getBlock());

        // Scalar offset selected by the current induction variables.
        Value rawElemOffset =
            inner.create<tensor::ExtractOp>(innerLoc, offsetTensor, allIVs);
        Value elemOffset = inner.create<arith::IndexCastOp>(
            innerLoc, inner.getIndexType(), rawElemOffset);

        // Collapse every offset recorded in `data` into a single base value.
        OpFoldResult summedBase = inner.getIndexAttr(0);
        for (const auto &dimOffset : data.getOffsetsRef())
          summedBase = addOpFoldResult(summedBase, dimOffset, innerLoc, inner);
        Value baseIdx =
            getValueOrCreateConstantIndexOp(inner, innerLoc, summedBase);
        Value elemAddress =
            inner.create<arith::AddIOp>(innerLoc, baseIdx, elemOffset);

        // From here on `data` describes a single element at `elemAddress`.
        auto &newOffsets = data.getOffsetsRef();
        newOffsets.clear();
        newOffsets.push_back(elemAddress);
        auto &newSizes = data.getSizesRef();
        newSizes.clear();
        newSizes.push_back(inner.getIndexAttr(1));
        auto &newStrides = data.getStridesRef();
        newStrides.clear();
        newStrides.push_back(inner.getIndexAttr(1));

        memref::ReinterpretCastOp elemCast =
            data.createCastOp({1}, innerLoc, inner);
        rewriter.replaceOp(op, elemCast);

        // Pull the consumer into the loop body, directly behind the cast.
        soleUser->moveAfter(elemCast);
        soleUser->setAttr("IndirectLoad", UnitAttr::get(op.getContext()));
        inner.create<scf::YieldOp>(innerLoc, iterArgs);
      });
  (void)loopNest;
}
}
}