#ifndef TRITON_CONVERSION_FMA_DOT_UTILITY_H
#define TRITON_CONVERSION_FMA_DOT_UTILITY_H
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Transforms/DialectConversion.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
namespace mlir::triton::gpu {
/// Backend hook that multiplies two vectors of Values into a single scalar.
///
/// Every backend provides its own subclass, which makes it possible to
/// generate hardware specific code for the multiplication.
class FMAVectorMultiplier {
public:
  virtual ~FMAVectorMultiplier() = default;

  /// Computes the scalar product of \p a and \p b and adds \p c to it.
  ///
  /// \returns a·b + c
  virtual Value multiplyVectors(ArrayRef<Value> a, ArrayRef<Value> b,
                                Value c) = 0;
};
/// Implements a framework for FMA dot conversion to LLVM.
///
/// This function implements the architecture independent part of the FMA dot
/// conversion and calls the "multiplier" object, which is defined by the
/// caller and implements the architecture dependent part of the conversion.
///
/// \param op the dot operation being converted.
/// \param adaptor adaptor for \p op (presumably carrying its already
///        converted operands; confirm against the implementation).
/// \param typeConverter LLVM type converter available to the conversion.
/// \param rewriter rewriter available to the conversion for emitting the
///        replacement code.
/// \param multiplier caller-provided, architecture specific implementation of
///        the vector multiplication step.
/// \returns a LogicalResult reporting the outcome of the conversion.
LogicalResult parametricConvertFMADot(DotOp op, DotOp::Adaptor adaptor,
const LLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter,
FMAVectorMultiplier &multiplier);
} // namespace mlir::triton::gpu
#endif // TRITON_CONVERSION_FMA_DOT_UTILITY_H