1 //===- ARMParallelDSP.cpp - Parallel DSP Pass -----------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 /// \file 10 /// Armv6 introduced instructions to perform 32-bit SIMD operations. The 11 /// purpose of this pass is do some IR pattern matching to create ACLE 12 /// DSP intrinsics, which map on these 32-bit SIMD operations. 13 /// This pass runs only when unaligned accesses is supported/enabled. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "llvm/ADT/Statistic.h" 18 #include "llvm/ADT/SmallPtrSet.h" 19 #include "llvm/Analysis/AliasAnalysis.h" 20 #include "llvm/Analysis/LoopAccessAnalysis.h" 21 #include "llvm/IR/Instructions.h" 22 #include "llvm/IR/NoFolder.h" 23 #include "llvm/Transforms/Scalar.h" 24 #include "llvm/Transforms/Utils/BasicBlockUtils.h" 25 #include "llvm/Pass.h" 26 #include "llvm/PassRegistry.h" 27 #include "llvm/PassSupport.h" 28 #include "llvm/Support/Debug.h" 29 #include "llvm/IR/PatternMatch.h" 30 #include "llvm/CodeGen/TargetPassConfig.h" 31 #include "ARM.h" 32 #include "ARMSubtarget.h" 33 34 using namespace llvm; 35 using namespace PatternMatch; 36 37 #define DEBUG_TYPE "arm-parallel-dsp" 38 39 STATISTIC(NumSMLAD , "Number of smlad instructions generated"); 40 41 static cl::opt<bool> 42 DisableParallelDSP("disable-arm-parallel-dsp", cl::Hidden, cl::init(false), 43 cl::desc("Disable the ARM Parallel DSP pass")); 44 45 namespace { 46 struct OpChain; 47 struct MulCandidate; 48 class Reduction; 49 50 using MulCandList = SmallVector<std::unique_ptr<MulCandidate>, 8>; 51 using ReductionList = SmallVector<Reduction, 8>; 52 using ValueList = SmallVector<Value*, 8>; 53 using MemInstList = SmallVector<LoadInst*, 8>; 54 using PMACPair = std::pair<MulCandidate*,MulCandidate*>; 55 using PMACPairList = SmallVector<PMACPair, 8>; 56 57 // 'MulCandidate' holds the multiplication instructions that are candidates 58 // for parallel execution. 59 struct MulCandidate { 60 Instruction *Root; 61 MemInstList VecLd; // Container for loads to widen. 62 Value* LHS; 63 Value* RHS; 64 bool Exchange = false; 65 bool ReadOnly = true; 66 67 MulCandidate(Instruction *I, ValueList &lhs, ValueList &rhs) : 68 Root(I), LHS(lhs.front()), RHS(rhs.front()) { } 69 70 bool HasTwoLoadInputs() const { 71 return isa<LoadInst>(LHS) && isa<LoadInst>(RHS); 72 } 73 74 LoadInst *getBaseLoad() const { 75 return cast<LoadInst>(LHS); 76 } 77 }; 78 79 /// Represent a sequence of multiply-accumulate operations with the aim to 80 /// perform the multiplications in parallel. 81 class Reduction { 82 Instruction *Root = nullptr; 83 Value *Acc = nullptr; 84 MulCandList Muls; 85 PMACPairList MulPairs; 86 SmallPtrSet<Instruction*, 4> Adds; 87 88 public: 89 Reduction() = delete; 90 91 Reduction (Instruction *Add) : Root(Add) { } 92 93 /// Record an Add instruction that is a part of the this reduction. 94 void InsertAdd(Instruction *I) { Adds.insert(I); } 95 96 /// Record a MulCandidate, rooted at a Mul instruction, that is a part of 97 /// this reduction. 98 void InsertMul(Instruction *I, ValueList &LHS, ValueList &RHS) { 99 Muls.push_back(make_unique<MulCandidate>(I, LHS, RHS)); 100 } 101 102 /// Add the incoming accumulator value, returns true if a value had not 103 /// already been added. Returning false signals to the user that this 104 /// reduction already has a value to initialise the accumulator. 105 bool InsertAcc(Value *V) { 106 if (Acc) 107 return false; 108 Acc = V; 109 return true; 110 } 111 112 /// Set two MulCandidates, rooted at muls, that can be executed as a single 113 /// parallel operation. 114 void AddMulPair(MulCandidate *Mul0, MulCandidate *Mul1) { 115 MulPairs.push_back(std::make_pair(Mul0, Mul1)); 116 } 117 118 /// Return true if enough mul operations are found that can be executed in 119 /// parallel. 120 bool CreateParallelPairs(); 121 122 /// Return the add instruction which is the root of the reduction. 123 Instruction *getRoot() { return Root; } 124 125 bool is64Bit() const { return Root->getType()->isIntegerTy(64); } 126 127 /// Return the incoming value to be accumulated. This maybe null. 128 Value *getAccumulator() { return Acc; } 129 130 /// Return the set of adds that comprise the reduction. 131 SmallPtrSetImpl<Instruction*> &getAdds() { return Adds; } 132 133 /// Return the MulCandidate, rooted at mul instruction, that comprise the 134 /// the reduction. 135 MulCandList &getMuls() { return Muls; } 136 137 /// Return the MulCandidate, rooted at mul instructions, that have been 138 /// paired for parallel execution. 139 PMACPairList &getMulPairs() { return MulPairs; } 140 141 /// To finalise, replace the uses of the root with the intrinsic call. 142 void UpdateRoot(Instruction *SMLAD) { 143 Root->replaceAllUsesWith(SMLAD); 144 } 145 }; 146 147 class WidenedLoad { 148 LoadInst *NewLd = nullptr; 149 SmallVector<LoadInst*, 4> Loads; 150 151 public: 152 WidenedLoad(SmallVectorImpl<LoadInst*> &Lds, LoadInst *Wide) 153 : NewLd(Wide) { 154 for (auto *I : Lds) 155 Loads.push_back(I); 156 } 157 LoadInst *getLoad() { 158 return NewLd; 159 } 160 }; 161 162 class ARMParallelDSP : public FunctionPass { 163 ScalarEvolution *SE; 164 AliasAnalysis *AA; 165 TargetLibraryInfo *TLI; 166 DominatorTree *DT; 167 const DataLayout *DL; 168 Module *M; 169 std::map<LoadInst*, LoadInst*> LoadPairs; 170 SmallPtrSet<LoadInst*, 4> OffsetLoads; 171 std::map<LoadInst*, std::unique_ptr<WidenedLoad>> WideLoads; 172 173 template<unsigned> 174 bool IsNarrowSequence(Value *V, ValueList &VL); 175 176 bool RecordMemoryOps(BasicBlock *BB); 177 void InsertParallelMACs(Reduction &Reduction); 178 bool AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1, MemInstList &VecMem); 179 LoadInst* CreateWideLoad(SmallVectorImpl<LoadInst*> &Loads, 180 IntegerType *LoadTy); 181 bool CreateParallelPairs(Reduction &R); 182 183 /// Try to match and generate: SMLAD, SMLADX - Signed Multiply Accumulate 184 /// Dual performs two signed 16x16-bit multiplications. It adds the 185 /// products to a 32-bit accumulate operand. Optionally, the instruction can 186 /// exchange the halfwords of the second operand before performing the 187 /// arithmetic. 188 bool MatchSMLAD(Function &F); 189 190 public: 191 static char ID; 192 193 ARMParallelDSP() : FunctionPass(ID) { } 194 195 void getAnalysisUsage(AnalysisUsage &AU) const override { 196 FunctionPass::getAnalysisUsage(AU); 197 AU.addRequired<AssumptionCacheTracker>(); 198 AU.addRequired<ScalarEvolutionWrapperPass>(); 199 AU.addRequired<AAResultsWrapperPass>(); 200 AU.addRequired<TargetLibraryInfoWrapperPass>(); 201 AU.addRequired<DominatorTreeWrapperPass>(); 202 AU.addRequired<TargetPassConfig>(); 203 AU.addPreserved<ScalarEvolutionWrapperPass>(); 204 AU.addPreserved<GlobalsAAWrapperPass>(); 205 AU.setPreservesCFG(); 206 } 207 208 bool runOnFunction(Function &F) override { 209 if (DisableParallelDSP) 210 return false; 211 if (skipFunction(F)) 212 return false; 213 214 SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE(); 215 AA = &getAnalysis<AAResultsWrapperPass>().getAAResults(); 216 TLI = &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); 217 DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree(); 218 auto &TPC = getAnalysis<TargetPassConfig>(); 219 220 M = F.getParent(); 221 DL = &M->getDataLayout(); 222 223 auto &TM = TPC.getTM<TargetMachine>(); 224 auto *ST = &TM.getSubtarget<ARMSubtarget>(F); 225 226 if (!ST->allowsUnalignedMem()) { 227 LLVM_DEBUG(dbgs() << "Unaligned memory access not supported: not " 228 "running pass ARMParallelDSP\n"); 229 return false; 230 } 231 232 if (!ST->hasDSP()) { 233 LLVM_DEBUG(dbgs() << "DSP extension not enabled: not running pass " 234 "ARMParallelDSP\n"); 235 return false; 236 } 237 238 if (!ST->isLittle()) { 239 LLVM_DEBUG(dbgs() << "Only supporting little endian: not running pass " 240 << "ARMParallelDSP\n"); 241 return false; 242 } 243 244 LLVM_DEBUG(dbgs() << "\n== Parallel DSP pass ==\n"); 245 LLVM_DEBUG(dbgs() << " - " << F.getName() << "\n\n"); 246 247 bool Changes = MatchSMLAD(F); 248 return Changes; 249 } 250 }; 251 } 252 253 template<typename MemInst> 254 static bool AreSequentialAccesses(MemInst *MemOp0, MemInst *MemOp1, 255 const DataLayout &DL, ScalarEvolution &SE) { 256 if (isConsecutiveAccess(MemOp0, MemOp1, DL, SE)) 257 return true; 258 return false; 259 } 260 261 bool ARMParallelDSP::AreSequentialLoads(LoadInst *Ld0, LoadInst *Ld1, 262 MemInstList &VecMem) { 263 if (!Ld0 || !Ld1) 264 return false; 265 266 if (!LoadPairs.count(Ld0) || LoadPairs[Ld0] != Ld1) 267 return false; 268 269 LLVM_DEBUG(dbgs() << "Loads are sequential and valid:\n"; 270 dbgs() << "Ld0:"; Ld0->dump(); 271 dbgs() << "Ld1:"; Ld1->dump(); 272 ); 273 274 VecMem.clear(); 275 VecMem.push_back(Ld0); 276 VecMem.push_back(Ld1); 277 return true; 278 } 279 280 // MaxBitwidth: the maximum supported bitwidth of the elements in the DSP 281 // instructions, which is set to 16. So here we should collect all i8 and i16 282 // narrow operations. 283 // TODO: we currently only collect i16, and will support i8 later, so that's 284 // why we check that types are equal to MaxBitWidth, and not <= MaxBitWidth. 285 template<unsigned MaxBitWidth> 286 bool ARMParallelDSP::IsNarrowSequence(Value *V, ValueList &VL) { 287 if (auto *SExt = dyn_cast<SExtInst>(V)) { 288 if (SExt->getSrcTy()->getIntegerBitWidth() != MaxBitWidth) 289 return false; 290 291 if (auto *Ld = dyn_cast<LoadInst>(SExt->getOperand(0))) { 292 // Check that these load could be paired. 293 if (!LoadPairs.count(Ld) && !OffsetLoads.count(Ld)) 294 return false; 295 296 VL.push_back(Ld); 297 VL.push_back(SExt); 298 return true; 299 } 300 } 301 return false; 302 } 303 304 /// Iterate through the block and record base, offset pairs of loads which can 305 /// be widened into a single load. 306 bool ARMParallelDSP::RecordMemoryOps(BasicBlock *BB) { 307 SmallVector<LoadInst*, 8> Loads; 308 SmallVector<Instruction*, 8> Writes; 309 LoadPairs.clear(); 310 WideLoads.clear(); 311 312 // Collect loads and instruction that may write to memory. For now we only 313 // record loads which are simple, sign-extended and have a single user. 314 // TODO: Allow zero-extended loads. 315 for (auto &I : *BB) { 316 if (I.mayWriteToMemory()) 317 Writes.push_back(&I); 318 auto *Ld = dyn_cast<LoadInst>(&I); 319 if (!Ld || !Ld->isSimple() || 320 !Ld->hasOneUse() || !isa<SExtInst>(Ld->user_back())) 321 continue; 322 Loads.push_back(Ld); 323 } 324 325 using InstSet = std::set<Instruction*>; 326 using DepMap = std::map<Instruction*, InstSet>; 327 DepMap RAWDeps; 328 329 // Record any writes that may alias a load. 330 const auto Size = LocationSize::unknown(); 331 for (auto Read : Loads) { 332 for (auto Write : Writes) { 333 MemoryLocation ReadLoc = 334 MemoryLocation(Read->getPointerOperand(), Size); 335 336 if (!isModOrRefSet(intersectModRef(AA->getModRefInfo(Write, ReadLoc), 337 ModRefInfo::ModRef))) 338 continue; 339 if (DT->dominates(Write, Read)) 340 RAWDeps[Read].insert(Write); 341 } 342 } 343 344 // Check whether there's not a write between the two loads which would 345 // prevent them from being safely merged. 346 auto SafeToPair = [&](LoadInst *Base, LoadInst *Offset) { 347 LoadInst *Dominator = DT->dominates(Base, Offset) ? Base : Offset; 348 LoadInst *Dominated = DT->dominates(Base, Offset) ? Offset : Base; 349 350 if (RAWDeps.count(Dominated)) { 351 InstSet &WritesBefore = RAWDeps[Dominated]; 352 353 for (auto Before : WritesBefore) { 354 355 // We can't move the second load backward, past a write, to merge 356 // with the first load. 357 if (DT->dominates(Dominator, Before)) 358 return false; 359 } 360 } 361 return true; 362 }; 363 364 // Record base, offset load pairs. 365 for (auto *Base : Loads) { 366 for (auto *Offset : Loads) { 367 if (Base == Offset) 368 continue; 369 370 if (AreSequentialAccesses<LoadInst>(Base, Offset, *DL, *SE) && 371 SafeToPair(Base, Offset)) { 372 LoadPairs[Base] = Offset; 373 OffsetLoads.insert(Offset); 374 break; 375 } 376 } 377 } 378 379 LLVM_DEBUG(if (!LoadPairs.empty()) { 380 dbgs() << "Consecutive load pairs:\n"; 381 for (auto &MapIt : LoadPairs) { 382 LLVM_DEBUG(dbgs() << *MapIt.first << ", " 383 << *MapIt.second << "\n"); 384 } 385 }); 386 return LoadPairs.size() > 1; 387 } 388 389 // The pass needs to identify integer add/sub reductions of 16-bit vector 390 // multiplications. 391 // To use SMLAD: 392 // 1) we first need to find integer add then look for this pattern: 393 // 394 // acc0 = ... 395 // ld0 = load i16 396 // sext0 = sext i16 %ld0 to i32 397 // ld1 = load i16 398 // sext1 = sext i16 %ld1 to i32 399 // mul0 = mul %sext0, %sext1 400 // ld2 = load i16 401 // sext2 = sext i16 %ld2 to i32 402 // ld3 = load i16 403 // sext3 = sext i16 %ld3 to i32 404 // mul1 = mul i32 %sext2, %sext3 405 // add0 = add i32 %mul0, %acc0 406 // acc1 = add i32 %add0, %mul1 407 // 408 // Which can be selected to: 409 // 410 // ldr r0 411 // ldr r1 412 // smlad r2, r0, r1, r2 413 // 414 // If constants are used instead of loads, these will need to be hoisted 415 // out and into a register. 416 // 417 // If loop invariants are used instead of loads, these need to be packed 418 // before the loop begins. 419 // 420 bool ARMParallelDSP::MatchSMLAD(Function &F) { 421 // Search recursively back through the operands to find a tree of values that 422 // form a multiply-accumulate chain. The search records the Add and Mul 423 // instructions that form the reduction and allows us to find a single value 424 // to be used as the initial input to the accumlator. 425 std::function<bool(Value*, BasicBlock*, Reduction&)> Search = [&] 426 (Value *V, BasicBlock *BB, Reduction &R) -> bool { 427 428 // If we find a non-instruction, try to use it as the initial accumulator 429 // value. This may have already been found during the search in which case 430 // this function will return false, signaling a search fail. 431 auto *I = dyn_cast<Instruction>(V); 432 if (!I) 433 return R.InsertAcc(V); 434 435 if (I->getParent() != BB) 436 return false; 437 438 switch (I->getOpcode()) { 439 default: 440 break; 441 case Instruction::PHI: 442 // Could be the accumulator value. 443 return R.InsertAcc(V); 444 case Instruction::Add: { 445 // Adds should be adding together two muls, or another add and a mul to 446 // be within the mac chain. One of the operands may also be the 447 // accumulator value at which point we should stop searching. 448 bool ValidLHS = Search(I->getOperand(0), BB, R); 449 bool ValidRHS = Search(I->getOperand(1), BB, R); 450 if (!ValidLHS && !ValidLHS) 451 return false; 452 else if (ValidLHS && ValidRHS) { 453 R.InsertAdd(I); 454 return true; 455 } else { 456 R.InsertAdd(I); 457 return R.InsertAcc(I); 458 } 459 } 460 case Instruction::Mul: { 461 Value *MulOp0 = I->getOperand(0); 462 Value *MulOp1 = I->getOperand(1); 463 if (isa<SExtInst>(MulOp0) && isa<SExtInst>(MulOp1)) { 464 ValueList LHS; 465 ValueList RHS; 466 if (IsNarrowSequence<16>(MulOp0, LHS) && 467 IsNarrowSequence<16>(MulOp1, RHS)) { 468 R.InsertMul(I, LHS, RHS); 469 return true; 470 } 471 } 472 return false; 473 } 474 case Instruction::SExt: 475 return Search(I->getOperand(0), BB, R); 476 } 477 return false; 478 }; 479 480 bool Changed = false; 481 482 for (auto &BB : F) { 483 SmallPtrSet<Instruction*, 4> AllAdds; 484 if (!RecordMemoryOps(&BB)) 485 continue; 486 487 for (Instruction &I : reverse(BB)) { 488 if (I.getOpcode() != Instruction::Add) 489 continue; 490 491 if (AllAdds.count(&I)) 492 continue; 493 494 const auto *Ty = I.getType(); 495 if (!Ty->isIntegerTy(32) && !Ty->isIntegerTy(64)) 496 continue; 497 498 Reduction R(&I); 499 if (!Search(&I, &BB, R)) 500 continue; 501 502 if (!CreateParallelPairs(R)) 503 continue; 504 505 InsertParallelMACs(R); 506 Changed = true; 507 AllAdds.insert(R.getAdds().begin(), R.getAdds().end()); 508 } 509 } 510 511 return Changed; 512 } 513 514 bool ARMParallelDSP::CreateParallelPairs(Reduction &R) { 515 516 // Not enough mul operations to make a pair. 517 if (R.getMuls().size() < 2) 518 return false; 519 520 // Check that the muls operate directly upon sign extended loads. 521 for (auto &MulCand : R.getMuls()) { 522 if (!MulCand->HasTwoLoadInputs()) 523 return false; 524 } 525 526 auto CanPair = [&](Reduction &R, MulCandidate *PMul0, MulCandidate *PMul1) { 527 // The first elements of each vector should be loads with sexts. If we 528 // find that its two pairs of consecutive loads, then these can be 529 // transformed into two wider loads and the users can be replaced with 530 // DSP intrinsics. 531 auto Ld0 = static_cast<LoadInst*>(PMul0->LHS); 532 auto Ld1 = static_cast<LoadInst*>(PMul1->LHS); 533 auto Ld2 = static_cast<LoadInst*>(PMul0->RHS); 534 auto Ld3 = static_cast<LoadInst*>(PMul1->RHS); 535 536 LLVM_DEBUG(dbgs() << "Loads:\n" 537 << " - " << *Ld0 << "\n" 538 << " - " << *Ld1 << "\n" 539 << " - " << *Ld2 << "\n" 540 << " - " << *Ld3 << "\n"); 541 542 if (AreSequentialLoads(Ld0, Ld1, PMul0->VecLd)) { 543 if (AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) { 544 LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n"); 545 R.AddMulPair(PMul0, PMul1); 546 return true; 547 } else if (AreSequentialLoads(Ld3, Ld2, PMul1->VecLd)) { 548 LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n"); 549 LLVM_DEBUG(dbgs() << " exchanging Ld2 and Ld3\n"); 550 PMul1->Exchange = true; 551 R.AddMulPair(PMul0, PMul1); 552 return true; 553 } 554 } else if (AreSequentialLoads(Ld1, Ld0, PMul0->VecLd) && 555 AreSequentialLoads(Ld2, Ld3, PMul1->VecLd)) { 556 LLVM_DEBUG(dbgs() << "OK: found two pairs of parallel loads!\n"); 557 LLVM_DEBUG(dbgs() << " exchanging Ld0 and Ld1\n"); 558 LLVM_DEBUG(dbgs() << " and swapping muls\n"); 559 PMul0->Exchange = true; 560 // Only the second operand can be exchanged, so swap the muls. 561 R.AddMulPair(PMul1, PMul0); 562 return true; 563 } 564 return false; 565 }; 566 567 MulCandList &Muls = R.getMuls(); 568 const unsigned Elems = Muls.size(); 569 SmallPtrSet<const Instruction*, 4> Paired; 570 for (unsigned i = 0; i < Elems; ++i) { 571 MulCandidate *PMul0 = static_cast<MulCandidate*>(Muls[i].get()); 572 if (Paired.count(PMul0->Root)) 573 continue; 574 575 for (unsigned j = 0; j < Elems; ++j) { 576 if (i == j) 577 continue; 578 579 MulCandidate *PMul1 = static_cast<MulCandidate*>(Muls[j].get()); 580 if (Paired.count(PMul1->Root)) 581 continue; 582 583 const Instruction *Mul0 = PMul0->Root; 584 const Instruction *Mul1 = PMul1->Root; 585 if (Mul0 == Mul1) 586 continue; 587 588 assert(PMul0 != PMul1 && "expected different chains"); 589 590 if (CanPair(R, PMul0, PMul1)) { 591 Paired.insert(Mul0); 592 Paired.insert(Mul1); 593 break; 594 } 595 } 596 } 597 return !R.getMulPairs().empty(); 598 } 599 600 601 void ARMParallelDSP::InsertParallelMACs(Reduction &R) { 602 603 auto CreateSMLAD = [&](LoadInst* WideLd0, LoadInst *WideLd1, 604 Value *Acc, bool Exchange, 605 Instruction *InsertAfter) { 606 // Replace the reduction chain with an intrinsic call 607 608 Value* Args[] = { WideLd0, WideLd1, Acc }; 609 Function *SMLAD = nullptr; 610 if (Exchange) 611 SMLAD = Acc->getType()->isIntegerTy(32) ? 612 Intrinsic::getDeclaration(M, Intrinsic::arm_smladx) : 613 Intrinsic::getDeclaration(M, Intrinsic::arm_smlaldx); 614 else 615 SMLAD = Acc->getType()->isIntegerTy(32) ? 616 Intrinsic::getDeclaration(M, Intrinsic::arm_smlad) : 617 Intrinsic::getDeclaration(M, Intrinsic::arm_smlald); 618 619 IRBuilder<NoFolder> Builder(InsertAfter->getParent(), 620 ++BasicBlock::iterator(InsertAfter)); 621 Instruction *Call = Builder.CreateCall(SMLAD, Args); 622 NumSMLAD++; 623 return Call; 624 }; 625 626 Instruction *InsertAfter = R.getRoot(); 627 Value *Acc = R.getAccumulator(); 628 if (!Acc) 629 Acc = ConstantInt::get(IntegerType::get(M->getContext(), 32), 0); 630 631 IntegerType *Ty = IntegerType::get(M->getContext(), 32); 632 LLVM_DEBUG(dbgs() << "Root: " << *InsertAfter << "\n" 633 << "Acc: " << *Acc << "\n"); 634 for (auto &Pair : R.getMulPairs()) { 635 MulCandidate *LHSMul = Pair.first; 636 MulCandidate *RHSMul = Pair.second; 637 LLVM_DEBUG(dbgs() << "Muls:\n" 638 << "- " << *LHSMul->Root << "\n" 639 << "- " << *RHSMul->Root << "\n"); 640 LoadInst *BaseLHS = LHSMul->getBaseLoad(); 641 LoadInst *BaseRHS = RHSMul->getBaseLoad(); 642 LoadInst *WideLHS = WideLoads.count(BaseLHS) ? 643 WideLoads[BaseLHS]->getLoad() : CreateWideLoad(LHSMul->VecLd, Ty); 644 LoadInst *WideRHS = WideLoads.count(BaseRHS) ? 645 WideLoads[BaseRHS]->getLoad() : CreateWideLoad(RHSMul->VecLd, Ty); 646 647 Acc = CreateSMLAD(WideLHS, WideRHS, Acc, RHSMul->Exchange, InsertAfter); 648 InsertAfter = cast<Instruction>(Acc); 649 } 650 R.UpdateRoot(cast<Instruction>(Acc)); 651 } 652 653 LoadInst* ARMParallelDSP::CreateWideLoad(SmallVectorImpl<LoadInst*> &Loads, 654 IntegerType *LoadTy) { 655 assert(Loads.size() == 2 && "currently only support widening two loads"); 656 657 LoadInst *Base = Loads[0]; 658 LoadInst *Offset = Loads[1]; 659 660 Instruction *BaseSExt = dyn_cast<SExtInst>(Base->user_back()); 661 Instruction *OffsetSExt = dyn_cast<SExtInst>(Offset->user_back()); 662 663 assert((BaseSExt && OffsetSExt) 664 && "Loads should have a single, extending, user"); 665 666 std::function<void(Value*, Value*)> MoveBefore = 667 [&](Value *A, Value *B) -> void { 668 if (!isa<Instruction>(A) || !isa<Instruction>(B)) 669 return; 670 671 auto *Source = cast<Instruction>(A); 672 auto *Sink = cast<Instruction>(B); 673 674 if (DT->dominates(Source, Sink) || 675 Source->getParent() != Sink->getParent() || 676 isa<PHINode>(Source) || isa<PHINode>(Sink)) 677 return; 678 679 Source->moveBefore(Sink); 680 for (auto &Op : Source->operands()) 681 MoveBefore(Op, Source); 682 }; 683 684 // Insert the load at the point of the original dominating load. 685 LoadInst *DomLoad = DT->dominates(Base, Offset) ? Base : Offset; 686 IRBuilder<NoFolder> IRB(DomLoad->getParent(), 687 ++BasicBlock::iterator(DomLoad)); 688 689 // Bitcast the pointer to a wider type and create the wide load, while making 690 // sure to maintain the original alignment as this prevents ldrd from being 691 // generated when it could be illegal due to memory alignment. 692 const unsigned AddrSpace = DomLoad->getPointerAddressSpace(); 693 Value *VecPtr = IRB.CreateBitCast(Base->getPointerOperand(), 694 LoadTy->getPointerTo(AddrSpace)); 695 LoadInst *WideLoad = IRB.CreateAlignedLoad(LoadTy, VecPtr, 696 Base->getAlignment()); 697 698 // Make sure everything is in the correct order in the basic block. 699 MoveBefore(Base->getPointerOperand(), VecPtr); 700 MoveBefore(VecPtr, WideLoad); 701 702 // From the wide load, create two values that equal the original two loads. 703 // Loads[0] needs trunc while Loads[1] needs a lshr and trunc. 704 // TODO: Support big-endian as well. 705 Value *Bottom = IRB.CreateTrunc(WideLoad, Base->getType()); 706 BaseSExt->setOperand(0, Bottom); 707 708 IntegerType *OffsetTy = cast<IntegerType>(Offset->getType()); 709 Value *ShiftVal = ConstantInt::get(LoadTy, OffsetTy->getBitWidth()); 710 Value *Top = IRB.CreateLShr(WideLoad, ShiftVal); 711 Value *Trunc = IRB.CreateTrunc(Top, OffsetTy); 712 OffsetSExt->setOperand(0, Trunc); 713 714 WideLoads.emplace(std::make_pair(Base, 715 make_unique<WidenedLoad>(Loads, WideLoad))); 716 return WideLoad; 717 } 718 719 Pass *llvm::createARMParallelDSPPass() { 720 return new ARMParallelDSP(); 721 } 722 723 char ARMParallelDSP::ID = 0; 724 725 INITIALIZE_PASS_BEGIN(ARMParallelDSP, "arm-parallel-dsp", 726 "Transform functions to use DSP intrinsics", false, false) 727 INITIALIZE_PASS_END(ARMParallelDSP, "arm-parallel-dsp", 728 "Transform functions to use DSP intrinsics", false, false) 729