//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
//                                    intrinsics
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass replaces masked memory intrinsics - when unsupported by the target
// - with a chain of basic blocks that deal with the elements one-by-one if the
// appropriate mask bit is set.
//
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/ScalarizeMaskedMemIntrin.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/DomTreeUpdater.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include <algorithm>
#include <cassert>

using namespace llvm;

#define DEBUG_TYPE "scalarize-masked-mem-intrin"

namespace {

class ScalarizeMaskedMemIntrinLegacyPass : public FunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid

  explicit ScalarizeMaskedMemIntrinLegacyPass() : FunctionPass(ID) {
    initializeScalarizeMaskedMemIntrinLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "Scalarize Masked Memory Intrinsics";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
    AU.addPreserved<DominatorTreeWrapperPass>();
  }
};

} // end anonymous namespace

static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
                          const TargetTransformInfo &TTI, const DataLayout &DL,
                          DomTreeUpdater *DTU);
static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
                             const TargetTransformInfo &TTI,
                             const DataLayout &DL, DomTreeUpdater *DTU);

char ScalarizeMaskedMemIntrinLegacyPass::ID = 0;

INITIALIZE_PASS_BEGIN(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
                      "Scalarize unsupported masked memory intrinsics", false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_END(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
                    "Scalarize unsupported masked memory intrinsics", false,
                    false)

FunctionPass *llvm::createScalarizeMaskedMemIntrinLegacyPass() {
  return new ScalarizeMaskedMemIntrinLegacyPass();
}

static bool isConstantIntVector(Value *Mask) {
  Constant *C = dyn_cast<Constant>(Mask);
  if (!C)
    return false;

  unsigned NumElts = cast<FixedVectorType>(Mask->getType())->getNumElements();
  for (unsigned i = 0; i != NumElts; ++i) {
    Constant *CElt = C->getAggregateElement(i);
    if (!CElt || !isa<ConstantInt>(CElt))
      return false;
  }

  return true;
}

// Translate a masked load intrinsic like
//   <16 x i32> @llvm.masked.load(<16 x i32>* %addr, i32 align,
//                                <16 x i1> %mask, <16 x i32> %passthru)
// to a chain of basic blocks that load the elements one-by-one if
// the appropriate mask bit is set
//
// %1 = bitcast i8* %addr to i32*
// %2 = extractelement <16 x i1> %mask, i32 0
// br i1 %2, label %cond.load, label %else
//
// cond.load: ; preds = %0
// %3 = getelementptr i32* %1, i32 0
// %4 = load i32* %3
// %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
// br label %else
//
// else: ; preds = %0, %cond.load
// %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ undef, %0 ]
// %6 = extractelement <16 x i1> %mask, i32 1
// br i1 %6, label %cond.load1, label %else2
//
// cond.load1: ; preds = %else
// %7 = getelementptr i32* %1, i32 1
// %8 = load i32* %7
// %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
// br label %else2
//
// else2: ; preds = %else, %cond.load1
// %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
// %10 = extractelement <16 x i1> %mask, i32 2
// br i1 %10, label %cond.load4, label %else5
//
static void scalarizeMaskedLoad(CallInst *CI, DomTreeUpdater *DTU,
                                bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
  VectorType *VecType = cast<FixedVectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Value *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  const Align AdjustedAlignVal =
      commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();

  // The result vector
  Value *VResult = Src0;

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
      LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
      VResult = Builder.CreateInsertElement(VResult, Load, Idx);
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    // %mask_1 = and i16 %scalar_mask, i16 1 << Idx
    // %cond = icmp ne i16 %mask_1, 0
    // br i1 %cond, label %cond.load, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    // %EltAddr = getelementptr i32* %1, i32 0
    // %Elt = load i32* %EltAddr
    // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.load");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked store intrinsic, like
//   void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
//                           <16 x i1> %mask)
// to a chain of basic blocks that store the elements one-by-one if
// the appropriate mask bit is set
//
// %1 = bitcast i8* %addr to i32*
// %2 = extractelement <16 x i1> %mask, i32 0
// br i1 %2, label %cond.store, label %else
//
// cond.store: ; preds = %0
// %3 = extractelement <16 x i32> %val, i32 0
// %4 = getelementptr i32* %1, i32 0
// store i32 %3, i32* %4
// br label %else
//
// else: ; preds = %0, %cond.store
// %5 = extractelement <16 x i1> %mask, i32 1
// br i1 %5, label %cond.store1, label %else2
//
// cond.store1: ; preds = %else
// %6 = extractelement <16 x i32> %val, i32 1
// %7 = getelementptr i32* %1, i32 1
// store i32 %6, i32* %7
// br label %else2
// . . .
static void scalarizeMaskedStore(CallInst *CI, DomTreeUpdater *DTU,
                                 bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
  auto *VecType = cast<VectorType>(Src->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Builder.CreateAlignedStore(Src, Ptr, AlignVal);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  const Align AdjustedAlignVal =
      commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt = Builder.CreateExtractElement(Src, Idx);
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
      Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    // %mask_1 = and i16 %scalar_mask, i16 1 << Idx
    // %cond = icmp ne i16 %mask_1, 0
    // br i1 %cond, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    // %OneElt = extractelement <16 x i32> %Src, i32 Idx
    // %EltAddr = getelementptr i32* %1, i32 0
    // store i32 %OneElt, i32* %EltAddr
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.store");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");

    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked gather intrinsic like
//   <16 x i32> @llvm.masked.gather.v16i32(<16 x i32*> %Ptrs, i32 4,
//                                         <16 x i1> %Mask, <16 x i32> %Src)
// to a chain of basic blocks that load the elements one-by-one if
// the appropriate mask bit is set
//
// %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
// %Mask0 = extractelement <16 x i1> %Mask, i32 0
// br i1 %Mask0, label %cond.load, label %else
//
// cond.load:
// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
// %Load0 = load i32, i32* %Ptr0, align 4
// %Res0 = insertelement <16 x i32> undef, i32 %Load0, i32 0
// br label %else
//
// else:
// %res.phi.else = phi <16 x i32> [ %Res0, %cond.load ], [ undef, %0 ]
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
// br i1 %Mask1, label %cond.load1, label %else2
//
// cond.load1:
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
// %Load1 = load i32, i32* %Ptr1, align 4
// %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
// br label %else2
// . . .
// %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
// ret <16 x i32> %Result
static void scalarizeMaskedGather(CallInst *CI, DomTreeUpdater *DTU,
                                  bool &ModifiedDT) {
  Value *Ptrs = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  auto *VecType = cast<FixedVectorType>(CI->getType());
  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // The result vector
  Value *VResult = Src0;
  unsigned VectorWidth = VecType->getNumElements();

  // Short-cut if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      LoadInst *Load =
          Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
      VResult =
          Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    // %Mask1 = and i16 %scalar_mask, i16 1 << Idx
    // %cond = icmp ne i16 %Mask1, 0
    // br i1 %cond, label %cond.load, label %else
    //

    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    // %EltAddr = getelementptr i32* %1, i32 0
    // %Elt = load i32* %EltAddr
    // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.load");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    LoadInst *Load =
        Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
    Value *NewVResult =
        Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked scatter intrinsic, like
//   void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*> %Ptrs, i32 4,
//                                    <16 x i1> %Mask)
// to a chain of basic blocks that store the elements one-by-one if
// the appropriate mask bit is set.
//
// %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
// %Mask0 = extractelement <16 x i1> %Mask, i32 0
// br i1 %Mask0, label %cond.store, label %else
//
// cond.store:
// %Elt0 = extractelement <16 x i32> %Src, i32 0
// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
// store i32 %Elt0, i32* %Ptr0, align 4
// br label %else
//
// else:
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
// br i1 %Mask1, label %cond.store1, label %else2
//
// cond.store1:
// %Elt1 = extractelement <16 x i32> %Src, i32 1
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
// store i32 %Elt1, i32* %Ptr1, align 4
// br label %else2
// . . .
static void scalarizeMaskedScatter(CallInst *CI, DomTreeUpdater *DTU,
                                   bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptrs = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  auto *SrcFVTy = cast<FixedVectorType>(Src->getType());

  assert(
      isa<VectorType>(Ptrs->getType()) &&
      isa<PointerType>(cast<VectorType>(Ptrs->getType())->getElementType()) &&
      "Vector of pointers is expected in masked scatter intrinsic");

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();
  unsigned VectorWidth = SrcFVTy->getNumElements();

  // Short-cut if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    // %Mask1 = and i16 %scalar_mask, i16 1 << Idx
    // %cond = icmp ne i16 %Mask1, 0
    // br i1 %cond, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    // %Elt1 = extractelement <16 x i32> %Src, i32 1
    // %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
    // store i32 %Elt1, i32* %Ptr1
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.store");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");

    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked expandload intrinsic, like
//   <16 x i32> @llvm.masked.expandload(i32* %ptr, <16 x i1> %mask,
//                                      <16 x i32> %passthru)
// to a chain of basic blocks that load from consecutive memory locations,
// advancing the pointer by one element only for lanes whose mask bit is set.
static void scalarizeMaskedExpandLoad(CallInst *CI, DomTreeUpdater *DTU,
                                      bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Mask = CI->getArgOperand(1);
  Value *PassThru = CI->getArgOperand(2);

  auto *VecType = cast<FixedVectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  unsigned VectorWidth = VecType->getNumElements();

  // The result vector
  Value *VResult = PassThru;

  // Short-cut if the mask is a vector of constants.
  // Create a build_vector pattern, with loads/undefs as necessary and then
  // shuffle blend with the pass through value.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    VResult = UndefValue::get(VecType);
    SmallVector<int, 16> ShuffleMask(VectorWidth, UndefMaskElem);
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      Value *InsertElt;
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) {
        InsertElt = UndefValue::get(EltTy);
        ShuffleMask[Idx] = Idx + VectorWidth;
      } else {
        Value *NewPtr =
            Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
        InsertElt = Builder.CreateAlignedLoad(EltTy, NewPtr, Align(1),
                                              "Load" + Twine(Idx));
        ShuffleMask[Idx] = Idx;
        ++MemIndex;
      }
      VResult = Builder.CreateInsertElement(VResult, InsertElt, Idx,
                                            "Res" + Twine(Idx));
    }
    VResult = Builder.CreateShuffleVector(VResult, PassThru, ShuffleMask);
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    // %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    // br i1 %mask_1, label %cond.load, label %else
    //

    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    // %EltAddr = getelementptr i32* %1, i32 0
    // %Elt = load i32* %EltAddr
    // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.load");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, Align(1));
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Move the pointer if there are more blocks to come.
    Value *NewPtr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());
    PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    ResultPhi->addIncoming(NewVResult, CondBlock);
    ResultPhi->addIncoming(VResult, PrevIfBlock);
    VResult = ResultPhi;

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked compressstore intrinsic, like
//   void @llvm.masked.compressstore(<16 x i32> %value, i32* %ptr,
//                                   <16 x i1> %mask)
// to a chain of basic blocks that store the active elements to consecutive
// memory locations, advancing the pointer by one element after each store.
static void scalarizeMaskedCompressStore(CallInst *CI, DomTreeUpdater *DTU,
                                         bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);

  auto *VecType = cast<FixedVectorType>(Src->getType());

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  Type *EltTy = VecType->getElementType();

  unsigned VectorWidth = VecType->getNumElements();

  // Short-cut if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
      Builder.CreateAlignedStore(OneElt, NewPtr, Align(1));
      ++MemIndex;
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    // %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    // br i1 %mask_1, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    // %OneElt = extractelement <16 x i32> %Src, i32 Idx
    // %EltAddr = getelementptr i32* %1, i32 0
    // store i32 %OneElt, i32* %EltAddr
    //
    Instruction *ThenTerm =
        SplitBlockAndInsertIfThen(Predicate, InsertPt, /*Unreachable=*/false,
                                  /*BranchWeights=*/nullptr, DTU);

    BasicBlock *CondBlock = ThenTerm->getParent();
    CondBlock->setName("cond.store");

    Builder.SetInsertPoint(CondBlock->getTerminator());
    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Builder.CreateAlignedStore(OneElt, Ptr, Align(1));

    // Move the pointer if there are more blocks to come.
    Value *NewPtr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = ThenTerm->getSuccessor(0);
    NewIfBlock->setName("else");
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    Builder.SetInsertPoint(NewIfBlock, NewIfBlock->begin());

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

static bool runImpl(Function &F, const TargetTransformInfo &TTI,
                    DominatorTree *DT) {
  Optional<DomTreeUpdater> DTU;
  if (DT)
    DTU.emplace(DT, DomTreeUpdater::UpdateStrategy::Lazy);

  bool EverMadeChange = false;
  bool MadeChange = true;
  auto &DL = F.getParent()->getDataLayout();
  while (MadeChange) {
    MadeChange = false;
    for (Function::iterator I = F.begin(); I != F.end();) {
      BasicBlock *BB = &*I++;
      bool ModifiedDTOnIteration = false;
      MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration, TTI, DL,
                                  DTU.hasValue() ? DTU.getPointer() : nullptr);

      // Restart BB iteration if the dominator tree of the Function was changed.
      if (ModifiedDTOnIteration)
        break;
    }

    EverMadeChange |= MadeChange;
  }
  return EverMadeChange;
}

bool ScalarizeMaskedMemIntrinLegacyPass::runOnFunction(Function &F) {
  auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
  DominatorTree *DT = nullptr;
  if (auto *DTWP = getAnalysisIfAvailable<DominatorTreeWrapperPass>())
    DT = &DTWP->getDomTree();
  return runImpl(F, TTI, DT);
}

PreservedAnalyses
ScalarizeMaskedMemIntrinPass::run(Function &F, FunctionAnalysisManager &AM) {
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  auto *DT = AM.getCachedResult<DominatorTreeAnalysis>(F);
  if (!runImpl(F, TTI, DT))
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  PA.preserve<TargetIRAnalysis>();
  PA.preserve<DominatorTreeAnalysis>();
  return PA;
}

static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
                          const TargetTransformInfo &TTI, const DataLayout &DL,
                          DomTreeUpdater *DTU) {
  bool MadeChange = false;

  BasicBlock::iterator CurInstIterator = BB.begin();
  while (CurInstIterator != BB.end()) {
    if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
      MadeChange |= optimizeCallInst(CI, ModifiedDT, TTI, DL, DTU);
    if (ModifiedDT)
      return true;
  }

  return MadeChange;
}

