1 //===- ConstantHoisting.cpp - Prepare code for expensive constants --------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This pass identifies expensive constants to hoist and coalesces them to
11 // better prepare it for SelectionDAG-based code generation. This works around
12 // the limitations of the basic-block-at-a-time approach.
13 //
// First it scans all instructions for integer constants and calculates their
// cost. If the constant can be folded into the instruction (the cost is
// TCC_Free) or the cost is just a simple operation (TCC_Basic), then we don't
// consider it expensive and leave it alone. This is the default behavior and
// the default implementation of getIntImmCost will always return TCC_Free.
//
// If the cost is more than TCC_Basic, then the integer constant can't be folded
// into the instruction and it might be beneficial to hoist the constant.
22 // Similar constants are coalesced to reduce register pressure and
23 // materialization code.
24 //
25 // When a constant is hoisted, it is also hidden behind a bitcast to force it to
26 // be live-out of the basic block. Otherwise the constant would be just
27 // duplicated and each basic block would have its own copy in the SelectionDAG.
28 // The SelectionDAG recognizes such constants as opaque and doesn't perform
29 // certain transformations on them, which would create a new expensive constant.
30 //
31 // This optimization is only applied to integer constants in instructions and
32 // simple (this means not nested) constant cast expressions. For example:
33 // %0 = load i64* inttoptr (i64 big_constant to i64*)
34 //===----------------------------------------------------------------------===//
35 
36 #include "llvm/Transforms/Scalar/ConstantHoisting.h"
37 #include "llvm/ADT/SmallSet.h"
38 #include "llvm/ADT/SmallVector.h"
39 #include "llvm/ADT/Statistic.h"
40 #include "llvm/IR/Constants.h"
41 #include "llvm/IR/GetElementPtrTypeIterator.h"
42 #include "llvm/IR/IntrinsicInst.h"
43 #include "llvm/Pass.h"
44 #include "llvm/Support/Debug.h"
45 #include "llvm/Support/raw_ostream.h"
46 #include "llvm/Transforms/Scalar.h"
47 #include <tuple>
48 
49 using namespace llvm;
50 using namespace consthoist;
51 
52 #define DEBUG_TYPE "consthoist"
53 
// Incremented once per hoisted base constant in emitBaseConstants().
STATISTIC(NumConstantsHoisted, "Number of constants hoisted");
// Updated in emitBaseConstants() from the number of non-base constants that
// are rewritten relative to a hoisted base constant.
STATISTIC(NumConstantsRebased, "Number of constants rebased");

// Opt-in switch (default off, hidden from -help). When set, the pass requests
// BlockFrequencyInfo and uses it to pick the hoisting insertion points; when
// clear, a single point in the nearest common dominator of the uses is used.
static cl::opt<bool> ConstHoistWithBlockFrequency(
    "consthoist-with-block-frequency", cl::init(false), cl::Hidden,
    cl::desc("Enable the use of the block frequency analysis to reduce the "
             "chance to execute const materialization more frequently than "
             "without hoisting."));
62 
namespace {
/// \brief The constant hoisting pass.
///
/// Legacy pass manager wrapper. It declares the analyses the transformation
/// needs and forwards the actual work to a ConstantHoistingPass object (Impl),
/// the same class the new pass manager runs directly.
class ConstantHoistingLegacyPass : public FunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid
  ConstantHoistingLegacyPass() : FunctionPass(ID) {
    initializeConstantHoistingLegacyPassPass(*PassRegistry::getPassRegistry());
  }

  /// Run constant hoisting on \p Fn; returns true if the IR was changed.
  bool runOnFunction(Function &Fn) override;

  StringRef getPassName() const override { return "Constant Hoisting"; }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    // The pass only inserts and rewrites instructions; blocks and edges are
    // left untouched.
    AU.setPreservesCFG();
    // Block frequencies are only requested when the opt-in heuristic is on.
    if (ConstHoistWithBlockFrequency)
      AU.addRequired<BlockFrequencyInfoWrapperPass>();
    AU.addRequired<DominatorTreeWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
  }

  // Forwarded to Impl. NOTE(review): presumably drops the per-function state
  // Impl accumulated; its body lives in the header — confirm there.
  void releaseMemory() override { Impl.releaseMemory(); }

private:
  // Shared implementation of the transformation.
  ConstantHoistingPass Impl;
};
} // end anonymous namespace
90 
char ConstantHoistingLegacyPass::ID = 0;
// Register the legacy pass under the command-line name "consthoist" together
// with the analyses it may depend on. BlockFrequencyInfo is listed
// unconditionally here, although getAnalysisUsage only requires it when
// -consthoist-with-block-frequency is set.
INITIALIZE_PASS_BEGIN(ConstantHoistingLegacyPass, "consthoist",
                      "Constant Hoisting", false, false)
