//===- MVETailPredication.cpp - MVE Tail Predication ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
/// \file
/// Armv8.1m introduced MVE, M-Profile Vector Extension, and low-overhead
/// branches to help accelerate DSP applications. These two extensions can be
/// combined to provide implicit vector predication within a low-overhead loop.
/// The HardwareLoops pass inserts intrinsics identifying loops that the
/// backend will attempt to convert into a low-overhead loop. The vectorizer is
/// responsible for generating a vectorized loop in which the lanes are
/// predicated upon the iteration counter. This pass looks at these predicated
/// vector loops, which are targets for low-overhead loops, and prepares them
/// for code generation. Once the vectorizer has produced a masked loop, there
/// are a couple of possible final forms:
/// - A tail-predicated loop, with implicit predication.
/// - A loop containing multiple VCTP instructions, predicating multiple VPT
///   blocks of instructions operating on different vector types.
///
/// This pass inserts the VCTP intrinsic to represent the effect of tail
/// predication. This will be picked up by the ARM Low-overhead loop pass,
/// which performs the final transformation to a DLSTP or WLSTP tail-predicated
/// loop.
28 29 #include "ARM.h" 30 #include "ARMSubtarget.h" 31 #include "llvm/Analysis/LoopInfo.h" 32 #include "llvm/Analysis/LoopPass.h" 33 #include "llvm/Analysis/ScalarEvolution.h" 34 #include "llvm/Analysis/ScalarEvolutionExpander.h" 35 #include "llvm/Analysis/ScalarEvolutionExpressions.h" 36 #include "llvm/Analysis/TargetTransformInfo.h" 37 #include "llvm/CodeGen/TargetPassConfig.h" 38 #include "llvm/InitializePasses.h" 39 #include "llvm/IR/IRBuilder.h" 40 #include "llvm/IR/Instructions.h" 41 #include "llvm/IR/IntrinsicsARM.h" 42 #include "llvm/IR/PatternMatch.h" 43 #include "llvm/Support/Debug.h" 44 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 45 #include "llvm/Transforms/Utils/LoopUtils.h" 46 47 using namespace llvm; 48 49 #define DEBUG_TYPE "mve-tail-predication" 50 #define DESC "Transform predicated vector loops to use MVE tail predication" 51 52 cl::opt<bool> 53 DisableTailPredication("disable-mve-tail-predication", cl::Hidden, 54 cl::init(true), 55 cl::desc("Disable MVE Tail Predication")); 56 namespace { 57 58 // Bookkeeping for pattern matching the loop trip count and the number of 59 // elements processed by the loop. 60 struct TripCountPattern { 61 // The Predicate used by the masked loads/stores, i.e. an icmp instruction 62 // which calculates active/inactive lanes 63 Instruction *Predicate = nullptr; 64 65 // The add instruction that increments the IV 66 Value *TripCount = nullptr; 67 68 // The number of elements processed by the vector loop. 
69 Value *NumElements = nullptr; 70 71 VectorType *VecTy = nullptr; 72 Instruction *Shuffle = nullptr; 73 Instruction *Induction = nullptr; 74 75 TripCountPattern(Instruction *P, Value *TC, VectorType *VT) 76 : Predicate(P), TripCount(TC), VecTy(VT){}; 77 }; 78 79 class MVETailPredication : public LoopPass { 80 SmallVector<IntrinsicInst*, 4> MaskedInsts; 81 Loop *L = nullptr; 82 LoopInfo *LI = nullptr; 83 const DataLayout *DL; 84 DominatorTree *DT = nullptr; 85 ScalarEvolution *SE = nullptr; 86 TargetTransformInfo *TTI = nullptr; 87 TargetLibraryInfo *TLI = nullptr; 88 bool ClonedVCTPInExitBlock = false; 89 90 public: 91 static char ID; 92 93 MVETailPredication() : LoopPass(ID) { } 94 95 void getAnalysisUsage(AnalysisUsage &AU) const override { 96 AU.addRequired<ScalarEvolutionWrapperPass>(); 97 AU.addRequired<LoopInfoWrapperPass>(); 98 AU.addRequired<TargetPassConfig>(); 99 AU.addRequired<TargetTransformInfoWrapperPass>(); 100 AU.addRequired<DominatorTreeWrapperPass>(); 101 AU.addRequired<TargetLibraryInfoWrapperPass>(); 102 AU.addPreserved<LoopInfoWrapperPass>(); 103 AU.setPreservesCFG(); 104 } 105 106 bool runOnLoop(Loop *L, LPPassManager&) override; 107 108 private: 109 /// Perform the relevant checks on the loop and convert if possible. 110 bool TryConvert(Value *TripCount); 111 112 /// Return whether this is a vectorized loop, that contains masked 113 /// load/stores. 114 bool IsPredicatedVectorLoop(); 115 116 /// Compute a value for the total number of elements that the predicated 117 /// loop will process if it is a runtime value. 118 bool ComputeRuntimeElements(TripCountPattern &TCP); 119 120 /// Is the icmp that generates an i1 vector, based upon a loop counter 121 /// and a limit that is defined outside the loop. 122 bool isTailPredicate(TripCountPattern &TCP); 123 124 /// Insert the intrinsic to represent the effect of tail predication. 
125 void InsertVCTPIntrinsic(TripCountPattern &TCP, 126 DenseMap<Instruction *, Instruction *> &NewPredicates); 127 128 /// Rematerialize the iteration count in exit blocks, which enables 129 /// ARMLowOverheadLoops to better optimise away loop update statements inside 130 /// hardware-loops. 131 void RematerializeIterCount(); 132 }; 133 134 } // end namespace 135 136 static bool IsDecrement(Instruction &I) { 137 auto *Call = dyn_cast<IntrinsicInst>(&I); 138 if (!Call) 139 return false; 140 141 Intrinsic::ID ID = Call->getIntrinsicID(); 142 return ID == Intrinsic::loop_decrement_reg; 143 } 144 145 static bool IsMasked(Instruction *I) { 146 auto *Call = dyn_cast<IntrinsicInst>(I); 147 if (!Call) 148 return false; 149 150 Intrinsic::ID ID = Call->getIntrinsicID(); 151 // TODO: Support gather/scatter expand/compress operations. 152 return ID == Intrinsic::masked_store || ID == Intrinsic::masked_load; 153 } 154 155 void MVETailPredication::RematerializeIterCount() { 156 SmallVector<WeakTrackingVH, 16> DeadInsts; 157 SCEVExpander Rewriter(*SE, *DL, "mvetp"); 158 ReplaceExitVal ReplaceExitValue = AlwaysRepl; 159 160 formLCSSARecursively(*L, *DT, LI, SE); 161 rewriteLoopExitValues(L, LI, TLI, SE, Rewriter, DT, ReplaceExitValue, 162 DeadInsts); 163 } 164 165 bool MVETailPredication::runOnLoop(Loop *L, LPPassManager&) { 166 if (skipLoop(L) || DisableTailPredication) 167 return false; 168 169 MaskedInsts.clear(); 170 Function &F = *L->getHeader()->getParent(); 171 auto &TPC = getAnalysis<TargetPassConfig>(); 172 auto &TM = TPC.getTM<TargetMachine>(); 173 auto *ST = &TM.getSubtarget<ARMSubtarget>(F); 174 DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); 175 LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo(); 176 TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); 177 SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); 178 auto *TLIP = getAnalysisIfAvailable<TargetLibraryInfoWrapperPass>(); 179 TLI = TLIP ? 
&TLIP->getTLI(*L->getHeader()->getParent()) : nullptr; 180 DL = &L->getHeader()->getModule()->getDataLayout(); 181 this->L = L; 182 183 // The MVE and LOB extensions are combined to enable tail-predication, but 184 // there's nothing preventing us from generating VCTP instructions for v8.1m. 185 if (!ST->hasMVEIntegerOps() || !ST->hasV8_1MMainlineOps()) { 186 LLVM_DEBUG(dbgs() << "ARM TP: Not a v8.1m.main+mve target.\n"); 187 return false; 188 } 189 190 BasicBlock *Preheader = L->getLoopPreheader(); 191 if (!Preheader) 192 return false; 193 194 auto FindLoopIterations = [](BasicBlock *BB) -> IntrinsicInst* { 195 for (auto &I : *BB) { 196 auto *Call = dyn_cast<IntrinsicInst>(&I); 197 if (!Call) 198 continue; 199 200 Intrinsic::ID ID = Call->getIntrinsicID(); 201 if (ID == Intrinsic::set_loop_iterations || 202 ID == Intrinsic::test_set_loop_iterations) 203 return cast<IntrinsicInst>(&I); 204 } 205 return nullptr; 206 }; 207 208 // Look for the hardware loop intrinsic that sets the iteration count. 209 IntrinsicInst *Setup = FindLoopIterations(Preheader); 210 211 // The test.set iteration could live in the pre-preheader. 212 if (!Setup) { 213 if (!Preheader->getSinglePredecessor()) 214 return false; 215 Setup = FindLoopIterations(Preheader->getSinglePredecessor()); 216 if (!Setup) 217 return false; 218 } 219 220 // Search for the hardware loop intrinic that decrements the loop counter. 
221 IntrinsicInst *Decrement = nullptr; 222 for (auto *BB : L->getBlocks()) { 223 for (auto &I : *BB) { 224 if (IsDecrement(I)) { 225 Decrement = cast<IntrinsicInst>(&I); 226 break; 227 } 228 } 229 } 230 231 if (!Decrement) 232 return false; 233 234 ClonedVCTPInExitBlock = false; 235 LLVM_DEBUG(dbgs() << "ARM TP: Running on Loop: " << *L << *Setup << "\n" 236 << *Decrement << "\n"); 237 238 if (TryConvert(Setup->getArgOperand(0))) { 239 if (ClonedVCTPInExitBlock) 240 RematerializeIterCount(); 241 return true; 242 } 243 244 return false; 245 } 246 247 // Pattern match predicates/masks and determine if they use the loop induction 248 // variable to control the number of elements processed by the loop. If so, 249 // the loop is a candidate for tail-predication. 250 bool MVETailPredication::isTailPredicate(TripCountPattern &TCP) { 251 using namespace PatternMatch; 252 253 // Pattern match the loop body and find the add with takes the index iv 254 // and adds a constant vector to it: 255 // 256 // vector.body: 257 // .. 258 // %index = phi i32 259 // %broadcast.splatinsert = insertelement <4 x i32> undef, i32 %index, i32 0 260 // %broadcast.splat = shufflevector <4 x i32> %broadcast.splatinsert, 261 // <4 x i32> undef, 262 // <4 x i32> zeroinitializer 263 // %induction = add <4 x i32> %broadcast.splat, <i32 0, i32 1, i32 2, i32 3> 264 // %pred = icmp ule <4 x i32> %induction, %broadcast.splat11 265 266 Instruction *BroadcastSplat = nullptr; 267 Constant *Const = nullptr; 268 if (!match(TCP.Induction, 269 m_Add(m_Instruction(BroadcastSplat), m_Constant(Const)))) 270 return false; 271 272 // Check that we're adding <0, 1, 2, 3... 273 if (auto *CDS = dyn_cast<ConstantDataSequential>(Const)) { 274 for (unsigned i = 0; i < CDS->getNumElements(); ++i) { 275 if (CDS->getElementAsInteger(i) != i) 276 return false; 277 } 278 } else 279 return false; 280 281 Instruction *Insert = nullptr; 282 // The shuffle which broadcasts the index iv into a vector. 
283 if (!match(BroadcastSplat, 284 m_ShuffleVector(m_Instruction(Insert), m_Undef(), m_Zero()))) 285 return false; 286 287 // The insert element which initialises a vector with the index iv. 288 Instruction *IV = nullptr; 289 if (!match(Insert, m_InsertElement(m_Undef(), m_Instruction(IV), m_Zero()))) 290 return false; 291 292 // The index iv. 293 auto *Phi = dyn_cast<PHINode>(IV); 294 if (!Phi) 295 return false; 296 297 // TODO: Don't think we need to check the entry value. 298 Value *OnEntry = Phi->getIncomingValueForBlock(L->getLoopPreheader()); 299 if (!match(OnEntry, m_Zero())) 300 return false; 301 302 Value *InLoop = Phi->getIncomingValueForBlock(L->getLoopLatch()); 303 unsigned Lanes = cast<VectorType>(Insert->getType())->getNumElements(); 304 305 Instruction *LHS = nullptr; 306 if (!match(InLoop, m_Add(m_Instruction(LHS), m_SpecificInt(Lanes)))) 307 return false; 308 309 return LHS == Phi; 310 } 311 312 static VectorType *getVectorType(IntrinsicInst *I) { 313 unsigned TypeOp = I->getIntrinsicID() == Intrinsic::masked_load ? 0 : 1; 314 auto *PtrTy = cast<PointerType>(I->getOperand(TypeOp)->getType()); 315 return cast<VectorType>(PtrTy->getElementType()); 316 } 317 318 bool MVETailPredication::IsPredicatedVectorLoop() { 319 // Check that the loop contains at least one masked load/store intrinsic. 320 // We only support 'normal' vector instructions - other than masked 321 // load/stores. 322 for (auto *BB : L->getBlocks()) { 323 for (auto &I : *BB) { 324 if (IsMasked(&I)) { 325 VectorType *VecTy = getVectorType(cast<IntrinsicInst>(&I)); 326 unsigned Lanes = VecTy->getNumElements(); 327 unsigned ElementWidth = VecTy->getScalarSizeInBits(); 328 // MVE vectors are 128-bit, but don't support 128 x i1. 329 // TODO: Can we support vectors larger than 128-bits? 
330 unsigned MaxWidth = TTI->getRegisterBitWidth(true); 331 if (Lanes * ElementWidth > MaxWidth || Lanes == MaxWidth) 332 return false; 333 MaskedInsts.push_back(cast<IntrinsicInst>(&I)); 334 } else if (auto *Int = dyn_cast<IntrinsicInst>(&I)) { 335 for (auto &U : Int->args()) { 336 if (isa<VectorType>(U->getType())) 337 return false; 338 } 339 } 340 } 341 } 342 343 return !MaskedInsts.empty(); 344 } 345 346 // Pattern match the predicate, which is an icmp with a constant vector of this 347 // form: 348 // 349 // icmp ult <4 x i32> %induction, <i32 32002, i32 32002, i32 32002, i32 32002> 350 // 351 // and return the constant, i.e. 32002 in this example. This is assumed to be 352 // the scalar loop iteration count: the number of loop elements by the 353 // the vector loop. Further checks are performed in function isTailPredicate(), 354 // to verify 'induction' behaves as an induction variable. 355 // 356 static bool ComputeConstElements(TripCountPattern &TCP) { 357 if (!dyn_cast<ConstantInt>(TCP.TripCount)) 358 return false; 359 360 ConstantInt *VF = ConstantInt::get( 361 cast<IntegerType>(TCP.TripCount->getType()), TCP.VecTy->getNumElements()); 362 using namespace PatternMatch; 363 CmpInst::Predicate CC; 364 365 if (!match(TCP.Predicate, m_ICmp(CC, m_Instruction(TCP.Induction), 366 m_AnyIntegralConstant())) || 367 CC != ICmpInst::ICMP_ULT) 368 return false; 369 370 LLVM_DEBUG(dbgs() << "ARM TP: icmp with constants: "; TCP.Predicate->dump();); 371 Value *ConstVec = TCP.Predicate->getOperand(1); 372 373 auto *CDS = dyn_cast<ConstantDataSequential>(ConstVec); 374 if (!CDS || CDS->getNumElements() != VF->getSExtValue()) 375 return false; 376 377 if ((TCP.NumElements = CDS->getSplatValue())) { 378 assert(dyn_cast<ConstantInt>(TCP.NumElements)->getSExtValue() % 379 VF->getSExtValue() != 380 0 && 381 "tail-predication: trip count should not be a multiple of the VF"); 382 LLVM_DEBUG(dbgs() << "ARM TP: Found const elem count: " << *TCP.NumElements 383 << "\n"); 384 return 
true; 385 } 386 return false; 387 } 388 389 // Pattern match the loop iteration count setup: 390 // 391 // %trip.count.minus.1 = add i32 %N, -1 392 // %broadcast.splatinsert10 = insertelement <4 x i32> undef, 393 // i32 %trip.count.minus.1, i32 0 394 // %broadcast.splat11 = shufflevector <4 x i32> %broadcast.splatinsert10, 395 // <4 x i32> undef, 396 // <4 x i32> zeroinitializer 397 // .. 398 // vector.body: 399 // .. 400 // 401 static bool MatchElemCountLoopSetup(Loop *L, Instruction *Shuffle, 402 Value *NumElements) { 403 using namespace PatternMatch; 404 Instruction *Insert = nullptr; 405 406 if (!match(Shuffle, 407 m_ShuffleVector(m_Instruction(Insert), m_Undef(), m_Zero()))) 408 return false; 409 410 // Insert the limit into a vector. 411 Instruction *BECount = nullptr; 412 if (!match(Insert, 413 m_InsertElement(m_Undef(), m_Instruction(BECount), m_Zero()))) 414 return false; 415 416 // The limit calculation, backedge count. 417 Value *TripCount = nullptr; 418 if (!match(BECount, m_Add(m_Value(TripCount), m_AllOnes()))) 419 return false; 420 421 if (TripCount != NumElements || !L->isLoopInvariant(BECount)) 422 return false; 423 424 return true; 425 } 426 427 bool MVETailPredication::ComputeRuntimeElements(TripCountPattern &TCP) { 428 using namespace PatternMatch; 429 const SCEV *TripCountSE = SE->getSCEV(TCP.TripCount); 430 ConstantInt *VF = ConstantInt::get( 431 cast<IntegerType>(TCP.TripCount->getType()), TCP.VecTy->getNumElements()); 432 433 if (VF->equalsInt(1)) 434 return false; 435 436 CmpInst::Predicate Pred; 437 if (!match(TCP.Predicate, m_ICmp(Pred, m_Instruction(TCP.Induction), 438 m_Instruction(TCP.Shuffle))) || 439 Pred != ICmpInst::ICMP_ULE) 440 return false; 441 442 LLVM_DEBUG(dbgs() << "Computing number of elements for vector trip count: "; 443 TCP.TripCount->dump()); 444 445 // Otherwise, continue and try to pattern match the vector iteration 446 // count expression 447 auto VisitAdd = [&](const SCEVAddExpr *S) -> const SCEVMulExpr * { 448 if 
(auto *Const = dyn_cast<SCEVConstant>(S->getOperand(0))) { 449 if (Const->getAPInt() != -VF->getValue()) 450 return nullptr; 451 } else 452 return nullptr; 453 return dyn_cast<SCEVMulExpr>(S->getOperand(1)); 454 }; 455 456 auto VisitMul = [&](const SCEVMulExpr *S) -> const SCEVUDivExpr * { 457 if (auto *Const = dyn_cast<SCEVConstant>(S->getOperand(0))) { 458 if (Const->getValue() != VF) 459 return nullptr; 460 } else 461 return nullptr; 462 return dyn_cast<SCEVUDivExpr>(S->getOperand(1)); 463 }; 464 465 auto VisitDiv = [&](const SCEVUDivExpr *S) -> const SCEV * { 466 if (auto *Const = dyn_cast<SCEVConstant>(S->getRHS())) { 467 if (Const->getValue() != VF) 468 return nullptr; 469 } else 470 return nullptr; 471 472 if (auto *RoundUp = dyn_cast<SCEVAddExpr>(S->getLHS())) { 473 if (auto *Const = dyn_cast<SCEVConstant>(RoundUp->getOperand(0))) { 474 if (Const->getAPInt() != (VF->getValue() - 1)) 475 return nullptr; 476 } else 477 return nullptr; 478 479 return RoundUp->getOperand(1); 480 } 481 return nullptr; 482 }; 483 484 // TODO: Can we use SCEV helpers, such as findArrayDimensions, and friends to 485 // determine the numbers of elements instead? Looks like this is what is used 486 // for delinearization, but I'm not sure if it can be applied to the 487 // vectorized form - at least not without a bit more work than I feel 488 // comfortable with. 
489 490 // Search for Elems in the following SCEV: 491 // (1 + ((-VF + (VF * (((VF - 1) + %Elems) /u VF))<nuw>) /u VF))<nuw><nsw> 492 const SCEV *Elems = nullptr; 493 if (auto *TC = dyn_cast<SCEVAddExpr>(TripCountSE)) 494 if (auto *Div = dyn_cast<SCEVUDivExpr>(TC->getOperand(1))) 495 if (auto *Add = dyn_cast<SCEVAddExpr>(Div->getLHS())) 496 if (auto *Mul = VisitAdd(Add)) 497 if (auto *Div = VisitMul(Mul)) 498 if (auto *Res = VisitDiv(Div)) 499 Elems = Res; 500 501 if (!Elems) 502 return false; 503 504 Instruction *InsertPt = L->getLoopPreheader()->getTerminator(); 505 if (!isSafeToExpandAt(Elems, InsertPt, *SE)) 506 return false; 507 508 auto DL = L->getHeader()->getModule()->getDataLayout(); 509 SCEVExpander Expander(*SE, DL, "elements"); 510 TCP.NumElements = Expander.expandCodeFor(Elems, Elems->getType(), InsertPt); 511 512 if (!MatchElemCountLoopSetup(L, TCP.Shuffle, TCP.NumElements)) 513 return false; 514 515 return true; 516 } 517 518 // Look through the exit block to see whether there's a duplicate predicate 519 // instruction. This can happen when we need to perform a select on values 520 // from the last and previous iteration. Instead of doing a straight 521 // replacement of that predicate with the vctp, clone the vctp and place it 522 // in the block. This means that the VPR doesn't have to be live into the 523 // exit block which should make it easier to convert this loop into a proper 524 // tail predicated loop. 
525 static bool Cleanup(DenseMap<Instruction*, Instruction*> &NewPredicates, 526 SetVector<Instruction*> &MaybeDead, Loop *L) { 527 BasicBlock *Exit = L->getUniqueExitBlock(); 528 if (!Exit) { 529 LLVM_DEBUG(dbgs() << "ARM TP: can't find loop exit block\n"); 530 return false; 531 } 532 533 bool ClonedVCTPInExitBlock = false; 534 535 for (auto &Pair : NewPredicates) { 536 Instruction *OldPred = Pair.first; 537 Instruction *NewPred = Pair.second; 538 539 for (auto &I : *Exit) { 540 if (I.isSameOperationAs(OldPred)) { 541 Instruction *PredClone = NewPred->clone(); 542 PredClone->insertBefore(&I); 543 I.replaceAllUsesWith(PredClone); 544 MaybeDead.insert(&I); 545 ClonedVCTPInExitBlock = true; 546 LLVM_DEBUG(dbgs() << "ARM TP: replacing: "; I.dump(); 547 dbgs() << "ARM TP: with: "; PredClone->dump()); 548 break; 549 } 550 } 551 } 552 553 // Drop references and add operands to check for dead. 554 SmallPtrSet<Instruction*, 4> Dead; 555 while (!MaybeDead.empty()) { 556 auto *I = MaybeDead.front(); 557 MaybeDead.remove(I); 558 if (I->hasNUsesOrMore(1)) 559 continue; 560 561 for (auto &U : I->operands()) { 562 if (auto *OpI = dyn_cast<Instruction>(U)) 563 MaybeDead.insert(OpI); 564 } 565 I->dropAllReferences(); 566 Dead.insert(I); 567 } 568 569 for (auto *I : Dead) { 570 LLVM_DEBUG(dbgs() << "ARM TP: removing dead insn: "; I->dump()); 571 I->eraseFromParent(); 572 } 573 574 for (auto I : L->blocks()) 575 DeleteDeadPHIs(I); 576 577 return ClonedVCTPInExitBlock; 578 } 579 580 void MVETailPredication::InsertVCTPIntrinsic(TripCountPattern &TCP, 581 DenseMap<Instruction*, Instruction*> &NewPredicates) { 582 IRBuilder<> Builder(L->getHeader()->getFirstNonPHI()); 583 Module *M = L->getHeader()->getModule(); 584 Type *Ty = IntegerType::get(M->getContext(), 32); 585 586 // Insert a phi to count the number of elements processed by the loop. 
587 PHINode *Processed = Builder.CreatePHI(Ty, 2); 588 Processed->addIncoming(TCP.NumElements, L->getLoopPreheader()); 589 590 // Insert the intrinsic to represent the effect of tail predication. 591 Builder.SetInsertPoint(cast<Instruction>(TCP.Predicate)); 592 ConstantInt *Factor = 593 ConstantInt::get(cast<IntegerType>(Ty), TCP.VecTy->getNumElements()); 594 595 Intrinsic::ID VCTPID; 596 switch (TCP.VecTy->getNumElements()) { 597 default: 598 llvm_unreachable("unexpected number of lanes"); 599 case 4: VCTPID = Intrinsic::arm_mve_vctp32; break; 600 case 8: VCTPID = Intrinsic::arm_mve_vctp16; break; 601 case 16: VCTPID = Intrinsic::arm_mve_vctp8; break; 602 603 // FIXME: vctp64 currently not supported because the predicate 604 // vector wants to be <2 x i1>, but v2i1 is not a legal MVE 605 // type, so problems happen at isel time. 606 // Intrinsic::arm_mve_vctp64 exists for ACLE intrinsics 607 // purposes, but takes a v4i1 instead of a v2i1. 608 } 609 Function *VCTP = Intrinsic::getDeclaration(M, VCTPID); 610 Value *TailPredicate = Builder.CreateCall(VCTP, Processed); 611 TCP.Predicate->replaceAllUsesWith(TailPredicate); 612 NewPredicates[TCP.Predicate] = cast<Instruction>(TailPredicate); 613 614 // Add the incoming value to the new phi. 615 // TODO: This add likely already exists in the loop. 
616 Value *Remaining = Builder.CreateSub(Processed, Factor); 617 Processed->addIncoming(Remaining, L->getLoopLatch()); 618 LLVM_DEBUG(dbgs() << "ARM TP: Insert processed elements phi: " 619 << *Processed << "\n" 620 << "ARM TP: Inserted VCTP: " << *TailPredicate << "\n"); 621 } 622 623 bool MVETailPredication::TryConvert(Value *TripCount) { 624 if (!IsPredicatedVectorLoop()) { 625 LLVM_DEBUG(dbgs() << "ARM TP: no masked instructions in loop.\n"); 626 return false; 627 } 628 629 LLVM_DEBUG(dbgs() << "ARM TP: Found predicated vector loop.\n"); 630 631 // Walk through the masked intrinsics and try to find whether the predicate 632 // operand is generated from an induction variable. 633 SetVector<Instruction*> Predicates; 634 DenseMap<Instruction*, Instruction*> NewPredicates; 635 636 for (auto *I : MaskedInsts) { 637 Intrinsic::ID ID = I->getIntrinsicID(); 638 unsigned PredOp = ID == Intrinsic::masked_load ? 2 : 3; 639 auto *Predicate = dyn_cast<Instruction>(I->getArgOperand(PredOp)); 640 if (!Predicate || Predicates.count(Predicate)) 641 continue; 642 643 TripCountPattern TCP(Predicate, TripCount, getVectorType(I)); 644 645 if (!(ComputeConstElements(TCP) || ComputeRuntimeElements(TCP))) 646 continue; 647 648 if (!isTailPredicate(TCP)) { 649 LLVM_DEBUG(dbgs() << "ARM TP: Not tail predicate: " << *Predicate << "\n"); 650 continue; 651 } 652 653 LLVM_DEBUG(dbgs() << "ARM TP: Found tail predicate: " << *Predicate << "\n"); 654 Predicates.insert(Predicate); 655 InsertVCTPIntrinsic(TCP, NewPredicates); 656 } 657 658 if (!NewPredicates.size()) 659 return false; 660 661 // Now clean up. 662 ClonedVCTPInExitBlock = Cleanup(NewPredicates, Predicates, L); 663 return true; 664 } 665 666 Pass *llvm::createMVETailPredicationPass() { 667 return new MVETailPredication(); 668 } 669 670 char MVETailPredication::ID = 0; 671 672 INITIALIZE_PASS_BEGIN(MVETailPredication, DEBUG_TYPE, DESC, false, false) 673 INITIALIZE_PASS_END(MVETailPredication, DEBUG_TYPE, DESC, false, false) 674