1 //===- MVEGatherScatterLowering.cpp - Gather/Scatter lowering -------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 /// This pass custom lowers llvm.gather and llvm.scatter instructions to 10 /// arm.mve.gather and arm.mve.scatter intrinsics, optimising the code to 11 /// produce a better final result as we go. 12 // 13 //===----------------------------------------------------------------------===// 14 15 #include "ARM.h" 16 #include "ARMBaseInstrInfo.h" 17 #include "ARMSubtarget.h" 18 #include "llvm/Analysis/LoopInfo.h" 19 #include "llvm/Analysis/TargetTransformInfo.h" 20 #include "llvm/CodeGen/TargetLowering.h" 21 #include "llvm/CodeGen/TargetPassConfig.h" 22 #include "llvm/CodeGen/TargetSubtargetInfo.h" 23 #include "llvm/InitializePasses.h" 24 #include "llvm/IR/BasicBlock.h" 25 #include "llvm/IR/Constant.h" 26 #include "llvm/IR/Constants.h" 27 #include "llvm/IR/DerivedTypes.h" 28 #include "llvm/IR/Function.h" 29 #include "llvm/IR/InstrTypes.h" 30 #include "llvm/IR/Instruction.h" 31 #include "llvm/IR/Instructions.h" 32 #include "llvm/IR/IntrinsicInst.h" 33 #include "llvm/IR/Intrinsics.h" 34 #include "llvm/IR/IntrinsicsARM.h" 35 #include "llvm/IR/IRBuilder.h" 36 #include "llvm/IR/PatternMatch.h" 37 #include "llvm/IR/Type.h" 38 #include "llvm/IR/Value.h" 39 #include "llvm/Pass.h" 40 #include "llvm/Support/Casting.h" 41 #include "llvm/Transforms/Utils/Local.h" 42 #include <algorithm> 43 #include <cassert> 44 45 using namespace llvm; 46 47 #define DEBUG_TYPE "arm-mve-gather-scatter-lowering" 48 49 cl::opt<bool> EnableMaskedGatherScatters( 50 "enable-arm-maskedgatscat", cl::Hidden, cl::init(true), 51 cl::desc("Enable the generation of masked gathers and scatters")); 52 53 namespace { 54 55 class 
MVEGatherScatterLowering : public FunctionPass { 56 public: 57 static char ID; // Pass identification, replacement for typeid 58 59 explicit MVEGatherScatterLowering() : FunctionPass(ID) { 60 initializeMVEGatherScatterLoweringPass(*PassRegistry::getPassRegistry()); 61 } 62 63 bool runOnFunction(Function &F) override; 64 65 StringRef getPassName() const override { 66 return "MVE gather/scatter lowering"; 67 } 68 69 void getAnalysisUsage(AnalysisUsage &AU) const override { 70 AU.setPreservesCFG(); 71 AU.addRequired<TargetPassConfig>(); 72 AU.addRequired<LoopInfoWrapperPass>(); 73 FunctionPass::getAnalysisUsage(AU); 74 } 75 76 private: 77 LoopInfo *LI = nullptr; 78 79 // Check this is a valid gather with correct alignment 80 bool isLegalTypeAndAlignment(unsigned NumElements, unsigned ElemSize, 81 Align Alignment); 82 // Check whether Ptr is hidden behind a bitcast and look through it 83 void lookThroughBitcast(Value *&Ptr); 84 // Check for a getelementptr and deduce base and offsets from it, on success 85 // returning the base directly and the offsets indirectly using the Offsets 86 // argument 87 Value *checkGEP(Value *&Offsets, FixedVectorType *Ty, GetElementPtrInst *GEP, 88 IRBuilder<> &Builder); 89 // Compute the scale of this gather/scatter instruction 90 int computeScale(unsigned GEPElemSize, unsigned MemoryElemSize); 91 // If the value is a constant, or derived from constants via additions 92 // and multilications, return its numeric value 93 Optional<int64_t> getIfConst(const Value *V); 94 // If Inst is an add instruction, check whether one summand is a 95 // constant. If so, scale this constant and return it together with 96 // the other summand. 
97 std::pair<Value *, int64_t> getVarAndConst(Value *Inst, int TypeScale); 98 99 Value *lowerGather(IntrinsicInst *I); 100 // Create a gather from a base + vector of offsets 101 Value *tryCreateMaskedGatherOffset(IntrinsicInst *I, Value *Ptr, 102 Instruction *&Root, IRBuilder<> &Builder); 103 // Create a gather from a vector of pointers 104 Value *tryCreateMaskedGatherBase(IntrinsicInst *I, Value *Ptr, 105 IRBuilder<> &Builder, int64_t Increment = 0); 106 // Create an incrementing gather from a vector of pointers 107 Value *tryCreateMaskedGatherBaseWB(IntrinsicInst *I, Value *Ptr, 108 IRBuilder<> &Builder, 109 int64_t Increment = 0); 110 111 Value *lowerScatter(IntrinsicInst *I); 112 // Create a scatter to a base + vector of offsets 113 Value *tryCreateMaskedScatterOffset(IntrinsicInst *I, Value *Offsets, 114 IRBuilder<> &Builder); 115 // Create a scatter to a vector of pointers 116 Value *tryCreateMaskedScatterBase(IntrinsicInst *I, Value *Ptr, 117 IRBuilder<> &Builder, 118 int64_t Increment = 0); 119 // Create an incrementing scatter from a vector of pointers 120 Value *tryCreateMaskedScatterBaseWB(IntrinsicInst *I, Value *Ptr, 121 IRBuilder<> &Builder, 122 int64_t Increment = 0); 123 124 // QI gathers and scatters can increment their offsets on their own if 125 // the increment is a constant value (digit) 126 Value *tryCreateIncrementingGatScat(IntrinsicInst *I, Value *BasePtr, 127 Value *Ptr, GetElementPtrInst *GEP, 128 IRBuilder<> &Builder); 129 // QI gathers/scatters can increment their offsets on their own if the 130 // increment is a constant value (digit) - this creates a writeback QI 131 // gather/scatter 132 Value *tryCreateIncrementingWBGatScat(IntrinsicInst *I, Value *BasePtr, 133 Value *Ptr, unsigned TypeScale, 134 IRBuilder<> &Builder); 135 136 // Optimise the base and offsets of the given address 137 bool optimiseAddress(Value *Address, BasicBlock *BB, LoopInfo *LI); 138 // Try to fold consecutive geps together into one 139 Value 
*foldGEP(GetElementPtrInst *GEP, Value *&Offsets, IRBuilder<> &Builder); 140 // Check whether these offsets could be moved out of the loop they're in 141 bool optimiseOffsets(Value *Offsets, BasicBlock *BB, LoopInfo *LI); 142 // Pushes the given add out of the loop 143 void pushOutAdd(PHINode *&Phi, Value *OffsSecondOperand, unsigned StartIndex); 144 // Pushes the given mul out of the loop 145 void pushOutMul(PHINode *&Phi, Value *IncrementPerRound, 146 Value *OffsSecondOperand, unsigned LoopIncrement, 147 IRBuilder<> &Builder); 148 }; 149 150 } // end anonymous namespace 151 152 char MVEGatherScatterLowering::ID = 0; 153 154 INITIALIZE_PASS(MVEGatherScatterLowering, DEBUG_TYPE, 155 "MVE gather/scattering lowering pass", false, false) 156 157 Pass *llvm::createMVEGatherScatterLoweringPass() { 158 return new MVEGatherScatterLowering(); 159 } 160 161 bool MVEGatherScatterLowering::isLegalTypeAndAlignment(unsigned NumElements, 162 unsigned ElemSize, 163 Align Alignment) { 164 if (((NumElements == 4 && 165 (ElemSize == 32 || ElemSize == 16 || ElemSize == 8)) || 166 (NumElements == 8 && (ElemSize == 16 || ElemSize == 8)) || 167 (NumElements == 16 && ElemSize == 8)) && 168 Alignment >= ElemSize / 8) 169 return true; 170 LLVM_DEBUG(dbgs() << "masked gathers/scatters: instruction does not have " 171 << "valid alignment or vector type \n"); 172 return false; 173 } 174 175 static bool checkOffsetSize(Value *Offsets, unsigned TargetElemCount) { 176 // Offsets that are not of type <N x i32> are sign extended by the 177 // getelementptr instruction, and MVE gathers/scatters treat the offset as 178 // unsigned. Thus, if the element size is smaller than 32, we can only allow 179 // positive offsets - i.e., the offsets are not allowed to be variables we 180 // can't look into. 
181 // Additionally, <N x i32> offsets have to either originate from a zext of a 182 // vector with element types smaller or equal the type of the gather we're 183 // looking at, or consist of constants that we can check are small enough 184 // to fit into the gather type. 185 // Thus we check that 0 < value < 2^TargetElemSize. 186 unsigned TargetElemSize = 128 / TargetElemCount; 187 unsigned OffsetElemSize = cast<FixedVectorType>(Offsets->getType()) 188 ->getElementType() 189 ->getScalarSizeInBits(); 190 if (OffsetElemSize != TargetElemSize || OffsetElemSize != 32) { 191 Constant *ConstOff = dyn_cast<Constant>(Offsets); 192 if (!ConstOff) 193 return false; 194 int64_t TargetElemMaxSize = (1ULL << TargetElemSize); 195 auto CheckValueSize = [TargetElemMaxSize](Value *OffsetElem) { 196 ConstantInt *OConst = dyn_cast<ConstantInt>(OffsetElem); 197 if (!OConst) 198 return false; 199 int SExtValue = OConst->getSExtValue(); 200 if (SExtValue >= TargetElemMaxSize || SExtValue < 0) 201 return false; 202 return true; 203 }; 204 if (isa<FixedVectorType>(ConstOff->getType())) { 205 for (unsigned i = 0; i < TargetElemCount; i++) { 206 if (!CheckValueSize(ConstOff->getAggregateElement(i))) 207 return false; 208 } 209 } else { 210 if (!CheckValueSize(ConstOff)) 211 return false; 212 } 213 } 214 return true; 215 } 216 217 Value *MVEGatherScatterLowering::checkGEP(Value *&Offsets, FixedVectorType *Ty, 218 GetElementPtrInst *GEP, 219 IRBuilder<> &Builder) { 220 if (!GEP) { 221 LLVM_DEBUG(dbgs() << "masked gathers/scatters: no getelementpointer " 222 << "found\n"); 223 return nullptr; 224 } 225 LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementpointer found." 
226 << " Looking at intrinsic for base + vector of offsets\n"); 227 Value *GEPPtr = GEP->getPointerOperand(); 228 Offsets = GEP->getOperand(1); 229 if (GEPPtr->getType()->isVectorTy() || 230 !isa<FixedVectorType>(Offsets->getType())) 231 return nullptr; 232 233 if (GEP->getNumOperands() != 2) { 234 LLVM_DEBUG(dbgs() << "masked gathers/scatters: getelementptr with too many" 235 << " operands. Expanding.\n"); 236 return nullptr; 237 } 238 Offsets = GEP->getOperand(1); 239 unsigned OffsetsElemCount = 240 cast<FixedVectorType>(Offsets->getType())->getNumElements(); 241 // Paranoid check whether the number of parallel lanes is the same 242 assert(Ty->getNumElements() == OffsetsElemCount); 243 244 ZExtInst *ZextOffs = dyn_cast<ZExtInst>(Offsets); 245 if (ZextOffs) 246 Offsets = ZextOffs->getOperand(0); 247 FixedVectorType *OffsetType = cast<FixedVectorType>(Offsets->getType()); 248 249 // If the offsets are already being zext-ed to <N x i32>, that relieves us of 250 // having to make sure that they won't overflow. 
251 if (!ZextOffs || cast<FixedVectorType>(ZextOffs->getDestTy()) 252 ->getElementType() 253 ->getScalarSizeInBits() != 32) 254 if (!checkOffsetSize(Offsets, OffsetsElemCount)) 255 return nullptr; 256 257 // The offset sizes have been checked; if any truncating or zext-ing is 258 // required to fix them, do that now 259 if (Ty != Offsets->getType()) { 260 if ((Ty->getElementType()->getScalarSizeInBits() < 261 OffsetType->getElementType()->getScalarSizeInBits())) { 262 Offsets = Builder.CreateTrunc(Offsets, Ty); 263 } else { 264 Offsets = Builder.CreateZExt(Offsets, VectorType::getInteger(Ty)); 265 } 266 } 267 // If none of the checks failed, return the gep's base pointer 268 LLVM_DEBUG(dbgs() << "masked gathers/scatters: found correct offsets\n"); 269 return GEPPtr; 270 } 271 272 void MVEGatherScatterLowering::lookThroughBitcast(Value *&Ptr) { 273 // Look through bitcast instruction if #elements is the same 274 if (auto *BitCast = dyn_cast<BitCastInst>(Ptr)) { 275 auto *BCTy = cast<FixedVectorType>(BitCast->getType()); 276 auto *BCSrcTy = cast<FixedVectorType>(BitCast->getOperand(0)->getType()); 277 if (BCTy->getNumElements() == BCSrcTy->getNumElements()) { 278 LLVM_DEBUG(dbgs() << "masked gathers/scatters: looking through " 279 << "bitcast\n"); 280 Ptr = BitCast->getOperand(0); 281 } 282 } 283 } 284 285 int MVEGatherScatterLowering::computeScale(unsigned GEPElemSize, 286 unsigned MemoryElemSize) { 287 // This can be a 32bit load/store scaled by 4, a 16bit load/store scaled by 2, 288 // or a 8bit, 16bit or 32bit load/store scaled by 1 289 if (GEPElemSize == 32 && MemoryElemSize == 32) 290 return 2; 291 else if (GEPElemSize == 16 && MemoryElemSize == 16) 292 return 1; 293 else if (GEPElemSize == 8) 294 return 0; 295 LLVM_DEBUG(dbgs() << "masked gathers/scatters: incorrect scale. 
Can't " 296 << "create intrinsic\n"); 297 return -1; 298 } 299 300 Optional<int64_t> MVEGatherScatterLowering::getIfConst(const Value *V) { 301 const Constant *C = dyn_cast<Constant>(V); 302 if (C != nullptr) 303 return Optional<int64_t>{C->getUniqueInteger().getSExtValue()}; 304 if (!isa<Instruction>(V)) 305 return Optional<int64_t>{}; 306 307 const Instruction *I = cast<Instruction>(V); 308 if (I->getOpcode() == Instruction::Add || 309 I->getOpcode() == Instruction::Mul) { 310 Optional<int64_t> Op0 = getIfConst(I->getOperand(0)); 311 Optional<int64_t> Op1 = getIfConst(I->getOperand(1)); 312 if (!Op0 || !Op1) 313 return Optional<int64_t>{}; 314 if (I->getOpcode() == Instruction::Add) 315 return Optional<int64_t>{Op0.getValue() + Op1.getValue()}; 316 if (I->getOpcode() == Instruction::Mul) 317 return Optional<int64_t>{Op0.getValue() * Op1.getValue()}; 318 } 319 return Optional<int64_t>{}; 320 } 321 322 std::pair<Value *, int64_t> 323 MVEGatherScatterLowering::getVarAndConst(Value *Inst, int TypeScale) { 324 std::pair<Value *, int64_t> ReturnFalse = 325 std::pair<Value *, int64_t>(nullptr, 0); 326 // At this point, the instruction we're looking at must be an add or we 327 // bail out 328 Instruction *Add = dyn_cast<Instruction>(Inst); 329 if (Add == nullptr || Add->getOpcode() != Instruction::Add) 330 return ReturnFalse; 331 332 Value *Summand; 333 Optional<int64_t> Const; 334 // Find out which operand the value that is increased is 335 if ((Const = getIfConst(Add->getOperand(0)))) 336 Summand = Add->getOperand(1); 337 else if ((Const = getIfConst(Add->getOperand(1)))) 338 Summand = Add->getOperand(0); 339 else 340 return ReturnFalse; 341 342 // Check that the constant is small enough for an incrementing gather 343 int64_t Immediate = Const.getValue() << TypeScale; 344 if (Immediate > 512 || Immediate < -512 || Immediate % 4 != 0) 345 return ReturnFalse; 346 347 return std::pair<Value *, int64_t>(Summand, Immediate); 348 } 349 350 Value 
*MVEGatherScatterLowering::lowerGather(IntrinsicInst *I) { 351 using namespace PatternMatch; 352 LLVM_DEBUG(dbgs() << "masked gathers: checking transform preconditions\n" 353 << *I << "\n"); 354 355 // @llvm.masked.gather.*(Ptrs, alignment, Mask, Src0) 356 // Attempt to turn the masked gather in I into a MVE intrinsic 357 // Potentially optimising the addressing modes as we do so. 358 auto *Ty = cast<FixedVectorType>(I->getType()); 359 Value *Ptr = I->getArgOperand(0); 360 Align Alignment = cast<ConstantInt>(I->getArgOperand(1))->getAlignValue(); 361 Value *Mask = I->getArgOperand(2); 362 Value *PassThru = I->getArgOperand(3); 363 364 if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(), 365 Alignment)) 366 return nullptr; 367 lookThroughBitcast(Ptr); 368 assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type"); 369 370 IRBuilder<> Builder(I->getContext()); 371 Builder.SetInsertPoint(I); 372 Builder.SetCurrentDebugLocation(I->getDebugLoc()); 373 374 Instruction *Root = I; 375 Value *Load = tryCreateMaskedGatherOffset(I, Ptr, Root, Builder); 376 if (!Load) 377 Load = tryCreateMaskedGatherBase(I, Ptr, Builder); 378 if (!Load) 379 return nullptr; 380 381 if (!isa<UndefValue>(PassThru) && !match(PassThru, m_Zero())) { 382 LLVM_DEBUG(dbgs() << "masked gathers: found non-trivial passthru - " 383 << "creating select\n"); 384 Load = Builder.CreateSelect(Mask, Load, PassThru); 385 } 386 387 Root->replaceAllUsesWith(Load); 388 Root->eraseFromParent(); 389 if (Root != I) 390 // If this was an extending gather, we need to get rid of the sext/zext 391 // sext/zext as well as of the gather itself 392 I->eraseFromParent(); 393 394 LLVM_DEBUG(dbgs() << "masked gathers: successfully built masked gather\n" 395 << *Load << "\n"); 396 return Load; 397 } 398 399 Value *MVEGatherScatterLowering::tryCreateMaskedGatherBase(IntrinsicInst *I, 400 Value *Ptr, 401 IRBuilder<> &Builder, 402 int64_t Increment) { 403 using namespace PatternMatch; 404 auto *Ty = 
cast<FixedVectorType>(I->getType()); 405 LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers\n"); 406 if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32) 407 // Can't build an intrinsic for this 408 return nullptr; 409 Value *Mask = I->getArgOperand(2); 410 if (match(Mask, m_One())) 411 return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base, 412 {Ty, Ptr->getType()}, 413 {Ptr, Builder.getInt32(Increment)}); 414 else 415 return Builder.CreateIntrinsic( 416 Intrinsic::arm_mve_vldr_gather_base_predicated, 417 {Ty, Ptr->getType(), Mask->getType()}, 418 {Ptr, Builder.getInt32(Increment), Mask}); 419 } 420 421 Value *MVEGatherScatterLowering::tryCreateMaskedGatherBaseWB( 422 IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) { 423 using namespace PatternMatch; 424 auto *Ty = cast<FixedVectorType>(I->getType()); 425 LLVM_DEBUG(dbgs() << "masked gathers: loading from vector of pointers with " 426 << "writeback\n"); 427 if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32) 428 // Can't build an intrinsic for this 429 return nullptr; 430 Value *Mask = I->getArgOperand(2); 431 if (match(Mask, m_One())) 432 return Builder.CreateIntrinsic(Intrinsic::arm_mve_vldr_gather_base_wb, 433 {Ty, Ptr->getType()}, 434 {Ptr, Builder.getInt32(Increment)}); 435 else 436 return Builder.CreateIntrinsic( 437 Intrinsic::arm_mve_vldr_gather_base_wb_predicated, 438 {Ty, Ptr->getType(), Mask->getType()}, 439 {Ptr, Builder.getInt32(Increment), Mask}); 440 } 441 442 Value *MVEGatherScatterLowering::tryCreateMaskedGatherOffset( 443 IntrinsicInst *I, Value *Ptr, Instruction *&Root, IRBuilder<> &Builder) { 444 using namespace PatternMatch; 445 446 Type *OriginalTy = I->getType(); 447 Type *ResultTy = OriginalTy; 448 449 unsigned Unsigned = 1; 450 // The size of the gather was already checked in isLegalTypeAndAlignment; 451 // if it was not a full vector width an appropriate extend should follow. 
452 auto *Extend = Root; 453 if (OriginalTy->getPrimitiveSizeInBits() < 128) { 454 // Only transform gathers with exactly one use 455 if (!I->hasOneUse()) 456 return nullptr; 457 458 // The correct root to replace is not the CallInst itself, but the 459 // instruction which extends it 460 Extend = cast<Instruction>(*I->users().begin()); 461 if (isa<SExtInst>(Extend)) { 462 Unsigned = 0; 463 } else if (!isa<ZExtInst>(Extend)) { 464 LLVM_DEBUG(dbgs() << "masked gathers: extend needed but not provided. " 465 << "Expanding\n"); 466 return nullptr; 467 } 468 LLVM_DEBUG(dbgs() << "masked gathers: found an extending gather\n"); 469 ResultTy = Extend->getType(); 470 // The final size of the gather must be a full vector width 471 if (ResultTy->getPrimitiveSizeInBits() != 128) { 472 LLVM_DEBUG(dbgs() << "masked gathers: extending from the wrong type. " 473 << "Expanding\n"); 474 return nullptr; 475 } 476 } 477 478 GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr); 479 Value *Offsets; 480 Value *BasePtr = 481 checkGEP(Offsets, cast<FixedVectorType>(ResultTy), GEP, Builder); 482 if (!BasePtr) 483 return nullptr; 484 // Check whether the offset is a constant increment that could be merged into 485 // a QI gather 486 Value *Load = tryCreateIncrementingGatScat(I, BasePtr, Offsets, GEP, Builder); 487 if (Load) 488 return Load; 489 490 int Scale = 491 computeScale(GEP->getSourceElementType()->getPrimitiveSizeInBits(), 492 OriginalTy->getScalarSizeInBits()); 493 if (Scale == -1) 494 return nullptr; 495 Root = Extend; 496 497 Value *Mask = I->getArgOperand(2); 498 if (!match(Mask, m_One())) 499 return Builder.CreateIntrinsic( 500 Intrinsic::arm_mve_vldr_gather_offset_predicated, 501 {ResultTy, BasePtr->getType(), Offsets->getType(), Mask->getType()}, 502 {BasePtr, Offsets, Builder.getInt32(OriginalTy->getScalarSizeInBits()), 503 Builder.getInt32(Scale), Builder.getInt32(Unsigned), Mask}); 504 else 505 return Builder.CreateIntrinsic( 506 Intrinsic::arm_mve_vldr_gather_offset, 
507 {ResultTy, BasePtr->getType(), Offsets->getType()}, 508 {BasePtr, Offsets, Builder.getInt32(OriginalTy->getScalarSizeInBits()), 509 Builder.getInt32(Scale), Builder.getInt32(Unsigned)}); 510 } 511 512 Value *MVEGatherScatterLowering::lowerScatter(IntrinsicInst *I) { 513 using namespace PatternMatch; 514 LLVM_DEBUG(dbgs() << "masked scatters: checking transform preconditions\n" 515 << *I << "\n"); 516 517 // @llvm.masked.scatter.*(data, ptrs, alignment, mask) 518 // Attempt to turn the masked scatter in I into a MVE intrinsic 519 // Potentially optimising the addressing modes as we do so. 520 Value *Input = I->getArgOperand(0); 521 Value *Ptr = I->getArgOperand(1); 522 Align Alignment = cast<ConstantInt>(I->getArgOperand(2))->getAlignValue(); 523 auto *Ty = cast<FixedVectorType>(Input->getType()); 524 525 if (!isLegalTypeAndAlignment(Ty->getNumElements(), Ty->getScalarSizeInBits(), 526 Alignment)) 527 return nullptr; 528 529 lookThroughBitcast(Ptr); 530 assert(Ptr->getType()->isVectorTy() && "Unexpected pointer type"); 531 532 IRBuilder<> Builder(I->getContext()); 533 Builder.SetInsertPoint(I); 534 Builder.SetCurrentDebugLocation(I->getDebugLoc()); 535 536 Value *Store = tryCreateMaskedScatterOffset(I, Ptr, Builder); 537 if (!Store) 538 Store = tryCreateMaskedScatterBase(I, Ptr, Builder); 539 if (!Store) 540 return nullptr; 541 542 LLVM_DEBUG(dbgs() << "masked scatters: successfully built masked scatter\n" 543 << *Store << "\n"); 544 I->eraseFromParent(); 545 return Store; 546 } 547 548 Value *MVEGatherScatterLowering::tryCreateMaskedScatterBase( 549 IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) { 550 using namespace PatternMatch; 551 Value *Input = I->getArgOperand(0); 552 auto *Ty = cast<FixedVectorType>(Input->getType()); 553 // Only QR variants allow truncating 554 if (!(Ty->getNumElements() == 4 && Ty->getScalarSizeInBits() == 32)) { 555 // Can't build an intrinsic for this 556 return nullptr; 557 } 558 Value *Mask = 
I->getArgOperand(3); 559 // int_arm_mve_vstr_scatter_base(_predicated) addr, offset, data(, mask) 560 LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers\n"); 561 if (match(Mask, m_One())) 562 return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base, 563 {Ptr->getType(), Input->getType()}, 564 {Ptr, Builder.getInt32(Increment), Input}); 565 else 566 return Builder.CreateIntrinsic( 567 Intrinsic::arm_mve_vstr_scatter_base_predicated, 568 {Ptr->getType(), Input->getType(), Mask->getType()}, 569 {Ptr, Builder.getInt32(Increment), Input, Mask}); 570 } 571 572 Value *MVEGatherScatterLowering::tryCreateMaskedScatterBaseWB( 573 IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder, int64_t Increment) { 574 using namespace PatternMatch; 575 Value *Input = I->getArgOperand(0); 576 auto *Ty = cast<FixedVectorType>(Input->getType()); 577 LLVM_DEBUG(dbgs() << "masked scatters: storing to a vector of pointers " 578 << "with writeback\n"); 579 if (Ty->getNumElements() != 4 || Ty->getScalarSizeInBits() != 32) 580 // Can't build an intrinsic for this 581 return nullptr; 582 Value *Mask = I->getArgOperand(3); 583 if (match(Mask, m_One())) 584 return Builder.CreateIntrinsic(Intrinsic::arm_mve_vstr_scatter_base_wb, 585 {Ptr->getType(), Input->getType()}, 586 {Ptr, Builder.getInt32(Increment), Input}); 587 else 588 return Builder.CreateIntrinsic( 589 Intrinsic::arm_mve_vstr_scatter_base_wb_predicated, 590 {Ptr->getType(), Input->getType(), Mask->getType()}, 591 {Ptr, Builder.getInt32(Increment), Input, Mask}); 592 } 593 594 Value *MVEGatherScatterLowering::tryCreateMaskedScatterOffset( 595 IntrinsicInst *I, Value *Ptr, IRBuilder<> &Builder) { 596 using namespace PatternMatch; 597 Value *Input = I->getArgOperand(0); 598 Value *Mask = I->getArgOperand(3); 599 Type *InputTy = Input->getType(); 600 Type *MemoryTy = InputTy; 601 LLVM_DEBUG(dbgs() << "masked scatters: getelementpointer found. 
Storing" 602 << " to base + vector of offsets\n"); 603 // If the input has been truncated, try to integrate that trunc into the 604 // scatter instruction (we don't care about alignment here) 605 if (TruncInst *Trunc = dyn_cast<TruncInst>(Input)) { 606 Value *PreTrunc = Trunc->getOperand(0); 607 Type *PreTruncTy = PreTrunc->getType(); 608 if (PreTruncTy->getPrimitiveSizeInBits() == 128) { 609 Input = PreTrunc; 610 InputTy = PreTruncTy; 611 } 612 } 613 if (InputTy->getPrimitiveSizeInBits() != 128) { 614 LLVM_DEBUG(dbgs() << "masked scatters: cannot create scatters for " 615 "non-standard input types. Expanding.\n"); 616 return nullptr; 617 } 618 619 GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Ptr); 620 Value *Offsets; 621 Value *BasePtr = 622 checkGEP(Offsets, cast<FixedVectorType>(InputTy), GEP, Builder); 623 if (!BasePtr) 624 return nullptr; 625 // Check whether the offset is a constant increment that could be merged into 626 // a QI gather 627 Value *Store = 628 tryCreateIncrementingGatScat(I, BasePtr, Offsets, GEP, Builder); 629 if (Store) 630 return Store; 631 int Scale = 632 computeScale(GEP->getSourceElementType()->getPrimitiveSizeInBits(), 633 MemoryTy->getScalarSizeInBits()); 634 if (Scale == -1) 635 return nullptr; 636 637 if (!match(Mask, m_One())) 638 return Builder.CreateIntrinsic( 639 Intrinsic::arm_mve_vstr_scatter_offset_predicated, 640 {BasePtr->getType(), Offsets->getType(), Input->getType(), 641 Mask->getType()}, 642 {BasePtr, Offsets, Input, 643 Builder.getInt32(MemoryTy->getScalarSizeInBits()), 644 Builder.getInt32(Scale), Mask}); 645 else 646 return Builder.CreateIntrinsic( 647 Intrinsic::arm_mve_vstr_scatter_offset, 648 {BasePtr->getType(), Offsets->getType(), Input->getType()}, 649 {BasePtr, Offsets, Input, 650 Builder.getInt32(MemoryTy->getScalarSizeInBits()), 651 Builder.getInt32(Scale)}); 652 } 653 654 Value *MVEGatherScatterLowering::tryCreateIncrementingGatScat( 655 IntrinsicInst *I, Value *BasePtr, Value *Offsets, 
GetElementPtrInst *GEP, 656 IRBuilder<> &Builder) { 657 FixedVectorType *Ty; 658 if (I->getIntrinsicID() == Intrinsic::masked_gather) 659 Ty = cast<FixedVectorType>(I->getType()); 660 else 661 Ty = cast<FixedVectorType>(I->getArgOperand(0)->getType()); 662 // Incrementing gathers only exist for v4i32 663 if (Ty->getNumElements() != 4 || 664 Ty->getScalarSizeInBits() != 32) 665 return nullptr; 666 Loop *L = LI->getLoopFor(I->getParent()); 667 if (L == nullptr) 668 // Incrementing gathers are not beneficial outside of a loop 669 return nullptr; 670 LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing " 671 "wb gather/scatter\n"); 672 673 // The gep was in charge of making sure the offsets are scaled correctly 674 // - calculate that factor so it can be applied by hand 675 DataLayout DT = I->getParent()->getParent()->getParent()->getDataLayout(); 676 int TypeScale = 677 computeScale(DT.getTypeSizeInBits(GEP->getOperand(0)->getType()), 678 DT.getTypeSizeInBits(GEP->getType()) / 679 cast<FixedVectorType>(GEP->getType())->getNumElements()); 680 if (TypeScale == -1) 681 return nullptr; 682 683 if (GEP->hasOneUse()) { 684 // Only in this case do we want to build a wb gather, because the wb will 685 // change the phi which does affect other users of the gep (which will still 686 // be using the phi in the old way) 687 Value *Load = 688 tryCreateIncrementingWBGatScat(I, BasePtr, Offsets, TypeScale, Builder); 689 if (Load != nullptr) 690 return Load; 691 } 692 LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to build incrementing " 693 "non-wb gather/scatter\n"); 694 695 std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale); 696 if (Add.first == nullptr) 697 return nullptr; 698 Value *OffsetsIncoming = Add.first; 699 int64_t Immediate = Add.second; 700 701 // Make sure the offsets are scaled correctly 702 Instruction *ScaledOffsets = BinaryOperator::Create( 703 Instruction::Shl, OffsetsIncoming, 704 
Builder.CreateVectorSplat(Ty->getNumElements(), Builder.getInt32(TypeScale)), 705 "ScaledIndex", I); 706 // Add the base to the offsets 707 OffsetsIncoming = BinaryOperator::Create( 708 Instruction::Add, ScaledOffsets, 709 Builder.CreateVectorSplat( 710 Ty->getNumElements(), 711 Builder.CreatePtrToInt( 712 BasePtr, 713 cast<VectorType>(ScaledOffsets->getType())->getElementType())), 714 "StartIndex", I); 715 716 if (I->getIntrinsicID() == Intrinsic::masked_gather) 717 return cast<IntrinsicInst>( 718 tryCreateMaskedGatherBase(I, OffsetsIncoming, Builder, Immediate)); 719 else 720 return cast<IntrinsicInst>( 721 tryCreateMaskedScatterBase(I, OffsetsIncoming, Builder, Immediate)); 722 } 723 724 Value *MVEGatherScatterLowering::tryCreateIncrementingWBGatScat( 725 IntrinsicInst *I, Value *BasePtr, Value *Offsets, unsigned TypeScale, 726 IRBuilder<> &Builder) { 727 // Check whether this gather's offset is incremented by a constant - if so, 728 // and the load is of the right type, we can merge this into a QI gather 729 Loop *L = LI->getLoopFor(I->getParent()); 730 // Offsets that are worth merging into this instruction will be incremented 731 // by a constant, thus we're looking for an add of a phi and a constant 732 PHINode *Phi = dyn_cast<PHINode>(Offsets); 733 if (Phi == nullptr || Phi->getNumIncomingValues() != 2 || 734 Phi->getParent() != L->getHeader() || Phi->getNumUses() != 2) 735 // No phi means no IV to write back to; if there is a phi, we expect it 736 // to have exactly two incoming values; the only phis we are interested in 737 // will be loop IV's and have exactly two uses, one in their increment and 738 // one in the gather's gep 739 return nullptr; 740 741 unsigned IncrementIndex = 742 Phi->getIncomingBlock(0) == L->getLoopLatch() ? 
0 : 1;
  // Look through the phi to the phi increment
  Offsets = Phi->getIncomingValue(IncrementIndex);

  // Split the increment into (variable, constant) parts, pre-scaled by
  // TypeScale; bail out if it does not have that shape.
  std::pair<Value *, int64_t> Add = getVarAndConst(Offsets, TypeScale);
  if (Add.first == nullptr)
    return nullptr;
  Value *OffsetsIncoming = Add.first;
  int64_t Immediate = Add.second;
  if (OffsetsIncoming != Phi)
    // Then the increment we are looking at is not an increment of the
    // induction variable, and we don't want to do a writeback
    return nullptr;

  // Insert the set-up code in the phi's non-incrementing predecessor
  // (presumably the loop preheader - confirm against the full file), just
  // before its terminator.
  Builder.SetInsertPoint(&Phi->getIncomingBlock(1 - IncrementIndex)->back());
  unsigned NumElems =
      cast<FixedVectorType>(OffsetsIncoming->getType())->getNumElements();

  // Make sure the offsets are scaled correctly
  Instruction *ScaledOffsets = BinaryOperator::Create(
      Instruction::Shl, Phi->getIncomingValue(1 - IncrementIndex),
      Builder.CreateVectorSplat(NumElems, Builder.getInt32(TypeScale)),
      "ScaledIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back());
  // Add the base to the offsets
  OffsetsIncoming = BinaryOperator::Create(
      Instruction::Add, ScaledOffsets,
      Builder.CreateVectorSplat(
          NumElems,
          Builder.CreatePtrToInt(
              BasePtr,
              cast<VectorType>(ScaledOffsets->getType())->getElementType())),
      "StartIndex", &Phi->getIncomingBlock(1 - IncrementIndex)->back());
  // The gather is pre-incrementing: subtract one round's increment from the
  // start value so the first post-writeback value is the intended one.
  OffsetsIncoming = BinaryOperator::Create(
      Instruction::Sub, OffsetsIncoming,
      Builder.CreateVectorSplat(NumElems, Builder.getInt32(Immediate)),
      "PreIncrementStartIndex",
      &Phi->getIncomingBlock(1 - IncrementIndex)->back());
  Phi->setIncomingValue(1 - IncrementIndex, OffsetsIncoming);

  Builder.SetInsertPoint(I);

  Value *EndResult;
  Value *NewInduction;
  if (I->getIntrinsicID() == Intrinsic::masked_gather) {
    // Build the incrementing gather
    Value *Load = tryCreateMaskedGatherBaseWB(I, Phi, Builder, Immediate);
    // One value to be handed to whoever uses the gather, one is the loop
    // increment
    EndResult = Builder.CreateExtractValue(Load, 0, "Gather");
    NewInduction = Builder.CreateExtractValue(Load, 1, "GatherIncrement");
  } else {
    // Build the incrementing scatter
    NewInduction = tryCreateMaskedScatterBaseWB(I, Phi, Builder, Immediate);
    EndResult = NewInduction;
  }
  // The writeback intrinsic now produces the incremented offsets itself, so
  // the original add is redundant: forward its users to the new induction
  // value and delete it.
  Instruction *AddInst = cast<Instruction>(Offsets);
  AddInst->replaceAllUsesWith(NewInduction);
  AddInst->eraseFromParent();
  Phi->setIncomingValue(IncrementIndex, NewInduction);

  return EndResult;
}

