//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
//                                    intrinsics
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass replaces masked memory intrinsics - when unsupported by the target
// - with a chain of basic blocks that deal with the elements one-by-one if the
// appropriate mask bit is set.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/ScalarizeMaskedMemIntrin.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Transforms/Scalar.h"
#include <algorithm>
#include <cassert>

using namespace llvm;

#define DEBUG_TYPE "scalarize-masked-mem-intrin"

namespace {

class ScalarizeMaskedMemIntrinLegacyPass : public FunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid

  explicit ScalarizeMaskedMemIntrinLegacyPass() : FunctionPass(ID) {
    initializeScalarizeMaskedMemIntrinLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "Scalarize Masked Memory Intrinsics";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
  }
};

} // end anonymous namespace

static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
                          const TargetTransformInfo &TTI, const DataLayout &DL);
static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
                             const TargetTransformInfo &TTI,
                             const DataLayout &DL);

char ScalarizeMaskedMemIntrinLegacyPass::ID = 0;

INITIALIZE_PASS(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
                "Scalarize unsupported masked memory intrinsics", false, false)

FunctionPass *llvm::createScalarizeMaskedMemIntrinLegacyPass() {
  return new ScalarizeMaskedMemIntrinLegacyPass();
}

static bool isConstantIntVector(Value *Mask) {
  Constant *C = dyn_cast<Constant>(Mask);
  if (!C)
    return false;

  unsigned NumElts = cast<FixedVectorType>(Mask->getType())->getNumElements();
  for (unsigned i = 0; i != NumElts; ++i) {
    Constant *CElt = C->getAggregateElement(i);
    if (!CElt || !isa<ConstantInt>(CElt))
      return false;
  }

  return true;
}

// Translate a masked load intrinsic like
// <16 x i32> @llvm.masked.load(<16 x i32>* %addr, i32 align,
//                              <16 x i1> %mask, <16 x i32> %passthru)
// to a chain of basic blocks, loading the elements one-by-one if
// the appropriate mask bit is set
//
//  %1 = bitcast i8* %addr to i32*
//  %2 = extractelement <16 x i1> %mask, i32 0
//  br i1 %2, label %cond.load, label %else
//
// cond.load:                                             ; preds = %0
//  %3 = getelementptr i32* %1, i32 0
//  %4 = load i32* %3
//  %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
//  br label %else
//
// else:                                        ; preds = %0, %cond.load
//  %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ undef, %0 ]
//  %6 = extractelement <16 x i1> %mask, i32 1
//  br i1 %6, label %cond.load1, label %else2
//
// cond.load1:                                            ; preds = %else
//  %7 = getelementptr i32* %1, i32 1
//  %8 = load i32* %7
//  %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
//  br label %else2
//
// else2:                                       ; preds = %else, %cond.load1
//  %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
//  %10 = extractelement <16 x i1> %mask, i32 2
//  br i1 %10, label %cond.load4, label %else5
//
static void scalarizeMaskedLoad(CallInst *CI, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
  VectorType *VecType = cast<FixedVectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Value *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  const Align AdjustedAlignVal =
      commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();

  // The result vector
  Value *VResult = Src0;

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
      LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
      VResult = Builder.CreateInsertElement(VResult, Load, Idx);
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
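  //
  // Illustrative sketch (assumed shape, not literal output): for an
  // <8 x i1> %mask, the per-element test emitted below for Idx == 2 is
  // roughly
  //
  //  %scalar_mask = bitcast <8 x i1> %mask to i8
  //  %mask_2 = and i8 %scalar_mask, 4          ; 1 << 2
  //  %cond = icmp ne i8 %mask_2, 0
  //  br i1 %cond, label %cond.load, label %else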
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = and i16 %scalar_mask, i16 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %cond, label %cond.load, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(),
                                                     "cond.load");
    Builder.SetInsertPoint(InsertPt);

    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked store intrinsic, like
// void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
//                         <16 x i1> %mask)
// to a chain of basic blocks that store the elements one-by-one if
// the appropriate mask bit is set
//
//  %1 = bitcast i8* %addr to i32*
//  %2 = extractelement <16 x i1> %mask, i32 0
//  br i1 %2, label %cond.store, label %else
//
// cond.store:                                            ; preds = %0
//  %3 = extractelement <16 x i32> %val, i32 0
//  %4 = getelementptr i32* %1, i32 0
//  store i32 %3, i32* %4
//  br label %else
//
// else:                                        ; preds = %0, %cond.store
//  %5 = extractelement <16 x i1> %mask, i32 1
//  br i1 %5, label %cond.store1, label %else2
//
// cond.store1:                                           ; preds = %else
//  %6 = extractelement <16 x i32> %val, i32 1
//  %7 = getelementptr i32* %1, i32 1
//  store i32 %6, i32* %7
//  br label %else2
//   . . .
static void scalarizeMaskedStore(CallInst *CI, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
  auto *VecType = cast<VectorType>(Src->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Builder.CreateAlignedStore(Src, Ptr, AlignVal);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  const Align AdjustedAlignVal =
      commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt = Builder.CreateExtractElement(Src, Idx);
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
      Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = and i16 %scalar_mask, i16 1 << Idx
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %cond, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  store i32 %OneElt, i32* %EltAddr
    //
    BasicBlock *CondBlock =
        IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    IfBlock = NewIfBlock;
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked gather intrinsic like
// <16 x i32> @llvm.masked.gather.v16i32(<16 x i32*> %Ptrs, i32 4,
//                                       <16 x i1> %Mask, <16 x i32> %Src)
// to a chain of basic blocks, loading the elements one-by-one if
// the appropriate mask bit is set
//
//  %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
//  %Mask0 = extractelement <16 x i1> %Mask, i32 0
//  br i1 %Mask0, label %cond.load, label %else
//
// cond.load:
//  %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
//  %Load0 = load i32, i32* %Ptr0, align 4
//  %Res0 = insertelement <16 x i32> undef, i32 %Load0, i32 0
//  br label %else
//
// else:
//  %res.phi.else = phi <16 x i32> [ %Res0, %cond.load ], [ undef, %0 ]
//  %Mask1 = extractelement <16 x i1> %Mask, i32 1
//  br i1 %Mask1, label %cond.load1, label %else2
//
// cond.load1:
//  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
//  %Load1 = load i32, i32* %Ptr1, align 4
//  %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
//  br label %else2
//   . . .
// %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
// ret <16 x i32> %Result
static void scalarizeMaskedGather(CallInst *CI, bool &ModifiedDT) {
  Value *Ptrs = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  auto *VecType = cast<FixedVectorType>(CI->getType());
  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // The result vector
  Value *VResult = Src0;
  unsigned VectorWidth = VecType->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      LoadInst *Load =
          Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
      VResult =
          Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = and i16 %scalar_mask, i16 1 << Idx
    //  %cond = icmp ne i16 %Mask1, 0
    //  br i1 %cond, label %cond.load, label %else
    //

    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load");
    Builder.SetInsertPoint(InsertPt);

    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    LoadInst *Load =
        Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
    Value *NewVResult =
        Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked scatter intrinsic, like
// void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*> %Ptrs, i32 4,
//                                  <16 x i1> %Mask)
// to a chain of basic blocks that store the elements one-by-one if
// the appropriate mask bit is set.
//
//  %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
//  %Mask0 = extractelement <16 x i1> %Mask, i32 0
//  br i1 %Mask0, label %cond.store, label %else
//
// cond.store:
//  %Elt0 = extractelement <16 x i32> %Src, i32 0
//  %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
//  store i32 %Elt0, i32* %Ptr0, align 4
//  br label %else
//
// else:
//  %Mask1 = extractelement <16 x i1> %Mask, i32 1
//  br i1 %Mask1, label %cond.store1, label %else2
//
// cond.store1:
//  %Elt1 = extractelement <16 x i32> %Src, i32 1
//  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
//  store i32 %Elt1, i32* %Ptr1, align 4
//  br label %else2
//   . . .
static void scalarizeMaskedScatter(CallInst *CI, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptrs = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  auto *SrcFVTy = cast<FixedVectorType>(Src->getType());

  assert(
      isa<VectorType>(Ptrs->getType()) &&
      isa<PointerType>(cast<VectorType>(Ptrs->getType())->getElementType()) &&
      "Vector of pointers is expected in masked scatter intrinsic");

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();
  unsigned VectorWidth = SrcFVTy->getNumElements();

  // Shorten the way if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = and i16 %scalar_mask, i16 1 << Idx
    //  %cond = icmp ne i16 %Mask1, 0
    //  br i1 %cond, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %Elt1 = extractelement <16 x i32> %Src, i32 1
    //  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
    //  store i32 %Elt1, i32* %Ptr1
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    IfBlock = NewIfBlock;
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

static void scalarizeMaskedExpandLoad(CallInst *CI, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Mask = CI->getArgOperand(1);
  Value *PassThru = CI->getArgOperand(2);

  auto *VecType = cast<FixedVectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  unsigned VectorWidth = VecType->getNumElements();

  // The result vector
  Value *VResult = PassThru;

  // Shorten the way if the mask is a vector of constants.
  // Create a build_vector pattern, with loads/undefs as necessary and then
  // shuffle blend with the pass through value.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    VResult = UndefValue::get(VecType);
    SmallVector<int, 16> ShuffleMask(VectorWidth, UndefMaskElem);
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      Value *InsertElt;
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) {
        InsertElt = UndefValue::get(EltTy);
        ShuffleMask[Idx] = Idx + VectorWidth;
      } else {
        Value *NewPtr =
            Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
        InsertElt = Builder.CreateAlignedLoad(EltTy, NewPtr, Align(1),
                                              "Load" + Twine(Idx));
        ShuffleMask[Idx] = Idx;
        ++MemIndex;
      }
      VResult = Builder.CreateInsertElement(VResult, InsertElt, Idx,
                                            "Res" + Twine(Idx));
    }
    VResult = Builder.CreateShuffleVector(VResult, PassThru, ShuffleMask);
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

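  // Illustrative sketch (not literal output) of the constant-mask shortcut
  // above: for a <4 x i32> expandload with mask <i1 1, i1 0, i1 1, i1 0>, the
  // set lanes are loaded from consecutive memory and blended with %PassThru,
  // roughly:
  //
  //  %Load0 = load i32, i32* %Ptr, align 1
  //  %Ptr1  = getelementptr inbounds i32, i32* %Ptr, i32 1
  //  %Load2 = load i32, i32* %Ptr1, align 1
  //  %vec   = ... insertelement chain building <%Load0, undef, %Load2, undef>
  //  %res   = shufflevector <4 x i32> %vec, <4 x i32> %PassThru,
  //                         <4 x i32> <i32 0, i32 5, i32 2, i32 7>
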
  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.load, label %else
    //

    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(),
                                                     "cond.load");
    Builder.SetInsertPoint(InsertPt);

    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, Align(1));
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Move the pointer if there are more blocks to come.
    Value *NewPtr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    ResultPhi->addIncoming(NewVResult, CondBlock);
    ResultPhi->addIncoming(VResult, PrevIfBlock);
    VResult = ResultPhi;

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

static void scalarizeMaskedCompressStore(CallInst *CI, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);

  auto *VecType = cast<FixedVectorType>(Src->getType());

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  Type *EltTy = VecType->getElementType();

  unsigned VectorWidth = VecType->getNumElements();

  // Shorten the way if the mask is a vector of constants.
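  // Illustrative sketch (not literal output): for a constant mask like
  // <i1 1, i1 0, i1 1, i1 1>, only the lanes whose mask bit is set are stored,
  // packed into consecutive locations, roughly:
  //
  //  %Elt0 = extractelement <4 x i32> %Src, i64 0
  //  store i32 %Elt0, i32* %Ptr, align 1
  //  %Elt2 = extractelement <4 x i32> %Src, i64 2
  //  %Ptr1 = getelementptr inbounds i32, i32* %Ptr, i32 1
  //  store i32 %Elt2, i32* %Ptr1, align 1
  //  %Elt3 = extractelement <4 x i32> %Src, i64 3
  //  %Ptr2 = getelementptr inbounds i32, i32* %Ptr, i32 2
  //  store i32 %Elt3, i32* %Ptr2, align 1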
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
      Builder.CreateAlignedStore(OneElt, NewPtr, Align(1));
      ++MemIndex;
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  store i32 %OneElt, i32* %EltAddr
    //
    BasicBlock *CondBlock =
        IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Builder.CreateAlignedStore(OneElt, Ptr, Align(1));

    // Move the pointer if there are more blocks to come.
    Value *NewPtr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

static bool runImpl(Function &F, const TargetTransformInfo &TTI) {
  bool EverMadeChange = false;
  bool MadeChange = true;
  auto &DL = F.getParent()->getDataLayout();
  while (MadeChange) {
    MadeChange = false;
    for (Function::iterator I = F.begin(); I != F.end();) {
      BasicBlock *BB = &*I++;
      bool ModifiedDTOnIteration = false;
      MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration, TTI, DL);

      // Restart BB iteration if the dominator tree of the Function was changed
      if (ModifiedDTOnIteration)
        break;
    }

    EverMadeChange |= MadeChange;
  }
  return EverMadeChange;
}

bool ScalarizeMaskedMemIntrinLegacyPass::runOnFunction(Function &F) {
  auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
  return runImpl(F, TTI);
}

PreservedAnalyses
ScalarizeMaskedMemIntrinPass::run(Function &F, FunctionAnalysisManager &AM) {
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  if (!runImpl(F, TTI))
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  PA.preserve<TargetIRAnalysis>();
  return PA;
}

static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
                          const TargetTransformInfo &TTI,
                          const DataLayout &DL) {
  bool MadeChange = false;

  BasicBlock::iterator CurInstIterator = BB.begin();
  while (CurInstIterator != BB.end()) {
    if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
      MadeChange |= optimizeCallInst(CI, ModifiedDT, TTI, DL);
    if (ModifiedDT)
      return true;
  }

  return MadeChange;
}

static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
                             const TargetTransformInfo &TTI,
                             const DataLayout &DL) {
  IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
  if (II) {
    // The scalarization code below does not work for scalable vectors.
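    // For example, a masked.load returning <vscale x 4 x i32>, or any call
    // with a scalable-vector operand, is skipped here and the intrinsic is
    // left as-is.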
    if (isa<ScalableVectorType>(II->getType()) ||
        any_of(II->arg_operands(),
               [](Value *V) { return isa<ScalableVectorType>(V->getType()); }))
      return false;

    switch (II->getIntrinsicID()) {
    default:
      break;
    case Intrinsic::masked_load:
      // Scalarize unsupported vector masked load
      if (TTI.isLegalMaskedLoad(
              CI->getType(),
              cast<ConstantInt>(CI->getArgOperand(1))->getAlignValue()))
        return false;
      scalarizeMaskedLoad(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_store:
      if (TTI.isLegalMaskedStore(
              CI->getArgOperand(0)->getType(),
              cast<ConstantInt>(CI->getArgOperand(2))->getAlignValue()))
        return false;
      scalarizeMaskedStore(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_gather: {
      unsigned AlignmentInt =
          cast<ConstantInt>(CI->getArgOperand(1))->getZExtValue();
      Type *LoadTy = CI->getType();
      Align Alignment =
          DL.getValueOrABITypeAlignment(MaybeAlign(AlignmentInt), LoadTy);
      if (TTI.isLegalMaskedGather(LoadTy, Alignment))
        return false;
      scalarizeMaskedGather(CI, ModifiedDT);
      return true;
    }
    case Intrinsic::masked_scatter: {
      unsigned AlignmentInt =
          cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
      Type *StoreTy = CI->getArgOperand(0)->getType();
      Align Alignment =
          DL.getValueOrABITypeAlignment(MaybeAlign(AlignmentInt), StoreTy);
      if (TTI.isLegalMaskedScatter(StoreTy, Alignment))
        return false;
      scalarizeMaskedScatter(CI, ModifiedDT);
      return true;
    }
    case Intrinsic::masked_expandload:
      if (TTI.isLegalMaskedExpandLoad(CI->getType()))
        return false;
      scalarizeMaskedExpandLoad(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_compressstore:
      if (TTI.isLegalMaskedCompressStore(CI->getArgOperand(0)->getType()))
        return false;
      scalarizeMaskedCompressStore(CI, ModifiedDT);
      return true;
    }
  }

  return false;
}
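
// Usage sketch (assumption: the standard opt driver; exact spelling may
// differ across LLVM versions). The pass can be exercised on its own with
// either pass manager, e.g.
//
//   opt -passes=scalarize-masked-mem-intrin -S in.ll   ; new pass manager
//   opt -scalarize-masked-mem-intrin -S in.ll          ; legacy pass manager
//
// Masked intrinsics that the target reports as legal through
// TargetTransformInfo are left untouched; all others are expanded into the
// branch chains constructed above.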