1 //===-- LoopUtils.cpp - Loop Utility functions -------------------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
10 // This file defines common loop utility functions.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/Transforms/Utils/LoopUtils.h"
15 #include "llvm/ADT/ScopeExit.h"
16 #include "llvm/Analysis/AliasAnalysis.h"
17 #include "llvm/Analysis/BasicAliasAnalysis.h"
18 #include "llvm/Analysis/GlobalsModRef.h"
19 #include "llvm/Analysis/InstructionSimplify.h"
20 #include "llvm/Analysis/LoopInfo.h"
21 #include "llvm/Analysis/LoopPass.h"
22 #include "llvm/Analysis/MustExecute.h"
23 #include "llvm/Analysis/ScalarEvolution.h"
24 #include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h"
25 #include "llvm/Analysis/ScalarEvolutionExpander.h"
26 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
27 #include "llvm/Analysis/TargetTransformInfo.h"
28 #include "llvm/Analysis/ValueTracking.h"
29 #include "llvm/IR/DomTreeUpdater.h"
30 #include "llvm/IR/Dominators.h"
31 #include "llvm/IR/Instructions.h"
32 #include "llvm/IR/Module.h"
33 #include "llvm/IR/PatternMatch.h"
34 #include "llvm/IR/ValueHandle.h"
35 #include "llvm/Pass.h"
36 #include "llvm/Support/Debug.h"
37 #include "llvm/Support/KnownBits.h"
38 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
39 
40 using namespace llvm;
41 using namespace llvm::PatternMatch;
42 
43 #define DEBUG_TYPE "loop-utils"
44 
45 bool llvm::formDedicatedExitBlocks(Loop *L, DominatorTree *DT, LoopInfo *LI,
46                                    bool PreserveLCSSA) {
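  // An exit block is "dedicated" when all of its predecessors come from inside
  // the loop. For exits that also have non-loop predecessors, the in-loop
  // predecessors are split off below into a fresh ".loopexit" block, which
  // becomes the new dedicated exit.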
47   bool Changed = false;
48 
  // We reuse a vector for the in-loop predecessors.
50   SmallVector<BasicBlock *, 4> InLoopPredecessors;
51 
52   auto RewriteExit = [&](BasicBlock *BB) {
53     assert(InLoopPredecessors.empty() &&
54            "Must start with an empty predecessors list!");
55     auto Cleanup = make_scope_exit([&] { InLoopPredecessors.clear(); });
56 
57     // See if there are any non-loop predecessors of this exit block and
58     // keep track of the in-loop predecessors.
59     bool IsDedicatedExit = true;
60     for (auto *PredBB : predecessors(BB))
61       if (L->contains(PredBB)) {
62         if (isa<IndirectBrInst>(PredBB->getTerminator()))
63           // We cannot rewrite exiting edges from an indirectbr.
64           return false;
65 
66         InLoopPredecessors.push_back(PredBB);
67       } else {
68         IsDedicatedExit = false;
69       }
70 
71     assert(!InLoopPredecessors.empty() && "Must have *some* loop predecessor!");
72 
73     // Nothing to do if this is already a dedicated exit.
74     if (IsDedicatedExit)
75       return false;
76 
77     auto *NewExitBB = SplitBlockPredecessors(
78         BB, InLoopPredecessors, ".loopexit", DT, LI, nullptr, PreserveLCSSA);
79 
80     if (!NewExitBB)
81       LLVM_DEBUG(
82           dbgs() << "WARNING: Can't create a dedicated exit block for loop: "
83                  << *L << "\n");
84     else
85       LLVM_DEBUG(dbgs() << "LoopSimplify: Creating dedicated exit block "
86                         << NewExitBB->getName() << "\n");
87     return true;
88   };
89 
90   // Walk the exit blocks directly rather than building up a data structure for
91   // them, but only visit each one once.
92   SmallPtrSet<BasicBlock *, 4> Visited;
93   for (auto *BB : L->blocks())
94     for (auto *SuccBB : successors(BB)) {
95       // We're looking for exit blocks so skip in-loop successors.
96       if (L->contains(SuccBB))
97         continue;
98 
99       // Visit each exit block exactly once.
100       if (!Visited.insert(SuccBB).second)
101         continue;
102 
103       Changed |= RewriteExit(SuccBB);
104     }
105 
106   return Changed;
107 }
108 
/// Returns the instructions defined inside the loop that are used outside of
/// it.
110 SmallVector<Instruction *, 8> llvm::findDefsUsedOutsideOfLoop(Loop *L) {
111   SmallVector<Instruction *, 8> UsedOutside;
112 
113   for (auto *Block : L->getBlocks())
114     // FIXME: I believe that this could use copy_if if the Inst reference could
115     // be adapted into a pointer.
116     for (auto &Inst : *Block) {
117       auto Users = Inst.users();
118       if (any_of(Users, [&](User *U) {
119             auto *Use = cast<Instruction>(U);
120             return !L->contains(Use->getParent());
121           }))
122         UsedOutside.push_back(&Inst);
123     }
124 
125   return UsedOutside;
126 }
127 
128 void llvm::getLoopAnalysisUsage(AnalysisUsage &AU) {
129   // By definition, all loop passes need the LoopInfo analysis and the
130   // Dominator tree it depends on. Because they all participate in the loop
131   // pass manager, they must also preserve these.
132   AU.addRequired<DominatorTreeWrapperPass>();
133   AU.addPreserved<DominatorTreeWrapperPass>();
134   AU.addRequired<LoopInfoWrapperPass>();
135   AU.addPreserved<LoopInfoWrapperPass>();
136 
137   // We must also preserve LoopSimplify and LCSSA. We locally access their IDs
138   // here because users shouldn't directly get them from this header.
139   extern char &LoopSimplifyID;
140   extern char &LCSSAID;
141   AU.addRequiredID(LoopSimplifyID);
142   AU.addPreservedID(LoopSimplifyID);
143   AU.addRequiredID(LCSSAID);
144   AU.addPreservedID(LCSSAID);
  // This is used in the LPPassManager to perform LCSSA verification on passes
  // which preserve LCSSA form.
147   AU.addRequired<LCSSAVerificationPass>();
148   AU.addPreserved<LCSSAVerificationPass>();
149 
150   // Loop passes are designed to run inside of a loop pass manager which means
151   // that any function analyses they require must be required by the first loop
152   // pass in the manager (so that it is computed before the loop pass manager
  // runs) and preserved by all loop passes in the manager. To make this
154   // reasonably robust, the set needed for most loop passes is maintained here.
155   // If your loop pass requires an analysis not listed here, you will need to
156   // carefully audit the loop pass manager nesting structure that results.
157   AU.addRequired<AAResultsWrapperPass>();
158   AU.addPreserved<AAResultsWrapperPass>();
159   AU.addPreserved<BasicAAWrapperPass>();
160   AU.addPreserved<GlobalsAAWrapperPass>();
161   AU.addPreserved<SCEVAAWrapperPass>();
162   AU.addRequired<ScalarEvolutionWrapperPass>();
163   AU.addPreserved<ScalarEvolutionWrapperPass>();
164 }
165 
166 /// Manually defined generic "LoopPass" dependency initialization. This is used
167 /// to initialize the exact set of passes from above in \c
168 /// getLoopAnalysisUsage. It can be used within a loop pass's initialization
169 /// with:
170 ///
171 ///   INITIALIZE_PASS_DEPENDENCY(LoopPass)
172 ///
173 /// As-if "LoopPass" were a pass.
174 void llvm::initializeLoopPassPass(PassRegistry &Registry) {
175   INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
176   INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
177   INITIALIZE_PASS_DEPENDENCY(LoopSimplify)
178   INITIALIZE_PASS_DEPENDENCY(LCSSAWrapperPass)
179   INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
180   INITIALIZE_PASS_DEPENDENCY(BasicAAWrapperPass)
181   INITIALIZE_PASS_DEPENDENCY(GlobalsAAWrapperPass)
182   INITIALIZE_PASS_DEPENDENCY(SCEVAAWrapperPass)
183   INITIALIZE_PASS_DEPENDENCY(ScalarEvolutionWrapperPass)
184 }
185 
186 /// Find string metadata for loop
187 ///
/// If the named metadata has a value (e.g. {"llvm.distribute", 1}), return the
/// value as an operand, or nullptr if it has no value. If the string metadata
/// is not found at all, return Optional's not-a-value.
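///
/// For example, given a loop ID of the form
///   !0 = distinct !{!0, !1}
///   !1 = !{!"llvm.loop.unroll.count", i32 4}
/// findStringMetadataForLoop(L, "llvm.loop.unroll.count") returns a pointer to
/// the "i32 4" operand.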
191 Optional<const MDOperand *> llvm::findStringMetadataForLoop(Loop *TheLoop,
192                                                             StringRef Name) {
193   MDNode *LoopID = TheLoop->getLoopID();
  // Return None if the loop has no loop ID metadata.
195   if (!LoopID)
196     return None;
197 
198   // First operand should refer to the loop id itself.
199   assert(LoopID->getNumOperands() > 0 && "requires at least one operand");
200   assert(LoopID->getOperand(0) == LoopID && "invalid loop id");
201 
  // Iterate over the LoopID operands and look for MDString metadata.
203   for (unsigned i = 1, e = LoopID->getNumOperands(); i < e; ++i) {
204     MDNode *MD = dyn_cast<MDNode>(LoopID->getOperand(i));
205     if (!MD)
206       continue;
207     MDString *S = dyn_cast<MDString>(MD->getOperand(0));
208     if (!S)
209       continue;
    // If the MDString matches Name, return its value operand if present, or
    // nullptr if it has no value.
211     if (Name.equals(S->getString()))
212       switch (MD->getNumOperands()) {
213       case 1:
214         return nullptr;
215       case 2:
216         return &MD->getOperand(1);
217       default:
        llvm_unreachable("unexpected number of operands in loop metadata");
219       }
220   }
221   return None;
222 }
223 
224 /// Does a BFS from a given node to all of its children inside a given loop.
225 /// The returned vector of nodes includes the starting point.
226 SmallVector<DomTreeNode *, 16>
227 llvm::collectChildrenInLoop(DomTreeNode *N, const Loop *CurLoop) {
228   SmallVector<DomTreeNode *, 16> Worklist;
229   auto AddRegionToWorklist = [&](DomTreeNode *DTN) {
    // Only include dominator tree nodes whose block is inside the loop.
231     BasicBlock *BB = DTN->getBlock();
232     if (CurLoop->contains(BB))
233       Worklist.push_back(DTN);
234   };
235 
236   AddRegionToWorklist(N);
237 
238   for (size_t I = 0; I < Worklist.size(); I++)
239     for (DomTreeNode *Child : Worklist[I]->getChildren())
240       AddRegionToWorklist(Child);
241 
242   return Worklist;
243 }
244 
void llvm::deleteDeadLoop(Loop *L, DominatorTree *DT, ScalarEvolution *SE,
                          LoopInfo *LI) {
248   assert((!DT || L->isLCSSAForm(*DT)) && "Expected LCSSA!");
249   auto *Preheader = L->getLoopPreheader();
250   assert(Preheader && "Preheader should exist!");
251 
252   // Now that we know the removal is safe, remove the loop by changing the
253   // branch from the preheader to go to the single exit block.
254   //
255   // Because we're deleting a large chunk of code at once, the sequence in which
256   // we remove things is very important to avoid invalidation issues.
257 
258   // Tell ScalarEvolution that the loop is deleted. Do this before
259   // deleting the loop so that ScalarEvolution can look at the loop
260   // to determine what it needs to clean up.
261   if (SE)
262     SE->forgetLoop(L);
263 
264   auto *ExitBlock = L->getUniqueExitBlock();
265   assert(ExitBlock && "Should have a unique exit block!");
266   assert(L->hasDedicatedExits() && "Loop should have dedicated exits!");
267 
268   auto *OldBr = dyn_cast<BranchInst>(Preheader->getTerminator());
269   assert(OldBr && "Preheader must end with a branch");
270   assert(OldBr->isUnconditional() && "Preheader must have a single successor");
271   // Connect the preheader to the exit block. Keep the old edge to the header
272   // around to perform the dominator tree update in two separate steps
273   // -- #1 insertion of the edge preheader -> exit and #2 deletion of the edge
274   // preheader -> header.
275   //
276   //
277   // 0.  Preheader          1.  Preheader           2.  Preheader
278   //        |                    |   |                   |
279   //        V                    |   V                   |
280   //      Header <--\            | Header <--\           | Header <--\
281   //       |  |     |            |  |  |     |           |  |  |     |
282   //       |  V     |            |  |  V     |           |  |  V     |
283   //       | Body --/            |  | Body --/           |  | Body --/
284   //       V                     V  V                    V  V
285   //      Exit                   Exit                    Exit
286   //
  // By doing this in two separate steps we can perform the dominator tree
288   // update without using the batch update API.
289   //
  // Even when the loop is never executed, we cannot remove the edge from the
  // source block to the exit block. Consider the case where the unexecuted loop
  // branches back to an outer loop. If we deleted the loop and removed the edge
  // coming into this inner loop, it would break the outer loop structure (by
  // deleting the backedge of the outer loop). If the outer loop is itself dead,
  // it will be deleted in a future iteration of the loop deletion pass.
296   IRBuilder<> Builder(OldBr);
297   Builder.CreateCondBr(Builder.getFalse(), L->getHeader(), ExitBlock);
298   // Remove the old branch. The conditional branch becomes a new terminator.
299   OldBr->eraseFromParent();
300 
301   // Rewrite phis in the exit block to get their inputs from the Preheader
302   // instead of the exiting block.
303   for (PHINode &P : ExitBlock->phis()) {
304     // Set the zero'th element of Phi to be from the preheader and remove all
305     // other incoming values. Given the loop has dedicated exits, all other
306     // incoming values must be from the exiting blocks.
307     int PredIndex = 0;
308     P.setIncomingBlock(PredIndex, Preheader);
309     // Removes all incoming values from all other exiting blocks (including
310     // duplicate values from an exiting block).
311     // Nuke all entries except the zero'th entry which is the preheader entry.
312     // NOTE! We need to remove Incoming Values in the reverse order as done
313     // below, to keep the indices valid for deletion (removeIncomingValues
314     // updates getNumIncomingValues and shifts all values down into the operand
315     // being deleted).
316     for (unsigned i = 0, e = P.getNumIncomingValues() - 1; i != e; ++i)
317       P.removeIncomingValue(e - i, false);
318 
319     assert((P.getNumIncomingValues() == 1 &&
320             P.getIncomingBlock(PredIndex) == Preheader) &&
321            "Should have exactly one value and that's from the preheader!");
322   }
323 
324   // Disconnect the loop body by branching directly to its exit.
325   Builder.SetInsertPoint(Preheader->getTerminator());
326   Builder.CreateBr(ExitBlock);
327   // Remove the old branch.
328   Preheader->getTerminator()->eraseFromParent();
329 
330   DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Eager);
331   if (DT) {
332     // Update the dominator tree by informing it about the new edge from the
333     // preheader to the exit.
334     DTU.insertEdge(Preheader, ExitBlock);
335     // Inform the dominator tree about the removed edge.
336     DTU.deleteEdge(Preheader, L->getHeader());
337   }
338 
339   // Given LCSSA form is satisfied, we should not have users of instructions
340   // within the dead loop outside of the loop. However, LCSSA doesn't take
341   // unreachable uses into account. We handle them here.
  // We could do this after dropping all references (in which case every user
  // inside the loop would already be eliminated, leaving less work to do), but
  // according to the API documentation of User::dropAllReferences the only
  // valid operation after dropping references is deletion. So first substitute
  // all uses of in-loop instructions with an undef value of the corresponding
  // type.
347   for (auto *Block : L->blocks())
348     for (Instruction &I : *Block) {
349       auto *Undef = UndefValue::get(I.getType());
350       for (Value::use_iterator UI = I.use_begin(), E = I.use_end(); UI != E;) {
351         Use &U = *UI;
352         ++UI;
353         if (auto *Usr = dyn_cast<Instruction>(U.getUser()))
354           if (L->contains(Usr->getParent()))
355             continue;
        // If we have a DT, check that uses outside the loop occur only in
        // unreachable blocks.
358         if (DT)
359           assert(!DT->isReachableFromEntry(U) &&
360                  "Unexpected user in reachable block");
361         U.set(Undef);
362       }
363     }
364 
365   // Remove the block from the reference counting scheme, so that we can
366   // delete it freely later.
367   for (auto *Block : L->blocks())
368     Block->dropAllReferences();
369 
370   if (LI) {
371     // Erase the instructions and the blocks without having to worry
372     // about ordering because we already dropped the references.
373     // NOTE: This iteration is safe because erasing the block does not remove
374     // its entry from the loop's block list.  We do that in the next section.
375     for (Loop::block_iterator LpI = L->block_begin(), LpE = L->block_end();
376          LpI != LpE; ++LpI)
377       (*LpI)->eraseFromParent();
378 
    // Finally, remove the blocks from LoopInfo.  This has to happen late
    // because otherwise our loop iterators won't work.
381 
382     SmallPtrSet<BasicBlock *, 8> blocks;
383     blocks.insert(L->block_begin(), L->block_end());
384     for (BasicBlock *BB : blocks)
385       LI->removeBlock(BB);
386 
387     // The last step is to update LoopInfo now that we've eliminated this loop.
388     LI->erase(L);
389   }
390 }
391 
392 Optional<unsigned> llvm::getLoopEstimatedTripCount(Loop *L) {
393   // Only support loops with a unique exiting block, and a latch.
394   if (!L->getExitingBlock())
395     return None;
396 
  // Get the branch weights for the loop's backedge.
  BasicBlock *Latch = L->getLoopLatch();
  if (!Latch)
    return None;
  BranchInst *LatchBR = dyn_cast<BranchInst>(Latch->getTerminator());
400   if (!LatchBR || LatchBR->getNumSuccessors() != 2)
401     return None;
402 
403   assert((LatchBR->getSuccessor(0) == L->getHeader() ||
404           LatchBR->getSuccessor(1) == L->getHeader()) &&
405          "At least one edge out of the latch must go to the header");
406 
407   // To estimate the number of times the loop body was executed, we want to
408   // know the number of times the backedge was taken, vs. the number of times
409   // we exited the loop.
410   uint64_t TrueVal, FalseVal;
411   if (!LatchBR->extractProfMetadata(TrueVal, FalseVal))
412     return None;
413 
414   if (!TrueVal || !FalseVal)
415     return 0;
416 
417   // Divide the count of the backedge by the count of the edge exiting the loop,
418   // rounding to nearest.
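  // For example, with branch weights {backedge: 300, exit: 100} the estimated
  // trip count is (300 + 50) / 100 = 3.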
419   if (LatchBR->getSuccessor(0) == L->getHeader())
420     return (TrueVal + (FalseVal / 2)) / FalseVal;
421   else
422     return (FalseVal + (TrueVal / 2)) / TrueVal;
423 }
424 
425 bool llvm::hasIterationCountInvariantInParent(Loop *InnerLoop,
426                                               ScalarEvolution &SE) {
427   Loop *OuterL = InnerLoop->getParentLoop();
428   if (!OuterL)
429     return true;
430 
431   // Get the backedge taken count for the inner loop
432   BasicBlock *InnerLoopLatch = InnerLoop->getLoopLatch();
433   const SCEV *InnerLoopBECountSC = SE.getExitCount(InnerLoop, InnerLoopLatch);
434   if (isa<SCEVCouldNotCompute>(InnerLoopBECountSC) ||
435       !InnerLoopBECountSC->getType()->isIntegerTy())
436     return false;
437 
438   // Get whether count is invariant to the outer loop
439   ScalarEvolution::LoopDisposition LD =
440       SE.getLoopDisposition(InnerLoopBECountSC, OuterL);
441   if (LD != ScalarEvolution::LoopInvariant)
442     return false;
443 
444   return true;
445 }
446 
447 /// Adds a 'fast' flag to floating point operations.
448 static Value *addFastMathFlag(Value *V) {
449   if (isa<FPMathOperator>(V)) {
450     FastMathFlags Flags;
451     Flags.setFast();
452     cast<Instruction>(V)->setFastMathFlags(Flags);
453   }
454   return V;
455 }
456 
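// Emit a compare + select pair implementing the requested min/max recurrence.
// For example, MRK_SIntMax produces:
//   %rdx.minmax.cmp    = icmp sgt %left, %right
//   %rdx.minmax.select = select i1 %rdx.minmax.cmp, %left, %right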
457 Value *llvm::createMinMaxOp(IRBuilder<> &Builder,
458                             RecurrenceDescriptor::MinMaxRecurrenceKind RK,
459                             Value *Left, Value *Right) {
460   CmpInst::Predicate P = CmpInst::ICMP_NE;
461   switch (RK) {
462   default:
463     llvm_unreachable("Unknown min/max recurrence kind");
464   case RecurrenceDescriptor::MRK_UIntMin:
465     P = CmpInst::ICMP_ULT;
466     break;
467   case RecurrenceDescriptor::MRK_UIntMax:
468     P = CmpInst::ICMP_UGT;
469     break;
470   case RecurrenceDescriptor::MRK_SIntMin:
471     P = CmpInst::ICMP_SLT;
472     break;
473   case RecurrenceDescriptor::MRK_SIntMax:
474     P = CmpInst::ICMP_SGT;
475     break;
476   case RecurrenceDescriptor::MRK_FloatMin:
477     P = CmpInst::FCMP_OLT;
478     break;
479   case RecurrenceDescriptor::MRK_FloatMax:
480     P = CmpInst::FCMP_OGT;
481     break;
482   }
483 
484   // We only match FP sequences that are 'fast', so we can unconditionally
485   // set it on any generated instructions.
486   IRBuilder<>::FastMathFlagGuard FMFG(Builder);
487   FastMathFlags FMF;
488   FMF.setFast();
489   Builder.setFastMathFlags(FMF);
490 
491   Value *Cmp;
492   if (RK == RecurrenceDescriptor::MRK_FloatMin ||
493       RK == RecurrenceDescriptor::MRK_FloatMax)
494     Cmp = Builder.CreateFCmp(P, Left, Right, "rdx.minmax.cmp");
495   else
496     Cmp = Builder.CreateICmp(P, Left, Right, "rdx.minmax.cmp");
497 
498   Value *Select = Builder.CreateSelect(Cmp, Left, Right, "rdx.minmax.select");
499   return Select;
500 }
501 
502 // Helper to generate an ordered reduction.
503 Value *
504 llvm::getOrderedReduction(IRBuilder<> &Builder, Value *Acc, Value *Src,
505                           unsigned Op,
506                           RecurrenceDescriptor::MinMaxRecurrenceKind MinMaxKind,
507                           ArrayRef<Value *> RedOps) {
508   unsigned VF = Src->getType()->getVectorNumElements();
509 
510   // Extract and apply reduction ops in ascending order:
  // e.g. ((((Acc + Scl[0]) + Scl[1]) + Scl[2]) + ...) + Scl[VF-1]
512   Value *Result = Acc;
513   for (unsigned ExtractIdx = 0; ExtractIdx != VF; ++ExtractIdx) {
514     Value *Ext =
515         Builder.CreateExtractElement(Src, Builder.getInt32(ExtractIdx));
516 
517     if (Op != Instruction::ICmp && Op != Instruction::FCmp) {
518       Result = Builder.CreateBinOp((Instruction::BinaryOps)Op, Result, Ext,
519                                    "bin.rdx");
520     } else {
521       assert(MinMaxKind != RecurrenceDescriptor::MRK_Invalid &&
522              "Invalid min/max");
523       Result = createMinMaxOp(Builder, MinMaxKind, Result, Ext);
524     }
525 
526     if (!RedOps.empty())
527       propagateIRFlags(Result, RedOps);
528   }
529 
530   return Result;
531 }
532 
533 // Helper to generate a log2 shuffle reduction.
534 Value *
535 llvm::getShuffleReduction(IRBuilder<> &Builder, Value *Src, unsigned Op,
536                           RecurrenceDescriptor::MinMaxRecurrenceKind MinMaxKind,
537                           ArrayRef<Value *> RedOps) {
538   unsigned VF = Src->getType()->getVectorNumElements();
539   // VF is a power of 2 so we can emit the reduction using log2(VF) shuffles
540   // and vector ops, reducing the set of values being computed by half each
541   // round.
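  // For example, with VF = 8 the three rounds use the shuffle masks
  //   <4, 5, 6, 7, undef, undef, undef, undef>
  //   <2, 3, undef, undef, undef, undef, undef, undef>
  //   <1, undef, undef, undef, undef, undef, undef, undef>
  // and the reduced value ends up in lane 0.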
542   assert(isPowerOf2_32(VF) &&
543          "Reduction emission only supported for pow2 vectors!");
544   Value *TmpVec = Src;
545   SmallVector<Constant *, 32> ShuffleMask(VF, nullptr);
546   for (unsigned i = VF; i != 1; i >>= 1) {
547     // Move the upper half of the vector to the lower half.
548     for (unsigned j = 0; j != i / 2; ++j)
549       ShuffleMask[j] = Builder.getInt32(i / 2 + j);
550 
551     // Fill the rest of the mask with undef.
552     std::fill(&ShuffleMask[i / 2], ShuffleMask.end(),
553               UndefValue::get(Builder.getInt32Ty()));
554 
555     Value *Shuf = Builder.CreateShuffleVector(
556         TmpVec, UndefValue::get(TmpVec->getType()),
557         ConstantVector::get(ShuffleMask), "rdx.shuf");
558 
559     if (Op != Instruction::ICmp && Op != Instruction::FCmp) {
560       // Floating point operations had to be 'fast' to enable the reduction.
561       TmpVec = addFastMathFlag(Builder.CreateBinOp((Instruction::BinaryOps)Op,
562                                                    TmpVec, Shuf, "bin.rdx"));
563     } else {
564       assert(MinMaxKind != RecurrenceDescriptor::MRK_Invalid &&
565              "Invalid min/max");
566       TmpVec = createMinMaxOp(Builder, MinMaxKind, TmpVec, Shuf);
567     }
568     if (!RedOps.empty())
569       propagateIRFlags(TmpVec, RedOps);
570   }
571   // The result is in the first element of the vector.
572   return Builder.CreateExtractElement(TmpVec, Builder.getInt32(0));
573 }
574 
575 /// Create a simple vector reduction specified by an opcode and some
576 /// flags (if generating min/max reductions).
577 Value *llvm::createSimpleTargetReduction(
578     IRBuilder<> &Builder, const TargetTransformInfo *TTI, unsigned Opcode,
579     Value *Src, TargetTransformInfo::ReductionFlags Flags,
580     ArrayRef<Value *> RedOps) {
581   assert(isa<VectorType>(Src->getType()) && "Type must be a vector");
582 
583   Value *ScalarUdf = UndefValue::get(Src->getType()->getVectorElementType());
584   std::function<Value *()> BuildFunc;
585   using RD = RecurrenceDescriptor;
586   RD::MinMaxRecurrenceKind MinMaxKind = RD::MRK_Invalid;
587   // TODO: Support creating ordered reductions.
588   FastMathFlags FMFFast;
589   FMFFast.setFast();
590 
591   switch (Opcode) {
592   case Instruction::Add:
593     BuildFunc = [&]() { return Builder.CreateAddReduce(Src); };
594     break;
595   case Instruction::Mul:
596     BuildFunc = [&]() { return Builder.CreateMulReduce(Src); };
597     break;
598   case Instruction::And:
599     BuildFunc = [&]() { return Builder.CreateAndReduce(Src); };
600     break;
601   case Instruction::Or:
602     BuildFunc = [&]() { return Builder.CreateOrReduce(Src); };
603     break;
604   case Instruction::Xor:
605     BuildFunc = [&]() { return Builder.CreateXorReduce(Src); };
606     break;
607   case Instruction::FAdd:
608     BuildFunc = [&]() {
609       auto Rdx = Builder.CreateFAddReduce(ScalarUdf, Src);
610       cast<CallInst>(Rdx)->setFastMathFlags(FMFFast);
611       return Rdx;
612     };
613     break;
614   case Instruction::FMul:
615     BuildFunc = [&]() {
616       auto Rdx = Builder.CreateFMulReduce(ScalarUdf, Src);
617       cast<CallInst>(Rdx)->setFastMathFlags(FMFFast);
618       return Rdx;
619     };
620     break;
621   case Instruction::ICmp:
622     if (Flags.IsMaxOp) {
623       MinMaxKind = Flags.IsSigned ? RD::MRK_SIntMax : RD::MRK_UIntMax;
624       BuildFunc = [&]() {
625         return Builder.CreateIntMaxReduce(Src, Flags.IsSigned);
626       };
627     } else {
628       MinMaxKind = Flags.IsSigned ? RD::MRK_SIntMin : RD::MRK_UIntMin;
629       BuildFunc = [&]() {
630         return Builder.CreateIntMinReduce(Src, Flags.IsSigned);
631       };
632     }
633     break;
634   case Instruction::FCmp:
635     if (Flags.IsMaxOp) {
636       MinMaxKind = RD::MRK_FloatMax;
637       BuildFunc = [&]() { return Builder.CreateFPMaxReduce(Src, Flags.NoNaN); };
638     } else {
639       MinMaxKind = RD::MRK_FloatMin;
640       BuildFunc = [&]() { return Builder.CreateFPMinReduce(Src, Flags.NoNaN); };
641     }
642     break;
643   default:
644     llvm_unreachable("Unhandled opcode");
645     break;
646   }
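  // Let the target decide whether to emit the reduction as a single reduction
  // intrinsic call (via BuildFunc above) or to fall back to the generic
  // log2(VF) shuffle-based expansion.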
647   if (TTI->useReductionIntrinsic(Opcode, Src->getType(), Flags))
648     return BuildFunc();
649   return getShuffleReduction(Builder, Src, Opcode, MinMaxKind, RedOps);
650 }
651 
652 /// Create a vector reduction using a given recurrence descriptor.
653 Value *llvm::createTargetReduction(IRBuilder<> &B,
654                                    const TargetTransformInfo *TTI,
655                                    RecurrenceDescriptor &Desc, Value *Src,
656                                    bool NoNaN) {
657   // TODO: Support in-order reductions based on the recurrence descriptor.
658   using RD = RecurrenceDescriptor;
659   RD::RecurrenceKind RecKind = Desc.getRecurrenceKind();
660   TargetTransformInfo::ReductionFlags Flags;
661   Flags.NoNaN = NoNaN;
662   switch (RecKind) {
663   case RD::RK_FloatAdd:
664     return createSimpleTargetReduction(B, TTI, Instruction::FAdd, Src, Flags);
665   case RD::RK_FloatMult:
666     return createSimpleTargetReduction(B, TTI, Instruction::FMul, Src, Flags);
667   case RD::RK_IntegerAdd:
668     return createSimpleTargetReduction(B, TTI, Instruction::Add, Src, Flags);
669   case RD::RK_IntegerMult:
670     return createSimpleTargetReduction(B, TTI, Instruction::Mul, Src, Flags);
671   case RD::RK_IntegerAnd:
672     return createSimpleTargetReduction(B, TTI, Instruction::And, Src, Flags);
673   case RD::RK_IntegerOr:
674     return createSimpleTargetReduction(B, TTI, Instruction::Or, Src, Flags);
675   case RD::RK_IntegerXor:
676     return createSimpleTargetReduction(B, TTI, Instruction::Xor, Src, Flags);
677   case RD::RK_IntegerMinMax: {
678     RD::MinMaxRecurrenceKind MMKind = Desc.getMinMaxRecurrenceKind();
679     Flags.IsMaxOp = (MMKind == RD::MRK_SIntMax || MMKind == RD::MRK_UIntMax);
680     Flags.IsSigned = (MMKind == RD::MRK_SIntMax || MMKind == RD::MRK_SIntMin);
681     return createSimpleTargetReduction(B, TTI, Instruction::ICmp, Src, Flags);
682   }
683   case RD::RK_FloatMinMax: {
684     Flags.IsMaxOp = Desc.getMinMaxRecurrenceKind() == RD::MRK_FloatMax;
685     return createSimpleTargetReduction(B, TTI, Instruction::FCmp, Src, Flags);
686   }
687   default:
688     llvm_unreachable("Unhandled RecKind");
689   }
690 }
691 
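// Intersect the IR flags (fast-math, nuw/nsw, exact, etc.) of the scalar
// instructions in VL into the vectorized instruction I. If OpValue is given,
// only scalars with the same opcode as OpValue take part in the intersection.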
692 void llvm::propagateIRFlags(Value *I, ArrayRef<Value *> VL, Value *OpValue) {
693   auto *VecOp = dyn_cast<Instruction>(I);
694   if (!VecOp)
695     return;
696   auto *Intersection = (OpValue == nullptr) ? dyn_cast<Instruction>(VL[0])
697                                             : dyn_cast<Instruction>(OpValue);
698   if (!Intersection)
699     return;
700   const unsigned Opcode = Intersection->getOpcode();
701   VecOp->copyIRFlags(Intersection);
702   for (auto *V : VL) {
703     auto *Instr = dyn_cast<Instruction>(V);
704     if (!Instr)
705       continue;
706     if (OpValue == nullptr || Opcode == Instr->getOpcode())
707       VecOp->andIRFlags(V);
708   }
709 }
710