// Hoist a loop-invariant add out of the loop: pre-add OffsSecondOperand to
// the phi's start value so the per-iteration add becomes unnecessary.
// StartIndex is the index of the phi's non-incrementing incoming value.
void MVEGatherScatterLowering::pushOutAdd(PHINode *&Phi,
                                          Value *OffsSecondOperand,
                                          unsigned StartIndex) {
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising add instruction\n");
  Instruction *InsertionPoint =
      &cast<Instruction>(Phi->getIncomingBlock(StartIndex)->back());
  // Initialize the phi with a vector that contains a sum of the constants
  Instruction *NewIndex = BinaryOperator::Create(
      Instruction::Add, Phi->getIncomingValue(StartIndex), OffsSecondOperand,
      "PushedOutAdd", InsertionPoint);
  unsigned IncrementIndex = StartIndex == 0 ? 1 : 0;

  // Order such that start index comes first (this reduces mov's)
  Phi->addIncoming(NewIndex, Phi->getIncomingBlock(StartIndex));
  Phi->addIncoming(Phi->getIncomingValue(IncrementIndex),
                   Phi->getIncomingBlock(IncrementIndex));
  // Drop the two original incoming values, leaving only the re-ordered pair
  // appended above.
  Phi->removeIncomingValue(IncrementIndex);
  Phi->removeIncomingValue(StartIndex);
}