INITIALIZE_PASS_DEPENDENCY(BlockFrequencyInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_END(ConstantHoistingLegacyPass, "consthoist",
                    "Constant Hoisting", false, false)
99 
100 FunctionPass *llvm::createConstantHoistingPass() {
101   return new ConstantHoistingLegacyPass();
102 }
103 
104 /// \brief Perform the constant hoisting optimization for the given function.
105 bool ConstantHoistingLegacyPass::runOnFunction(Function &Fn) {
106   if (skipFunction(Fn))
107     return false;
108 
109   DEBUG(dbgs() << "********** Begin Constant Hoisting **********\n");
110   DEBUG(dbgs() << "********** Function: " << Fn.getName() << '\n');
111 
112   bool MadeChange =
113       Impl.runImpl(Fn, getAnalysis<TargetTransformInfoWrapperPass>().getTTI(Fn),
114                    getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
115                    ConstHoistWithBlockFrequency
116                        ? &getAnalysis<BlockFrequencyInfoWrapperPass>().getBFI()
117                        : nullptr,
118                    Fn.getEntryBlock());
119 
120   if (MadeChange) {
121     DEBUG(dbgs() << "********** Function after Constant Hoisting: "
122                  << Fn.getName() << '\n');
123     DEBUG(dbgs() << Fn);
124   }
125   DEBUG(dbgs() << "********** End Constant Hoisting **********\n");
126 
127   return MadeChange;
128 }
129 
130 
131 /// \brief Find the constant materialization insertion point.
132 Instruction *ConstantHoistingPass::findMatInsertPt(Instruction *Inst,
133                                                    unsigned Idx) const {
134   // If the operand is a cast instruction, then we have to materialize the
135   // constant before the cast instruction.
136   if (Idx != ~0U) {
137     Value *Opnd = Inst->getOperand(Idx);
138     if (auto CastInst = dyn_cast<Instruction>(Opnd))
139       if (CastInst->isCast())
140         return CastInst;
141   }
142 
143   // The simple and common case. This also includes constant expressions.
144   if (!isa<PHINode>(Inst) && !Inst->isEHPad())
145     return Inst;
146 
147   // We can't insert directly before a phi node or an eh pad. Insert before
148   // the terminator of the incoming or dominating block.
149   assert(Entry != Inst->getParent() && "PHI or landing pad in entry block!");
150   if (Idx != ~0U && isa<PHINode>(Inst))
151     return cast<PHINode>(Inst)->getIncomingBlock(Idx)->getTerminator();
152 
153   // This must be an EH pad. Iterate over immediate dominators until we find a
154   // non-EH pad. We need to skip over catchswitch blocks, which are both EH pads
155   // and terminators.
156   auto IDom = DT->getNode(Inst->getParent())->getIDom();
157   while (IDom->getBlock()->isEHPad()) {
158     assert(Entry != IDom->getBlock() && "eh pad in entry block");
159     IDom = IDom->getIDom();
160   }
161 
162   return IDom->getBlock()->getTerminator();
163 }
164 
165 /// \brief Given \p BBs as input, find another set of BBs which collectively
166 /// dominates \p BBs and have the minimal sum of frequencies. Return the BB
167 /// set found in \p BBs.
168 static void findBestInsertionSet(DominatorTree &DT, BlockFrequencyInfo &BFI,
169                                  BasicBlock *Entry,
170                                  SmallPtrSet<BasicBlock *, 8> &BBs) {
171   assert(!BBs.count(Entry) && "Assume Entry is not in BBs");
172   // Nodes on the current path to the root.
173   SmallPtrSet<BasicBlock *, 8> Path;
174   // Candidates includes any block 'BB' in set 'BBs' that is not strictly
175   // dominated by any other blocks in set 'BBs', and all nodes in the path
176   // in the dominator tree from Entry to 'BB'.
177   SmallPtrSet<BasicBlock *, 16> Candidates;
178   for (auto BB : BBs) {
179     Path.clear();
180     // Walk up the dominator tree until Entry or another BB in BBs
181     // is reached. Insert the nodes on the way to the Path.
182     BasicBlock *Node = BB;
183     // The "Path" is a candidate path to be added into Candidates set.
184     bool isCandidate = false;
185     do {
186       Path.insert(Node);
187       if (Node == Entry || Candidates.count(Node)) {
188         isCandidate = true;
189         break;
190       }
191       assert(DT.getNode(Node)->getIDom() &&
192              "Entry doens't dominate current Node");
193       Node = DT.getNode(Node)->getIDom()->getBlock();
194     } while (!BBs.count(Node));
195 
196     // If isCandidate is false, Node is another Block in BBs dominating
197     // current 'BB'. Drop the nodes on the Path.
198     if (!isCandidate)
199       continue;
200 
201     // Add nodes on the Path into Candidates.
202     Candidates.insert(Path.begin(), Path.end());
203   }
204 
205   // Sort the nodes in Candidates in top-down order and save the nodes
206   // in Orders.
207   unsigned Idx = 0;
208   SmallVector<BasicBlock *, 16> Orders;
209   Orders.push_back(Entry);
210   while (Idx != Orders.size()) {
211     BasicBlock *Node = Orders[Idx++];
212     for (auto ChildDomNode : DT.getNode(Node)->getChildren()) {
213       if (Candidates.count(ChildDomNode->getBlock()))
214         Orders.push_back(ChildDomNode->getBlock());
215     }
216   }
217 
218   // Visit Orders in bottom-up order.
219   typedef std::pair<SmallPtrSet<BasicBlock *, 16>, BlockFrequency>
220       InsertPtsCostPair;
221   // InsertPtsMap is a map from a BB to the best insertion points for the
222   // subtree of BB (subtree not including the BB itself).
223   DenseMap<BasicBlock *, InsertPtsCostPair> InsertPtsMap;
224   InsertPtsMap.reserve(Orders.size() + 1);
225   for (auto RIt = Orders.rbegin(); RIt != Orders.rend(); RIt++) {
226     BasicBlock *Node = *RIt;
227     bool NodeInBBs = BBs.count(Node);
228     SmallPtrSet<BasicBlock *, 16> &InsertPts = InsertPtsMap[Node].first;
229     BlockFrequency &InsertPtsFreq = InsertPtsMap[Node].second;
230 
231     // Return the optimal insert points in BBs.
232     if (Node == Entry) {
233       BBs.clear();
234       if (InsertPtsFreq > BFI.getBlockFreq(Node))
235         BBs.insert(Entry);
236       else
237         BBs.insert(InsertPts.begin(), InsertPts.end());
238       break;
239     }
240 
241     BasicBlock *Parent = DT.getNode(Node)->getIDom()->getBlock();
242     // Initially, ParentInsertPts is empty and ParentPtsFreq is 0. Every child
243     // will update its parent's ParentInsertPts and ParentPtsFreq.
244     SmallPtrSet<BasicBlock *, 16> &ParentInsertPts = InsertPtsMap[Parent].first;
245     BlockFrequency &ParentPtsFreq = InsertPtsMap[Parent].second;
246     // Choose to insert in Node or in subtree of Node.
247     if (InsertPtsFreq > BFI.getBlockFreq(Node) || NodeInBBs) {
248       ParentInsertPts.insert(Node);
249       ParentPtsFreq += BFI.getBlockFreq(Node);
250     } else {
251       ParentInsertPts.insert(InsertPts.begin(), InsertPts.end());
252       ParentPtsFreq += InsertPtsFreq;
253     }
254   }
255 }
256 
257 /// \brief Find an insertion point that dominates all uses.
258 SmallPtrSet<Instruction *, 8> ConstantHoistingPass::findConstantInsertionPoint(
259     const ConstantInfo &ConstInfo) const {
260   assert(!ConstInfo.RebasedConstants.empty() && "Invalid constant info entry.");
261   // Collect all basic blocks.
262   SmallPtrSet<BasicBlock *, 8> BBs;
263   SmallPtrSet<Instruction *, 8> InsertPts;
264   for (auto const &RCI : ConstInfo.RebasedConstants)
265     for (auto const &U : RCI.Uses)
266       BBs.insert(findMatInsertPt(U.Inst, U.OpndIdx)->getParent());
267 
268   if (BBs.count(Entry)) {
269     InsertPts.insert(&Entry->front());
270     return InsertPts;
271   }
272 
273   if (BFI) {
274     findBestInsertionSet(*DT, *BFI, Entry, BBs);
275     for (auto BB : BBs) {
276       BasicBlock::iterator InsertPt = BB->begin();
277       for (; isa<PHINode>(InsertPt) || InsertPt->isEHPad(); ++InsertPt)
278         ;
279       InsertPts.insert(&*InsertPt);
280     }
281     return InsertPts;
282   }
283 
284   while (BBs.size() >= 2) {
285     BasicBlock *BB, *BB1, *BB2;
286     BB1 = *BBs.begin();
287     BB2 = *std::next(BBs.begin());
288     BB = DT->findNearestCommonDominator(BB1, BB2);
289     if (BB == Entry) {
290       InsertPts.insert(&Entry->front());
291       return InsertPts;
292     }
293     BBs.erase(BB1);
294     BBs.erase(BB2);
295     BBs.insert(BB);
296   }
297   assert((BBs.size() == 1) && "Expected only one element.");
298   Instruction &FirstInst = (*BBs.begin())->front();
299   InsertPts.insert(findMatInsertPt(&FirstInst));
300   return InsertPts;
301 }
302 
303 
304 /// \brief Record constant integer ConstInt for instruction Inst at operand
305 /// index Idx.
306 ///
307 /// The operand at index Idx is not necessarily the constant integer itself. It
308 /// could also be a cast instruction or a constant expression that uses the
309 // constant integer.
310 void ConstantHoistingPass::collectConstantCandidates(
311     ConstCandMapType &ConstCandMap, Instruction *Inst, unsigned Idx,
312     ConstantInt *ConstInt) {
313   unsigned Cost;
314   // Ask the target about the cost of materializing the constant for the given
315   // instruction and operand index.
316   if (auto IntrInst = dyn_cast<IntrinsicInst>(Inst))
317     Cost = TTI->getIntImmCost(IntrInst->getIntrinsicID(), Idx,
318                               ConstInt->getValue(), ConstInt->getType());
319   else
320     Cost = TTI->getIntImmCost(Inst->getOpcode(), Idx, ConstInt->getValue(),
321                               ConstInt->getType());
322 
323   // Ignore cheap integer constants.
324   if (Cost > TargetTransformInfo::TCC_Basic) {
325     ConstCandMapType::iterator Itr;
326     bool Inserted;
327     std::tie(Itr, Inserted) = ConstCandMap.insert(std::make_pair(ConstInt, 0));
328     if (Inserted) {
329       ConstCandVec.push_back(ConstantCandidate(ConstInt));
330       Itr->second = ConstCandVec.size() - 1;
331     }
332     ConstCandVec[Itr->second].addUser(Inst, Idx, Cost);
333     DEBUG(if (isa<ConstantInt>(Inst->getOperand(Idx)))
334             dbgs() << "Collect constant " << *ConstInt << " from " << *Inst
335                    << " with cost " << Cost << '\n';
336           else
337           dbgs() << "Collect constant " << *ConstInt << " indirectly from "
338                  << *Inst << " via " << *Inst->getOperand(Idx) << " with cost "
339                  << Cost << '\n';
340     );
341   }
342 }
343 
344 
345 /// \brief Check the operand for instruction Inst at index Idx.
346 void ConstantHoistingPass::collectConstantCandidates(
347     ConstCandMapType &ConstCandMap, Instruction *Inst, unsigned Idx) {
348   Value *Opnd = Inst->getOperand(Idx);
349 
350   // Visit constant integers.
351   if (auto ConstInt = dyn_cast<ConstantInt>(Opnd)) {
352     collectConstantCandidates(ConstCandMap, Inst, Idx, ConstInt);
353     return;
354   }
355 
356   // Visit cast instructions that have constant integers.
357   if (auto CastInst = dyn_cast<Instruction>(Opnd)) {
358     // Only visit cast instructions, which have been skipped. All other
359     // instructions should have already been visited.
360     if (!CastInst->isCast())
361       return;
362 
363     if (auto *ConstInt = dyn_cast<ConstantInt>(CastInst->getOperand(0))) {
364       // Pretend the constant is directly used by the instruction and ignore
365       // the cast instruction.
366       collectConstantCandidates(ConstCandMap, Inst, Idx, ConstInt);
367       return;
368     }
369   }
370 
371   // Visit constant expressions that have constant integers.
372   if (auto ConstExpr = dyn_cast<ConstantExpr>(Opnd)) {
373     // Only visit constant cast expressions.
374     if (!ConstExpr->isCast())
375       return;
376 
377     if (auto ConstInt = dyn_cast<ConstantInt>(ConstExpr->getOperand(0))) {
378       // Pretend the constant is directly used by the instruction and ignore
379       // the constant expression.
380       collectConstantCandidates(ConstCandMap, Inst, Idx, ConstInt);
381       return;
382     }
383   }
384 }
385 
386 
387 /// \brief Scan the instruction for expensive integer constants and record them
388 /// in the constant candidate vector.
389 void ConstantHoistingPass::collectConstantCandidates(
390     ConstCandMapType &ConstCandMap, Instruction *Inst) {
391   // Skip all cast instructions. They are visited indirectly later on.
392   if (Inst->isCast())
393     return;
394 
395   // Can't handle inline asm. Skip it.
396   if (auto Call = dyn_cast<CallInst>(Inst))
397     if (isa<InlineAsm>(Call->getCalledValue()))
398       return;
399 
400   // Switch cases must remain constant, and if the value being tested is
401   // constant the entire thing should disappear.
402   if (isa<SwitchInst>(Inst))
403     return;
404 
405   // Static allocas (constant size in the entry block) are handled by
406   // prologue/epilogue insertion so they're free anyway. We definitely don't
407   // want to make them non-constant.
408   auto AI = dyn_cast<AllocaInst>(Inst);
409   if (AI && AI->isStaticAlloca())
410     return;
411 
412   // Constants in GEPs that index into a struct type should not be hoisted.
413   if (isa<GetElementPtrInst>(Inst)) {
414     gep_type_iterator GTI = gep_type_begin(Inst);
415 
416     // Collect constant for first operand.
417     collectConstantCandidates(ConstCandMap, Inst, 0);
418     // Scan rest operands.
419     for (unsigned Idx = 1, E = Inst->getNumOperands(); Idx != E; ++Idx, ++GTI) {
420       // Only collect constants that index into a non struct type.
421       if (!GTI.isStruct()) {
422         collectConstantCandidates(ConstCandMap, Inst, Idx);
423       }
424     }
425     return;
426   }
427 
428   // Scan all operands.
429   for (unsigned Idx = 0, E = Inst->getNumOperands(); Idx != E; ++Idx) {
430     collectConstantCandidates(ConstCandMap, Inst, Idx);
431   } // end of for all operands
432 }
433 
434 /// \brief Collect all integer constants in the function that cannot be folded
435 /// into an instruction itself.
436 void ConstantHoistingPass::collectConstantCandidates(Function &Fn) {
437   ConstCandMapType ConstCandMap;
438   for (BasicBlock &BB : Fn)
439     for (Instruction &Inst : BB)
440       collectConstantCandidates(ConstCandMap, &Inst);
441 }
442 
443 // This helper function is necessary to deal with values that have different
444 // bit widths (APInt Operator- does not like that). If the value cannot be
445 // represented in uint64 we return an "empty" APInt. This is then interpreted
446 // as the value is not in range.
447 static llvm::Optional<APInt> calculateOffsetDiff(const APInt &V1,
448                                                  const APInt &V2) {
449   llvm::Optional<APInt> Res = None;
450   unsigned BW = V1.getBitWidth() > V2.getBitWidth() ?
451                 V1.getBitWidth() : V2.getBitWidth();
452   uint64_t LimVal1 = V1.getLimitedValue();
453   uint64_t LimVal2 = V2.getLimitedValue();
454 
455   if (LimVal1 == ~0ULL || LimVal2 == ~0ULL)
456     return Res;
457 
458   uint64_t Diff = LimVal1 - LimVal2;
459   return APInt(BW, Diff, true);
460 }
461 
462 // From a list of constants, one needs to picked as the base and the other
463 // constants will be transformed into an offset from that base constant. The
464 // question is which we can pick best? For example, consider these constants
465 // and their number of uses:
466 //
467 //  Constants| 2 | 4 | 12 | 42 |
468 //  NumUses  | 3 | 2 |  8 |  7 |
469 //
470 // Selecting constant 12 because it has the most uses will generate negative
471 // offsets for constants 2 and 4 (i.e. -10 and -8 respectively). If negative
472 // offsets lead to less optimal code generation, then there might be better
473 // solutions. Suppose immediates in the range of 0..35 are most optimally
474 // supported by the architecture, then selecting constant 2 is most optimal
475 // because this will generate offsets: 0, 2, 10, 40. Offsets 0, 2 and 10 are in
476 // range 0..35, and thus 3 + 2 + 8 = 13 uses are in range. Selecting 12 would
477 // have only 8 uses in range, so choosing 2 as a base is more optimal. Thus, in
478 // selecting the base constant the range of the offsets is a very important
479 // factor too that we take into account here. This algorithm calculates a total
480 // costs for selecting a constant as the base and substract the costs if
481 // immediates are out of range. It has quadratic complexity, so we call this
482 // function only when we're optimising for size and there are less than 100
483 // constants, we fall back to the straightforward algorithm otherwise
484 // which does not do all the offset calculations.
485 unsigned
486 ConstantHoistingPass::maximizeConstantsInRange(ConstCandVecType::iterator S,
487                                            ConstCandVecType::iterator E,
488                                            ConstCandVecType::iterator &MaxCostItr) {
489   unsigned NumUses = 0;
490 
491   if(!Entry->getParent()->optForSize() || std::distance(S,E) > 100) {
492     for (auto ConstCand = S; ConstCand != E; ++ConstCand) {
493       NumUses += ConstCand->Uses.size();
494       if (ConstCand->CumulativeCost > MaxCostItr->CumulativeCost)
495         MaxCostItr = ConstCand;
496     }
497     return NumUses;
498   }
499 
500   DEBUG(dbgs() << "== Maximize constants in range ==\n");
501   int MaxCost = -1;
502   for (auto ConstCand = S; ConstCand != E; ++ConstCand) {
503     auto Value = ConstCand->ConstInt->getValue();
504     Type *Ty = ConstCand->ConstInt->getType();
505     int Cost = 0;
506     NumUses += ConstCand->Uses.size();
507     DEBUG(dbgs() << "= Constant: " << ConstCand->ConstInt->getValue() << "\n");
508 
509     for (auto User : ConstCand->Uses) {
510       unsigned Opcode = User.Inst->getOpcode();
511       unsigned OpndIdx = User.OpndIdx;
512       Cost += TTI->getIntImmCost(Opcode, OpndIdx, Value, Ty);
513       DEBUG(dbgs() << "Cost: " << Cost << "\n");
514 
515       for (auto C2 = S; C2 != E; ++C2) {
516         llvm::Optional<APInt> Diff = calculateOffsetDiff(
517                                       C2->ConstInt->getValue(),
518                                       ConstCand->ConstInt->getValue());
519         if (Diff) {
520           const int ImmCosts =
521             TTI->getIntImmCodeSizeCost(Opcode, OpndIdx, Diff.getValue(), Ty);
522           Cost -= ImmCosts;
523           DEBUG(dbgs() << "Offset " << Diff.getValue() << " "
524                        << "has penalty: " << ImmCosts << "\n"
525                        << "Adjusted cost: " << Cost << "\n");
526         }
527       }
528     }
529     DEBUG(dbgs() << "Cumulative cost: " << Cost << "\n");
530     if (Cost > MaxCost) {
531       MaxCost = Cost;
532       MaxCostItr = ConstCand;
533       DEBUG(dbgs() << "New candidate: " << MaxCostItr->ConstInt->getValue()
534                    << "\n");
535     }
536   }
537   return NumUses;
538 }
539 
540 /// \brief Find the base constant within the given range and rebase all other
541 /// constants with respect to the base constant.
542 void ConstantHoistingPass::findAndMakeBaseConstant(
543     ConstCandVecType::iterator S, ConstCandVecType::iterator E) {
544   auto MaxCostItr = S;
545   unsigned NumUses = maximizeConstantsInRange(S, E, MaxCostItr);
546 
547   // Don't hoist constants that have only one use.
548   if (NumUses <= 1)
549     return;
550 
551   ConstantInfo ConstInfo;
552   ConstInfo.BaseConstant = MaxCostItr->ConstInt;
553   Type *Ty = ConstInfo.BaseConstant->getType();
554 
555   // Rebase the constants with respect to the base constant.
556   for (auto ConstCand = S; ConstCand != E; ++ConstCand) {
557     APInt Diff = ConstCand->ConstInt->getValue() -
558                  ConstInfo.BaseConstant->getValue();
559     Constant *Offset = Diff == 0 ? nullptr : ConstantInt::get(Ty, Diff);
560     ConstInfo.RebasedConstants.push_back(
561       RebasedConstantInfo(std::move(ConstCand->Uses), Offset));
562   }
563   ConstantVec.push_back(std::move(ConstInfo));
564 }
565 
566 /// \brief Finds and combines constant candidates that can be easily
567 /// rematerialized with an add from a common base constant.
568 void ConstantHoistingPass::findBaseConstants() {
569   // Sort the constants by value and type. This invalidates the mapping!
570   std::sort(ConstCandVec.begin(), ConstCandVec.end(),
571             [](const ConstantCandidate &LHS, const ConstantCandidate &RHS) {
572     if (LHS.ConstInt->getType() != RHS.ConstInt->getType())
573       return LHS.ConstInt->getType()->getBitWidth() <
574              RHS.ConstInt->getType()->getBitWidth();
575     return LHS.ConstInt->getValue().ult(RHS.ConstInt->getValue());
576   });
577 
578   // Simple linear scan through the sorted constant candidate vector for viable
579   // merge candidates.
580   auto MinValItr = ConstCandVec.begin();
581   for (auto CC = std::next(ConstCandVec.begin()), E = ConstCandVec.end();
582        CC != E; ++CC) {
583     if (MinValItr->ConstInt->getType() == CC->ConstInt->getType()) {
584       // Check if the constant is in range of an add with immediate.
585       APInt Diff = CC->ConstInt->getValue() - MinValItr->ConstInt->getValue();
586       if ((Diff.getBitWidth() <= 64) &&
587           TTI->isLegalAddImmediate(Diff.getSExtValue()))
588         continue;
589     }
590     // We either have now a different constant type or the constant is not in
591     // range of an add with immediate anymore.
592     findAndMakeBaseConstant(MinValItr, CC);
593     // Start a new base constant search.
594     MinValItr = CC;
595   }
596   // Finalize the last base constant search.
597   findAndMakeBaseConstant(MinValItr, ConstCandVec.end());
598 }
599 
600 /// \brief Updates the operand at Idx in instruction Inst with the result of
601 ///        instruction Mat. If the instruction is a PHI node then special
602 ///        handling for duplicate values form the same incoming basic block is
603 ///        required.
604 /// \return The update will always succeed, but the return value indicated if
605 ///         Mat was used for the update or not.
606 static bool updateOperand(Instruction *Inst, unsigned Idx, Instruction *Mat) {
607   if (auto PHI = dyn_cast<PHINode>(Inst)) {
608     // Check if any previous operand of the PHI node has the same incoming basic
609     // block. This is a very odd case that happens when the incoming basic block
610     // has a switch statement. In this case use the same value as the previous
611     // operand(s), otherwise we will fail verification due to different values.
612     // The values are actually the same, but the variable names are different
613     // and the verifier doesn't like that.
614     BasicBlock *IncomingBB = PHI->getIncomingBlock(Idx);
615     for (unsigned i = 0; i < Idx; ++i) {
616       if (PHI->getIncomingBlock(i) == IncomingBB) {
617         Value *IncomingVal = PHI->getIncomingValue(i);
618         Inst->setOperand(Idx, IncomingVal);
619         return false;
620       }
621     }
622   }
623 
624   Inst->setOperand(Idx, Mat);
625   return true;
626 }
627 
628 /// \brief Emit materialization code for all rebased constants and update their
629 /// users.
630 void ConstantHoistingPass::emitBaseConstants(Instruction *Base,
631                                              Constant *Offset,
632                                              const ConstantUser &ConstUser) {
633   Instruction *Mat = Base;
634   if (Offset) {
635     Instruction *InsertionPt = findMatInsertPt(ConstUser.Inst,
636                                                ConstUser.OpndIdx);
637     Mat = BinaryOperator::Create(Instruction::Add, Base, Offset,
638                                  "const_mat", InsertionPt);
639 
640     DEBUG(dbgs() << "Materialize constant (" << *Base->getOperand(0)
641                  << " + " << *Offset << ") in BB "
642                  << Mat->getParent()->getName() << '\n' << *Mat << '\n');
643     Mat->setDebugLoc(ConstUser.Inst->getDebugLoc());
644   }
645   Value *Opnd = ConstUser.Inst->getOperand(ConstUser.OpndIdx);
646 
647   // Visit constant integer.
648   if (isa<ConstantInt>(Opnd)) {
649     DEBUG(dbgs() << "Update: " << *ConstUser.Inst << '\n');
650     if (!updateOperand(ConstUser.Inst, ConstUser.OpndIdx, Mat) && Offset)
651       Mat->eraseFromParent();
652     DEBUG(dbgs() << "To    : " << *ConstUser.Inst << '\n');
653     return;
654   }
655 
656   // Visit cast instruction.
657   if (auto CastInst = dyn_cast<Instruction>(Opnd)) {
658     assert(CastInst->isCast() && "Expected an cast instruction!");
659     // Check if we already have visited this cast instruction before to avoid
660     // unnecessary cloning.
661     Instruction *&ClonedCastInst = ClonedCastMap[CastInst];
662     if (!ClonedCastInst) {
663       ClonedCastInst = CastInst->clone();
664       ClonedCastInst->setOperand(0, Mat);
665       ClonedCastInst->insertAfter(CastInst);
666       // Use the same debug location as the original cast instruction.
667       ClonedCastInst->setDebugLoc(CastInst->getDebugLoc());
668       DEBUG(dbgs() << "Clone instruction: " << *CastInst << '\n'
669                    << "To               : " << *ClonedCastInst << '\n');
670     }
671 
672     DEBUG(dbgs() << "Update: " << *ConstUser.Inst << '\n');
673     updateOperand(ConstUser.Inst, ConstUser.OpndIdx, ClonedCastInst);
674     DEBUG(dbgs() << "To    : " << *ConstUser.Inst << '\n');
675     return;
676   }
677 
678   // Visit constant expression.
679   if (auto ConstExpr = dyn_cast<ConstantExpr>(Opnd)) {
680     Instruction *ConstExprInst = ConstExpr->getAsInstruction();
681     ConstExprInst->setOperand(0, Mat);
682     ConstExprInst->insertBefore(findMatInsertPt(ConstUser.Inst,
683                                                 ConstUser.OpndIdx));
684 
685     // Use the same debug location as the instruction we are about to update.
686     ConstExprInst->setDebugLoc(ConstUser.Inst->getDebugLoc());
687 
688     DEBUG(dbgs() << "Create instruction: " << *ConstExprInst << '\n'
689                  << "From              : " << *ConstExpr << '\n');
690     DEBUG(dbgs() << "Update: " << *ConstUser.Inst << '\n');
691     if (!updateOperand(ConstUser.Inst, ConstUser.OpndIdx, ConstExprInst)) {
692       ConstExprInst->eraseFromParent();
693       if (Offset)
694         Mat->eraseFromParent();
695     }
696     DEBUG(dbgs() << "To    : " << *ConstUser.Inst << '\n');
697     return;
698   }
699 }
700 
701 /// \brief Hoist and hide the base constant behind a bitcast and emit
702 /// materialization code for derived constants.
703 bool ConstantHoistingPass::emitBaseConstants() {
704   bool MadeChange = false;
705   for (auto const &ConstInfo : ConstantVec) {
706     // Hoist and hide the base constant behind a bitcast.
707     SmallPtrSet<Instruction *, 8> IPSet = findConstantInsertionPoint(ConstInfo);
708     assert(!IPSet.empty() && "IPSet is empty");
709 
710     unsigned UsesNum = 0;
711     unsigned ReBasesNum = 0;
712     for (Instruction *IP : IPSet) {
713       IntegerType *Ty = ConstInfo.BaseConstant->getType();
714       Instruction *Base =
715           new BitCastInst(ConstInfo.BaseConstant, Ty, "const", IP);
716       DEBUG(dbgs() << "Hoist constant (" << *ConstInfo.BaseConstant
717                    << ") to BB " << IP->getParent()->getName() << '\n'
718                    << *Base << '\n');
719 
720       // Emit materialization code for all rebased constants.
721       unsigned Uses = 0;
722       for (auto const &RCI : ConstInfo.RebasedConstants) {
723         for (auto const &U : RCI.Uses) {
724           Uses++;
725           BasicBlock *OrigMatInsertBB =
726               findMatInsertPt(U.Inst, U.OpndIdx)->getParent();
727           // If Base constant is to be inserted in multiple places,
728           // generate rebase for U using the Base dominating U.
729           if (IPSet.size() == 1 ||
730               DT->dominates(Base->getParent(), OrigMatInsertBB)) {
731             emitBaseConstants(Base, RCI.Offset, U);
732             ReBasesNum++;
733           }
734         }
735       }
736       UsesNum = Uses;
737 
738       // Use the same debug location as the last user of the constant.
739       assert(!Base->use_empty() && "The use list is empty!?");
740       assert(isa<Instruction>(Base->user_back()) &&
741              "All uses should be instructions.");
742       Base->setDebugLoc(cast<Instruction>(Base->user_back())->getDebugLoc());
743     }
744     (void)UsesNum;
745     (void)ReBasesNum;
746     // Expect all uses are rebased after rebase is done.
747     assert(UsesNum == ReBasesNum && "Not all uses are rebased");
748 
749     NumConstantsHoisted++;
750 
751     // Base constant is also included in ConstInfo.RebasedConstants, so
752     // deduct 1 from ConstInfo.RebasedConstants.size().
753     NumConstantsRebased = ConstInfo.RebasedConstants.size() - 1;
754 
755     MadeChange = true;
756   }
757   return MadeChange;
758 }
759 
760 /// \brief Check all cast instructions we made a copy of and remove them if they
761 /// have no more users.
762 void ConstantHoistingPass::deleteDeadCastInst() const {
763   for (auto const &I : ClonedCastMap)
764     if (I.first->use_empty())
765       I.first->eraseFromParent();
766 }
767 
768 /// \brief Optimize expensive integer constants in the given function.
769 bool ConstantHoistingPass::runImpl(Function &Fn, TargetTransformInfo &TTI,
770                                    DominatorTree &DT, BlockFrequencyInfo *BFI,
771                                    BasicBlock &Entry) {
772   this->TTI = &TTI;
773   this->DT = &DT;
774   this->BFI = BFI;
775   this->Entry = &Entry;
776   // Collect all constant candidates.
777   collectConstantCandidates(Fn);
778 
779   // There are no constant candidates to worry about.
780   if (ConstCandVec.empty())
781     return false;
782 
783   // Combine constants that can be easily materialized with an add from a common
784   // base constant.
785   findBaseConstants();
786 
787   // There are no constants to emit.
788   if (ConstantVec.empty())
789     return false;
790 
791   // Finally hoist the base constant and emit materialization code for dependent
792   // constants.
793   bool MadeChange = emitBaseConstants();
794 
795   // Cleanup dead instructions.
796   deleteDeadCastInst();
797 
798   return MadeChange;
799 }
800 
801 PreservedAnalyses ConstantHoistingPass::run(Function &F,
802                                             FunctionAnalysisManager &AM) {
803   auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
804   auto &TTI = AM.getResult<TargetIRAnalysis>(F);
805   auto BFI = ConstHoistWithBlockFrequency
806                  ? &AM.getResult<BlockFrequencyAnalysis>(F)
807                  : nullptr;
808   if (!runImpl(F, TTI, DT, BFI, F.getEntryBlock()))
809     return PreservedAnalyses::all();
810 
811   PreservedAnalyses PA;
812   PA.preserveSet<CFGAnalyses>();
813   return PA;
814 }
815