1 //===- MergeICmps.cpp - Optimize chains of integer comparisons ------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass turns chains of integer comparisons into memcmp (the memcmp is
10 // later typically inlined as a chain of efficient hardware comparisons). This
11 // typically benefits c++ member or nonmember operator==().
12 //
// The basic idea is to replace a longer chain of comparisons of integers loaded
// from contiguous memory locations with a shorter chain of larger integer
// comparisons. The benefits are twofold:
// - There are fewer jumps, and therefore fewer opportunities for
//   mispredictions and I-cache misses.
// - Code size is smaller, both because jumps are removed and because the
//   encoding of a 2*n byte compare is smaller than that of two n-byte
//   compares.
21 //
22 // Example:
23 //
24 // struct S {
25 // int a;
26 // char b;
27 // char c;
28 // uint16_t d;
29 // bool operator==(const S& o) const {
30 // return a == o.a && b == o.b && c == o.c && d == o.d;
31 // }
32 // };
33 //
// is optimized as:
35 //
36 // bool S::operator==(const S& o) const {
37 // return memcmp(this, &o, 8) == 0;
38 // }
39 //
// which will later be expanded (by ExpandMemCmp) into a single 8-byte icmp.
41 //
42 //===----------------------------------------------------------------------===//
43
44 #include "llvm/Transforms/Scalar/MergeICmps.h"
45 #include "llvm/Analysis/DomTreeUpdater.h"
46 #include "llvm/Analysis/GlobalsModRef.h"
47 #include "llvm/Analysis/Loads.h"
48 #include "llvm/Analysis/TargetLibraryInfo.h"
49 #include "llvm/Analysis/TargetTransformInfo.h"
50 #include "llvm/IR/Dominators.h"
51 #include "llvm/IR/Function.h"
52 #include "llvm/IR/IRBuilder.h"
53 #include "llvm/InitializePasses.h"
54 #include "llvm/Pass.h"
55 #include "llvm/Transforms/Scalar.h"
56 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
57 #include "llvm/Transforms/Utils/BuildLibCalls.h"
58 #include <algorithm>
59 #include <numeric>
60 #include <utility>
61 #include <vector>
62
63 using namespace llvm;
64
65 namespace {
66
67 #define DEBUG_TYPE "mergeicmps"
68
69 // A BCE atom "Binary Compare Expression Atom" represents an integer load
70 // that is a constant offset from a base value, e.g. `a` or `o.c` in the example
71 // at the top.
72 struct BCEAtom {
73 BCEAtom() = default;
BCEAtom__anon2691366f0111::BCEAtom74 BCEAtom(GetElementPtrInst *GEP, LoadInst *LoadI, int BaseId, APInt Offset)
75 : GEP(GEP), LoadI(LoadI), BaseId(BaseId), Offset(Offset) {}
76
77 BCEAtom(const BCEAtom &) = delete;
78 BCEAtom &operator=(const BCEAtom &) = delete;
79
80 BCEAtom(BCEAtom &&that) = default;
operator =__anon2691366f0111::BCEAtom81 BCEAtom &operator=(BCEAtom &&that) {
82 if (this == &that)
83 return *this;
84 GEP = that.GEP;
85 LoadI = that.LoadI;
86 BaseId = that.BaseId;
87 Offset = std::move(that.Offset);
88 return *this;
89 }
90
91 // We want to order BCEAtoms by (Base, Offset). However we cannot use
92 // the pointer values for Base because these are non-deterministic.
93 // To make sure that the sort order is stable, we first assign to each atom
94 // base value an index based on its order of appearance in the chain of
95 // comparisons. We call this index `BaseOrdering`. For example, for:
96 // b[3] == c[2] && a[1] == d[1] && b[4] == c[3]
97 // | block 1 | | block 2 | | block 3 |
98 // b gets assigned index 0 and a index 1, because b appears as LHS in block 1,
99 // which is before block 2.
100 // We then sort by (BaseOrdering[LHS.Base()], LHS.Offset), which is stable.
operator <__anon2691366f0111::BCEAtom101 bool operator<(const BCEAtom &O) const {
102 return BaseId != O.BaseId ? BaseId < O.BaseId : Offset.slt(O.Offset);
103 }
104
105 GetElementPtrInst *GEP = nullptr;
106 LoadInst *LoadI = nullptr;
107 unsigned BaseId = 0;
108 APInt Offset;
109 };
110
111 // A class that assigns increasing ids to values in the order in which they are
112 // seen. See comment in `BCEAtom::operator<()``.
113 class BaseIdentifier {
114 public:
115 // Returns the id for value `Base`, after assigning one if `Base` has not been
116 // seen before.
getBaseId(const Value * Base)117 int getBaseId(const Value *Base) {
118 assert(Base && "invalid base");
119 const auto Insertion = BaseToIndex.try_emplace(Base, Order);
120 if (Insertion.second)
121 ++Order;
122 return Insertion.first->second;
123 }
124
125 private:
126 unsigned Order = 1;
127 DenseMap<const Value*, int> BaseToIndex;
128 };
129
130 // If this value is a load from a constant offset w.r.t. a base address, and
131 // there are no other users of the load or address, returns the base address and
132 // the offset.
visitICmpLoadOperand(Value * const Val,BaseIdentifier & BaseId)133 BCEAtom visitICmpLoadOperand(Value *const Val, BaseIdentifier &BaseId) {
134 auto *const LoadI = dyn_cast<LoadInst>(Val);
135 if (!LoadI)
136 return {};
137 LLVM_DEBUG(dbgs() << "load\n");
138 if (LoadI->isUsedOutsideOfBlock(LoadI->getParent())) {
139 LLVM_DEBUG(dbgs() << "used outside of block\n");
140 return {};
141 }
142 // Do not optimize atomic loads to non-atomic memcmp
143 if (!LoadI->isSimple()) {
144 LLVM_DEBUG(dbgs() << "volatile or atomic\n");
145 return {};
146 }
147 Value *Addr = LoadI->getOperand(0);
148 if (Addr->getType()->getPointerAddressSpace() != 0) {
149 LLVM_DEBUG(dbgs() << "from non-zero AddressSpace\n");
150 return {};
151 }
152 const auto &DL = LoadI->getModule()->getDataLayout();
153 if (!isDereferenceablePointer(Addr, LoadI->getType(), DL)) {
154 LLVM_DEBUG(dbgs() << "not dereferenceable\n");
155 // We need to make sure that we can do comparison in any order, so we
156 // require memory to be unconditionnally dereferencable.
157 return {};
158 }
159
160 APInt Offset = APInt(DL.getPointerTypeSizeInBits(Addr->getType()), 0);
161 Value *Base = Addr;
162 auto *GEP = dyn_cast<GetElementPtrInst>(Addr);
163 if (GEP) {
164 LLVM_DEBUG(dbgs() << "GEP\n");
165 if (GEP->isUsedOutsideOfBlock(LoadI->getParent())) {
166 LLVM_DEBUG(dbgs() << "used outside of block\n");
167 return {};
168 }
169 if (!GEP->accumulateConstantOffset(DL, Offset))
170 return {};
171 Base = GEP->getPointerOperand();
172 }
173 return BCEAtom(GEP, LoadI, BaseId.getBaseId(Base), Offset);
174 }
175
176 // A comparison between two BCE atoms, e.g. `a == o.a` in the example at the
177 // top.
178 // Note: the terminology is misleading: the comparison is symmetric, so there
179 // is no real {l/r}hs. What we want though is to have the same base on the
180 // left (resp. right), so that we can detect consecutive loads. To ensure this
181 // we put the smallest atom on the left.
182 struct BCECmp {
183 BCEAtom Lhs;
184 BCEAtom Rhs;
185 int SizeBits;
186 const ICmpInst *CmpI;
187
BCECmp__anon2691366f0111::BCECmp188 BCECmp(BCEAtom L, BCEAtom R, int SizeBits, const ICmpInst *CmpI)
189 : Lhs(std::move(L)), Rhs(std::move(R)), SizeBits(SizeBits), CmpI(CmpI) {
190 if (Rhs < Lhs) std::swap(Rhs, Lhs);
191 }
192 };
193
194 // A basic block with a comparison between two BCE atoms.
195 // The block might do extra work besides the atom comparison, in which case
196 // doesOtherWork() returns true. Under some conditions, the block can be
197 // split into the atom comparison part and the "other work" part
198 // (see canSplit()).
199 class BCECmpBlock {
200 public:
201 typedef SmallDenseSet<const Instruction *, 8> InstructionSet;
202
BCECmpBlock(BCECmp Cmp,BasicBlock * BB,InstructionSet BlockInsts)203 BCECmpBlock(BCECmp Cmp, BasicBlock *BB, InstructionSet BlockInsts)
204 : BB(BB), BlockInsts(std::move(BlockInsts)), Cmp(std::move(Cmp)) {}
205
Lhs() const206 const BCEAtom &Lhs() const { return Cmp.Lhs; }
Rhs() const207 const BCEAtom &Rhs() const { return Cmp.Rhs; }
SizeBits() const208 int SizeBits() const { return Cmp.SizeBits; }
209
210 // Returns true if the block does other works besides comparison.
211 bool doesOtherWork() const;
212
213 // Returns true if the non-BCE-cmp instructions can be separated from BCE-cmp
214 // instructions in the block.
215 bool canSplit(AliasAnalysis &AA) const;
216
217 // Return true if this all the relevant instructions in the BCE-cmp-block can
218 // be sunk below this instruction. By doing this, we know we can separate the
219 // BCE-cmp-block instructions from the non-BCE-cmp-block instructions in the
220 // block.
221 bool canSinkBCECmpInst(const Instruction *, AliasAnalysis &AA) const;
222
223 // We can separate the BCE-cmp-block instructions and the non-BCE-cmp-block
224 // instructions. Split the old block and move all non-BCE-cmp-insts into the
225 // new parent block.
226 void split(BasicBlock *NewParent, AliasAnalysis &AA) const;
227
228 // The basic block where this comparison happens.
229 BasicBlock *BB;
230 // Instructions relating to the BCECmp and branch.
231 InstructionSet BlockInsts;
232 // The block requires splitting.
233 bool RequireSplit = false;
234 // Original order of this block in the chain.
235 unsigned OrigOrder = 0;
236
237 private:
238 BCECmp Cmp;
239 };
240
canSinkBCECmpInst(const Instruction * Inst,AliasAnalysis & AA) const241 bool BCECmpBlock::canSinkBCECmpInst(const Instruction *Inst,
242 AliasAnalysis &AA) const {
243 // If this instruction may clobber the loads and is in middle of the BCE cmp
244 // block instructions, then bail for now.
245 if (Inst->mayWriteToMemory()) {
246 auto MayClobber = [&](LoadInst *LI) {
247 // If a potentially clobbering instruction comes before the load,
248 // we can still safely sink the load.
249 return (Inst->getParent() != LI->getParent() || !Inst->comesBefore(LI)) &&
250 isModSet(AA.getModRefInfo(Inst, MemoryLocation::get(LI)));
251 };
252 if (MayClobber(Cmp.Lhs.LoadI) || MayClobber(Cmp.Rhs.LoadI))
253 return false;
254 }
255 // Make sure this instruction does not use any of the BCE cmp block
256 // instructions as operand.
257 return llvm::none_of(Inst->operands(), [&](const Value *Op) {
258 const Instruction *OpI = dyn_cast<Instruction>(Op);
259 return OpI && BlockInsts.contains(OpI);
260 });
261 }
262
split(BasicBlock * NewParent,AliasAnalysis & AA) const263 void BCECmpBlock::split(BasicBlock *NewParent, AliasAnalysis &AA) const {
264 llvm::SmallVector<Instruction *, 4> OtherInsts;
265 for (Instruction &Inst : *BB) {
266 if (BlockInsts.count(&Inst))
267 continue;
268 assert(canSinkBCECmpInst(&Inst, AA) && "Split unsplittable block");
269 // This is a non-BCE-cmp-block instruction. And it can be separated
270 // from the BCE-cmp-block instruction.
271 OtherInsts.push_back(&Inst);
272 }
273
274 // Do the actual spliting.
275 for (Instruction *Inst : reverse(OtherInsts))
276 Inst->moveBefore(*NewParent, NewParent->begin());
277 }
278
canSplit(AliasAnalysis & AA) const279 bool BCECmpBlock::canSplit(AliasAnalysis &AA) const {
280 for (Instruction &Inst : *BB) {
281 if (!BlockInsts.count(&Inst)) {
282 if (!canSinkBCECmpInst(&Inst, AA))
283 return false;
284 }
285 }
286 return true;
287 }
288
doesOtherWork() const289 bool BCECmpBlock::doesOtherWork() const {
290 // TODO(courbet): Can we allow some other things ? This is very conservative.
291 // We might be able to get away with anything does not have any side
292 // effects outside of the basic block.
293 // Note: The GEPs and/or loads are not necessarily in the same block.
294 for (const Instruction &Inst : *BB) {
295 if (!BlockInsts.count(&Inst))
296 return true;
297 }
298 return false;
299 }
300
301 // Visit the given comparison. If this is a comparison between two valid
302 // BCE atoms, returns the comparison.
visitICmp(const ICmpInst * const CmpI,const ICmpInst::Predicate ExpectedPredicate,BaseIdentifier & BaseId)303 Optional<BCECmp> visitICmp(const ICmpInst *const CmpI,
304 const ICmpInst::Predicate ExpectedPredicate,
305 BaseIdentifier &BaseId) {
306 // The comparison can only be used once:
307 // - For intermediate blocks, as a branch condition.
308 // - For the final block, as an incoming value for the Phi.
309 // If there are any other uses of the comparison, we cannot merge it with
310 // other comparisons as we would create an orphan use of the value.
311 if (!CmpI->hasOneUse()) {
312 LLVM_DEBUG(dbgs() << "cmp has several uses\n");
313 return None;
314 }
315 if (CmpI->getPredicate() != ExpectedPredicate)
316 return None;
317 LLVM_DEBUG(dbgs() << "cmp "
318 << (ExpectedPredicate == ICmpInst::ICMP_EQ ? "eq" : "ne")
319 << "\n");
320 auto Lhs = visitICmpLoadOperand(CmpI->getOperand(0), BaseId);
321 if (!Lhs.BaseId)
322 return None;
323 auto Rhs = visitICmpLoadOperand(CmpI->getOperand(1), BaseId);
324 if (!Rhs.BaseId)
325 return None;
326 const auto &DL = CmpI->getModule()->getDataLayout();
327 return BCECmp(std::move(Lhs), std::move(Rhs),
328 DL.getTypeSizeInBits(CmpI->getOperand(0)->getType()), CmpI);
329 }
330
331 // Visit the given comparison block. If this is a comparison between two valid
332 // BCE atoms, returns the comparison.
visitCmpBlock(Value * const Val,BasicBlock * const Block,const BasicBlock * const PhiBlock,BaseIdentifier & BaseId)333 Optional<BCECmpBlock> visitCmpBlock(Value *const Val, BasicBlock *const Block,
334 const BasicBlock *const PhiBlock,
335 BaseIdentifier &BaseId) {
336 if (Block->empty()) return None;
337 auto *const BranchI = dyn_cast<BranchInst>(Block->getTerminator());
338 if (!BranchI) return None;
339 LLVM_DEBUG(dbgs() << "branch\n");
340 Value *Cond;
341 ICmpInst::Predicate ExpectedPredicate;
342 if (BranchI->isUnconditional()) {
343 // In this case, we expect an incoming value which is the result of the
344 // comparison. This is the last link in the chain of comparisons (note
345 // that this does not mean that this is the last incoming value, blocks
346 // can be reordered).
347 Cond = Val;
348 ExpectedPredicate = ICmpInst::ICMP_EQ;
349 } else {
350 // In this case, we expect a constant incoming value (the comparison is
351 // chained).
352 const auto *const Const = cast<ConstantInt>(Val);
353 LLVM_DEBUG(dbgs() << "const\n");
354 if (!Const->isZero()) return None;
355 LLVM_DEBUG(dbgs() << "false\n");
356 assert(BranchI->getNumSuccessors() == 2 && "expecting a cond branch");
357 BasicBlock *const FalseBlock = BranchI->getSuccessor(1);
358 Cond = BranchI->getCondition();
359 ExpectedPredicate =
360 FalseBlock == PhiBlock ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE;
361 }
362
363 auto *CmpI = dyn_cast<ICmpInst>(Cond);
364 if (!CmpI) return None;
365 LLVM_DEBUG(dbgs() << "icmp\n");
366
367 Optional<BCECmp> Result = visitICmp(CmpI, ExpectedPredicate, BaseId);
368 if (!Result)
369 return None;
370
371 BCECmpBlock::InstructionSet BlockInsts(
372 {Result->Lhs.LoadI, Result->Rhs.LoadI, Result->CmpI, BranchI});
373 if (Result->Lhs.GEP)
374 BlockInsts.insert(Result->Lhs.GEP);
375 if (Result->Rhs.GEP)
376 BlockInsts.insert(Result->Rhs.GEP);
377 return BCECmpBlock(std::move(*Result), Block, BlockInsts);
378 }
379
enqueueBlock(std::vector<BCECmpBlock> & Comparisons,BCECmpBlock && Comparison)380 static inline void enqueueBlock(std::vector<BCECmpBlock> &Comparisons,
381 BCECmpBlock &&Comparison) {
382 LLVM_DEBUG(dbgs() << "Block '" << Comparison.BB->getName()
383 << "': Found cmp of " << Comparison.SizeBits()
384 << " bits between " << Comparison.Lhs().BaseId << " + "
385 << Comparison.Lhs().Offset << " and "
386 << Comparison.Rhs().BaseId << " + "
387 << Comparison.Rhs().Offset << "\n");
388 LLVM_DEBUG(dbgs() << "\n");
389 Comparison.OrigOrder = Comparisons.size();
390 Comparisons.push_back(std::move(Comparison));
391 }
392
393 // A chain of comparisons.
394 class BCECmpChain {
395 public:
396 using ContiguousBlocks = std::vector<BCECmpBlock>;
397
398 BCECmpChain(const std::vector<BasicBlock *> &Blocks, PHINode &Phi,
399 AliasAnalysis &AA);
400
401 bool simplify(const TargetLibraryInfo &TLI, AliasAnalysis &AA,
402 DomTreeUpdater &DTU);
403
atLeastOneMerged() const404 bool atLeastOneMerged() const {
405 return any_of(MergedBlocks_,
406 [](const auto &Blocks) { return Blocks.size() > 1; });
407 }
408
409 private:
410 PHINode &Phi_;
411 // The list of all blocks in the chain, grouped by contiguity.
412 std::vector<ContiguousBlocks> MergedBlocks_;
413 // The original entry block (before sorting);
414 BasicBlock *EntryBlock_;
415 };
416
areContiguous(const BCECmpBlock & First,const BCECmpBlock & Second)417 static bool areContiguous(const BCECmpBlock &First, const BCECmpBlock &Second) {
418 return First.Lhs().BaseId == Second.Lhs().BaseId &&
419 First.Rhs().BaseId == Second.Rhs().BaseId &&
420 First.Lhs().Offset + First.SizeBits() / 8 == Second.Lhs().Offset &&
421 First.Rhs().Offset + First.SizeBits() / 8 == Second.Rhs().Offset;
422 }
423
getMinOrigOrder(const BCECmpChain::ContiguousBlocks & Blocks)424 static unsigned getMinOrigOrder(const BCECmpChain::ContiguousBlocks &Blocks) {
425 unsigned MinOrigOrder = std::numeric_limits<unsigned>::max();
426 for (const BCECmpBlock &Block : Blocks)
427 MinOrigOrder = std::min(MinOrigOrder, Block.OrigOrder);
428 return MinOrigOrder;
429 }
430
431 /// Given a chain of comparison blocks, groups the blocks into contiguous
432 /// ranges that can be merged together into a single comparison.
433 static std::vector<BCECmpChain::ContiguousBlocks>
mergeBlocks(std::vector<BCECmpBlock> && Blocks)434 mergeBlocks(std::vector<BCECmpBlock> &&Blocks) {
435 std::vector<BCECmpChain::ContiguousBlocks> MergedBlocks;
436
437 // Sort to detect continuous offsets.
438 llvm::sort(Blocks,
439 [](const BCECmpBlock &LhsBlock, const BCECmpBlock &RhsBlock) {
440 return std::tie(LhsBlock.Lhs(), LhsBlock.Rhs()) <
441 std::tie(RhsBlock.Lhs(), RhsBlock.Rhs());
442 });
443
444 BCECmpChain::ContiguousBlocks *LastMergedBlock = nullptr;
445 for (BCECmpBlock &Block : Blocks) {
446 if (!LastMergedBlock || !areContiguous(LastMergedBlock->back(), Block)) {
447 MergedBlocks.emplace_back();
448 LastMergedBlock = &MergedBlocks.back();
449 } else {
450 LLVM_DEBUG(dbgs() << "Merging block " << Block.BB->getName() << " into "
451 << LastMergedBlock->back().BB->getName() << "\n");
452 }
453 LastMergedBlock->push_back(std::move(Block));
454 }
455
456 // While we allow reordering for merging, do not reorder unmerged comparisons.
457 // Doing so may introduce branch on poison.
458 llvm::sort(MergedBlocks, [](const BCECmpChain::ContiguousBlocks &LhsBlocks,
459 const BCECmpChain::ContiguousBlocks &RhsBlocks) {
460 return getMinOrigOrder(LhsBlocks) < getMinOrigOrder(RhsBlocks);
461 });
462
463 return MergedBlocks;
464 }
465
BCECmpChain(const std::vector<BasicBlock * > & Blocks,PHINode & Phi,AliasAnalysis & AA)466 BCECmpChain::BCECmpChain(const std::vector<BasicBlock *> &Blocks, PHINode &Phi,
467 AliasAnalysis &AA)
468 : Phi_(Phi) {
469 assert(!Blocks.empty() && "a chain should have at least one block");
470 // Now look inside blocks to check for BCE comparisons.
471 std::vector<BCECmpBlock> Comparisons;
472 BaseIdentifier BaseId;
473 for (BasicBlock *const Block : Blocks) {
474 assert(Block && "invalid block");
475 Optional<BCECmpBlock> Comparison = visitCmpBlock(
476 Phi.getIncomingValueForBlock(Block), Block, Phi.getParent(), BaseId);
477 if (!Comparison) {
478 LLVM_DEBUG(dbgs() << "chain with invalid BCECmpBlock, no merge.\n");
479 return;
480 }
481 if (Comparison->doesOtherWork()) {
482 LLVM_DEBUG(dbgs() << "block '" << Comparison->BB->getName()
483 << "' does extra work besides compare\n");
484 if (Comparisons.empty()) {
485 // This is the initial block in the chain, in case this block does other
486 // work, we can try to split the block and move the irrelevant
487 // instructions to the predecessor.
488 //
489 // If this is not the initial block in the chain, splitting it wont
490 // work.
491 //
492 // As once split, there will still be instructions before the BCE cmp
493 // instructions that do other work in program order, i.e. within the
494 // chain before sorting. Unless we can abort the chain at this point
495 // and start anew.
496 //
497 // NOTE: we only handle blocks a with single predecessor for now.
498 if (Comparison->canSplit(AA)) {
499 LLVM_DEBUG(dbgs()
500 << "Split initial block '" << Comparison->BB->getName()
501 << "' that does extra work besides compare\n");
502 Comparison->RequireSplit = true;
503 enqueueBlock(Comparisons, std::move(*Comparison));
504 } else {
505 LLVM_DEBUG(dbgs()
506 << "ignoring initial block '" << Comparison->BB->getName()
507 << "' that does extra work besides compare\n");
508 }
509 continue;
510 }
511 // TODO(courbet): Right now we abort the whole chain. We could be
512 // merging only the blocks that don't do other work and resume the
513 // chain from there. For example:
514 // if (a[0] == b[0]) { // bb1
515 // if (a[1] == b[1]) { // bb2
516 // some_value = 3; //bb3
517 // if (a[2] == b[2]) { //bb3
518 // do a ton of stuff //bb4
519 // }
520 // }
521 // }
522 //
523 // This is:
524 //
525 // bb1 --eq--> bb2 --eq--> bb3* -eq--> bb4 --+
526 // \ \ \ \
527 // ne ne ne \
528 // \ \ \ v
529 // +------------+-----------+----------> bb_phi
530 //
531 // We can only merge the first two comparisons, because bb3* does
532 // "other work" (setting some_value to 3).
533 // We could still merge bb1 and bb2 though.
534 return;
535 }
536 enqueueBlock(Comparisons, std::move(*Comparison));
537 }
538
539 // It is possible we have no suitable comparison to merge.
540 if (Comparisons.empty()) {
541 LLVM_DEBUG(dbgs() << "chain with no BCE basic blocks, no merge\n");
542 return;
543 }
544 EntryBlock_ = Comparisons[0].BB;
545 MergedBlocks_ = mergeBlocks(std::move(Comparisons));
546 }
547
548 namespace {
549
550 // A class to compute the name of a set of merged basic blocks.
551 // This is optimized for the common case of no block names.
552 class MergedBlockName {
553 // Storage for the uncommon case of several named blocks.
554 SmallString<16> Scratch;
555
556 public:
MergedBlockName(ArrayRef<BCECmpBlock> Comparisons)557 explicit MergedBlockName(ArrayRef<BCECmpBlock> Comparisons)
558 : Name(makeName(Comparisons)) {}
559 const StringRef Name;
560
561 private:
makeName(ArrayRef<BCECmpBlock> Comparisons)562 StringRef makeName(ArrayRef<BCECmpBlock> Comparisons) {
563 assert(!Comparisons.empty() && "no basic block");
564 // Fast path: only one block, or no names at all.
565 if (Comparisons.size() == 1)
566 return Comparisons[0].BB->getName();
567 const int size = std::accumulate(Comparisons.begin(), Comparisons.end(), 0,
568 [](int i, const BCECmpBlock &Cmp) {
569 return i + Cmp.BB->getName().size();
570 });
571 if (size == 0)
572 return StringRef("", 0);
573
574 // Slow path: at least two blocks, at least one block with a name.
575 Scratch.clear();
576 // We'll have `size` bytes for name and `Comparisons.size() - 1` bytes for
577 // separators.
578 Scratch.reserve(size + Comparisons.size() - 1);
579 const auto append = [this](StringRef str) {
580 Scratch.append(str.begin(), str.end());
581 };
582 append(Comparisons[0].BB->getName());
583 for (int I = 1, E = Comparisons.size(); I < E; ++I) {
584 const BasicBlock *const BB = Comparisons[I].BB;
585 if (!BB->getName().empty()) {
586 append("+");
587 append(BB->getName());
588 }
589 }
590 return Scratch.str();
591 }
592 };
593 } // namespace
594
595 // Merges the given contiguous comparison blocks into one memcmp block.
mergeComparisons(ArrayRef<BCECmpBlock> Comparisons,BasicBlock * const InsertBefore,BasicBlock * const NextCmpBlock,PHINode & Phi,const TargetLibraryInfo & TLI,AliasAnalysis & AA,DomTreeUpdater & DTU)596 static BasicBlock *mergeComparisons(ArrayRef<BCECmpBlock> Comparisons,
597 BasicBlock *const InsertBefore,
598 BasicBlock *const NextCmpBlock,
599 PHINode &Phi, const TargetLibraryInfo &TLI,
600 AliasAnalysis &AA, DomTreeUpdater &DTU) {
601 assert(!Comparisons.empty() && "merging zero comparisons");
602 LLVMContext &Context = NextCmpBlock->getContext();
603 const BCECmpBlock &FirstCmp = Comparisons[0];
604
605 // Create a new cmp block before next cmp block.
606 BasicBlock *const BB =
607 BasicBlock::Create(Context, MergedBlockName(Comparisons).Name,
608 NextCmpBlock->getParent(), InsertBefore);
609 IRBuilder<> Builder(BB);
610 // Add the GEPs from the first BCECmpBlock.
611 Value *Lhs, *Rhs;
612 if (FirstCmp.Lhs().GEP)
613 Lhs = Builder.Insert(FirstCmp.Lhs().GEP->clone());
614 else
615 Lhs = FirstCmp.Lhs().LoadI->getPointerOperand();
616 if (FirstCmp.Rhs().GEP)
617 Rhs = Builder.Insert(FirstCmp.Rhs().GEP->clone());
618 else
619 Rhs = FirstCmp.Rhs().LoadI->getPointerOperand();
620
621 Value *IsEqual = nullptr;
622 LLVM_DEBUG(dbgs() << "Merging " << Comparisons.size() << " comparisons -> "
623 << BB->getName() << "\n");
624
625 // If there is one block that requires splitting, we do it now, i.e.
626 // just before we know we will collapse the chain. The instructions
627 // can be executed before any of the instructions in the chain.
628 const auto ToSplit = llvm::find_if(
629 Comparisons, [](const BCECmpBlock &B) { return B.RequireSplit; });
630 if (ToSplit != Comparisons.end()) {
631 LLVM_DEBUG(dbgs() << "Splitting non_BCE work to header\n");
632 ToSplit->split(BB, AA);
633 }
634
635 if (Comparisons.size() == 1) {
636 LLVM_DEBUG(dbgs() << "Only one comparison, updating branches\n");
637 Value *const LhsLoad =
638 Builder.CreateLoad(FirstCmp.Lhs().LoadI->getType(), Lhs);
639 Value *const RhsLoad =
640 Builder.CreateLoad(FirstCmp.Rhs().LoadI->getType(), Rhs);
641 // There are no blocks to merge, just do the comparison.
642 IsEqual = Builder.CreateICmpEQ(LhsLoad, RhsLoad);
643 } else {
644 const unsigned TotalSizeBits = std::accumulate(
645 Comparisons.begin(), Comparisons.end(), 0u,
646 [](int Size, const BCECmpBlock &C) { return Size + C.SizeBits(); });
647
648 // Create memcmp() == 0.
649 const auto &DL = Phi.getModule()->getDataLayout();
650 Value *const MemCmpCall = emitMemCmp(
651 Lhs, Rhs,
652 ConstantInt::get(DL.getIntPtrType(Context), TotalSizeBits / 8), Builder,
653 DL, &TLI);
654 IsEqual = Builder.CreateICmpEQ(
655 MemCmpCall, ConstantInt::get(Type::getInt32Ty(Context), 0));
656 }
657
658 BasicBlock *const PhiBB = Phi.getParent();
659 // Add a branch to the next basic block in the chain.
660 if (NextCmpBlock == PhiBB) {
661 // Continue to phi, passing it the comparison result.
662 Builder.CreateBr(PhiBB);
663 Phi.addIncoming(IsEqual, BB);
664 DTU.applyUpdates({{DominatorTree::Insert, BB, PhiBB}});
665 } else {
666 // Continue to next block if equal, exit to phi else.
667 Builder.CreateCondBr(IsEqual, NextCmpBlock, PhiBB);
668 Phi.addIncoming(ConstantInt::getFalse(Context), BB);
669 DTU.applyUpdates({{DominatorTree::Insert, BB, NextCmpBlock},
670 {DominatorTree::Insert, BB, PhiBB}});
671 }
672 return BB;
673 }
674
simplify(const TargetLibraryInfo & TLI,AliasAnalysis & AA,DomTreeUpdater & DTU)675 bool BCECmpChain::simplify(const TargetLibraryInfo &TLI, AliasAnalysis &AA,
676 DomTreeUpdater &DTU) {
677 assert(atLeastOneMerged() && "simplifying trivial BCECmpChain");
678 LLVM_DEBUG(dbgs() << "Simplifying comparison chain starting at block "
679 << EntryBlock_->getName() << "\n");
680
681 // Effectively merge blocks. We go in the reverse direction from the phi block
682 // so that the next block is always available to branch to.
683 BasicBlock *InsertBefore = EntryBlock_;
684 BasicBlock *NextCmpBlock = Phi_.getParent();
685 for (const auto &Blocks : reverse(MergedBlocks_)) {
686 InsertBefore = NextCmpBlock = mergeComparisons(
687 Blocks, InsertBefore, NextCmpBlock, Phi_, TLI, AA, DTU);
688 }
689
690 // Replace the original cmp chain with the new cmp chain by pointing all
691 // predecessors of EntryBlock_ to NextCmpBlock instead. This makes all cmp
692 // blocks in the old chain unreachable.
693 while (!pred_empty(EntryBlock_)) {
694 BasicBlock* const Pred = *pred_begin(EntryBlock_);
695 LLVM_DEBUG(dbgs() << "Updating jump into old chain from " << Pred->getName()
696 << "\n");
697 Pred->getTerminator()->replaceUsesOfWith(EntryBlock_, NextCmpBlock);
698 DTU.applyUpdates({{DominatorTree::Delete, Pred, EntryBlock_},
699 {DominatorTree::Insert, Pred, NextCmpBlock}});
700 }
701
702 // If the old cmp chain was the function entry, we need to update the function
703 // entry.
704 const bool ChainEntryIsFnEntry = EntryBlock_->isEntryBlock();
705 if (ChainEntryIsFnEntry && DTU.hasDomTree()) {
706 LLVM_DEBUG(dbgs() << "Changing function entry from "
707 << EntryBlock_->getName() << " to "
708 << NextCmpBlock->getName() << "\n");
709 DTU.getDomTree().setNewRoot(NextCmpBlock);
710 DTU.applyUpdates({{DominatorTree::Delete, NextCmpBlock, EntryBlock_}});
711 }
712 EntryBlock_ = nullptr;
713
714 // Delete merged blocks. This also removes incoming values in phi.
715 SmallVector<BasicBlock *, 16> DeadBlocks;
716 for (const auto &Blocks : MergedBlocks_) {
717 for (const BCECmpBlock &Block : Blocks) {
718 LLVM_DEBUG(dbgs() << "Deleting merged block " << Block.BB->getName()
719 << "\n");
720 DeadBlocks.push_back(Block.BB);
721 }
722 }
723 DeleteDeadBlocks(DeadBlocks, &DTU);
724
725 MergedBlocks_.clear();
726 return true;
727 }
728
getOrderedBlocks(PHINode & Phi,BasicBlock * const LastBlock,int NumBlocks)729 std::vector<BasicBlock *> getOrderedBlocks(PHINode &Phi,
730 BasicBlock *const LastBlock,
731 int NumBlocks) {
732 // Walk up from the last block to find other blocks.
733 std::vector<BasicBlock *> Blocks(NumBlocks);
734 assert(LastBlock && "invalid last block");
735 BasicBlock *CurBlock = LastBlock;
736 for (int BlockIndex = NumBlocks - 1; BlockIndex > 0; --BlockIndex) {
737 if (CurBlock->hasAddressTaken()) {
738 // Somebody is jumping to the block through an address, all bets are
739 // off.
740 LLVM_DEBUG(dbgs() << "skip: block " << BlockIndex
741 << " has its address taken\n");
742 return {};
743 }
744 Blocks[BlockIndex] = CurBlock;
745 auto *SinglePredecessor = CurBlock->getSinglePredecessor();
746 if (!SinglePredecessor) {
747 // The block has two or more predecessors.
748 LLVM_DEBUG(dbgs() << "skip: block " << BlockIndex
749 << " has two or more predecessors\n");
750 return {};
751 }
752 if (Phi.getBasicBlockIndex(SinglePredecessor) < 0) {
753 // The block does not link back to the phi.
754 LLVM_DEBUG(dbgs() << "skip: block " << BlockIndex
755 << " does not link back to the phi\n");
756 return {};
757 }
758 CurBlock = SinglePredecessor;
759 }
760 Blocks[0] = CurBlock;
761 return Blocks;
762 }
763
processPhi(PHINode & Phi,const TargetLibraryInfo & TLI,AliasAnalysis & AA,DomTreeUpdater & DTU)764 bool processPhi(PHINode &Phi, const TargetLibraryInfo &TLI, AliasAnalysis &AA,
765 DomTreeUpdater &DTU) {
766 LLVM_DEBUG(dbgs() << "processPhi()\n");
767 if (Phi.getNumIncomingValues() <= 1) {
768 LLVM_DEBUG(dbgs() << "skip: only one incoming value in phi\n");
769 return false;
770 }
771 // We are looking for something that has the following structure:
772 // bb1 --eq--> bb2 --eq--> bb3 --eq--> bb4 --+
773 // \ \ \ \
774 // ne ne ne \
775 // \ \ \ v
776 // +------------+-----------+----------> bb_phi
777 //
778 // - The last basic block (bb4 here) must branch unconditionally to bb_phi.
779 // It's the only block that contributes a non-constant value to the Phi.
780 // - All other blocks (b1, b2, b3) must have exactly two successors, one of
781 // them being the phi block.
782 // - All intermediate blocks (bb2, bb3) must have only one predecessor.
783 // - Blocks cannot do other work besides the comparison, see doesOtherWork()
784
785 // The blocks are not necessarily ordered in the phi, so we start from the
786 // last block and reconstruct the order.
787 BasicBlock *LastBlock = nullptr;
788 for (unsigned I = 0; I < Phi.getNumIncomingValues(); ++I) {
789 if (isa<ConstantInt>(Phi.getIncomingValue(I))) continue;
790 if (LastBlock) {
791 // There are several non-constant values.
792 LLVM_DEBUG(dbgs() << "skip: several non-constant values\n");
793 return false;
794 }
795 if (!isa<ICmpInst>(Phi.getIncomingValue(I)) ||
796 cast<ICmpInst>(Phi.getIncomingValue(I))->getParent() !=
797 Phi.getIncomingBlock(I)) {
798 // Non-constant incoming value is not from a cmp instruction or not
799 // produced by the last block. We could end up processing the value
800 // producing block more than once.
801 //
802 // This is an uncommon case, so we bail.
803 LLVM_DEBUG(
804 dbgs()
805 << "skip: non-constant value not from cmp or not from last block.\n");
806 return false;
807 }
808 LastBlock = Phi.getIncomingBlock(I);
809 }
810 if (!LastBlock) {
811 // There is no non-constant block.
812 LLVM_DEBUG(dbgs() << "skip: no non-constant block\n");
813 return false;
814 }
815 if (LastBlock->getSingleSuccessor() != Phi.getParent()) {
816 LLVM_DEBUG(dbgs() << "skip: last block non-phi successor\n");
817 return false;
818 }
819
820 const auto Blocks =
821 getOrderedBlocks(Phi, LastBlock, Phi.getNumIncomingValues());
822 if (Blocks.empty()) return false;
823 BCECmpChain CmpChain(Blocks, Phi, AA);
824
825 if (!CmpChain.atLeastOneMerged()) {
826 LLVM_DEBUG(dbgs() << "skip: nothing merged\n");
827 return false;
828 }
829
830 return CmpChain.simplify(TLI, AA, DTU);
831 }
832
runImpl(Function & F,const TargetLibraryInfo & TLI,const TargetTransformInfo & TTI,AliasAnalysis & AA,DominatorTree * DT)833 static bool runImpl(Function &F, const TargetLibraryInfo &TLI,
834 const TargetTransformInfo &TTI, AliasAnalysis &AA,
835 DominatorTree *DT) {
836 LLVM_DEBUG(dbgs() << "MergeICmpsLegacyPass: " << F.getName() << "\n");
837
838 // We only try merging comparisons if the target wants to expand memcmp later.
839 // The rationale is to avoid turning small chains into memcmp calls.
840 if (!TTI.enableMemCmpExpansion(F.hasOptSize(), true))
841 return false;
842
843 // If we don't have memcmp avaiable we can't emit calls to it.
844 if (!TLI.has(LibFunc_memcmp))
845 return false;
846
847 DomTreeUpdater DTU(DT, /*PostDominatorTree*/ nullptr,
848 DomTreeUpdater::UpdateStrategy::Eager);
849
850 bool MadeChange = false;
851
852 for (BasicBlock &BB : llvm::drop_begin(F)) {
853 // A Phi operation is always first in a basic block.
854 if (auto *const Phi = dyn_cast<PHINode>(&*BB.begin()))
855 MadeChange |= processPhi(*Phi, TLI, AA, DTU);
856 }
857
858 return MadeChange;
859 }
860
861 class MergeICmpsLegacyPass : public FunctionPass {
862 public:
863 static char ID;
864
MergeICmpsLegacyPass()865 MergeICmpsLegacyPass() : FunctionPass(ID) {
866 initializeMergeICmpsLegacyPassPass(*PassRegistry::getPassRegistry());
867 }
868
runOnFunction(Function & F)869 bool runOnFunction(Function &F) override {
870 if (skipFunction(F)) return false;
871 const auto &TLI = getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
872 const auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
873 // MergeICmps does not need the DominatorTree, but we update it if it's
874 // already available.
875 auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>();
876 auto &AA = getAnalysis<AAResultsWrapperPass>().getAAResults();
877 return runImpl(F, TLI, TTI, AA, DTWP ? &DTWP->getDomTree() : nullptr);
878 }
879
880 private:
getAnalysisUsage(AnalysisUsage & AU) const881 void getAnalysisUsage(AnalysisUsage &AU) const override {
882 AU.addRequired<TargetLibraryInfoWrapperPass>();
883 AU.addRequired<TargetTransformInfoWrapperPass>();
884 AU.addRequired<AAResultsWrapperPass>();
885 AU.addPreserved<GlobalsAAWrapperPass>();
886 AU.addPreserved<DominatorTreeWrapperPass>();
887 }
888 };
889
890 } // namespace
891
// Legacy pass identifier: its address, not its value, identifies the pass.
char MergeICmpsLegacyPass::ID = 0;
// Register the legacy pass under the name "mergeicmps" together with the
// analyses it requires. The dominator tree is not listed because the pass
// only preserves it and never requires it.
INITIALIZE_PASS_BEGIN(MergeICmpsLegacyPass, "mergeicmps",
                      "Merge contiguous icmps into a memcmp", false, false)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
INITIALIZE_PASS_END(MergeICmpsLegacyPass, "mergeicmps",
                    "Merge contiguous icmps into a memcmp", false, false)
900
901 Pass *llvm::createMergeICmpsLegacyPass() { return new MergeICmpsLegacyPass(); }
902
run(Function & F,FunctionAnalysisManager & AM)903 PreservedAnalyses MergeICmpsPass::run(Function &F,
904 FunctionAnalysisManager &AM) {
905 auto &TLI = AM.getResult<TargetLibraryAnalysis>(F);
906 auto &TTI = AM.getResult<TargetIRAnalysis>(F);
907 auto &AA = AM.getResult<AAManager>(F);
908 auto *DT = AM.getCachedResult<DominatorTreeAnalysis>(F);
909 const bool MadeChanges = runImpl(F, TLI, TTI, AA, DT);
910 if (!MadeChanges)
911 return PreservedAnalyses::all();
912 PreservedAnalyses PA;
913 PA.preserve<DominatorTreeAnalysis>();
914 return PA;
915 }
916