// Hoist a loop-invariant mul out of the loop: multiply the phi's start value
// once outside the loop, then advance it by (IncrementPerRound *
// OffsSecondOperand) each iteration instead of re-multiplying.
// LoopIncrement is the index of the phi's incrementing incoming value.
void MVEGatherScatterLowering::pushOutMul(PHINode *&Phi,
                                          Value *IncrementPerRound,
                                          Value *OffsSecondOperand,
                                          unsigned LoopIncrement,
                                          IRBuilder<> &Builder) {
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: optimising mul instruction\n");

  // Create a new scalar add outside of the loop and transform it to a splat
  // by which loop variable can be incremented
  Instruction *InsertionPoint = &cast<Instruction>(
      Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1)->back());

  // Create a new index
  Value *StartIndex = BinaryOperator::Create(
      Instruction::Mul, Phi->getIncomingValue(LoopIncrement == 1 ? 0 : 1),
      OffsSecondOperand, "PushedOutMul", InsertionPoint);

  Instruction *Product =
      BinaryOperator::Create(Instruction::Mul, IncrementPerRound,
                             OffsSecondOperand, "Product", InsertionPoint);
  // Increment NewIndex by Product instead of the multiplication
  Instruction *NewIncrement = BinaryOperator::Create(
      Instruction::Add, Phi, Product, "IncrementPushedOutMul",
      cast<Instruction>(Phi->getIncomingBlock(LoopIncrement)->back())
          .getPrevNode());

  Phi->addIncoming(StartIndex,
                   Phi->getIncomingBlock(LoopIncrement == 1 ? 0 : 1));
  Phi->addIncoming(NewIncrement, Phi->getIncomingBlock(LoopIncrement));
  // Remove both original incoming values; removing index 0 twice deletes
  // them, keeping only the pair appended above.
  Phi->removeIncomingValue((unsigned)0);
  Phi->removeIncomingValue((unsigned)0);
}

