/**
* Copyright 2025 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "akg/Analysis/BufferAnalysis.h"
#include <algorithm>
#include <limits>
#include <optional>
#include <set>
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/PriorityQueue.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/raw_ostream.h"
namespace mlir {
namespace akg {
/// Returns the element count of `shapes` (product of all extents), or
/// std::nullopt as soon as a dynamic extent is found. A rank-0 shape yields 1.
/// Note: the product is not checked for int64_t overflow.
static std::optional<int64_t> getStaticTotalSize(ArrayRef<int64_t> shapes) {
  int64_t elementCount = 1;
  for (size_t idx = 0; idx < shapes.size(); ++idx) {
    const int64_t extent = shapes[idx];
    if (ShapedType::isDynamic(extent)) {
      return std::nullopt;
    }
    elementCount *= extent;
  }
  return elementCount;
}
/// Follows `memrefVal` backwards through memref view-like ops (subview,
/// reshape, expand_shape, collapse_shape, reinterpret_cast) and returns the
/// first value that is not produced by one of them.
static Value tracebackMemRef(Value memrefVal) {
  Value root = memrefVal;
  for (;;) {
    Value next;
    if (auto subview = root.getDefiningOp<memref::SubViewOp>()) {
      next = subview.getSource();
    } else if (auto reshape = root.getDefiningOp<memref::ReshapeOp>()) {
      next = reshape.getSource();
    } else if (auto expand = root.getDefiningOp<memref::ExpandShapeOp>()) {
      next = expand.getSrc();
    } else if (auto collapse = root.getDefiningOp<memref::CollapseShapeOp>()) {
      next = collapse.getSrc();
    } else if (auto reinterpret = root.getDefiningOp<memref::ReinterpretCastOp>()) {
      next = reinterpret.getSource();
    }
    // No view-like producer: `root` is the underlying value.
    if (!next) {
      return root;
    }
    root = next;
  }
}
/// Returns true iff `value` is the result of a memref.alloc or memref.alloca.
/// Values without a defining op (block arguments) return false.
static bool isDefiningOpAllocLike(Value value) {
  Operation *producer = value.getDefiningOp();
  return producer != nullptr && isa<memref::AllocOp, memref::AllocaOp>(producer);
}
/// Decides whether `genBuffer` may reuse the storage of `killBuffer` in place.
///
/// Returns true only when:
///  - both buffers have an entry in `bufferInfos`,
///  - the killed buffer is at least as large (in bits) as the generated one, and
///  - both element types are integer/float types with the same bit width.
/// Buffers with any other element type (e.g. index, vector) are never reused.
static bool canInplaceReuse(Value genBuffer, Value killBuffer, const llvm::DenseMap<Value, BufferInfo> &bufferInfos) {
  auto genIt = bufferInfos.find(genBuffer);
  auto killIt = bufferInfos.find(killBuffer);
  if (genIt == bufferInfos.end() || killIt == bufferInfos.end()) {
    return false;
  }
  const BufferInfo &genInfo = genIt->second;
  const BufferInfo &killInfo = killIt->second;
  // The killed storage must be able to hold the whole generated buffer.
  if (killInfo.constBits < genInfo.constBits) {
    return false;
  }
  // getIntOrFloatBitWidth() asserts on non int/float types; guard it the same
  // way gatherDataTypeWeights does for this field.
  if (!genInfo.elementType.isIntOrFloat() || !killInfo.elementType.isIntOrFloat()) {
    return false;
  }
  return genInfo.elementType.getIntOrFloatBitWidth() == killInfo.elementType.getIntOrFloatBitWidth();
}
namespace {
/// Collects (aliasing result, alias source) pairs for `op`.
///
/// Only ops accepted by isMemRefAliasingOp contribute: each of their
/// memref-typed results is paired with the value returned by getAliasSource.
/// Every other op yields an empty list.
SmallVector<std::pair<Value, Value>> getOperationAliasInfo(Operation *op) {
  SmallVector<std::pair<Value, Value>> pairs;
  if (!isMemRefAliasingOp(op)) {
    return pairs;
  }
  const Value aliasSource = getAliasSource(op);
  for (Value res : op->getResults()) {
    if (!isa<MemRefType>(res.getType())) {
      continue;
    }
    pairs.push_back({res, aliasSource});
  }
  return pairs;
}
}  // namespace
/// Appends a new OpInfo for `op`, stamped with the next sequence index, to the
/// linearized operation list. Returns a non-owning pointer; the OpInfo itself
/// is owned by `linearOperation`.
OpInfo *BufferAnalysis::UpdateLinearOperation(Operation *op) {
  linearOperation.push_back(std::make_unique<OpInfo>(op, seqIndex++));
  return linearOperation.back().get();
}
/// Records BufferInfo (defining op, static size in bits, element type) for
/// every value in `results` that has no status yet, and marks it DEFINED.
///
///  - MemRef values are traced back through view-like ops to their root memref.
///    Their size is the root's static element count times its element bit width.
///    A value is skipped if its root is unranked, has a dynamic shape, or has a
///    non integer/float element type (e.g. index, vector).
///  - Integer/float scalars are recorded with their own bit width.
///  - Values of any other type are ignored.
void BufferAnalysis::UpdateOpBufferInfo(Operation *op, const ValueRange &results) {
  for (Value result : results) {
    // Already has a status (e.g. UNDEFINED set by alias tracking): keep it.
    if (buffer2status.find(result) != buffer2status.end()) {
      continue;
    }
    Type resultType = result.getType();
    Type elementType;
    int64_t constBits = 0;
    if (isa<MemRefType>(resultType)) {
      // memref.reshape / memref.reinterpret_cast accept unranked sources, so
      // the traced root is not guaranteed to be a ranked MemRefType.
      auto rootType = dyn_cast<MemRefType>(tracebackMemRef(result).getType());
      if (!rootType) {
        continue;
      }
      elementType = rootType.getElementType();
      // getIntOrFloatBitWidth() asserts on other element types.
      if (!elementType.isIntOrFloat()) {
        continue;
      }
      std::optional<int64_t> numElements = getStaticTotalSize(rootType.getShape());
      if (!numElements.has_value()) {
        continue;
      }
      constBits = *numElements * static_cast<int64_t>(elementType.getIntOrFloatBitWidth());
    } else if (resultType.isIntOrFloat()) {
      elementType = resultType;
      constBits = static_cast<int64_t>(resultType.getIntOrFloatBitWidth());
    } else {
      continue;
    }
    bufferInfos[result] = {op, constBits, elementType};
    buffer2status[result] = BufferStatus::DEFINED;
  }
}
/// Records an unconditional alias between `buffer` and `aliasBuffer`. This is
/// the three-argument overload with `hasCond == false`.
void BufferAnalysis::UpdateBufferAlias(Value buffer, Value aliasBuffer) {
  constexpr bool kUnconditional = false;
  UpdateBufferAlias(buffer, aliasBuffer, kUnconditional);
}
/// Merges the alias groups of `buffer` and `aliasBuffer`.
///
/// Each group is the value itself plus its recorded aliases. Every member of
/// one group is recorded as an alias of every member of the other group. The
/// alias is flagged conditional when `hasCond` is true. If `buffer` is not
/// produced by memref.alloc/alloca, its status is reset to UNDEFINED.
void BufferAnalysis::UpdateBufferAlias(Value buffer, Value aliasBuffer, bool hasCond) {
  // Snapshot both groups before buffer2AliasVec is modified.
  SetVector<Value> lhsGroup = GetAliasBuffers(buffer);
  lhsGroup.insert(buffer);
  SetVector<Value> rhsGroup = GetAliasBuffers(aliasBuffer);
  rhsGroup.insert(aliasBuffer);

  UpdateBuffer2AliasVec(lhsGroup, rhsGroup, hasCond);
  UpdateBuffer2AliasVec(rhsGroup, lhsGroup, hasCond);

  const bool ownsStorage = isDefiningOpAllocLike(buffer);
  if (!ownsStorage) {
    buffer2status[buffer] = BufferStatus::UNDEFINED;
  }
}
void BufferAnalysis::UpdateBuffer2AliasVec(const SetVector<Value> &buffers, const SetVector<Value> &aliasBuffers,
bool hasCond) {
for (auto buffer : buffers) {
for (auto aliasValue : aliasBuffers) {
auto it = std::find_if(buffer2AliasVec[buffer].begin(), buffer2AliasVec[buffer].end(),
[aliasValue](const std::pair<Value, bool> &p) { return p.first == aliasValue; });
if (it != buffer2AliasVec[buffer].end()) {
it->second = it->second || hasCond;
} else {
buffer2AliasVec[buffer].push_back(std::make_pair(aliasValue, hasCond));
}
}
}
}
/// Returns the values recorded as aliases of `aliasBuffer`, in recording order.
/// The value itself is included only if it was recorded as its own alias.
/// Read-only lookup; never inserts into buffer2AliasVec.
SetVector<Value> BufferAnalysis::GetAliasBuffers(Value aliasBuffer) {
  SetVector<Value> result;
  auto entry = buffer2AliasVec.find(aliasBuffer);
  if (entry == buffer2AliasVec.end()) {
    return result;
  }
  for (const auto &aliasAndCond : entry->second) {
    result.insert(aliasAndCond.first);
  }
  return result;
}
/// For each value in `results`, and each of its recorded aliases, attempts a
/// DEFINED -> GENED transition at `opInfo` via UpdateOperandGenInfo.
void BufferAnalysis::UpdateOpGenInfo(OpInfo *opInfo, const ValueRange &results) {
  for (Value produced : results) {
    SetVector<Value> group = GetAliasBuffers(produced);
    group.insert(produced);
    for (Value member : group) {
      UpdateOperandGenInfo(opInfo, member);
    }
  }
}
/// Moves a tracked DEFINED buffer to GENED and appends it to the GEN list of
/// `opInfo`. Untracked buffers, and buffers in any other state, are left
/// untouched. The exception is KILLED: reaching it is a fatal error.
void BufferAnalysis::UpdateOperandGenInfo(OpInfo *opInfo, Value operand) {
  auto statusIt = buffer2status.find(operand);
  if (statusIt == buffer2status.end()) {
    return;
  }
  switch (statusIt->second) {
    case BufferStatus::DEFINED:
      genKillMap[opInfo].gen.push_back(operand);
      statusIt->second = BufferStatus::GENED;
      break;
    case BufferStatus::KILLED:
      llvm_unreachable("The buffer memory has been released and cannot be used again!");
    default:
      break;
  }
}
/// Visits every value live at `opInfo->operation`, according to the liveness
/// info of `block`. For each one, kills the value and any recorded aliases
/// that are dead after the op (see UpdateOpKillInfo).
/// NOTE(review): `live` is passed by value, so the Liveness object is copied
/// on each call; consider `const Liveness &` in the header declaration.
void BufferAnalysis::OpKillHandle(OpInfo *opInfo, Liveness live, Block *block) {
  const auto *blockLiveness = live.getLiveness(block);
  assert(blockLiveness != nullptr && opInfo != nullptr);
  for (Value liveValue : blockLiveness->currentlyLiveValues(opInfo->operation)) {
    UpdateOpKillInfo(opInfo, liveValue, live);
  }
}
/// Kills `operand` and each of its recorded aliases at `opInfo`. A buffer is
/// killed only when all three conditions hold:
///  - it is currently GENED,
///  - its defining op passes isParentOpDominate against the current op, and
///  - it is dead after the current op (IsBufferDeadAfter).
/// Killed buffers are appended to the KILL list and marked KILLED.
void BufferAnalysis::UpdateOpKillInfo(OpInfo *opInfo, Value operand, Liveness live) {
  SetVector<Value> group = GetAliasBuffers(operand);
  group.insert(operand);
  for (Value member : group) {
    auto statusIt = buffer2status.find(member);
    if (statusIt == buffer2status.end() || statusIt->second != BufferStatus::GENED) {
      continue;
    }
    if (!isParentOpDominate(member.getDefiningOp(), opInfo->operation)) {
      continue;
    }
    if (!IsBufferDeadAfter(opInfo->operation, member, live)) {
      continue;
    }
    genKillMap[opInfo].kill.push_back(member);
    statusIt->second = BufferStatus::KILLED;
  }
}
/// Returns true when the parent op of `op2` is an ancestor of, or the same as,
/// the parent op of `op1`. In other words, `op1` is nested inside `op2`'s
/// parent. Returns false if either op, or either parent, is null (e.g. a block
/// argument has no defining op).
/// Despite the name, this is a structural nesting check, not a dominance query.
bool BufferAnalysis::isParentOpDominate(Operation *op1, Operation *op2) const {
  if (op1 == nullptr || op2 == nullptr) {
    return false;
  }
  Operation *parent1 = op1->getParentOp();
  Operation *parent2 = op2->getParentOp();
  if (parent1 == nullptr || parent2 == nullptr) {
    return false;
  }
  return parent2->isAncestor(parent1);
}
/// Returns true if `afterBlock` appears strictly after `beforeBlock` in the
/// block list of `beforeBlock`'s region. Returns false for the same block, or
/// when `afterBlock` does not follow `beforeBlock` in that region.
bool BufferAnalysis::IsBlockAfter(Block *afterBlock, Block *beforeBlock) const {
  if (afterBlock == beforeBlock) {
    return false;
  }
  assert(afterBlock != nullptr && beforeBlock != nullptr);
  Region *region = beforeBlock->getParent();
  assert(region != nullptr);
  bool passedBefore = false;
  for (Block &candidate : *region) {
    if (!passedBefore) {
      passedBefore = (&candidate == beforeBlock);
      continue;
    }
    if (&candidate == afterBlock) {
      return true;
    }
  }
  return false;
}
/// Returns false if some user of `value` lives in a block that comes after
/// `block` in the same region (per IsBlockAfter). Returns true otherwise.
bool BufferAnalysis::IsDeadAfterBlock(Value value, Block *block) const {
  for (OpOperand &use : value.getUses()) {
    Operation *user = use.getOwner();
    assert(user != nullptr);
    Block *userBlock = user->getBlock();
    const bool usedInLaterBlock = userBlock != block && IsBlockAfter(userBlock, block);
    if (usedInLaterBlock) {
      return false;
    }
  }
  return true;
}
/// A buffer is dead after `op` when two things hold: MLIR liveness reports it
/// dead after `op`, and it has no user in a later block of the same region.
bool BufferAnalysis::IsBufferDeadAfter(Operation *op, Value buffer, Liveness live) const {
  return live.isDeadAfter(buffer, op) && IsDeadAfterBlock(buffer, op->getBlock());
}
/// Collects the tracked buffers (those with a status in buffer2status) among
/// the values live at `loopOp` and their recorded aliases. The result may
/// contain duplicates when several live values share an alias.
template <typename ForOpType>
SmallVector<Value> BufferAnalysis::GetLiveBuffersInLoopImpl(ForOpType loopOp, Liveness live) {
  SmallVector<Value> tracked;
  const auto *blockLiveness = live.getLiveness(loopOp->getBlock());
  assert(blockLiveness != nullptr);
  auto liveAtLoop = blockLiveness->currentlyLiveValues(loopOp.getOperation());
  SetVector<Value> orderedLive(liveAtLoop.begin(), liveAtLoop.end());
  for (Value liveValue : orderedLive) {
    SetVector<Value> group = GetAliasBuffers(liveValue);
    group.insert(liveValue);
    for (Value member : group) {
      if (buffer2status.count(member) != 0) {
        tracked.push_back(member);
      }
    }
  }
  return tracked;
}
template <typename ForOpType>
void BufferAnalysis::UpdateForOpInitArgsAliasImpl(ForOpType forOp) {
SmallVector<Value> inits;
if constexpr (std::is_same_v<ForOpType, mlir::affine::AffineForOp>) {
inits = forOp.getInits();
} else {
inits = forOp.getInitArgs();
}
if (inits.empty()) {
return;
}
assert(inits.size() == forOp.getRegionIterArgs().size());
for (auto [i, arg] : llvm::enumerate(inits)) {
UpdateBufferAlias(forOp.getRegionIterArgs()[i], arg);
}
}
template <typename ForOpType>
void BufferAnalysis::UpdateForOpBufferAliasImpl(ForOpType forOp) {
if (forOp.getResults().empty()) {
return;
}
SmallVector<Value> yieldedValues;
if constexpr (std::is_same_v<ForOpType, mlir::affine::AffineForOp>) {
yieldedValues = forOp.getYieldedValues();
} else {
auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
yieldedValues = SmallVector<Value>(yieldOp.getOperands().begin(), yieldOp.getOperands().end());
}
if (!forOp.getRegionIterArgs().empty()) {
assert(yieldedValues.size() == forOp.getRegionIterArgs().size());
SmallVector<Value> inits;
if constexpr (std::is_same_v<ForOpType, mlir::affine::AffineForOp>) {
inits = forOp.getInits();
} else {
inits = forOp.getInitArgs();
}
assert(inits.size() == forOp.getRegionIterArgs().size());
for (auto [i, arg] : llvm::enumerate(forOp.getRegionIterArgs())) {
UpdateBufferAlias(yieldedValues[i], arg);
}
}
assert(forOp->getResults().size() == yieldedValues.size());
for (auto [i, arg] : llvm::enumerate(yieldedValues)) {
UpdateBufferAlias(forOp->getResult(i), arg);
}
}
template <typename ForOpType>
void BufferAnalysis::RecursiveForOpImpl(ForOpType forOp, Liveness live) {
auto forBeginSeq = UpdateLinearOperation(forOp.getOperation());
UpdateOpGenInfo(forBeginSeq, GetLiveBuffersInLoopImpl(forOp, live));
UpdateForOpInitArgsAliasImpl(forOp);
RecursionIR(&forOp.getRegion(), live);
UpdateForOpBufferAliasImpl(forOp);
auto forEndSeq = UpdateLinearOperation(forOp.getOperation());
OpKillHandle(forEndSeq, live, forOp->getBlock());
}
/// Conditionally aliases each result of `ifOp` with the corresponding operand
/// of `yieldOp`, the terminator of one branch. Does nothing if `ifOp` has no
/// results.
template <typename IfOpType, typename YieldOpType>
void BufferAnalysis::UpdateIfOpBufferAliasImpl(IfOpType ifOp, YieldOpType yieldOp) {
  if (ifOp->getNumResults() == 0) {
    return;
  }
  const unsigned numYielded = yieldOp->getNumOperands();
  assert(ifOp->getNumResults() == numYielded);
  for (unsigned idx = 0; idx < numYielded; ++idx) {
    UpdateBufferAlias(ifOp->getResult(idx), yieldOp->getOperand(idx), true);
  }
}
template <typename IfOpType, typename YieldOpType>
void BufferAnalysis::RecursiveIfOpImpl(IfOpType ifOp, Liveness live) {
(void)UpdateLinearOperation(ifOp.getOperation());
RecursionIR(&ifOp.getThenRegion(), live);
auto curIfElse = UpdateLinearOperation(ifOp.getOperation());
if (!ifOp.getThenRegion().empty()) {
auto &thenBlock = ifOp.getThenRegion().front();
if (auto thenYield = dyn_cast<YieldOpType>(thenBlock.getTerminator())) {
UpdateIfOpBufferAliasImpl(ifOp, thenYield);
}
}
auto curIfEnd = curIfElse;
bool hasElse = false;
if constexpr (std::is_same_v<IfOpType, mlir::affine::AffineIfOp>) {
hasElse = ifOp.hasElse();
} else {
hasElse = !ifOp.getElseRegion().empty();
}
if (hasElse) {
RecursionIR(&ifOp.getElseRegion(), live);
curIfEnd = UpdateLinearOperation(ifOp.getOperation());
if (!ifOp.getElseRegion().empty()) {
auto &elseBlock = ifOp.getElseRegion().front();
if (auto elseYield = dyn_cast<YieldOpType>(elseBlock.getTerminator())) {
UpdateIfOpBufferAliasImpl(ifOp, elseYield);
}
}
}
OpKillHandle(curIfEnd, live, ifOp->getBlock());
}
/// Walks `region` in pre-order and linearizes its ops into `linearOperation`,
/// recording buffer info, aliasing, and GEN/KILL events.
///
/// affine/scf for and if ops are delegated to their recursive helpers, and
/// their bodies are skipped by this walk. Every other op first gets a linear
/// entry; then, by kind:
///  - to_memref / to_tensor / arith.constant: nothing more.
///  - memref-aliasing ops (except arith.select): results alias the source.
///  - memref.alloc / alloca: buffer info only.
///  - affine/memref loads: result info, GEN of results, KILL of dead buffers.
///  - affine/memref stores: GEN of the target memref, KILL of dead buffers.
///  - arith.select: result conditionally aliases both inputs, then KILL.
///  - other ops with any int/float or memref result: info, GEN, KILL.
void BufferAnalysis::RecursionIR(Region *region, Liveness live) {
  WalkResult walkStatus = region->walk<WalkOrder::PreOrder>([&](Operation *op) {
    // Structured control flow: the helpers recurse into the bodies themselves.
    if (auto affineFor = dyn_cast<mlir::affine::AffineForOp>(op)) {
      RecursiveForOpImpl(affineFor, live);
      return WalkResult::skip();
    }
    if (auto affineIf = dyn_cast<mlir::affine::AffineIfOp>(op)) {
      RecursiveIfOpImpl<mlir::affine::AffineIfOp, mlir::affine::AffineYieldOp>(affineIf, live);
      return WalkResult::skip();
    }
    if (auto scfFor = dyn_cast<mlir::scf::ForOp>(op)) {
      RecursiveForOpImpl(scfFor, live);
      return WalkResult::skip();
    }
    if (auto scfIf = dyn_cast<mlir::scf::IfOp>(op)) {
      RecursiveIfOpImpl<mlir::scf::IfOp, mlir::scf::YieldOp>(scfIf, live);
      return WalkResult::skip();
    }

    OpInfo *current = UpdateLinearOperation(op);
    if (isa<bufferization::ToMemrefOp, bufferization::ToTensorOp, arith::ConstantOp>(op)) {
      return WalkResult::advance();
    }

    // Shared handling for ops whose results become tracked buffers.
    auto recordResultsGenAndKill = [&]() {
      UpdateOpBufferInfo(op, op->getResults());
      UpdateOpGenInfo(current, op->getResults());
      OpKillHandle(current, live, op->getBlock());
    };

    auto aliasPairs = getOperationAliasInfo(op);
    if (!aliasPairs.empty() && !isa<arith::SelectOp>(op)) {
      for (const auto &[aliasResult, aliasSource] : aliasPairs) {
        UpdateBufferAlias(aliasResult, aliasSource);
      }
    } else if (isa<memref::AllocOp, memref::AllocaOp>(op)) {
      UpdateOpBufferInfo(op, op->getResults());
    } else if (isa<mlir::affine::AffineLoadOp, memref::LoadOp>(op)) {
      recordResultsGenAndKill();
    } else if (auto affineStore = dyn_cast<mlir::affine::AffineStoreOp>(op)) {
      UpdateStoreOpInfo(current, affineStore.getMemRef(), live);
    } else if (auto memrefStore = dyn_cast<memref::StoreOp>(op)) {
      UpdateStoreOpInfo(current, memrefStore.getMemRef(), live);
    } else if (auto selectOp = dyn_cast<arith::SelectOp>(op)) {
      UpdateBufferAlias(selectOp.getResult(), selectOp.getTrueValue(), true);
      UpdateBufferAlias(selectOp.getResult(), selectOp.getFalseValue(), true);
      OpKillHandle(current, live, op->getBlock());
    } else {
      auto results = op->getResults();
      const bool producesTrackable = std::any_of(results.begin(), results.end(), [](Value produced) {
        Type producedType = produced.getType();
        return producedType.isIntOrFloat() || isa<MemRefType>(producedType);
      });
      if (producesTrackable) {
        recordResultsGenAndKill();
      }
    }
    return WalkResult::advance();
  });
  if (walkStatus == WalkResult::interrupt()) {
    llvm_unreachable("BufferAnalysis Traverse IR Failed!");
  }
}
/// Handles a store: GENs the stored-to memref (and its aliases) at `opInfo`,
/// then kills buffers that are dead after the store.
void BufferAnalysis::UpdateStoreOpInfo(OpInfo *opInfo, const Value storeValue, Liveness live) {
  SmallVector<Value, 1> target{storeValue};
  UpdateOpGenInfo(opInfo, target);
  Operation *storeOp = opInfo->operation;
  OpKillHandle(opInfo, live, storeOp->getBlock());
}
/// Extra scratch weight for ops carrying a "reduction_axes" attribute. It is
/// max(1, number of results) times the data-type weight of operand 0 (weight
/// defaults to 1). Returns 0 for ops without the attribute or without operands.
int64_t BufferAnalysis::getExtraBufferSizeByFactor(Operation *op) const {
  if (!op->hasAttr("reduction_axes") || op->getNumOperands() == 0) {
    return 0;
  }
  constexpr int64_t kReduceFactor = 1;
  const int64_t resultCount = std::max<int64_t>(1, static_cast<int64_t>(op->getNumResults()));
  const auto typeWeight = static_cast<int64_t>(getValDataTypeWeight(op->getOperand(0), 1, dataTypeWeightMap));
  return kReduceFactor * resultCount * typeWeight;
}
void BufferAnalysis::GenerateBufferLife() {
int scopeTime = 0;
for (size_t i = 0; i < linearOperation.size(); ++i) {
auto it = genKillMap.find(linearOperation[i].get());
if (it == genKillMap.end()) {
scopeTime++;
continue;
}
for (const Value &genBuffer : it->second.gen) {
std::unique_ptr<BufferLife> bufferLife = std::make_unique<BufferLife>(genBuffer);
bufferLife->allocTime = scopeTime;
buffer2Life[genBuffer] = std::move(bufferLife);
}
for (const Value &killBuffer : it->second.kill) {
auto iter = buffer2Life.find(killBuffer);
if (iter != buffer2Life.end()) {
iter->second->freeTime = scopeTime;
}
}
scopeTime++;
}
for (auto &[buffer, life] : buffer2Life) {
if (life->freeTime == -1 && life->allocTime != -1) {
life->freeTime = scopeTime > 0 ? scopeTime - 1 : 0;
}
}
}
/// Returns the multi-buffer count configured for `value` in the options, or
/// `def` when none is configured.
uint32_t BufferAnalysis::getValMultiBuffer(const Value &value, uint32_t def) const {
  const auto found = options.multiBufferCount.find(value);
  if (found == options.multiBufferCount.end()) {
    return def;
  }
  return static_cast<uint32_t>(found->second);
}
/// Returns the weight recorded for `value` in `weightMap`, or `def` when the
/// value has no entry.
uint32_t BufferAnalysis::getValDataTypeWeight(const Value &value, uint32_t def,
                                              const DataTypeWeightMap &weightMap) const {
  const auto found = weightMap.find(value);
  if (found == weightMap.end()) {
    return def;
  }
  return found->second;
}
void BufferAnalysis::printLiveRanges() const {
llvm::outs() << "\n==================== Live Ranges ====================\n";
llvm::outs() << "Considering " << liveRanges.size() << " live ranges:\n\n";
for (size_t i = 0; i < liveRanges.size(); ++i) {
const auto &liveRange = liveRanges[i];
llvm::outs() << "Live Range #" << i << ":\n";
llvm::outs() << " Start: " << liveRange.start << ", End: " << liveRange.end << ", Weight: " << liveRange.weight
<< "\n";
if (liveRange.op) {
llvm::outs() << " Operation: ";
liveRange.op->print(llvm::outs(), OpPrintingFlags().skipRegions());
llvm::outs() << "\n";
}
}
llvm::outs() << "=====================================================\n\n";
}
void BufferAnalysis::printBufferAnalysisInfo() const {
llvm::outs() << "\n================== Buffer Analysis ==================\n\n";
llvm::outs() << "--- Linear Operations ---\n";
for (size_t i = 0; i < linearOperation.size(); ++i) {
const auto &opInfo = linearOperation[i];
llvm::outs() << "[" << i << "] ";
if (opInfo->operation) {
opInfo->operation->print(llvm::outs(), OpPrintingFlags().skipRegions());
}
llvm::outs() << "\n";
auto it = genKillMap.find(opInfo.get());
if (it != genKillMap.end()) {
const auto &genKill = it->second;
if (!genKill.gen.empty()) {
llvm::outs() << " GEN: ";
for (size_t j = 0; j < genKill.gen.size(); ++j) {
if (j > 0) llvm::outs() << ", ";
genKill.gen[j].print(llvm::outs());
}
llvm::outs() << "\n";
}
if (!genKill.kill.empty()) {
llvm::outs() << " KILL: ";
for (size_t j = 0; j < genKill.kill.size(); ++j) {
if (j > 0) llvm::outs() << ", ";
genKill.kill[j].print(llvm::outs());
}
llvm::outs() << "\n";
}
}
}
llvm::outs() << "\n";
printLiveRanges();
int64_t maxBuffer = lineSweepRanges();
llvm::outs() << "--- Max Buffer ---\n";
llvm::outs() << " MaxBuffer: " << maxBuffer << "\n";
llvm::outs() << "\n";
llvm::outs() << "=====================================================\n\n";
}
/// Returns the generated buffers that can reuse, in place, the storage of some
/// buffer killed at the same linear operation (per canInplaceReuse).
llvm::DenseSet<Value> BufferAnalysis::gatherInplaceReuseBuffers() const {
  llvm::DenseSet<Value> reusable;
  for (const auto &[opInfo, events] : genKillMap) {
    for (const Value &produced : events.gen) {
      const bool fitsKilledBuffer =
          std::any_of(events.kill.begin(), events.kill.end(),
                      [&](const Value &killed) { return canInplaceReuse(produced, killed, bufferInfos); });
      if (fitsKilledBuffer) {
        reusable.insert(produced);
      }
    }
  }
  return reusable;
}
/// Computes a relative data-type weight for each int/float buffer: its element
/// bit width divided by the smallest such width, rounded up. Also stores that
/// smallest width in `smallestTypeBits`, or 1 when there are no int/float
/// buffers.
void BufferAnalysis::gatherDataTypeWeights() {
  dataTypeWeightMap.clear();
  uint32_t minBits = std::numeric_limits<uint32_t>::max();
  for (const auto &[buffer, info] : bufferInfos) {
    if (!info.elementType.isIntOrFloat()) {
      continue;
    }
    const auto bits = static_cast<uint32_t>(info.elementType.getIntOrFloatBitWidth());
    dataTypeWeightMap[buffer] = bits;
    minBits = std::min(minBits, bits);
  }
  smallestTypeBits = (minBits == std::numeric_limits<uint32_t>::max()) ? 1u : minBits;
  for (auto &entry : dataTypeWeightMap) {
    entry.second = (entry.second + smallestTypeBits - 1) / smallestTypeBits;
  }
}
/// Emits one live range per buffer life that has both an alloc and a free
/// time, that has BufferInfo, and that is not reused in place. The weight is
/// multi-buffer count × data-type weight, each defaulting to 1.
void BufferAnalysis::createLiveRangesFromBufferLife(const llvm::DenseSet<Value> &inplaceReuseBuffers,
                                                    const DataTypeWeightMap &dataTypeWeightMap) {
  for (auto &[buffer, life] : buffer2Life) {
    const bool hasFullInterval = life->allocTime != -1 && life->freeTime != -1;
    if (!hasFullInterval) {
      continue;
    }
    auto infoIt = bufferInfos.find(buffer);
    if (infoIt == bufferInfos.end() || inplaceReuseBuffers.contains(buffer)) {
      continue;
    }
    const int64_t weight = static_cast<int64_t>(getValMultiBuffer(buffer, 1)) *
                           static_cast<int64_t>(getValDataTypeWeight(buffer, 1, dataTypeWeightMap));
    liveRanges.emplace_back(static_cast<uint32_t>(life->allocTime), static_cast<uint32_t>(life->freeTime), weight,
                            infoIt->second.operation);
  }
}
/// For each linear operation with a non-zero extra buffer size (see
/// getExtraBufferSizeByFactor), adds a single-step live range at that op's
/// time step, which equals its index. The weight is scaled by
/// max(1, multi-buffer(result 0) + multi-buffer(last operand)), where missing
/// entries count as 0. The `dataTypeWeightMap` parameter is currently unused.
void BufferAnalysis::addExtraBufferLiveRanges(const DataTypeWeightMap &dataTypeWeightMap) {
  for (size_t idx = 0; idx < linearOperation.size(); ++idx) {
    Operation *op = linearOperation[idx]->operation;
    int64_t extraWeight = getExtraBufferSizeByFactor(op);
    if (extraWeight == 0) {
      continue;
    }
    const uint32_t resultMultiBuffer = op->getNumResults() > 0 ? getValMultiBuffer(op->getResult(0), 0) : 0u;
    const uint32_t initMultiBuffer =
        op->getNumOperands() > 0 ? getValMultiBuffer(op->getOperand(op->getNumOperands() - 1), 0) : 0u;
    const int64_t multiBufferFactor = std::max(static_cast<uint32_t>(1), resultMultiBuffer + initMultiBuffer);
    extraWeight *= multiBufferFactor;
    const auto step = static_cast<uint32_t>(idx);
    liveRanges.emplace_back(step, step, extraWeight, op);
  }
}
void BufferAnalysis::gatherLiveRanges() {
liveRanges.clear();
auto inplaceReuseBuffers = gatherInplaceReuseBuffers();
createLiveRangesFromBufferLife(inplaceReuseBuffers, dataTypeWeightMap);
addExtraBufferLiveRanges(dataTypeWeightMap);
llvm::sort(liveRanges);
}
int64_t BufferAnalysis::lineSweepRanges() const {
llvm::PriorityQueue<WeightedEndPair, llvm::SmallVector<WeightedEndPair>, std::greater<WeightedEndPair>> earlyDone;
int64_t maxBuffer = 0;
int64_t currentBuffer = 0;
for (const auto &liveRange : liveRanges) {
if (liveRange.start == liveRange.end) {
}
while (!earlyDone.empty() && earlyDone.top().first < liveRange.start) {
currentBuffer -= earlyDone.top().second;
earlyDone.pop();
}
earlyDone.push({liveRange.end, liveRange.weight});
currentBuffer += liveRange.weight;
maxBuffer = std::max(maxBuffer, currentBuffer);
}
return maxBuffer;
}
std::pair<int64_t, uint32_t> BufferAnalysis::calculateMaxBuffer() {
Region &funcRegion = func.getBody();
Liveness live(func);
RecursionIR(&funcRegion, live);
GenerateBufferLife();
gatherDataTypeWeights();
gatherLiveRanges();
if (options.printBufferInfo) {
printBufferAnalysisInfo();
}
int64_t maxBuffer = lineSweepRanges();
return {maxBuffer, smallestTypeBits};
}
/// Public entry point. Returns BufferAnalysis::calculateMaxBuffer() for
/// `func`, or {-1, 0} when the function body is not a single block.
std::pair<int64_t, uint32_t> countMaxBuffer(mlir::func::FuncOp func, const BufferAnalysisOptions &options) {
  const bool isSingleBlock = func.getBody().getBlocks().size() == 1;
  if (!isSingleBlock) {
    return {-1, 0};
  }
  BufferAnalysis analysis(func, options);
  return analysis.calculateMaxBuffer();
}
}
}