* Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
* Copyright 2018-2020 Philippe Tillet
* Copyright 2020-2022 OpenAI
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "ir.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Types.h"
using namespace mlir;
namespace py = pybind11;
// Bit width of the IntegerAttr type used for the static sizes/strides that
// `subview` attaches to memref.subview.
constexpr unsigned kIntegerAttrBitWidth{64U};
struct BufferOpBuilder : public TritonOpBuilder {};
/// Registers the buffer IR bindings on the Python module `m`.
///
/// Adds a module-level `load_dialects(context)` that registers and loads the
/// MemRef and Bufferization dialects, and a `buffer_builder` class (derived
/// from TritonOpBuilder) whose methods create memref/bufferization ops.
/// Invalid arguments are reported by throwing std::runtime_error, which
/// pybind11 translates into a Python RuntimeError.
void init_buffer_ir(py::module &&m)
{
    // Make the dialects used by buffer_builder available on `context`.
    m.def("load_dialects", [](MLIRContext &context) {
        DialectRegistry registry;
        registry.insert<memref::MemRefDialect>();
        registry.insert<bufferization::BufferizationDialect>();
        context.appendDialectRegistry(registry);
        context.loadAllAvailableDialects();
    });

    py::class_<BufferOpBuilder, TritonOpBuilder>(
        m, "buffer_builder", py::module_local(), py::dynamic_attr())
        .def(py::init<MLIRContext *>())
        // Returns a null Attribute; to_buffer skips the address-space cast when
        // given one.
        .def("get_null_attr", [](BufferOpBuilder &self) { return Attribute(); })
        // Builds an ArrayAttr of string attributes from a list of Python strings.
        .def("get_str_array_attr",
             [](BufferOpBuilder &self, const std::vector<std::string> &array) -> ArrayAttr {
                 auto strRefVec =
                     to_vector(llvm::map_range(array, [](const auto &s) { return llvm::StringRef(s); }));
                 return self.getBuilder().getStrArrayAttr(llvm::ArrayRef<StringRef> {strRefVec});
             })
        // Creates a memref.alloc whose result has type `memrefType`.
        .def("alloc",
             [](BufferOpBuilder &self, Type memrefType) -> Value {
                 // dyn_cast rather than cast: a wrong type coming from Python must be
                 // reported, not hit an assertion (or UB in release builds).
                 auto type = dyn_cast<MemRefType>(memrefType);
                 if (!type) {
                     throw std::runtime_error("alloc: memrefType must be a ranked memref type");
                 }
                 return self.create<memref::AllocOp>(type);
             })
        // Converts a ranked tensor to a memref (bufferization.to_memref, identity
        // layout). When `addressSpace` is non-null, the memref is additionally
        // cast into that memory space with memref.memory_space_cast.
        .def("to_buffer",
             [](BufferOpBuilder &self, Value &src, const Attribute &addressSpace) -> Value {
                 auto tensorType = dyn_cast<RankedTensorType>(src.getType());
                 if (!tensorType) {
                     // Throw (-> Python RuntimeError) instead of llvm::report_fatal_error,
                     // which would abort the whole interpreter process.
                     throw std::runtime_error("to_buffer: src must be tensor type");
                 }
                 auto memrefType =
                     MemRefType::get(tensorType.getShape(), tensorType.getElementType(), MemRefLayoutAttrInterface {});
                 Value buffer = self.create<bufferization::ToMemrefOp>(memrefType, src);
                 if (addressSpace) {
                     buffer = self.create<memref::MemorySpaceCastOp>(
                         MemRefType::get(memrefType.getShape(), memrefType.getElementType(), memrefType.getLayout(),
                                         addressSpace),
                         buffer);
                 }
                 return buffer;
             })
        // Converts a memref to a tensor with bufferization.to_tensor (restrict is
        // always set; `writable` is forwarded). If `src` carries a memory space,
        // it is first cast to a memref with no memory space (same shape, element
        // type and layout).
        .def("to_tensor",
             [](BufferOpBuilder &self, Value &src, bool writable) -> Value {
                 auto memrefType = dyn_cast<MemRefType>(src.getType());
                 if (!memrefType) {
                     throw std::runtime_error("to_tensor: src must be a ranked memref type");
                 }
                 Value buffer = src;
                 if (memrefType.getMemorySpace()) {
                     buffer = self.create<memref::MemorySpaceCastOp>(
                         MemRefType::get(memrefType.getShape(), memrefType.getElementType(), memrefType.getLayout()),
                         src);
                 }
                 return self.create<bufferization::ToTensorOp>(buffer, /*restrict=*/true, writable);
             })
        // Creates a memref.subview of `source` with dynamic `offsets` (cast to
        // index when needed) and static positive `sizes`/`strides`; one entry per
        // dimension of `source` is required.
        .def("subview",
             [](BufferOpBuilder &self, Value source, const std::vector<Value> &offsets,
                const std::vector<int64_t> &sizes,
                const std::vector<int64_t> &strides) -> Value {
                 auto sourceType = dyn_cast<MemRefType>(source.getType());
                 if (!sourceType) {
                     throw std::runtime_error("subview: source must be a ranked memref type");
                 }
                 // getRank() is non-negative; compare in size_t to avoid a
                 // signed/unsigned mismatch with the vector sizes.
                 const auto rank = static_cast<size_t>(sourceType.getRank());
                 if (offsets.size() != rank || sizes.size() != rank ||
                     strides.size() != rank) {
                     throw std::runtime_error("Number of offsets, sizes, and strides "
                                              "must match memref rank");
                 }

                 auto &builder = self.getBuilder();
                 auto *context = builder.getContext();
                 const Type indexType = builder.getIndexType();

                 SmallVector<mlir::OpFoldResult> mixedOffsets;
                 mixedOffsets.reserve(rank);
                 for (Value offset : offsets) {
                     // Dynamic subview offsets are index-typed; cast any other type.
                     if (offset.getType() != indexType) {
                         offset = self.create<arith::IndexCastOp>(indexType, offset);
                     }
                     mixedOffsets.push_back(offset);
                 }

                 const auto attrType = IntegerType::get(context, kIntegerAttrBitWidth);
                 SmallVector<mlir::OpFoldResult> mixedSizes;
                 SmallVector<mlir::OpFoldResult> mixedStrides;
                 mixedSizes.reserve(rank);
                 mixedStrides.reserve(rank);
                 for (size_t i = 0; i < rank; ++i) {
                     const int64_t size = sizes[i];
                     const int64_t stride = strides[i];
                     const int64_t srcDim = sourceType.getDimSize(i);
                     if (size <= 0) {
                         throw std::runtime_error("Expected sizes to be positive");
                     }
                     if (stride <= 0) {
                         throw std::runtime_error("Expected strides to be positive");
                     }
                     // Bounds can only be validated against static source dimensions.
                     if (!ShapedType::isDynamic(srcDim)) {
                         if (size > srcDim) {
                             throw std::runtime_error(
                                 "Subview size cannot exceed source dimension size");
                         }
                         if (stride > srcDim) {
                             throw std::runtime_error(
                                 "Stride cannot exceed source dimension size");
                         }
                     }
                     mixedSizes.push_back(IntegerAttr::get(attrType, size));
                     mixedStrides.push_back(IntegerAttr::get(attrType, stride));
                 }
                 return self.create<memref::SubViewOp>(source, mixedOffsets,
                                                       mixedSizes, mixedStrides);
             });
}