// Check whether all usages of this instruction are as offsets of
// gathers/scatters or simple arithmetics only used by gathers/scatters
static bool hasAllGatScatUsers(Instruction *I) {
  // An instruction with no users at all cannot be feeding a gather/scatter.
  if (I->hasNUses(0)) {
    return false;
  }
  bool Gatscat = true;
  for (User *U : I->users()) {
    if (!isa<Instruction>(U))
      return false;
    if (isa<GetElementPtrInst>(U) ||
        isGatherScatter(dyn_cast<IntrinsicInst>(U))) {
      return Gatscat;
    } else {
      // Adds and muls are looked through recursively; any other user
      // disqualifies the instruction.
      unsigned OpCode = cast<Instruction>(U)->getOpcode();
      if ((OpCode == Instruction::Add || OpCode == Instruction::Mul) &&
          hasAllGatScatUsers(cast<Instruction>(U))) {
        continue;
      }
      return false;
    }
  }
  return Gatscat;
}

bool MVEGatherScatterLowering::optimiseOffsets(Value *Offsets, BasicBlock *BB,
                                               LoopInfo *LI) {
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: trying to optimize\n"
                    << *Offsets << "\n");
  // Optimise the addresses of gathers/scatters by moving invariant
  // calculations out of the loop
  if (!isa<Instruction>(Offsets))
    return false;
