1245679b6SChristopher Tetreault //===- MVETailPredication.cpp - MVE Tail Predication ------------*- C++ -*-===//
2312409e4SSam Parker //
3312409e4SSam Parker // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4312409e4SSam Parker // See https://llvm.org/LICENSE.txt for license information.
5312409e4SSam Parker // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6312409e4SSam Parker //
7312409e4SSam Parker //===----------------------------------------------------------------------===//
8312409e4SSam Parker //
/// \file
/// Armv8.1m introduced MVE, M-Profile Vector Extension, and low-overhead
/// branches to help accelerate DSP applications. These two extensions,
/// combined with a new form of predication called tail-predication, can be used
/// to provide implicit vector predication within a low-overhead loop.
/// This is implicit because the predicate of active/inactive lanes is
/// calculated by hardware, and thus does not need to be explicitly passed
/// to vector instructions. The instructions responsible for this are the
/// DLSTP and WLSTP instructions, which set up a tail-predicated loop and the
/// total number of data elements processed by the loop. The loop-end
/// LETP instruction is responsible for decrementing and setting the remaining
/// elements to be processed and generating the mask of active lanes.
///
/// The HardwareLoops pass inserts intrinsics identifying loops that the
/// backend will attempt to convert into a low-overhead loop. The vectorizer is
/// responsible for generating a vectorized loop in which the lanes are
/// predicated upon a get.active.lane.mask intrinsic. This pass looks at these
/// get.active.lane.mask intrinsics and attempts to convert them to VCTP
/// instructions. This will be picked up by the ARM Low-overhead loop pass later
/// in the backend, which performs the final transformation to a DLSTP or WLSTP
/// tail-predicated loop.
300e49a40dSDavid Green //
310e49a40dSDavid Green //===----------------------------------------------------------------------===//
32312409e4SSam Parker
335d986953SReid Kleckner #include "ARM.h"
345d986953SReid Kleckner #include "ARMSubtarget.h"
35595270aeSSjoerd Meijer #include "ARMTargetTransformInfo.h"
36312409e4SSam Parker #include "llvm/Analysis/LoopInfo.h"
37312409e4SSam Parker #include "llvm/Analysis/LoopPass.h"
38312409e4SSam Parker #include "llvm/Analysis/ScalarEvolution.h"
39312409e4SSam Parker #include "llvm/Analysis/ScalarEvolutionExpressions.h"
40c18b7536SSimon Pilgrim #include "llvm/Analysis/TargetLibraryInfo.h"
41312409e4SSam Parker #include "llvm/Analysis/TargetTransformInfo.h"
42312409e4SSam Parker #include "llvm/CodeGen/TargetPassConfig.h"
43312409e4SSam Parker #include "llvm/IR/IRBuilder.h"
445d986953SReid Kleckner #include "llvm/IR/Instructions.h"
455d986953SReid Kleckner #include "llvm/IR/IntrinsicsARM.h"
46312409e4SSam Parker #include "llvm/IR/PatternMatch.h"
47bcbd26bfSFlorian Hahn #include "llvm/InitializePasses.h"
48312409e4SSam Parker #include "llvm/Support/Debug.h"
49312409e4SSam Parker #include "llvm/Transforms/Utils/BasicBlockUtils.h"
500e49a40dSDavid Green #include "llvm/Transforms/Utils/Local.h"
518cba99e2SSjoerd Meijer #include "llvm/Transforms/Utils/LoopUtils.h"
52bcbd26bfSFlorian Hahn #include "llvm/Transforms/Utils/ScalarEvolutionExpander.h"
53312409e4SSam Parker
54312409e4SSam Parker using namespace llvm;
55312409e4SSam Parker
56312409e4SSam Parker #define DEBUG_TYPE "mve-tail-predication"
57312409e4SSam Parker #define DESC "Transform predicated vector loops to use MVE tail predication"
58312409e4SSam Parker
59595270aeSSjoerd Meijer cl::opt<TailPredication::Mode> EnableTailPredication(
602fc690acSSjoerd Meijer "tail-predication", cl::desc("MVE tail-predication pass options"),
611696dd27SSjoerd Meijer cl::init(TailPredication::Enabled),
62595270aeSSjoerd Meijer cl::values(clEnumValN(TailPredication::Disabled, "disabled",
63595270aeSSjoerd Meijer "Don't tail-predicate loops"),
64595270aeSSjoerd Meijer clEnumValN(TailPredication::EnabledNoReductions,
65595270aeSSjoerd Meijer "enabled-no-reductions",
66595270aeSSjoerd Meijer "Enable tail-predication, but not for reduction loops"),
67595270aeSSjoerd Meijer clEnumValN(TailPredication::Enabled,
68595270aeSSjoerd Meijer "enabled",
69595270aeSSjoerd Meijer "Enable tail-predication, including reduction loops"),
70595270aeSSjoerd Meijer clEnumValN(TailPredication::ForceEnabledNoReductions,
71595270aeSSjoerd Meijer "force-enabled-no-reductions",
72595270aeSSjoerd Meijer "Enable tail-predication, but not for reduction loops, "
73595270aeSSjoerd Meijer "and force this which might be unsafe"),
74595270aeSSjoerd Meijer clEnumValN(TailPredication::ForceEnabled,
75595270aeSSjoerd Meijer "force-enabled",
76595270aeSSjoerd Meijer "Enable tail-predication, including reduction loops, "
77595270aeSSjoerd Meijer "and force this which might be unsafe")));
78d1522513SSjoerd Meijer
79595270aeSSjoerd Meijer
80312409e4SSam Parker namespace {
81312409e4SSam Parker
82312409e4SSam Parker class MVETailPredication : public LoopPass {
83312409e4SSam Parker SmallVector<IntrinsicInst*, 4> MaskedInsts;
84312409e4SSam Parker Loop *L = nullptr;
85312409e4SSam Parker ScalarEvolution *SE = nullptr;
86312409e4SSam Parker TargetTransformInfo *TTI = nullptr;
87d9cb811cSSamuel Tebbs const ARMSubtarget *ST = nullptr;
88312409e4SSam Parker
89312409e4SSam Parker public:
90312409e4SSam Parker static char ID;
91312409e4SSam Parker
MVETailPredication()92312409e4SSam Parker MVETailPredication() : LoopPass(ID) { }
93312409e4SSam Parker
getAnalysisUsage(AnalysisUsage & AU) const94312409e4SSam Parker void getAnalysisUsage(AnalysisUsage &AU) const override {
95312409e4SSam Parker AU.addRequired<ScalarEvolutionWrapperPass>();
96312409e4SSam Parker AU.addRequired<LoopInfoWrapperPass>();
97312409e4SSam Parker AU.addRequired<TargetPassConfig>();
98312409e4SSam Parker AU.addRequired<TargetTransformInfoWrapperPass>();
99312409e4SSam Parker AU.addPreserved<LoopInfoWrapperPass>();
100312409e4SSam Parker AU.setPreservesCFG();
101312409e4SSam Parker }
102312409e4SSam Parker
103312409e4SSam Parker bool runOnLoop(Loop *L, LPPassManager&) override;
104312409e4SSam Parker
105312409e4SSam Parker private:
1060e49a40dSDavid Green /// Perform the relevant checks on the loop and convert active lane masks if
1070e49a40dSDavid Green /// possible.
1080e49a40dSDavid Green bool TryConvertActiveLaneMask(Value *TripCount);
109312409e4SSam Parker
110676febc0SSjoerd Meijer /// Perform several checks on the arguments of @llvm.get.active.lane.mask
111676febc0SSjoerd Meijer /// intrinsic. E.g., check that the loop induction variable and the element
112676febc0SSjoerd Meijer /// count are of the form we expect, and also perform overflow checks for
113676febc0SSjoerd Meijer /// the new expressions that are created.
1140e49a40dSDavid Green bool IsSafeActiveMask(IntrinsicInst *ActiveLaneMask, Value *TripCount);
1150efc9e5aSSjoerd Meijer
1160efc9e5aSSjoerd Meijer /// Insert the intrinsic to represent the effect of tail predication.
1170e49a40dSDavid Green void InsertVCTPIntrinsic(IntrinsicInst *ActiveLaneMask, Value *TripCount);
1188cba99e2SSjoerd Meijer
1198cba99e2SSjoerd Meijer /// Rematerialize the iteration count in exit blocks, which enables
1208cba99e2SSjoerd Meijer /// ARMLowOverheadLoops to better optimise away loop update statements inside
1218cba99e2SSjoerd Meijer /// hardware-loops.
1228cba99e2SSjoerd Meijer void RematerializeIterCount();
123312409e4SSam Parker };
124312409e4SSam Parker
125312409e4SSam Parker } // end namespace
126312409e4SSam Parker
runOnLoop(Loop * L,LPPassManager &)127312409e4SSam Parker bool MVETailPredication::runOnLoop(Loop *L, LPPassManager&) {
128595270aeSSjoerd Meijer if (skipLoop(L) || !EnableTailPredication)
129312409e4SSam Parker return false;
130312409e4SSam Parker
131c04b9ba5SSam Parker MaskedInsts.clear();
132312409e4SSam Parker Function &F = *L->getHeader()->getParent();
133312409e4SSam Parker auto &TPC = getAnalysis<TargetPassConfig>();
134312409e4SSam Parker auto &TM = TPC.getTM<TargetMachine>();
135d9cb811cSSamuel Tebbs ST = &TM.getSubtarget<ARMSubtarget>(F);
136312409e4SSam Parker TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
137312409e4SSam Parker SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
138312409e4SSam Parker this->L = L;
139312409e4SSam Parker
140312409e4SSam Parker // The MVE and LOB extensions are combined to enable tail-predication, but
141312409e4SSam Parker // there's nothing preventing us from generating VCTP instructions for v8.1m.
142312409e4SSam Parker if (!ST->hasMVEIntegerOps() || !ST->hasV8_1MMainlineOps()) {
1430efc9e5aSSjoerd Meijer LLVM_DEBUG(dbgs() << "ARM TP: Not a v8.1m.main+mve target.\n");
144312409e4SSam Parker return false;
145312409e4SSam Parker }
146312409e4SSam Parker
147312409e4SSam Parker BasicBlock *Preheader = L->getLoopPreheader();
148312409e4SSam Parker if (!Preheader)
149312409e4SSam Parker return false;
150312409e4SSam Parker
151312409e4SSam Parker auto FindLoopIterations = [](BasicBlock *BB) -> IntrinsicInst* {
152312409e4SSam Parker for (auto &I : *BB) {
153312409e4SSam Parker auto *Call = dyn_cast<IntrinsicInst>(&I);
154312409e4SSam Parker if (!Call)
155312409e4SSam Parker continue;
156312409e4SSam Parker
157312409e4SSam Parker Intrinsic::ID ID = Call->getIntrinsicID();
158b2ac9681SDavid Green if (ID == Intrinsic::start_loop_iterations ||
159fad70c30SDavid Green ID == Intrinsic::test_start_loop_iterations)
160312409e4SSam Parker return cast<IntrinsicInst>(&I);
161312409e4SSam Parker }
162312409e4SSam Parker return nullptr;
163312409e4SSam Parker };
164312409e4SSam Parker
165312409e4SSam Parker // Look for the hardware loop intrinsic that sets the iteration count.
166312409e4SSam Parker IntrinsicInst *Setup = FindLoopIterations(Preheader);
167312409e4SSam Parker
168312409e4SSam Parker // The test.set iteration could live in the pre-preheader.
169312409e4SSam Parker if (!Setup) {
170312409e4SSam Parker if (!Preheader->getSinglePredecessor())
171312409e4SSam Parker return false;
172312409e4SSam Parker Setup = FindLoopIterations(Preheader->getSinglePredecessor());
173312409e4SSam Parker if (!Setup)
174312409e4SSam Parker return false;
175312409e4SSam Parker }
176312409e4SSam Parker
1770e49a40dSDavid Green LLVM_DEBUG(dbgs() << "ARM TP: Running on Loop: " << *L << *Setup << "\n");
178312409e4SSam Parker
1790e49a40dSDavid Green bool Changed = TryConvertActiveLaneMask(Setup->getArgOperand(0));
180312409e4SSam Parker
1810e49a40dSDavid Green return Changed;
1829feb429aSSam Parker }
1839feb429aSSam Parker
// The active lane intrinsic has this form:
//
//    @llvm.get.active.lane.mask(IV, TC)
//
// Here we perform checks that this intrinsic behaves as expected,
// which means:
//
// 1) Check that the TripCount (TC) belongs to this loop (originally).
// 2) The element count (TC) needs to be sufficiently large that the decrement
//    of the element counter doesn't overflow, which means that we need to
//    prove:
//        ceil(ElementCount / VectorWidth) >= TripCount
//    by rounding ElementCount up:
//        (ElementCount + (VectorWidth - 1)) / VectorWidth
//    and evaluating whether this expression isKnownNonNegative:
//        ((ElementCount + (VectorWidth - 1)) / VectorWidth) - TripCount
// 3) The IV must be an induction phi with an increment equal to the
//    vector width.
IsSafeActiveMask(IntrinsicInst * ActiveLaneMask,Value * TripCount)2011319d9bbSSjoerd Meijer bool MVETailPredication::IsSafeActiveMask(IntrinsicInst *ActiveLaneMask,
2020e49a40dSDavid Green Value *TripCount) {
203595270aeSSjoerd Meijer bool ForceTailPredication =
204595270aeSSjoerd Meijer EnableTailPredication == TailPredication::ForceEnabledNoReductions ||
205595270aeSSjoerd Meijer EnableTailPredication == TailPredication::ForceEnabled;
2066716e786SSjoerd Meijer
207676febc0SSjoerd Meijer Value *ElemCount = ActiveLaneMask->getOperand(1);
208258e2e9aSDavid Green bool Changed = false;
209258e2e9aSDavid Green if (!L->makeLoopInvariant(ElemCount, Changed))
210258e2e9aSDavid Green return false;
211258e2e9aSDavid Green
212676febc0SSjoerd Meijer auto *EC= SE->getSCEV(ElemCount);
213676febc0SSjoerd Meijer auto *TC = SE->getSCEV(TripCount);
2140e49a40dSDavid Green int VectorWidth =
2150e49a40dSDavid Green cast<FixedVectorType>(ActiveLaneMask->getType())->getNumElements();
216*ab0c5ceaSDavid Green if (VectorWidth != 2 && VectorWidth != 4 && VectorWidth != 8 &&
217*ab0c5ceaSDavid Green VectorWidth != 16)
2180e49a40dSDavid Green return false;
219676febc0SSjoerd Meijer ConstantInt *ConstElemCount = nullptr;
220676febc0SSjoerd Meijer
221f39f92c1SSjoerd Meijer // 1) Smoke tests that the original scalar loop TripCount (TC) belongs to
222f39f92c1SSjoerd Meijer // this loop. The scalar tripcount corresponds the number of elements
223f39f92c1SSjoerd Meijer // processed by the loop, so we will refer to that from this point on.
224676febc0SSjoerd Meijer if (!SE->isLoopInvariant(EC, L)) {
225676febc0SSjoerd Meijer LLVM_DEBUG(dbgs() << "ARM TP: element count must be loop invariant.\n");
226676febc0SSjoerd Meijer return false;
227676febc0SSjoerd Meijer }
228676febc0SSjoerd Meijer
229676febc0SSjoerd Meijer if ((ConstElemCount = dyn_cast<ConstantInt>(ElemCount))) {
230676febc0SSjoerd Meijer ConstantInt *TC = dyn_cast<ConstantInt>(TripCount);
231676febc0SSjoerd Meijer if (!TC) {
232676febc0SSjoerd Meijer LLVM_DEBUG(dbgs() << "ARM TP: Constant tripcount expected in "
233676febc0SSjoerd Meijer "set.loop.iterations\n");
234676febc0SSjoerd Meijer return false;
235676febc0SSjoerd Meijer }
236676febc0SSjoerd Meijer
237676febc0SSjoerd Meijer // Calculate 2 tripcount values and check that they are consistent with
238f5abf0bdSDavid Green // each other. The TripCount for a predicated vector loop body is
239f5abf0bdSDavid Green // ceil(ElementCount/Width), or floor((ElementCount+Width-1)/Width) as we
240f5abf0bdSDavid Green // work it out here.
241f5abf0bdSDavid Green uint64_t TC1 = TC->getZExtValue();
242f5abf0bdSDavid Green uint64_t TC2 =
243f5abf0bdSDavid Green (ConstElemCount->getZExtValue() + VectorWidth - 1) / VectorWidth;
244676febc0SSjoerd Meijer
245f5abf0bdSDavid Green // If the tripcount values are inconsistent, we can't insert the VCTP and
246f5abf0bdSDavid Green // trigger tail-predication; keep the intrinsic as a get.active.lane.mask
247f5abf0bdSDavid Green // and legalize this.
248676febc0SSjoerd Meijer if (TC1 != TC2) {
249676febc0SSjoerd Meijer LLVM_DEBUG(dbgs() << "ARM TP: inconsistent constant tripcount values: "
250676febc0SSjoerd Meijer << TC1 << " from set.loop.iterations, and "
251676febc0SSjoerd Meijer << TC2 << " from get.active.lane.mask\n");
252676febc0SSjoerd Meijer return false;
253676febc0SSjoerd Meijer }
254b5c3efebSSjoerd Meijer } else if (!ForceTailPredication) {
255f39f92c1SSjoerd Meijer // 2) We need to prove that the sub expression that we create in the
256f39f92c1SSjoerd Meijer // tail-predicated loop body, which calculates the remaining elements to be
257f39f92c1SSjoerd Meijer // processed, is non-negative, i.e. it doesn't overflow:
258d1522513SSjoerd Meijer //
259f39f92c1SSjoerd Meijer // ((ElementCount + VectorWidth - 1) / VectorWidth) - TripCount >= 0
260d1522513SSjoerd Meijer //
261f39f92c1SSjoerd Meijer // This is true if:
262d1522513SSjoerd Meijer //
263f39f92c1SSjoerd Meijer // TripCount == (ElementCount + VectorWidth - 1) / VectorWidth
264d1522513SSjoerd Meijer //
265f39f92c1SSjoerd Meijer // which what we will be using here.
266d1522513SSjoerd Meijer //
267f39f92c1SSjoerd Meijer auto *VW = SE->getSCEV(ConstantInt::get(TripCount->getType(), VectorWidth));
268f39f92c1SSjoerd Meijer // ElementCount + (VW-1):
269676febc0SSjoerd Meijer auto *ECPlusVWMinus1 = SE->getAddExpr(EC,
270d1522513SSjoerd Meijer SE->getSCEV(ConstantInt::get(TripCount->getType(), VectorWidth - 1)));
271d1522513SSjoerd Meijer
272f39f92c1SSjoerd Meijer // Ceil = ElementCount + (VW-1) / VW
273f39f92c1SSjoerd Meijer auto *Ceil = SE->getUDivExpr(ECPlusVWMinus1, VW);
274f39f92c1SSjoerd Meijer
275509fba75STres Popp // Prevent unused variable warnings with TC
276509fba75STres Popp (void)TC;
277f39f92c1SSjoerd Meijer LLVM_DEBUG(
278f39f92c1SSjoerd Meijer dbgs() << "ARM TP: Analysing overflow behaviour for:\n";
279f39f92c1SSjoerd Meijer dbgs() << "ARM TP: - TripCount = "; TC->dump();
280f39f92c1SSjoerd Meijer dbgs() << "ARM TP: - ElemCount = "; EC->dump();
281f39f92c1SSjoerd Meijer dbgs() << "ARM TP: - VecWidth = " << VectorWidth << "\n";
282f39f92c1SSjoerd Meijer dbgs() << "ARM TP: - (ElemCount+VW-1) / VW = "; Ceil->dump();
283f39f92c1SSjoerd Meijer );
284f39f92c1SSjoerd Meijer
285f39f92c1SSjoerd Meijer // As an example, almost all the tripcount expressions (produced by the
286f39f92c1SSjoerd Meijer // vectoriser) look like this:
287f39f92c1SSjoerd Meijer //
288f39f92c1SSjoerd Meijer // TC = ((-4 + (4 * ((3 + %N) /u 4))<nuw>) /u 4)
289f39f92c1SSjoerd Meijer //
290f39f92c1SSjoerd Meijer // and "ElementCount + (VW-1) / VW":
291f39f92c1SSjoerd Meijer //
292f39f92c1SSjoerd Meijer // Ceil = ((3 + %N) /u 4)
293f39f92c1SSjoerd Meijer //
294f39f92c1SSjoerd Meijer // Check for equality of TC and Ceil by calculating SCEV expression
295f39f92c1SSjoerd Meijer // TC - Ceil and test it for zero.
296f39f92c1SSjoerd Meijer //
29729fa37ecSPhilip Reames const SCEV *Sub =
29829fa37ecSPhilip Reames SE->getMinusSCEV(SE->getBackedgeTakenCount(L),
299f39f92c1SSjoerd Meijer SE->getUDivExpr(SE->getAddExpr(SE->getMulExpr(Ceil, VW),
300f39f92c1SSjoerd Meijer SE->getNegativeSCEV(VW)),
30129fa37ecSPhilip Reames VW));
302f39f92c1SSjoerd Meijer
30329fa37ecSPhilip Reames // Use context sensitive facts about the path to the loop to refine. This
30429fa37ecSPhilip Reames // comes up as the backedge taken count can incorporate context sensitive
30529fa37ecSPhilip Reames // reasoning, and our RHS just above doesn't.
30629fa37ecSPhilip Reames Sub = SE->applyLoopGuards(Sub, L);
30729fa37ecSPhilip Reames
30829fa37ecSPhilip Reames if (!Sub->isZero()) {
309f39f92c1SSjoerd Meijer LLVM_DEBUG(dbgs() << "ARM TP: possible overflow in sub expression.\n");
310d1522513SSjoerd Meijer return false;
311d1522513SSjoerd Meijer }
312f39f92c1SSjoerd Meijer }
313d1522513SSjoerd Meijer
3146716e786SSjoerd Meijer // 3) Find out if IV is an induction phi. Note that we can't use Loop
315d1522513SSjoerd Meijer // helpers here to get the induction variable, because the hardware loop is
3166716e786SSjoerd Meijer // no longer in loopsimplify form, and also the hwloop intrinsic uses a
317d1522513SSjoerd Meijer // different counter. Using SCEV, we check that the induction is of the
318d1522513SSjoerd Meijer // form i = i + 4, where the increment must be equal to the VectorWidth.
319d1522513SSjoerd Meijer auto *IV = ActiveLaneMask->getOperand(0);
320d1522513SSjoerd Meijer auto *IVExpr = SE->getSCEV(IV);
321d1522513SSjoerd Meijer auto *AddExpr = dyn_cast<SCEVAddRecExpr>(IVExpr);
322f39f92c1SSjoerd Meijer
323d1522513SSjoerd Meijer if (!AddExpr) {
324d1522513SSjoerd Meijer LLVM_DEBUG(dbgs() << "ARM TP: induction not an add expr: "; IVExpr->dump());
325d1522513SSjoerd Meijer return false;
326d1522513SSjoerd Meijer }
327d1522513SSjoerd Meijer // Check that this AddRec is associated with this loop.
328d1522513SSjoerd Meijer if (AddExpr->getLoop() != L) {
329d1522513SSjoerd Meijer LLVM_DEBUG(dbgs() << "ARM TP: phi not part of this loop\n");
330d1522513SSjoerd Meijer return false;
331d1522513SSjoerd Meijer }
332f39f92c1SSjoerd Meijer auto *Base = dyn_cast<SCEVConstant>(AddExpr->getOperand(0));
333f39f92c1SSjoerd Meijer if (!Base || !Base->isZero()) {
334f39f92c1SSjoerd Meijer LLVM_DEBUG(dbgs() << "ARM TP: induction base is not 0\n");
335f39f92c1SSjoerd Meijer return false;
336f39f92c1SSjoerd Meijer }
337d1522513SSjoerd Meijer auto *Step = dyn_cast<SCEVConstant>(AddExpr->getOperand(1));
338d1522513SSjoerd Meijer if (!Step) {
339d1522513SSjoerd Meijer LLVM_DEBUG(dbgs() << "ARM TP: induction step is not a constant: ";
340d1522513SSjoerd Meijer AddExpr->getOperand(1)->dump());
341d1522513SSjoerd Meijer return false;
342d1522513SSjoerd Meijer }
343d1522513SSjoerd Meijer auto StepValue = Step->getValue()->getSExtValue();
344d1522513SSjoerd Meijer if (VectorWidth == StepValue)
345d1522513SSjoerd Meijer return true;
346d1522513SSjoerd Meijer
3470e49a40dSDavid Green LLVM_DEBUG(dbgs() << "ARM TP: Step value " << StepValue
3480e49a40dSDavid Green << " doesn't match vector width " << VectorWidth << "\n");
349d1522513SSjoerd Meijer
350d1522513SSjoerd Meijer return false;
351d1522513SSjoerd Meijer }
352d1522513SSjoerd Meijer
InsertVCTPIntrinsic(IntrinsicInst * ActiveLaneMask,Value * TripCount)353d1522513SSjoerd Meijer void MVETailPredication::InsertVCTPIntrinsic(IntrinsicInst *ActiveLaneMask,
3540e49a40dSDavid Green Value *TripCount) {
355d1522513SSjoerd Meijer IRBuilder<> Builder(L->getLoopPreheader()->getTerminator());
356312409e4SSam Parker Module *M = L->getHeader()->getModule();
357312409e4SSam Parker Type *Ty = IntegerType::get(M->getContext(), 32);
3580e49a40dSDavid Green unsigned VectorWidth =
3590e49a40dSDavid Green cast<FixedVectorType>(ActiveLaneMask->getType())->getNumElements();
360312409e4SSam Parker
361d1522513SSjoerd Meijer // Insert a phi to count the number of elements processed by the loop.
362d1522513SSjoerd Meijer Builder.SetInsertPoint(L->getHeader()->getFirstNonPHI());
363d1522513SSjoerd Meijer PHINode *Processed = Builder.CreatePHI(Ty, 2);
364c352e7fbSSjoerd Meijer Processed->addIncoming(ActiveLaneMask->getOperand(1), L->getLoopPreheader());
365d1522513SSjoerd Meijer
366c352e7fbSSjoerd Meijer // Replace @llvm.get.active.mask() with the ARM specific VCTP intrinic, and
367c352e7fbSSjoerd Meijer // thus represent the effect of tail predication.
368d1522513SSjoerd Meijer Builder.SetInsertPoint(ActiveLaneMask);
369c352e7fbSSjoerd Meijer ConstantInt *Factor = ConstantInt::get(cast<IntegerType>(Ty), VectorWidth);
3700efc9e5aSSjoerd Meijer
371312409e4SSam Parker Intrinsic::ID VCTPID;
3721319d9bbSSjoerd Meijer switch (VectorWidth) {
373312409e4SSam Parker default:
374312409e4SSam Parker llvm_unreachable("unexpected number of lanes");
375*ab0c5ceaSDavid Green case 2: VCTPID = Intrinsic::arm_mve_vctp64; break;
37648cce077SSimon Tatham case 4: VCTPID = Intrinsic::arm_mve_vctp32; break;
37748cce077SSimon Tatham case 8: VCTPID = Intrinsic::arm_mve_vctp16; break;
37848cce077SSimon Tatham case 16: VCTPID = Intrinsic::arm_mve_vctp8; break;
379312409e4SSam Parker }
380312409e4SSam Parker Function *VCTP = Intrinsic::getDeclaration(M, VCTPID);
381d1522513SSjoerd Meijer Value *VCTPCall = Builder.CreateCall(VCTP, Processed);
382d1522513SSjoerd Meijer ActiveLaneMask->replaceAllUsesWith(VCTPCall);
383312409e4SSam Parker
384312409e4SSam Parker // Add the incoming value to the new phi.
385aac03ae0SSam Parker // TODO: This add likely already exists in the loop.
386aac03ae0SSam Parker Value *Remaining = Builder.CreateSub(Processed, Factor);
387312409e4SSam Parker Processed->addIncoming(Remaining, L->getLoopLatch());
3880efc9e5aSSjoerd Meijer LLVM_DEBUG(dbgs() << "ARM TP: Insert processed elements phi: "
389312409e4SSam Parker << *Processed << "\n"
390d1522513SSjoerd Meijer << "ARM TP: Inserted VCTP: " << *VCTPCall << "\n");
3910efc9e5aSSjoerd Meijer }
3920efc9e5aSSjoerd Meijer
TryConvertActiveLaneMask(Value * TripCount)3930e49a40dSDavid Green bool MVETailPredication::TryConvertActiveLaneMask(Value *TripCount) {
3940e49a40dSDavid Green SmallVector<IntrinsicInst *, 4> ActiveLaneMasks;
3950e49a40dSDavid Green for (auto *BB : L->getBlocks())
3960e49a40dSDavid Green for (auto &I : *BB)
3970e49a40dSDavid Green if (auto *Int = dyn_cast<IntrinsicInst>(&I))
3980e49a40dSDavid Green if (Int->getIntrinsicID() == Intrinsic::get_active_lane_mask)
3990e49a40dSDavid Green ActiveLaneMasks.push_back(Int);
4000e49a40dSDavid Green
4010e49a40dSDavid Green if (ActiveLaneMasks.empty())
4020efc9e5aSSjoerd Meijer return false;
4030efc9e5aSSjoerd Meijer
4040efc9e5aSSjoerd Meijer LLVM_DEBUG(dbgs() << "ARM TP: Found predicated vector loop.\n");
4050efc9e5aSSjoerd Meijer
4060e49a40dSDavid Green for (auto *ActiveLaneMask : ActiveLaneMasks) {
407d1522513SSjoerd Meijer LLVM_DEBUG(dbgs() << "ARM TP: Found active lane mask: "
408d1522513SSjoerd Meijer << *ActiveLaneMask << "\n");
4090736d1ccSSjoerd Meijer
4100e49a40dSDavid Green if (!IsSafeActiveMask(ActiveLaneMask, TripCount)) {
411d1522513SSjoerd Meijer LLVM_DEBUG(dbgs() << "ARM TP: Not safe to insert VCTP.\n");
412b567ff2fSSjoerd Meijer return false;
4130736d1ccSSjoerd Meijer }
414d1522513SSjoerd Meijer LLVM_DEBUG(dbgs() << "ARM TP: Safe to insert VCTP.\n");
4150e49a40dSDavid Green InsertVCTPIntrinsic(ActiveLaneMask, TripCount);
416d1522513SSjoerd Meijer }
417b567ff2fSSjoerd Meijer
4180e49a40dSDavid Green // Remove dead instructions and now dead phis.
4190e49a40dSDavid Green for (auto *II : ActiveLaneMasks)
4200e49a40dSDavid Green RecursivelyDeleteTriviallyDeadInstructions(II);
4210e49a40dSDavid Green for (auto I : L->blocks())
4220e49a40dSDavid Green DeleteDeadPHIs(I);
423312409e4SSam Parker return true;
424312409e4SSam Parker }
425312409e4SSam Parker
createMVETailPredicationPass()426312409e4SSam Parker Pass *llvm::createMVETailPredicationPass() {
427312409e4SSam Parker return new MVETailPredication();
428312409e4SSam Parker }
429312409e4SSam Parker
430312409e4SSam Parker char MVETailPredication::ID = 0;
431312409e4SSam Parker
432312409e4SSam Parker INITIALIZE_PASS_BEGIN(MVETailPredication, DEBUG_TYPE, DESC, false, false)
433312409e4SSam Parker INITIALIZE_PASS_END(MVETailPredication, DEBUG_TYPE, DESC, false, false)
434