#include "mlir/Dialect/GPU/TransformOps/Utils.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/GPU/TransformOps/GPUTransformOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
using namespace mlir;
using namespace mlir::gpu;
using namespace mlir::transform;
using namespace mlir::transform::gpu;
#define DEBUG_TYPE "gpu-transforms"
// DBGS_ALIAS() below expands DEBUG_TYPE_ALIAS, which this file otherwise never
// defines; provide a default so the macro is actually usable. The #ifndef
// keeps any definition supplied elsewhere (e.g. on the command line) intact.
#ifndef DEBUG_TYPE_ALIAS
#define DEBUG_TYPE_ALIAS "gpu-transforms-alias"
#endif
// Debug stream prefixed with "[<DEBUG_TYPE>] ".
#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
// Prints X and a newline on DBGS(), guarded by LLVM_DEBUG.
#define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")
// Same as DBGS(), but tagged with DEBUG_TYPE_ALIAS instead of DEBUG_TYPE.
#define DBGS_ALIAS() (llvm::dbgs() << '[' << DEBUG_TYPE_ALIAS << "] ")
template <typename ThreadOrBlockIdOp>
/// Materializes the linearized id `x + y * basis[0] + z * basis[0] * basis[1]`
/// where x/y/z are the results of three freshly created `ThreadOrBlockIdOp`s
/// (in x, y, z order) and `basis` is `originalBasisOfr`. Exactly three basis
/// entries are required by the assert, but only the first two appear in the
/// expression. The affine apply is composed/folded where possible and the
/// result is returned as a `Value` (an index constant if it folds completely).
static Value buildLinearId(RewriterBase &rewriter, Location loc,
                           ArrayRef<OpFoldResult> originalBasisOfr) {
  LLVM_DEBUG({
    DBGS() << "----buildLinearId with originalBasisOfr: ";
    llvm::interleaveComma(originalBasisOfr, llvm::dbgs());
    llvm::dbgs() << "\n";
  });
  assert(originalBasisOfr.size() == 3 && "expected 3 sizes");

  MLIRContext *ctx = rewriter.getContext();
  IndexType indexType = rewriter.getIndexType();

  // Dimensions stand for the hardware ids, symbols for the basis sizes.
  AffineExpr idX, idY, idZ;
  bindDims(ctx, idX, idY, idZ);
  AffineExpr sizeX, sizeY;
  bindSymbols(ctx, sizeX, sizeY);
  AffineExpr linearExpr = idX + idY * sizeX + idZ * sizeX * sizeY;

  // Operand order mirrors the expression: the three dims, then two symbols.
  SmallVector<OpFoldResult> operands;
  for (Dimension dim : {Dimension::x, Dimension::y, Dimension::z})
    operands.push_back(
        rewriter.create<ThreadOrBlockIdOp>(loc, indexType, dim).getResult());
  operands.push_back(originalBasisOfr[0]);
  operands.push_back(originalBasisOfr[1]);

  OpFoldResult folded = affine::makeComposedFoldedAffineApply(
      rewriter, loc, linearExpr, operands);
  return getValueOrCreateConstantIndexOp(rewriter, loc, folded);
}
template <typename ThreadOrBlockIdOp>
/// Returns an id builder that:
///   1. linearizes the 3-D `ThreadOrBlockIdOp` ids over `originalBasis`,
///   2. floor-divides that linear id by `multiplicity`,
///   3. delinearizes the quotient over `forallMappingSizes`, with
///      `forallMappingSizes[0]` varying fastest.
/// The produced IdBuilderResult holds, in order: the delinearized ids, the
/// product of `originalBasis`, the product of `forallMappingSizes` multiplied
/// back by `multiplicity`, and the undivided linear id.
static GpuIdBuilderFnType commonLinearIdBuilderFn(int64_t multiplicity = 1) {
  return [multiplicity](RewriterBase &rewriter, Location loc,
                        ArrayRef<int64_t> forallMappingSizes,
                        ArrayRef<int64_t> originalBasis) {
    MLIRContext *ctx = rewriter.getContext();
    Value linearId = buildLinearId<ThreadOrBlockIdOp>(
        rewriter, loc, getAsIndexOpFoldResult(ctx, originalBasis));

    // Row-major strides over the reversed sizes give stride 1 to the last
    // reversed entry, i.e. to forallMappingSizes[0].
    SmallVector<int64_t> reversedSizes(llvm::reverse(forallMappingSizes));
    SmallVector<int64_t> strides = computeStrides(reversedSizes);

    AffineExpr d0 = getAffineDimExpr(0, ctx);
    OpFoldResult scaledLinearId = affine::makeComposedFoldedAffineApply(
        rewriter, loc, d0.floorDiv(multiplicity), {OpFoldResult(linearId)});

    // The exprs follow the reversed-size order; walk them backwards so that
    // ids[i] lines up with forallMappingSizes[i].
    SmallVector<AffineExpr> delinearizingExprs = delinearize(d0, strides);
    SmallVector<Value> ids;
    ids.reserve(delinearizingExprs.size());
    for (AffineExpr expr : llvm::reverse(delinearizingExprs))
      ids.push_back(affine::makeComposedAffineApply(rewriter, loc, expr,
                                                    {scaledLinearId}));

    LLVM_DEBUG({
      DBGS() << "--delinearization basis: ";
      llvm::interleaveComma(reversedSizes, llvm::dbgs());
      llvm::dbgs() << "\n";
      DBGS() << "--delinearization strides: ";
      llvm::interleaveComma(strides, llvm::dbgs());
      llvm::dbgs() << "\n";
      DBGS() << "--delinearization exprs: ";
      llvm::interleaveComma(delinearizingExprs, llvm::dbgs());
      llvm::dbgs() << "\n";
      DBGS() << "--ids: ";
      llvm::interleaveComma(ids, llvm::dbgs());
      llvm::dbgs() << "\n";
    });

    int64_t availableSize = computeProduct(originalBasis);
    // forallMappingSizes live in the divided-by-multiplicity space; scale the
    // total back so it is expressed in the original basis.
    int64_t activeSize = computeProduct(forallMappingSizes) * multiplicity;
    return IdBuilderResult{ids, SmallVector<int64_t>{availableSize},
                           SmallVector<int64_t>{activeSize},
                           SmallVector<Value>{linearId}};
  };
}
template <typename ThreadOrBlockIdOp>
/// Returns an id builder that uses the x/y/z `ThreadOrBlockIdOp` ids directly.
/// Only the x id is floor-divided by `multiplicity`. The produced
/// IdBuilderResult holds, in order:
///   - the (x-divided) ids,
///   - `originalBasis`,
///   - `forallMappingSizes` with its first entry multiplied by `multiplicity`,
///   - the undivided hardware ids.
static GpuIdBuilderFnType common3DIdBuilderFn(int64_t multiplicity = 1) {
  return [multiplicity](RewriterBase &rewriter, Location loc,
                        ArrayRef<int64_t> forallMappingSizes,
                        ArrayRef<int64_t> originalBasis) {
    IndexType indexType = rewriter.getIndexType();
    SmallVector<Value> hardwareIds;
    for (Dimension dim : {Dimension::x, Dimension::y, Dimension::z})
      hardwareIds.push_back(
          rewriter.create<ThreadOrBlockIdOp>(loc, indexType, dim).getResult());

    // y and z pass through unchanged; x is divided by the multiplicity.
    SmallVector<Value> scaledIds(hardwareIds);
    AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
    OpFoldResult scaledX = affine::makeComposedFoldedAffineApply(
        rewriter, loc, d0.floorDiv(multiplicity), {hardwareIds[0]});
    scaledIds[0] = scaledX.get<Value>();

    // Undo the x scaling on the mapping sizes to express them in the
    // original basis.
    SmallVector<int64_t> activeSizes(forallMappingSizes.begin(),
                                     forallMappingSizes.end());
    activeSizes[0] *= multiplicity;

    return IdBuilderResult{scaledIds, SmallVector<int64_t>{originalBasis},
                           SmallVector<int64_t>{activeSizes}, hardwareIds};
  };
}
namespace mlir {
namespace transform {
namespace gpu {

/// Builds one mapping attribute per applicable `MappingId` by calling `fn`.
/// With `useLinearMapping`, every id from `MappingId::LinearDim0` through
/// `getMaxEnumValForMappingId()` is used; otherwise only `DimX`..`DimZ`.
/// `idBuilder` is left empty here; the derived builders assign it.
GpuIdBuilder::GpuIdBuilder(MLIRContext *ctx, bool useLinearMapping,
                           const MappingIdBuilderFnType &fn)
    : mappingAttributes(), idBuilder() {
  if (useLinearMapping) {
    for (uint64_t d = static_cast<uint64_t>(MappingId::LinearDim0),
                  e = getMaxEnumValForMappingId();
         d <= e; ++d)
      mappingAttributes.push_back(fn(ctx, symbolizeMappingId(d).value()));
  } else {
    for (uint64_t d = static_cast<uint64_t>(MappingId::DimX),
                  e = static_cast<uint64_t>(MappingId::DimZ);
         d <= e; ++d)
      mappingAttributes.push_back(fn(ctx, symbolizeMappingId(d).value()));
  }
}

/// Id builder for block mapping: `BlockIdOp`-based ids with multiplicity 1.
GpuBlockIdBuilder::GpuBlockIdBuilder(MLIRContext *ctx, bool useLinearMapping)
    : GpuIdBuilder(ctx, useLinearMapping, [](MLIRContext *ctx, MappingId id) {
        return GPUBlockMappingAttr::get(ctx, id);
      }) {
  idBuilder = useLinearMapping ? commonLinearIdBuilderFn<BlockIdOp>(1)
                               : common3DIdBuilderFn<BlockIdOp>(1);
}

/// Id builder for warpgroup mapping: `ThreadIdOp`-based ids divided by
/// `kNumWarpsPerGroup * warpSize` (presumably the number of threads per
/// warpgroup; `kNumWarpsPerGroup` is declared elsewhere).
GpuWarpgroupIdBuilder::GpuWarpgroupIdBuilder(MLIRContext *ctx, int64_t warpSize,
                                             bool useLinearMapping)
    : GpuIdBuilder(ctx, useLinearMapping,
                   [](MLIRContext *ctx, MappingId id) {
                     return GPUWarpgroupMappingAttr::get(ctx, id);
                   }),
      warpSize(warpSize) {
  idBuilder = useLinearMapping
                  ? commonLinearIdBuilderFn<ThreadIdOp>(
                        kNumWarpsPerGroup * warpSize)
                  : common3DIdBuilderFn<ThreadIdOp>(
                        kNumWarpsPerGroup * warpSize);
}

/// Id builder for warp mapping: `ThreadIdOp`-based ids divided by `warpSize`.
GpuWarpIdBuilder::GpuWarpIdBuilder(MLIRContext *ctx, int64_t warpSize,
                                   bool useLinearMapping)
    : GpuIdBuilder(ctx, useLinearMapping,
                   [](MLIRContext *ctx, MappingId id) {
                     return GPUWarpMappingAttr::get(ctx, id);
                   }),
      warpSize(warpSize) {
  idBuilder = useLinearMapping
                  ? commonLinearIdBuilderFn<ThreadIdOp>(warpSize)
                  : common3DIdBuilderFn<ThreadIdOp>(warpSize);
}

/// Id builder for thread mapping: `ThreadIdOp`-based ids with multiplicity 1.
GpuThreadIdBuilder::GpuThreadIdBuilder(MLIRContext *ctx, bool useLinearMapping)
    : GpuIdBuilder(ctx, useLinearMapping, [](MLIRContext *ctx, MappingId id) {
        return GPUThreadMappingAttr::get(ctx, id);
      }) {
  idBuilder = useLinearMapping ? commonLinearIdBuilderFn<ThreadIdOp>(1)
                               : common3DIdBuilderFn<ThreadIdOp>(1);
}

/// Returns a silenceable failure on `transformOp` when the requested grid or
/// block sizes exceed the hard-coded per-dimension or total limits below.
/// Unset dimensions count as 1.
/// NOTE(review): the limits are fixed constants, not queried from the target.
DiagnosedSilenceableFailure checkGpuLimits(TransformOpInterface transformOp,
                                           std::optional<int64_t> gridDimX,
                                           std::optional<int64_t> gridDimY,
                                           std::optional<int64_t> gridDimZ,
                                           std::optional<int64_t> blockDimX,
                                           std::optional<int64_t> blockDimY,
                                           std::optional<int64_t> blockDimZ) {
  static constexpr int64_t maxTotalBlockdim = 1024;
  static constexpr int64_t maxBlockdimx = 1024;
  static constexpr int64_t maxBlockdimy = 1024;
  static constexpr int64_t maxBlockdimz = 64;
  static constexpr int64_t maxTotalGriddim = 2147483647;
  static constexpr int64_t maxGriddimx = 2147483647;
  static constexpr int64_t maxGriddimy = 65535;
  static constexpr int64_t maxGriddimz = 65535;

  const int64_t gridX = gridDimX.value_or(1);
  const int64_t gridY = gridDimY.value_or(1);
  const int64_t gridZ = gridDimZ.value_or(1);
  const int64_t blockX = blockDimX.value_or(1);
  const int64_t blockY = blockDimY.value_or(1);
  const int64_t blockZ = blockDimZ.value_or(1);

  // The per-dimension limits are tested first, so `||` short-circuits before
  // any product is formed with an out-of-range extent. For non-negative
  // extents within their limits, the largest product,
  // (2^31 - 1) * 65535 * 65535, still fits in int64_t, so the products cannot
  // overflow. NOTE(review): negative extents are not rejected here.
  const bool exceedsLimits =
      blockX > maxBlockdimx || blockY > maxBlockdimy ||
      blockZ > maxBlockdimz || gridX > maxGriddimx || gridY > maxGriddimy ||
      gridZ > maxGriddimz || blockX * blockY * blockZ > maxTotalBlockdim ||
      gridX * gridY * gridZ > maxTotalGriddim;
  if (exceedsLimits) {
    return transformOp.emitSilenceableError()
           << "Trying to launch a GPU kernel with grid_dims = (" << gridX
           << ", " << gridY << ", " << gridZ << ") block_dims = (" << blockX
           << ", " << blockY << ", " << blockZ
           << "). It is larger than the limits.";
  }
  return DiagnosedSilenceableFailure::success();
}

/// Validates the sizes with `checkGpuLimits`. On success it:
///   - creates a `LaunchOp` at the rewriter's current insertion point, with
///     index-constant size operands (unset dimensions use a shared constant 1),
///   - stores the new op in `launchOp`,
///   - appends a `TerminatorOp` to its body.
/// The rewriter's insertion point is restored on return (InsertionGuard).
DiagnosedSilenceableFailure createGpuLaunch(
    RewriterBase &rewriter, Location loc, TransformOpInterface transformOp,
    LaunchOp &launchOp, std::optional<int64_t> gridDimX,
    std::optional<int64_t> gridDimY, std::optional<int64_t> gridDimZ,
    std::optional<int64_t> blockDimX, std::optional<int64_t> blockDimY,
    std::optional<int64_t> blockDimZ) {
  DiagnosedSilenceableFailure diag =
      checkGpuLimits(transformOp, gridDimX, gridDimY, gridDimZ, blockDimX,
                     blockDimY, blockDimZ);
  if (!diag.succeeded())
    return diag;

  // Takes int64_t (the type of the requested sizes) so nothing is narrowed
  // to `int` on the way to the index constant.
  auto createConst = [&](int64_t dim) {
    return rewriter.create<arith::ConstantIndexOp>(loc, dim);
  };
  OpBuilder::InsertionGuard guard(rewriter);
  Value one = createConst(1);
  Value gridSizeX = gridDimX.has_value() ? createConst(gridDimX.value()) : one;
  Value gridSizeY = gridDimY.has_value() ? createConst(gridDimY.value()) : one;
  Value gridSizeZ = gridDimZ.has_value() ? createConst(gridDimZ.value()) : one;
  Value blkSizeX = blockDimX.has_value() ? createConst(blockDimX.value()) : one;
  Value blkSizeY = blockDimY.has_value() ? createConst(blockDimY.value()) : one;
  Value blkSizeZ = blockDimZ.has_value() ? createConst(blockDimZ.value()) : one;
  launchOp = rewriter.create<LaunchOp>(loc, gridSizeX, gridSizeY, gridSizeZ,
                                       blkSizeX, blkSizeY, blkSizeZ);
  rewriter.setInsertionPointToEnd(&launchOp.getBody().front());
  rewriter.create<TerminatorOp>(loc);
  return DiagnosedSilenceableFailure::success();
}

/// Validates the sizes with `checkGpuLimits`. On success, every size operand
/// of `gpuLaunch` whose optional is set is replaced with a new index constant;
/// unset dimensions keep their current operand. The constants are inserted
/// right after the definition of the current block-size-x operand (presumably
/// so they dominate `gpuLaunch`). The operand update is reported to `rewriter`
/// as an in-place modification.
DiagnosedSilenceableFailure alterGpuLaunch(
    RewriterBase &rewriter, LaunchOp gpuLaunch,
    TransformOpInterface transformOp, std::optional<int64_t> gridDimX,
    std::optional<int64_t> gridDimY, std::optional<int64_t> gridDimZ,
    std::optional<int64_t> blockDimX, std::optional<int64_t> blockDimY,
    std::optional<int64_t> blockDimZ) {
  DiagnosedSilenceableFailure diag =
      checkGpuLimits(transformOp, gridDimX, gridDimY, gridDimZ, blockDimX,
                     blockDimY, blockDimZ);
  if (!diag.succeeded())
    return diag;

  KernelDim3 currentBlockdim = gpuLaunch.getBlockSizeOperandValues();
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPointAfterValue(currentBlockdim.x);
  // Takes int64_t (the type of the requested sizes) so nothing is narrowed
  // to `int` on the way to the index constant.
  auto createConstValue = [&](int64_t dim) {
    return rewriter.create<arith::ConstantIndexOp>(currentBlockdim.x.getLoc(),
                                                   dim);
  };

  // Operands are mutated directly on the op, so route the change through the
  // rewriter to keep its listeners informed.
  rewriter.modifyOpInPlace(gpuLaunch, [&]() {
    if (gridDimX.has_value())
      gpuLaunch.getGridSizeXMutable().assign(
          createConstValue(gridDimX.value()));
    if (gridDimY.has_value())
      gpuLaunch.getGridSizeYMutable().assign(
          createConstValue(gridDimY.value()));
    if (gridDimZ.has_value())
      gpuLaunch.getGridSizeZMutable().assign(
          createConstValue(gridDimZ.value()));
    if (blockDimX.has_value())
      gpuLaunch.getBlockSizeXMutable().assign(
          createConstValue(blockDimX.value()));
    if (blockDimY.has_value())
      gpuLaunch.getBlockSizeYMutable().assign(
          createConstValue(blockDimY.value()));
    if (blockDimZ.has_value())
      gpuLaunch.getBlockSizeZMutable().assign(
          createConstValue(blockDimZ.value()));
  });
  return DiagnosedSilenceableFailure::success();
}

} // namespace gpu
} // namespace transform
} // namespace mlir