892 Instruction *Offs = cast<Instruction>(Offsets); 893 if (Offs->getOpcode() != Instruction::Add && 894 Offs->getOpcode() != Instruction::Mul) 895 return false; 896 Loop *L = LI->getLoopFor(BB); 897 if (L == nullptr) 898 return false; 899 if (!Offs->hasOneUse()) { 900 if (!hasAllGatScatUsers(Offs)) 901 return false; 902 } 903 904 // Find out which, if any, operand of the instruction 905 // is a phi node 906 PHINode *Phi; 907 int OffsSecondOp; 908 if (isa<PHINode>(Offs->getOperand(0))) { 909 Phi = cast<PHINode>(Offs->getOperand(0)); 910 OffsSecondOp = 1; 911 } else if (isa<PHINode>(Offs->getOperand(1))) { 912 Phi = cast<PHINode>(Offs->getOperand(1)); 913 OffsSecondOp = 0; 914 } else { 915 bool Changed = true; 916 if (isa<Instruction>(Offs->getOperand(0)) && 917 L->contains(cast<Instruction>(Offs->getOperand(0)))) 918 Changed |= optimiseOffsets(Offs->getOperand(0), BB, LI); 919 if (isa<Instruction>(Offs->getOperand(1)) && 920 L->contains(cast<Instruction>(Offs->getOperand(1)))) 921 Changed |= optimiseOffsets(Offs->getOperand(1), BB, LI); 922 if (!Changed) { 923 return false; 924 } else { 925 if (isa<PHINode>(Offs->getOperand(0))) { 926 Phi = cast<PHINode>(Offs->getOperand(0)); 927 OffsSecondOp = 1; 928 } else if (isa<PHINode>(Offs->getOperand(1))) { 929 Phi = cast<PHINode>(Offs->getOperand(1)); 930 OffsSecondOp = 0; 931 } else { 932 return false; 933 } 934 } 935 } 936 // A phi node we want to perform this function on should be from the 937 // loop header, and shouldn't have more than 2 incoming values 938 if (Phi->getParent() != L->getHeader() || 939 Phi->getNumIncomingValues() != 2) 940 return false; 941 942 // The phi must be an induction variable 943 int IncrementingBlock = -1; 944 945 for (int i = 0; i < 2; i++) 946 if (auto *Op = dyn_cast<Instruction>(Phi->getIncomingValue(i))) 947 if (Op->getOpcode() == Instruction::Add && 948 (Op->getOperand(0) == Phi || Op->getOperand(1) == Phi)) 949 IncrementingBlock = i; 950 if (IncrementingBlock == -1) 951 return false; 
  Instruction *IncInstruction =
      cast<Instruction>(Phi->getIncomingValue(IncrementingBlock));

  // If the phi is not used by anything else, we can just adapt it when
  // replacing the instruction; if it is, we'll have to duplicate it
  PHINode *NewPhi;
  // The loop-invariant step the phi is advanced by each round (whichever
  // operand of the increment is not the phi itself).
  Value *IncrementPerRound = IncInstruction->getOperand(
      (IncInstruction->getOperand(0) == Phi) ? 1 : 0);

  // Get the value that is added to/multiplied with the phi
  Value *OffsSecondOperand = Offs->getOperand(OffsSecondOp);

  if (IncrementPerRound->getType() != OffsSecondOperand->getType() ||
      !L->isLoopInvariant(OffsSecondOperand))
    // Something has gone wrong, abort
    return false;

  // Only proceed if the increment per round is a constant or an instruction
  // which does not originate from within the loop
  if (!isa<Constant>(IncrementPerRound) &&
      !(isa<Instruction>(IncrementPerRound) &&
        !L->contains(cast<Instruction>(IncrementPerRound))))
    return false;

  if (Phi->getNumUses() == 2) {
    // No other users -> reuse existing phi (One user is the instruction
    // we're looking at, the other is the phi increment)
    if (IncInstruction->getNumUses() != 1) {
      // If the incrementing instruction does have more users than
      // our phi, we need to copy it
      IncInstruction = BinaryOperator::Create(
          Instruction::BinaryOps(IncInstruction->getOpcode()), Phi,
          IncrementPerRound, "LoopIncrement", IncInstruction);
      Phi->setIncomingValue(IncrementingBlock, IncInstruction);
    }
    NewPhi = Phi;
  } else {
    // There are other users -> create a new phi
    NewPhi = PHINode::Create(Phi->getType(), 0, "NewPhi", Phi);
    std::vector<Value *> Increases;
    // Copy the incoming values of the old phi
    NewPhi->addIncoming(Phi->getIncomingValue(IncrementingBlock == 1 ? 0 : 1),
                        Phi->getIncomingBlock(IncrementingBlock == 1 ? 0 : 1));
    IncInstruction = BinaryOperator::Create(
        Instruction::BinaryOps(IncInstruction->getOpcode()), NewPhi,
        IncrementPerRound, "LoopIncrement", IncInstruction);
    NewPhi->addIncoming(IncInstruction,
                        Phi->getIncomingBlock(IncrementingBlock));
    // By construction the incrementing value is now the second incoming
    // value of NewPhi.
    IncrementingBlock = 1;
  }

  IRBuilder<> Builder(BB->getContext());
  Builder.SetInsertPoint(Phi);
  Builder.SetCurrentDebugLocation(Offs->getDebugLoc());

  // Hoist the invariant part of the add/mul out of the loop.
  switch (Offs->getOpcode()) {
  case Instruction::Add:
    pushOutAdd(NewPhi, OffsSecondOperand, IncrementingBlock == 1 ? 0 : 1);
    break;
  case Instruction::Mul:
    pushOutMul(NewPhi, IncrementPerRound, OffsSecondOperand, IncrementingBlock,
               Builder);
    break;
  default:
    return false;
  }
  LLVM_DEBUG(dbgs() << "masked gathers/scatters: simplified loop variable "
                    << "add/mul\n");

  // The instruction has now been "absorbed" into the phi value
  Offs->replaceAllUsesWith(NewPhi);
  if (Offs->hasNUses(0))
    Offs->eraseFromParent();
  // Clean up the old increment in case it's unused because we built a new
  // one
  if (IncInstruction->hasNUses(0))
    IncInstruction->eraseFromParent();

  return true;
}

