1245679b6SChristopher Tetreault //===- MVETailPredication.cpp - MVE Tail Predication ------------*- C++ -*-===//
2312409e4SSam Parker //
3312409e4SSam Parker // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4312409e4SSam Parker // See https://llvm.org/LICENSE.txt for license information.
5312409e4SSam Parker // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6312409e4SSam Parker //
7312409e4SSam Parker //===----------------------------------------------------------------------===//
8312409e4SSam Parker //
9312409e4SSam Parker /// \file
10312409e4SSam Parker /// Armv8.1m introduced MVE, M-Profile Vector Extension, and low-overhead
110736d1ccSSjoerd Meijer /// branches to help accelerate DSP applications. These two extensions,
120736d1ccSSjoerd Meijer /// combined with a new form of predication called tail-predication, can be used
130736d1ccSSjoerd Meijer /// to provide implicit vector predication within a low-overhead loop.
140736d1ccSSjoerd Meijer /// This is implicit because the predicate of active/inactive lanes is
150736d1ccSSjoerd Meijer /// calculated by hardware, and thus does not need to be explicitly passed
160736d1ccSSjoerd Meijer /// to vector instructions. The instructions responsible for this are the
/// DLSTP and WLSTP instructions, which set up a tail-predicated loop and the
/// total number of data elements processed by the loop. The loop-end
190736d1ccSSjoerd Meijer /// LETP instruction is responsible for decrementing and setting the remaining
200736d1ccSSjoerd Meijer /// elements to be processed and generating the mask of active lanes.
210736d1ccSSjoerd Meijer ///
22312409e4SSam Parker /// The HardwareLoops pass inserts intrinsics identifying loops that the
23312409e4SSam Parker /// backend will attempt to convert into a low-overhead loop. The vectorizer is
24312409e4SSam Parker /// responsible for generating a vectorized loop in which the lanes are
/// predicated upon a get.active.lane.mask intrinsic. This pass looks at these
/// get.active.lane.mask intrinsics and attempts to convert them to VCTP
270e49a40dSDavid Green /// instructions. This will be picked up by the ARM Low-overhead loop pass later
280e49a40dSDavid Green /// in the backend, which performs the final transformation to a DLSTP or WLSTP
290e49a40dSDavid Green /// tail-predicated loop.
300e49a40dSDavid Green //
310e49a40dSDavid Green //===----------------------------------------------------------------------===//
32312409e4SSam Parker 
335d986953SReid Kleckner #include "ARM.h"
345d986953SReid Kleckner #include "ARMSubtarget.h"
35595270aeSSjoerd Meijer #include "ARMTargetTransformInfo.h"
36312409e4SSam Parker #include "llvm/Analysis/LoopInfo.h"
37312409e4SSam Parker #include "llvm/Analysis/LoopPass.h"
38312409e4SSam Parker #include "llvm/Analysis/ScalarEvolution.h"
39312409e4SSam Parker #include "llvm/Analysis/ScalarEvolutionExpressions.h"
40c18b7536SSimon Pilgrim #include "llvm/Analysis/TargetLibraryInfo.h"
41312409e4SSam Parker #include "llvm/Analysis/TargetTransformInfo.h"
42312409e4SSam Parker #include "llvm/CodeGen/TargetPassConfig.h"
43312409e4SSam Parker #include "llvm/IR/IRBuilder.h"
445d986953SReid Kleckner #include "llvm/IR/Instructions.h"
455d986953SReid Kleckner #include "llvm/IR/IntrinsicsARM.h"
46312409e4SSam Parker #include "llvm/IR/PatternMatch.h"
47bcbd26bfSFlorian Hahn #include "llvm/InitializePasses.h"
48312409e4SSam Parker #include "llvm/Support/Debug.h"
49312409e4SSam Parker #include "llvm/Transforms/Utils/BasicBlockUtils.h"
500e49a40dSDavid Green #include "llvm/Transforms/Utils/Local.h"
518cba99e2SSjoerd Meijer #include "llvm/Transforms/Utils/LoopUtils.h"
52bcbd26bfSFlorian Hahn #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
53312409e4SSam Parker 
54312409e4SSam Parker using namespace llvm;
55312409e4SSam Parker 
56312409e4SSam Parker #define DEBUG_TYPE "mve-tail-predication"
57312409e4SSam Parker #define DESC "Transform predicated vector loops to use MVE tail predication"
58312409e4SSam Parker 
59595270aeSSjoerd Meijer cl::opt<TailPredication::Mode> EnableTailPredication(
602fc690acSSjoerd Meijer    "tail-predication", cl::desc("MVE tail-predication pass options"),
611696dd27SSjoerd Meijer    cl::init(TailPredication::Enabled),
62595270aeSSjoerd Meijer    cl::values(clEnumValN(TailPredication::Disabled, "disabled",
63595270aeSSjoerd Meijer                          "Don't tail-predicate loops"),
64595270aeSSjoerd Meijer               clEnumValN(TailPredication::EnabledNoReductions,
65595270aeSSjoerd Meijer                          "enabled-no-reductions",
66595270aeSSjoerd Meijer                          "Enable tail-predication, but not for reduction loops"),
67595270aeSSjoerd Meijer               clEnumValN(TailPredication::Enabled,
68595270aeSSjoerd Meijer                          "enabled",
69595270aeSSjoerd Meijer                          "Enable tail-predication, including reduction loops"),
70595270aeSSjoerd Meijer               clEnumValN(TailPredication::ForceEnabledNoReductions,
71595270aeSSjoerd Meijer                          "force-enabled-no-reductions",
72595270aeSSjoerd Meijer                          "Enable tail-predication, but not for reduction loops, "
73595270aeSSjoerd Meijer                          "and force this which might be unsafe"),
74595270aeSSjoerd Meijer               clEnumValN(TailPredication::ForceEnabled,
75595270aeSSjoerd Meijer                          "force-enabled",
76595270aeSSjoerd Meijer                          "Enable tail-predication, including reduction loops, "
77595270aeSSjoerd Meijer                          "and force this which might be unsafe")));
78d1522513SSjoerd Meijer 
79595270aeSSjoerd Meijer 
80312409e4SSam Parker namespace {
81312409e4SSam Parker 
82312409e4SSam Parker class MVETailPredication : public LoopPass {
83312409e4SSam Parker   SmallVector<IntrinsicInst*, 4> MaskedInsts;
84312409e4SSam Parker   Loop *L = nullptr;
85312409e4SSam Parker   ScalarEvolution *SE = nullptr;
86312409e4SSam Parker   TargetTransformInfo *TTI = nullptr;
87d9cb811cSSamuel Tebbs   const ARMSubtarget *ST = nullptr;
88312409e4SSam Parker 
89312409e4SSam Parker public:
90312409e4SSam Parker   static char ID;
91312409e4SSam Parker 
MVETailPredication()92312409e4SSam Parker   MVETailPredication() : LoopPass(ID) { }
93312409e4SSam Parker 
getAnalysisUsage(AnalysisUsage & AU) const94312409e4SSam Parker   void getAnalysisUsage(AnalysisUsage &AU) const override {
95312409e4SSam Parker     AU.addRequired<ScalarEvolutionWrapperPass>();
96312409e4SSam Parker     AU.addRequired<LoopInfoWrapperPass>();
97312409e4SSam Parker     AU.addRequired<TargetPassConfig>();
98312409e4SSam Parker     AU.addRequired<TargetTransformInfoWrapperPass>();
99312409e4SSam Parker     AU.addPreserved<LoopInfoWrapperPass>();
100312409e4SSam Parker     AU.setPreservesCFG();
101312409e4SSam Parker   }
102312409e4SSam Parker 
103312409e4SSam Parker   bool runOnLoop(Loop *L, LPPassManager&) override;
104312409e4SSam Parker 
105312409e4SSam Parker private:
1060e49a40dSDavid Green   /// Perform the relevant checks on the loop and convert active lane masks if
1070e49a40dSDavid Green   /// possible.
1080e49a40dSDavid Green   bool TryConvertActiveLaneMask(Value *TripCount);
109312409e4SSam Parker 
110676febc0SSjoerd Meijer   /// Perform several checks on the arguments of @llvm.get.active.lane.mask
111676febc0SSjoerd Meijer   /// intrinsic. E.g., check that the loop induction variable and the element
112676febc0SSjoerd Meijer   /// count are of the form we expect, and also perform overflow checks for
113676febc0SSjoerd Meijer   /// the new expressions that are created.
1140e49a40dSDavid Green   bool IsSafeActiveMask(IntrinsicInst *ActiveLaneMask, Value *TripCount);
1150efc9e5aSSjoerd Meijer 
1160efc9e5aSSjoerd Meijer   /// Insert the intrinsic to represent the effect of tail predication.
1170e49a40dSDavid Green   void InsertVCTPIntrinsic(IntrinsicInst *ActiveLaneMask, Value *TripCount);
1188cba99e2SSjoerd Meijer 
1198cba99e2SSjoerd Meijer   /// Rematerialize the iteration count in exit blocks, which enables
1208cba99e2SSjoerd Meijer   /// ARMLowOverheadLoops to better optimise away loop update statements inside
1218cba99e2SSjoerd Meijer   /// hardware-loops.
1228cba99e2SSjoerd Meijer   void RematerializeIterCount();
123312409e4SSam Parker };
124312409e4SSam Parker 
125312409e4SSam Parker } // end namespace
126312409e4SSam Parker 
runOnLoop(Loop * L,LPPassManager &)127312409e4SSam Parker bool MVETailPredication::runOnLoop(Loop *L, LPPassManager&) {
128595270aeSSjoerd Meijer   if (skipLoop(L) || !EnableTailPredication)
129312409e4SSam Parker     return false;
130312409e4SSam Parker 
131c04b9ba5SSam Parker   MaskedInsts.clear();
132312409e4SSam Parker   Function &F = *L->getHeader()->getParent();
133312409e4SSam Parker   auto &TPC = getAnalysis<TargetPassConfig>();
134312409e4SSam Parker   auto &TM = TPC.getTM<TargetMachine>();
135d9cb811cSSamuel Tebbs   ST = &TM.getSubtarget<ARMSubtarget>(F);
136312409e4SSam Parker   TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
137312409e4SSam Parker   SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
138312409e4SSam Parker   this->L = L;
139312409e4SSam Parker 
140312409e4SSam Parker   // The MVE and LOB extensions are combined to enable tail-predication, but
141312409e4SSam Parker   // there's nothing preventing us from generating VCTP instructions for v8.1m.
142312409e4SSam Parker   if (!ST->hasMVEIntegerOps() || !ST->hasV8_1MMainlineOps()) {
1430efc9e5aSSjoerd Meijer     LLVM_DEBUG(dbgs() << "ARM TP: Not a v8.1m.main+mve target.\n");
144312409e4SSam Parker     return false;
145312409e4SSam Parker   }
146312409e4SSam Parker 
147312409e4SSam Parker   BasicBlock *Preheader = L->getLoopPreheader();
148312409e4SSam Parker   if (!Preheader)
149312409e4SSam Parker     return false;
150312409e4SSam Parker 
151312409e4SSam Parker   auto FindLoopIterations = [](BasicBlock *BB) -> IntrinsicInst* {
152312409e4SSam Parker     for (auto &I : *BB) {
153312409e4SSam Parker       auto *Call = dyn_cast<IntrinsicInst>(&I);
154312409e4SSam Parker       if (!Call)
155312409e4SSam Parker         continue;
156312409e4SSam Parker 
157312409e4SSam Parker       Intrinsic::ID ID = Call->getIntrinsicID();
158b2ac9681SDavid Green       if (ID == Intrinsic::start_loop_iterations ||
159fad70c30SDavid Green           ID == Intrinsic::test_start_loop_iterations)
160312409e4SSam Parker         return cast<IntrinsicInst>(&I);
161312409e4SSam Parker     }
162312409e4SSam Parker     return nullptr;
163312409e4SSam Parker   };
164312409e4SSam Parker 
165312409e4SSam Parker   // Look for the hardware loop intrinsic that sets the iteration count.
166312409e4SSam Parker   IntrinsicInst *Setup = FindLoopIterations(Preheader);
167312409e4SSam Parker 
168312409e4SSam Parker   // The test.set iteration could live in the pre-preheader.
169312409e4SSam Parker   if (!Setup) {
170312409e4SSam Parker     if (!Preheader->getSinglePredecessor())
171312409e4SSam Parker       return false;
172312409e4SSam Parker     Setup = FindLoopIterations(Preheader->getSinglePredecessor());
173312409e4SSam Parker     if (!Setup)
174312409e4SSam Parker       return false;
175312409e4SSam Parker   }
176312409e4SSam Parker 
1770e49a40dSDavid Green   LLVM_DEBUG(dbgs() << "ARM TP: Running on Loop: " << *L << *Setup << "\n");
178312409e4SSam Parker 
1790e49a40dSDavid Green   bool Changed = TryConvertActiveLaneMask(Setup->getArgOperand(0));
180312409e4SSam Parker 
1810e49a40dSDavid Green   return Changed;
1829feb429aSSam Parker }
1839feb429aSSam Parker 
184d1522513SSjoerd Meijer // The active lane intrinsic has this form:
185d1522513SSjoerd Meijer //
186c352e7fbSSjoerd Meijer //    @llvm.get.active.lane.mask(IV, TC)
187d1522513SSjoerd Meijer //
188d1522513SSjoerd Meijer // Here we perform checks that this intrinsic behaves as expected,
189d1522513SSjoerd Meijer // which means:
190d1522513SSjoerd Meijer //
191c352e7fbSSjoerd Meijer // 1) Check that the TripCount (TC) belongs to this loop (originally).
192c352e7fbSSjoerd Meijer // 2) The element count (TC) needs to be sufficiently large that the decrement
193c352e7fbSSjoerd Meijer //    of element counter doesn't overflow, which means that we need to prove:
194d1522513SSjoerd Meijer //        ceil(ElementCount / VectorWidth) >= TripCount
//    by rounding ElementCount up:
//        ((ElementCount + (VectorWidth - 1)) / VectorWidth)
//    and evaluate if the expression isKnownNonNegative:
//        ((ElementCount + (VectorWidth - 1)) / VectorWidth) - TripCount
199d1522513SSjoerd Meijer // 3) The IV must be an induction phi with an increment equal to the
200d1522513SSjoerd Meijer //    vector width.
IsSafeActiveMask(IntrinsicInst * ActiveLaneMask,Value * TripCount)2011319d9bbSSjoerd Meijer bool MVETailPredication::IsSafeActiveMask(IntrinsicInst *ActiveLaneMask,
2020e49a40dSDavid Green                                           Value *TripCount) {
203595270aeSSjoerd Meijer   bool ForceTailPredication =
204595270aeSSjoerd Meijer     EnableTailPredication == TailPredication::ForceEnabledNoReductions ||
205595270aeSSjoerd Meijer     EnableTailPredication == TailPredication::ForceEnabled;
2066716e786SSjoerd Meijer 
207676febc0SSjoerd Meijer   Value *ElemCount = ActiveLaneMask->getOperand(1);
208258e2e9aSDavid Green   bool Changed = false;
209258e2e9aSDavid Green   if (!L->makeLoopInvariant(ElemCount, Changed))
210258e2e9aSDavid Green     return false;
211258e2e9aSDavid Green 
212676febc0SSjoerd Meijer   auto *EC= SE->getSCEV(ElemCount);
213676febc0SSjoerd Meijer   auto *TC = SE->getSCEV(TripCount);
2140e49a40dSDavid Green   int VectorWidth =
2150e49a40dSDavid Green       cast<FixedVectorType>(ActiveLaneMask->getType())->getNumElements();
216*ab0c5ceaSDavid Green   if (VectorWidth != 2 && VectorWidth != 4 && VectorWidth != 8 &&
217*ab0c5ceaSDavid Green       VectorWidth != 16)
2180e49a40dSDavid Green     return false;
219676febc0SSjoerd Meijer   ConstantInt *ConstElemCount = nullptr;
220676febc0SSjoerd Meijer 
221f39f92c1SSjoerd Meijer   // 1) Smoke tests that the original scalar loop TripCount (TC) belongs to
222f39f92c1SSjoerd Meijer   // this loop.  The scalar tripcount corresponds the number of elements
223f39f92c1SSjoerd Meijer   // processed by the loop, so we will refer to that from this point on.
224676febc0SSjoerd Meijer   if (!SE->isLoopInvariant(EC, L)) {
225676febc0SSjoerd Meijer     LLVM_DEBUG(dbgs() << "ARM TP: element count must be loop invariant.\n");
226676febc0SSjoerd Meijer     return false;
227676febc0SSjoerd Meijer   }
228676febc0SSjoerd Meijer 
229676febc0SSjoerd Meijer   if ((ConstElemCount = dyn_cast<ConstantInt>(ElemCount))) {
230676febc0SSjoerd Meijer     ConstantInt *TC = dyn_cast<ConstantInt>(TripCount);
231676febc0SSjoerd Meijer     if (!TC) {
232676febc0SSjoerd Meijer       LLVM_DEBUG(dbgs() << "ARM TP: Constant tripcount expected in "
233676febc0SSjoerd Meijer                            "set.loop.iterations\n");
234676febc0SSjoerd Meijer       return false;
235676febc0SSjoerd Meijer     }
236676febc0SSjoerd Meijer 
237676febc0SSjoerd Meijer     // Calculate 2 tripcount values and check that they are consistent with
238f5abf0bdSDavid Green     // each other. The TripCount for a predicated vector loop body is
239f5abf0bdSDavid Green     // ceil(ElementCount/Width), or floor((ElementCount+Width-1)/Width) as we
240f5abf0bdSDavid Green     // work it out here.
241f5abf0bdSDavid Green     uint64_t TC1 = TC->getZExtValue();
242f5abf0bdSDavid Green     uint64_t TC2 =
243f5abf0bdSDavid Green         (ConstElemCount->getZExtValue() + VectorWidth - 1) / VectorWidth;
244676febc0SSjoerd Meijer 
245f5abf0bdSDavid Green     // If the tripcount values are inconsistent, we can't insert the VCTP and
246f5abf0bdSDavid Green     // trigger tail-predication; keep the intrinsic as a get.active.lane.mask
247f5abf0bdSDavid Green     // and legalize this.
248676febc0SSjoerd Meijer     if (TC1 != TC2) {
249676febc0SSjoerd Meijer       LLVM_DEBUG(dbgs() << "ARM TP: inconsistent constant tripcount values: "
250676febc0SSjoerd Meijer                  << TC1 << " from set.loop.iterations, and "
251676febc0SSjoerd Meijer                  << TC2 << " from get.active.lane.mask\n");
252676febc0SSjoerd Meijer       return false;
253676febc0SSjoerd Meijer     }
254b5c3efebSSjoerd Meijer   } else if (!ForceTailPredication) {
255f39f92c1SSjoerd Meijer     // 2) We need to prove that the sub expression that we create in the
256f39f92c1SSjoerd Meijer     // tail-predicated loop body, which calculates the remaining elements to be
257f39f92c1SSjoerd Meijer     // processed, is non-negative, i.e. it doesn't overflow:
258d1522513SSjoerd Meijer     //
259f39f92c1SSjoerd Meijer     //   ((ElementCount + VectorWidth - 1) / VectorWidth) - TripCount >= 0
260d1522513SSjoerd Meijer     //
261f39f92c1SSjoerd Meijer     // This is true if:
262d1522513SSjoerd Meijer     //
263f39f92c1SSjoerd Meijer     //    TripCount == (ElementCount + VectorWidth - 1) / VectorWidth
264d1522513SSjoerd Meijer     //
265f39f92c1SSjoerd Meijer     // which what we will be using here.
266d1522513SSjoerd Meijer     //
267f39f92c1SSjoerd Meijer     auto *VW = SE->getSCEV(ConstantInt::get(TripCount->getType(), VectorWidth));
268f39f92c1SSjoerd Meijer     // ElementCount + (VW-1):
269676febc0SSjoerd Meijer     auto *ECPlusVWMinus1 = SE->getAddExpr(EC,
270d1522513SSjoerd Meijer         SE->getSCEV(ConstantInt::get(TripCount->getType(), VectorWidth - 1)));
271d1522513SSjoerd Meijer 
272f39f92c1SSjoerd Meijer     // Ceil = ElementCount + (VW-1) / VW
273f39f92c1SSjoerd Meijer     auto *Ceil = SE->getUDivExpr(ECPlusVWMinus1, VW);
274f39f92c1SSjoerd Meijer 
275509fba75STres Popp     // Prevent unused variable warnings with TC
276509fba75STres Popp     (void)TC;
277f39f92c1SSjoerd Meijer     LLVM_DEBUG(
278f39f92c1SSjoerd Meijer       dbgs() << "ARM TP: Analysing overflow behaviour for:\n";
279f39f92c1SSjoerd Meijer       dbgs() << "ARM TP: - TripCount = "; TC->dump();
280f39f92c1SSjoerd Meijer       dbgs() << "ARM TP: - ElemCount = "; EC->dump();
281f39f92c1SSjoerd Meijer       dbgs() << "ARM TP: - VecWidth =  " << VectorWidth << "\n";
282f39f92c1SSjoerd Meijer       dbgs() << "ARM TP: - (ElemCount+VW-1) / VW = "; Ceil->dump();
283f39f92c1SSjoerd Meijer     );
284f39f92c1SSjoerd Meijer 
285f39f92c1SSjoerd Meijer     // As an example, almost all the tripcount expressions (produced by the
286f39f92c1SSjoerd Meijer     // vectoriser) look like this:
287f39f92c1SSjoerd Meijer     //
288f39f92c1SSjoerd Meijer     //   TC = ((-4 + (4 * ((3 + %N) /u 4))<nuw>) /u 4)
289f39f92c1SSjoerd Meijer     //
290f39f92c1SSjoerd Meijer     // and "ElementCount + (VW-1) / VW":
291f39f92c1SSjoerd Meijer     //
292f39f92c1SSjoerd Meijer     //   Ceil = ((3 + %N) /u 4)
293f39f92c1SSjoerd Meijer     //
294f39f92c1SSjoerd Meijer     // Check for equality of TC and Ceil by calculating SCEV expression
295f39f92c1SSjoerd Meijer     // TC - Ceil and test it for zero.
296f39f92c1SSjoerd Meijer     //
29729fa37ecSPhilip Reames     const SCEV *Sub =
29829fa37ecSPhilip Reames       SE->getMinusSCEV(SE->getBackedgeTakenCount(L),
299f39f92c1SSjoerd Meijer                        SE->getUDivExpr(SE->getAddExpr(SE->getMulExpr(Ceil, VW),
300f39f92c1SSjoerd Meijer                                                       SE->getNegativeSCEV(VW)),
30129fa37ecSPhilip Reames                                        VW));
302f39f92c1SSjoerd Meijer 
30329fa37ecSPhilip Reames     // Use context sensitive facts about the path to the loop to refine.  This
30429fa37ecSPhilip Reames     // comes up as the backedge taken count can incorporate context sensitive
30529fa37ecSPhilip Reames     // reasoning, and our RHS just above doesn't.
30629fa37ecSPhilip Reames     Sub = SE->applyLoopGuards(Sub, L);
30729fa37ecSPhilip Reames 
30829fa37ecSPhilip Reames     if (!Sub->isZero()) {
309f39f92c1SSjoerd Meijer       LLVM_DEBUG(dbgs() << "ARM TP: possible overflow in sub expression.\n");
310d1522513SSjoerd Meijer       return false;
311d1522513SSjoerd Meijer     }
312f39f92c1SSjoerd Meijer   }
313d1522513SSjoerd Meijer 
3146716e786SSjoerd Meijer   // 3) Find out if IV is an induction phi. Note that we can't use Loop
315d1522513SSjoerd Meijer   // helpers here to get the induction variable, because the hardware loop is
3166716e786SSjoerd Meijer   // no longer in loopsimplify form, and also the hwloop intrinsic uses a
317d1522513SSjoerd Meijer   // different counter. Using SCEV, we check that the induction is of the
318d1522513SSjoerd Meijer   // form i = i + 4, where the increment must be equal to the VectorWidth.
319d1522513SSjoerd Meijer   auto *IV = ActiveLaneMask->getOperand(0);
320d1522513SSjoerd Meijer   auto *IVExpr = SE->getSCEV(IV);
321d1522513SSjoerd Meijer   auto *AddExpr = dyn_cast<SCEVAddRecExpr>(IVExpr);
322f39f92c1SSjoerd Meijer 
323d1522513SSjoerd Meijer   if (!AddExpr) {
324d1522513SSjoerd Meijer     LLVM_DEBUG(dbgs() << "ARM TP: induction not an add expr: "; IVExpr->dump());
325d1522513SSjoerd Meijer     return false;
326d1522513SSjoerd Meijer   }
327d1522513SSjoerd Meijer   // Check that this AddRec is associated with this loop.
328d1522513SSjoerd Meijer   if (AddExpr->getLoop() != L) {
329d1522513SSjoerd Meijer     LLVM_DEBUG(dbgs() << "ARM TP: phi not part of this loop\n");
330d1522513SSjoerd Meijer     return false;
331d1522513SSjoerd Meijer   }
332f39f92c1SSjoerd Meijer   auto *Base = dyn_cast<SCEVConstant>(AddExpr->getOperand(0));
333f39f92c1SSjoerd Meijer   if (!Base || !Base->isZero()) {
334f39f92c1SSjoerd Meijer     LLVM_DEBUG(dbgs() << "ARM TP: induction base is not 0\n");
335f39f92c1SSjoerd Meijer     return false;
336f39f92c1SSjoerd Meijer   }
337d1522513SSjoerd Meijer   auto *Step = dyn_cast<SCEVConstant>(AddExpr->getOperand(1));
338d1522513SSjoerd Meijer   if (!Step) {
339d1522513SSjoerd Meijer     LLVM_DEBUG(dbgs() << "ARM TP: induction step is not a constant: ";
340d1522513SSjoerd Meijer                AddExpr->getOperand(1)->dump());
341d1522513SSjoerd Meijer     return false;
342d1522513SSjoerd Meijer   }
343d1522513SSjoerd Meijer   auto StepValue = Step->getValue()->getSExtValue();
344d1522513SSjoerd Meijer   if (VectorWidth == StepValue)
345d1522513SSjoerd Meijer     return true;
346d1522513SSjoerd Meijer 
3470e49a40dSDavid Green   LLVM_DEBUG(dbgs() << "ARM TP: Step value " << StepValue
3480e49a40dSDavid Green                     << " doesn't match vector width " << VectorWidth << "\n");
349d1522513SSjoerd Meijer 
350d1522513SSjoerd Meijer   return false;
351d1522513SSjoerd Meijer }
352d1522513SSjoerd Meijer 
InsertVCTPIntrinsic(IntrinsicInst * ActiveLaneMask,Value * TripCount)353d1522513SSjoerd Meijer void MVETailPredication::InsertVCTPIntrinsic(IntrinsicInst *ActiveLaneMask,
3540e49a40dSDavid Green                                              Value *TripCount) {
355d1522513SSjoerd Meijer   IRBuilder<> Builder(L->getLoopPreheader()->getTerminator());
356312409e4SSam Parker   Module *M = L->getHeader()->getModule();
357312409e4SSam Parker   Type *Ty = IntegerType::get(M->getContext(), 32);
3580e49a40dSDavid Green   unsigned VectorWidth =
3590e49a40dSDavid Green       cast<FixedVectorType>(ActiveLaneMask->getType())->getNumElements();
360312409e4SSam Parker 
361d1522513SSjoerd Meijer   // Insert a phi to count the number of elements processed by the loop.
362d1522513SSjoerd Meijer   Builder.SetInsertPoint(L->getHeader()->getFirstNonPHI());
363d1522513SSjoerd Meijer   PHINode *Processed = Builder.CreatePHI(Ty, 2);
364c352e7fbSSjoerd Meijer   Processed->addIncoming(ActiveLaneMask->getOperand(1), L->getLoopPreheader());
365d1522513SSjoerd Meijer 
366c352e7fbSSjoerd Meijer   // Replace @llvm.get.active.mask() with the ARM specific VCTP intrinic, and
367c352e7fbSSjoerd Meijer   // thus represent the effect of tail predication.
368d1522513SSjoerd Meijer   Builder.SetInsertPoint(ActiveLaneMask);
369c352e7fbSSjoerd Meijer   ConstantInt *Factor = ConstantInt::get(cast<IntegerType>(Ty), VectorWidth);
3700efc9e5aSSjoerd Meijer 
371312409e4SSam Parker   Intrinsic::ID VCTPID;
3721319d9bbSSjoerd Meijer   switch (VectorWidth) {
373312409e4SSam Parker   default:
374312409e4SSam Parker     llvm_unreachable("unexpected number of lanes");
375*ab0c5ceaSDavid Green   case 2:  VCTPID = Intrinsic::arm_mve_vctp64; break;
37648cce077SSimon Tatham   case 4:  VCTPID = Intrinsic::arm_mve_vctp32; break;
37748cce077SSimon Tatham   case 8:  VCTPID = Intrinsic::arm_mve_vctp16; break;
37848cce077SSimon Tatham   case 16: VCTPID = Intrinsic::arm_mve_vctp8; break;
379312409e4SSam Parker   }
380312409e4SSam Parker   Function *VCTP = Intrinsic::getDeclaration(M, VCTPID);
381d1522513SSjoerd Meijer   Value *VCTPCall = Builder.CreateCall(VCTP, Processed);
382d1522513SSjoerd Meijer   ActiveLaneMask->replaceAllUsesWith(VCTPCall);
383312409e4SSam Parker 
384312409e4SSam Parker   // Add the incoming value to the new phi.
385aac03ae0SSam Parker   // TODO: This add likely already exists in the loop.
386aac03ae0SSam Parker   Value *Remaining = Builder.CreateSub(Processed, Factor);
387312409e4SSam Parker   Processed->addIncoming(Remaining, L->getLoopLatch());
3880efc9e5aSSjoerd Meijer   LLVM_DEBUG(dbgs() << "ARM TP: Insert processed elements phi: "
389312409e4SSam Parker              << *Processed << "\n"
390d1522513SSjoerd Meijer              << "ARM TP: Inserted VCTP: " << *VCTPCall << "\n");
3910efc9e5aSSjoerd Meijer }
3920efc9e5aSSjoerd Meijer 
TryConvertActiveLaneMask(Value * TripCount)3930e49a40dSDavid Green bool MVETailPredication::TryConvertActiveLaneMask(Value *TripCount) {
3940e49a40dSDavid Green   SmallVector<IntrinsicInst *, 4> ActiveLaneMasks;
3950e49a40dSDavid Green   for (auto *BB : L->getBlocks())
3960e49a40dSDavid Green     for (auto &I : *BB)
3970e49a40dSDavid Green       if (auto *Int = dyn_cast<IntrinsicInst>(&I))
3980e49a40dSDavid Green         if (Int->getIntrinsicID() == Intrinsic::get_active_lane_mask)
3990e49a40dSDavid Green           ActiveLaneMasks.push_back(Int);
4000e49a40dSDavid Green 
4010e49a40dSDavid Green   if (ActiveLaneMasks.empty())
4020efc9e5aSSjoerd Meijer     return false;
4030efc9e5aSSjoerd Meijer 
4040efc9e5aSSjoerd Meijer   LLVM_DEBUG(dbgs() << "ARM TP: Found predicated vector loop.\n");
4050efc9e5aSSjoerd Meijer 
4060e49a40dSDavid Green   for (auto *ActiveLaneMask : ActiveLaneMasks) {
407d1522513SSjoerd Meijer     LLVM_DEBUG(dbgs() << "ARM TP: Found active lane mask: "
408d1522513SSjoerd Meijer                       << *ActiveLaneMask << "\n");
4090736d1ccSSjoerd Meijer 
4100e49a40dSDavid Green     if (!IsSafeActiveMask(ActiveLaneMask, TripCount)) {
411d1522513SSjoerd Meijer       LLVM_DEBUG(dbgs() << "ARM TP: Not safe to insert VCTP.\n");
412b567ff2fSSjoerd Meijer       return false;
4130736d1ccSSjoerd Meijer     }
414d1522513SSjoerd Meijer     LLVM_DEBUG(dbgs() << "ARM TP: Safe to insert VCTP.\n");
4150e49a40dSDavid Green     InsertVCTPIntrinsic(ActiveLaneMask, TripCount);
416d1522513SSjoerd Meijer   }
417b567ff2fSSjoerd Meijer 
4180e49a40dSDavid Green   // Remove dead instructions and now dead phis.
4190e49a40dSDavid Green   for (auto *II : ActiveLaneMasks)
4200e49a40dSDavid Green     RecursivelyDeleteTriviallyDeadInstructions(II);
4210e49a40dSDavid Green   for (auto I : L->blocks())
4220e49a40dSDavid Green     DeleteDeadPHIs(I);
423312409e4SSam Parker   return true;
424312409e4SSam Parker }
425312409e4SSam Parker 
createMVETailPredicationPass()426312409e4SSam Parker Pass *llvm::createMVETailPredicationPass() {
427312409e4SSam Parker   return new MVETailPredication();
428312409e4SSam Parker }
429312409e4SSam Parker 
430312409e4SSam Parker char MVETailPredication::ID = 0;
431312409e4SSam Parker 
432312409e4SSam Parker INITIALIZE_PASS_BEGIN(MVETailPredication, DEBUG_TYPE, DESC, false, false)
433312409e4SSam Parker INITIALIZE_PASS_END(MVETailPredication, DEBUG_TYPE, DESC, false, false)
434