static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
                             const TargetTransformInfo &TTI,
                             const DataLayout &DL, DomTreeUpdater *DTU) {
  IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
  if (II) {
    // The scalarization code below does not work for scalable vectors.
    if (isa<ScalableVectorType>(II->getType()) ||
        any_of(II->arg_operands(),
               [](Value *V) { return isa<ScalableVectorType>(V->getType()); }))
      return false;

    switch (II->getIntrinsicID()) {
    default:
      break;
    case Intrinsic::masked_load:
      // Scalarize unsupported vector masked load
      if (TTI.isLegalMaskedLoad(
              CI->getType(),
              cast<ConstantInt>(CI->getArgOperand(1))->getAlignValue()))
        return false;
      scalarizeMaskedLoad(CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_store:
      if (TTI.isLegalMaskedStore(
              CI->getArgOperand(0)->getType(),
              cast<ConstantInt>(CI->getArgOperand(2))->getAlignValue()))
        return false;
      scalarizeMaskedStore(CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_gather: {
      unsigned AlignmentInt =
          cast<ConstantInt>(CI->getArgOperand(1))->getZExtValue();
      Type *LoadTy = CI->getType();
      Align Alignment =
          DL.getValueOrABITypeAlignment(MaybeAlign(AlignmentInt), LoadTy);
      if (TTI.isLegalMaskedGather(LoadTy, Alignment))
        return false;
      scalarizeMaskedGather(CI, DTU, ModifiedDT);
      return true;
    }
    case Intrinsic::masked_scatter: {
      unsigned AlignmentInt =
          cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
      Type *StoreTy = CI->getArgOperand(0)->getType();
      Align Alignment =
          DL.getValueOrABITypeAlignment(MaybeAlign(AlignmentInt), StoreTy);
      if (TTI.isLegalMaskedScatter(StoreTy, Alignment))
        return false;
      scalarizeMaskedScatter(CI, DTU, ModifiedDT);
      return true;
    }
    case Intrinsic::masked_expandload:
      if (TTI.isLegalMaskedExpandLoad(CI->getType()))
        return false;
      scalarizeMaskedExpandLoad(CI, DTU, ModifiedDT);
      return true;
    case Intrinsic::masked_compressstore:
      if (TTI.isLegalMaskedCompressStore(CI->getArgOperand(0)->getType()))
        return false;
      scalarizeMaskedCompressStore(CI, DTU, ModifiedDT);
      return true;
    }
  }

  return false;
}