// Build the sum of two gep offset operands, splatting a scalar summand into a
// vector when necessary, and refusing combinations whose addition could
// overflow the narrow (<32-bit) offset element types. Returns the sum, or
// nullptr if the values cannot be added safely.
static Value *CheckAndCreateOffsetAdd(Value *X, Value *Y, Value *GEP,
                                      IRBuilder<> &Builder) {
  // Splat the non-vector value to a vector of the given type - if the value is
  // a constant (and its value isn't too big), we can even use this opportunity
  // to scale it to the size of the vector elements
  auto FixSummands = [&Builder](FixedVectorType *&VT, Value *&NonVectorVal) {
    ConstantInt *Const;
    if ((Const = dyn_cast<ConstantInt>(NonVectorVal)) &&
        VT->getElementType() != NonVectorVal->getType()) {
      unsigned TargetElemSize = VT->getElementType()->getPrimitiveSizeInBits();
      uint64_t N = Const->getZExtValue();
      // Only retype the constant if it fits the signed range of the target
      // element width.
      if (N < (unsigned)(1 << (TargetElemSize - 1))) {
        NonVectorVal = Builder.CreateVectorSplat(
            VT->getNumElements(), Builder.getIntN(TargetElemSize, N));
        return;
      }
    }
    NonVectorVal =
        Builder.CreateVectorSplat(VT->getNumElements(), NonVectorVal);
  };

  FixedVectorType *XElType = dyn_cast<FixedVectorType>(X->getType());
  FixedVectorType *YElType = dyn_cast<FixedVectorType>(Y->getType());
  // If one of X, Y is not a vector, we have to splat it in order
  // to add the two of them.
  if (XElType && !YElType) {
    FixSummands(XElType, Y);
    YElType = cast<FixedVectorType>(Y->getType());
  } else if (YElType && !XElType) {
    FixSummands(YElType, X);
    XElType = cast<FixedVectorType>(X->getType());
  }
  assert(XElType && YElType && "Unknown vector types");
  // Check that the summands are of compatible types
  if (XElType != YElType) {
    LLVM_DEBUG(dbgs() << "masked gathers/scatters: incompatible gep offsets\n");
    return nullptr;
  }

  if (XElType->getElementType()->getScalarSizeInBits() != 32) {
    // Check that by adding the vectors we do not accidentally
    // create an overflow
    Constant *ConstX = dyn_cast<Constant>(X);
    Constant *ConstY = dyn_cast<Constant>(Y);
    if (!ConstX || !ConstY)
      return nullptr;
    // 128-bit MVE vector: the per-lane width is 128 / NumElements.
    unsigned TargetElemSize = 128 / XElType->getNumElements();
    for (unsigned i = 0; i < XElType->getNumElements(); i++) {
      ConstantInt *ConstXEl =
          dyn_cast<ConstantInt>(ConstX->getAggregateElement(i));
      ConstantInt *ConstYEl =
          dyn_cast<ConstantInt>(ConstY->getAggregateElement(i));
      if (!ConstXEl || !ConstYEl ||
          ConstXEl->getZExtValue() + ConstYEl->getZExtValue() >=
              (unsigned)(1 << (TargetElemSize - 1)))
        return nullptr;
    }
  }

  Value *Add = Builder.CreateAdd(X, Y);

  FixedVectorType *GEPType = cast<FixedVectorType>(GEP->getType());
  // The combined offsets must still be usable by an MVE gather/scatter for
  // this number of lanes.
  if (checkOffsetSize(Add, GEPType->getNumElements()))
    return Add;
  else
    return nullptr;
}
// Recursively fold a chain of geps into a single base pointer plus one
// combined offset value (returned through Offsets). Returns the innermost
// base pointer, or nullptr if the chain cannot be merged safely.
Value *MVEGatherScatterLowering::foldGEP(GetElementPtrInst *GEP,
                                         Value *&Offsets,
                                         IRBuilder<> &Builder) {
  Value *GEPPtr = GEP->getPointerOperand();
  Offsets = GEP->getOperand(1);
  // We only merge geps with constant offsets, because only for those
  // we can make sure that we do not cause an overflow
  if (!isa<Constant>(Offsets))
    return nullptr;
  GetElementPtrInst *BaseGEP;
  if ((BaseGEP = dyn_cast<GetElementPtrInst>(GEPPtr))) {
    // Merge the two geps into one
    Value *BaseBasePtr = foldGEP(BaseGEP, Offsets, Builder);
    if (!BaseBasePtr)
      return nullptr;
    // Offsets now holds the inner gep's folded offsets; add this gep's own
    // offsets on top (with overflow checking).
    Offsets =
        CheckAndCreateOffsetAdd(Offsets, GEP->getOperand(1), GEP, Builder);
    if (Offsets == nullptr)
      return nullptr;
    return BaseBasePtr;
  }
  return GEPPtr;
}

