#include "bolt/Passes/ThreeWayBranch.h"
using namespace llvm;
namespace llvm {
namespace bolt {
// Decide whether the three-way-branch rewrite may be applied to Function.
// Scans every instruction of every basic block and returns false as soon as
// the target's MCPlusBuilder reports one as packed (MIB->isPacked); returns
// true when no such instruction exists.
// NOTE(review): why packed instructions disqualify a function is not visible
// here — confirm against MCPlusBuilder::isPacked before relying on it.
bool ThreeWayBranch::shouldRunOnFunction(BinaryFunction &Function) {
  const auto &MIB = Function.getBinaryContext().MIB;
  for (const BinaryBasicBlock &Block : Function) {
    for (const MCInst &Instr : Block) {
      if (MIB->isPacked(Instr))
        return false;
    }
  }
  return true;
}
// Reorder "three-way" branches in the hot blocks of Function.
//
// Pattern matched: a profiled, executed block BB with exactly two successors
// (no jump table). One of those successors, SecondBranch, consists of a single
// instruction and has two successors of its own. The other successor of BB is
// FirstEndpoint. SecondBranch's successors are SecondEndpoint (false) and
// ThirdEndpoint (true).
//
// The three endpoints are sorted by execution count. The two jumps are then
// retargeted so that BB's jump decides between the hottest endpoint and
// SecondBranch, and SecondBranch's jump decides between the two colder
// endpoints. New condition codes are built by combining the two original
// condition codes with getCondCodesLogicalOr / getInvertedCondCode. That
// combination only makes sense if both jumps test the same condition state,
// which is presumably why SecondBranch must contain nothing but its jump.
//
// Successor edges, edge counts and SecondBranch's execution count are updated
// to match, and BranchesAltered is incremented per rewritten pattern.
void ThreeWayBranch::runOnFunction(BinaryFunction &Function) {
  BinaryContext &BC = Function.getBinaryContext();
  MCContext *Ctx = BC.Ctx.get();

  // Iterate over a snapshot of the layout order, since successor lists are
  // rewritten inside the loop.
  BinaryFunction::BasicBlockOrderType BlockLayout(
      Function.getLayout().block_begin(), Function.getLayout().block_end());

  // A block usable as SecondBranch: exactly two successors and exactly one
  // instruction (the branch itself).
  auto IsBranchOnlyBlock = [](BinaryBasicBlock *Block) {
    return Block->succ_size() == 2 && Block->size() == 1;
  };

  for (BinaryBasicBlock *BB : BlockLayout) {
    // Only executed blocks that carry profile data are considered.
    if (BB->getExecutionCount() == 0 ||
        BB->getExecutionCount() == BinaryBasicBlock::COUNT_NO_PROFILE)
      continue;
    // BB must be a plain two-way branch, not a jump-table dispatch.
    if (BB->succ_size() != 2 || BB->hasJumpTable())
      continue;

    BinaryBasicBlock *FalseSucc = BB->getConditionalSuccessor(false);
    BinaryBasicBlock *TrueSucc = BB->getConditionalSuccessor(true);

    // Pick the successor holding the second conditional jump. The block that
    // is actually chosen must satisfy BOTH requirements. Testing only
    // succ_size() could select a multi-instruction FalseSucc when only
    // TrueSucc qualifies, and would then merge condition codes of unrelated
    // jumps.
    BinaryBasicBlock *SecondBranch;
    BinaryBasicBlock *FirstEndpoint;
    if (IsBranchOnlyBlock(FalseSucc)) {
      SecondBranch = FalseSucc;
      FirstEndpoint = TrueSucc;
    } else if (IsBranchOnlyBlock(TrueSucc)) {
      SecondBranch = TrueSucc;
      FirstEndpoint = FalseSucc;
    } else {
      continue;
    }

    BinaryBasicBlock *SecondEndpoint =
        SecondBranch->getConditionalSuccessor(false);
    BinaryBasicBlock *ThirdEndpoint =
        SecondBranch->getConditionalSuccessor(true);

    // SecondBranch's jump is rewritten below. Only do that when BB is its
    // sole predecessor, so no other path through it is disturbed.
    if (SecondBranch->pred_size() != 1)
      continue;

    MCInst *FirstJump = BB->getLastNonPseudoInstr();
    MCInst *SecondJump = SecondBranch->getLastNonPseudoInstr();
    // Defensive: bail out if either block yields no non-pseudo instruction.
    if (!FirstJump || !SecondJump)
      continue;

    // FirstCC: condition under which BB reaches FirstEndpoint. FirstJump's
    // condition selects TrueSucc, so invert it when FirstEndpoint is FalseSucc.
    unsigned FirstCC = BC.MIB->getCondCode(*FirstJump);
    if (SecondBranch != FalseSucc)
      FirstCC = BC.MIB->getInvertedCondCode(FirstCC);

    // With C2 = SecondJump's condition (selects ThirdEndpoint):
    //   ThirdCC  = C2 && !FirstCC  = !(!C2 || FirstCC)
    //   SecondCC = !C2 && !FirstCC = !(C2 || FirstCC)
    const unsigned SecondJumpCC = BC.MIB->getCondCode(*SecondJump);
    unsigned ThirdCC = BC.MIB->getInvertedCondCode(
        BC.MIB->getCondCodesLogicalOr(
            BC.MIB->getInvertedCondCode(SecondJumpCC), FirstCC));
    unsigned SecondCC = BC.MIB->getInvertedCondCode(
        BC.MIB->getCondCodesLogicalOr(SecondJumpCC, FirstCC));

    // Give up if any combined condition is not expressible on this target.
    if (!BC.MIB->isValidCondCode(FirstCC) ||
        !BC.MIB->isValidCondCode(ThirdCC) || !BC.MIB->isValidCondCode(SecondCC))
      continue;

    // Each entry pairs an endpoint with the condition under which it is
    // reached.
    using Endpoint = std::pair<BinaryBasicBlock *, unsigned>;
    std::vector<Endpoint> Blocks;
    Blocks.push_back(std::make_pair(FirstEndpoint, FirstCC));
    Blocks.push_back(std::make_pair(SecondEndpoint, SecondCC));
    Blocks.push_back(std::make_pair(ThirdEndpoint, ThirdCC));

    // Sort ascending by execution count, so Blocks[2] is the hottest endpoint.
    // A stable sort keeps ties in the fixed order above. llvm::sort leaves
    // ties unspecified (and shuffles first in EXPENSIVE_CHECKS builds), which
    // would make the emitted layout nondeterministic.
    llvm::stable_sort(Blocks, [](const Endpoint &A, const Endpoint &B) {
      return A.first->getExecutionCount() < B.first->getExecutionCount();
    });

    const uint64_t HottestCount = Blocks[2].first->getExecutionCount();
    uint64_t NewSecondBranchCount = Blocks[1].first->getExecutionCount() +
                                    Blocks[0].first->getExecutionCount();
    bool SecondBranchBigger = NewSecondBranchCount > HottestCount;

    // BB now chooses between the hottest endpoint and SecondBranch. The block
    // that FirstJump will target (see replaceBranchCondition below) is added
    // first.
    BB->removeAllSuccessors();
    if (SecondBranchBigger) {
      BB->addSuccessor(Blocks[2].first, HottestCount);
      BB->addSuccessor(SecondBranch, NewSecondBranchCount);
    } else {
      BB->addSuccessor(SecondBranch, NewSecondBranchCount);
      BB->addSuccessor(Blocks[2].first, HottestCount);
    }

    // SecondBranch now chooses between the two colder endpoints. Its edges are
    // cleared and rebuilt, so no stale or duplicate successor survives.
    SecondBranch->removeAllSuccessors();
    SecondBranch->addSuccessor(Blocks[0].first,
                               Blocks[0].first->getExecutionCount());
    SecondBranch->addSuccessor(Blocks[1].first,
                               Blocks[1].first->getExecutionCount());
    SecondBranch->setExecutionCount(NewSecondBranchCount);

    // Retarget BB's jump. Jump to the hottest endpoint when the colder pair is
    // more frequent in total; otherwise jump to SecondBranch under the
    // inverted condition of the hottest endpoint.
    if (SecondBranchBigger)
      BC.MIB->replaceBranchCondition(*FirstJump, Blocks[2].first->getLabel(),
                                     Ctx, Blocks[2].second);
    else
      BC.MIB->replaceBranchCondition(
          *FirstJump, SecondBranch->getLabel(), Ctx,
          BC.MIB->getInvertedCondCode(Blocks[2].second));

    // Retarget SecondBranch's jump to the coldest endpoint. The middle
    // endpoint becomes its other successor.
    BC.MIB->replaceBranchCondition(*SecondJump, Blocks[0].first->getLabel(),
                                   Ctx, Blocks[0].second);

    ++BranchesAltered;
  }
}
// Pass entry point: apply the three-way-branch rewrite to every function
// registered in BC.
// Functions rejected by shouldRunOnFunction() are left untouched. After all
// functions are visited, the accumulated BranchesAltered counter is reported
// on BC.outs(). This pass always returns Error::success().
Error ThreeWayBranch::runOnFunctions(BinaryContext &BC) {
  for (auto &Entry : BC.getBinaryFunctions()) {
    BinaryFunction &BF = Entry.second;
    if (shouldRunOnFunction(BF))
      runOnFunction(BF);
  }

  BC.outs() << "BOLT-INFO: number of three way branches order changed: "
            << BranchesAltered << "\n";
  return Error::success();
}
}
}