bool MVEGatherScatterLowering::optimiseAddress(Value *Address, BasicBlock *BB,
                                               LoopInfo *LI) {
  GetElementPtrInst *GEP = dyn_cast<GetElementPtrInst>(Address);
  if (!GEP)
    return false;
  bool Changed = false;
  // First try to merge a gep-of-gep chain into a single gep.
  if (GEP->hasOneUse() &&
      dyn_cast<GetElementPtrInst>(GEP->getPointerOperand())) {
    IRBuilder<> Builder(GEP->getContext());
    Builder.SetInsertPoint(GEP);
    Builder.SetCurrentDebugLocation(GEP->getDebugLoc());
    Value *Offsets;
    Value *Base = foldGEP(GEP, Offsets, Builder);
    // We only want to merge the geps if there is a real chance that they can be
    // used by an MVE gather; thus the offset has to have the correct size
    // (always i32 if it is not of vector type) and the base has to be a
    // pointer.
    if (Offsets && Base && Base != GEP) {
      GetElementPtrInst *NewAddress = GetElementPtrInst::Create(
          GEP->getSourceElementType(), Base, Offsets, "gep.merged", GEP);
      GEP->replaceAllUsesWith(NewAddress);
      GEP = NewAddress;
      Changed = true;
    }
  }
  // Independently of any gep merging, try to hoist loop-invariant parts of
  // the offset computation out of the loop.
  Changed |= optimiseOffsets(GEP->getOperand(1), GEP->getParent(), LI);
  return Changed;
}

bool MVEGatherScatterLowering::runOnFunction(Function &F) {
  if (!EnableMaskedGatherScatters)
    return false;
  auto &TPC = getAnalysis<TargetPassConfig>();
  auto &TM = TPC.getTM<TargetMachine>();
  auto *ST = &TM.getSubtarget<ARMSubtarget>(F);
  // This lowering only applies to subtargets with MVE integer instructions.
  if (!ST->hasMVEIntegerOps())
    return false;
  LI = &getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
  SmallVector<IntrinsicInst *, 4> Gathers;
  SmallVector<IntrinsicInst *, 4> Scatters;

  bool Changed = false;

  // First pass: collect candidate gathers/scatters (only those operating on
  // fixed-width vectors) and optimise their address computations.
  for (BasicBlock &BB : F) {
    Changed |= SimplifyInstructionsInBlock(&BB);

    for (Instruction &I : BB) {
      IntrinsicInst *II = dyn_cast<IntrinsicInst>(&I);
      if (II && II->getIntrinsicID() == Intrinsic::masked_gather &&
          isa<FixedVectorType>(II->getType())) {
        Gathers.push_back(II);
        Changed |= optimiseAddress(II->getArgOperand(0), II->getParent(), LI);
      } else if (II && II->getIntrinsicID() == Intrinsic::masked_scatter &&
                 isa<FixedVectorType>(II->getArgOperand(0)->getType())) {
        Scatters.push_back(II);
        // For scatters the pointer vector is argument 1 (argument 0 is the
        // data being stored).
        Changed |= optimiseAddress(II->getArgOperand(1), II->getParent(), LI);
      }
    }
  }
  // Second pass: lower the collected intrinsics to MVE equivalents; lowerGather
  // returns nullptr when the intrinsic could not be lowered.
  for (unsigned i = 0; i < Gathers.size(); i++) {
    IntrinsicInst *I = Gathers[i];
    Value *L = lowerGather(I);
    if (L == nullptr)
      continue;

    // Get rid of any now dead instructions
    SimplifyInstructionsInBlock(cast<Instruction>(L)->getParent());
    Changed = true;
  }

  for (unsigned i = 0; i < Scatters.size(); i++) {
    IntrinsicInst *I = Scatters[i];
    Value *S = lowerScatter(I);
    if (S == nullptr)
      continue;

    // Get rid of any now dead instructions
    SimplifyInstructionsInBlock(cast<Instruction>(S)->getParent());
    Changed = true;
  }
  return Changed;
}