1 //=== ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ===// 2 //=== instrinsics ===// 3 // 4 // The LLVM Compiler Infrastructure 5 // 6 // This file is distributed under the University of Illinois Open Source 7 // License. See LICENSE.TXT for details. 8 // 9 //===----------------------------------------------------------------------===// 10 // 11 // This pass replaces masked memory intrinsics - when unsupported by the target 12 // - with a chain of basic blocks, that deal with the elements one-by-one if the 13 // appropriate mask bit is set. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "llvm/Analysis/TargetTransformInfo.h" 18 #include "llvm/IR/IRBuilder.h" 19 #include "llvm/Target/TargetSubtargetInfo.h" 20 21 using namespace llvm; 22 23 #define DEBUG_TYPE "scalarize-masked-mem-intrin" 24 25 namespace { 26 27 class ScalarizeMaskedMemIntrin : public FunctionPass { 28 const TargetTransformInfo *TTI; 29 30 public: 31 static char ID; // Pass identification, replacement for typeid 32 explicit ScalarizeMaskedMemIntrin() : FunctionPass(ID), TTI(nullptr) { 33 initializeScalarizeMaskedMemIntrinPass(*PassRegistry::getPassRegistry()); 34 } 35 bool runOnFunction(Function &F) override; 36 37 StringRef getPassName() const override { 38 return "Scalarize Masked Memory Intrinsics"; 39 } 40 41 void getAnalysisUsage(AnalysisUsage &AU) const override { 42 AU.addRequired<TargetTransformInfoWrapperPass>(); 43 } 44 45 private: 46 bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT); 47 bool optimizeCallInst(CallInst *CI, bool &ModifiedDT); 48 }; 49 } // namespace 50 51 char ScalarizeMaskedMemIntrin::ID = 0; 52 INITIALIZE_PASS(ScalarizeMaskedMemIntrin, DEBUG_TYPE, 53 "Scalarize unsupported masked memory intrinsics", false, false) 54 55 FunctionPass *llvm::createScalarizeMaskedMemIntrinPass() { 56 return new ScalarizeMaskedMemIntrin(); 57 } 58 59 // Translate a masked load intrinsic like 60 // <16 x i32 > 
// @llvm.masked.load( <16 x i32>* %addr, i32 align,
//                    <16 x i1> %mask, <16 x i32> %passthru)
// to a chain of basic blocks, with loading element one-by-one if
// the appropriate mask bit is set
//
//  %1 = bitcast i8* %addr to i32*
//  %2 = extractelement <16 x i1> %mask, i32 0
//  %3 = icmp eq i1 %2, true
//  br i1 %3, label %cond.load, label %else
//
// cond.load:                                        ; preds = %0
//  %4 = getelementptr i32* %1, i32 0
//  %5 = load i32* %4
//  %6 = insertelement <16 x i32> undef, i32 %5, i32 0
//  br label %else
//
// else:                                             ; preds = %0, %cond.load
//  %res.phi.else = phi <16 x i32> [ %6, %cond.load ], [ undef, %0 ]
//  %7 = extractelement <16 x i1> %mask, i32 1
//  %8 = icmp eq i1 %7, true
//  br i1 %8, label %cond.load1, label %else2
//
// cond.load1:                                       ; preds = %else
//  %9 = getelementptr i32* %1, i32 1
//  %10 = load i32* %9
//  %11 = insertelement <16 x i32> %res.phi.else, i32 %10, i32 1
//  br label %else2
//
// else2:                                          ; preds = %else, %cond.load1
//  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
//  %12 = extractelement <16 x i1> %mask, i32 2
//  %13 = icmp eq i1 %12, true
//  br i1 %13, label %cond.load4, label %else5
//
static void scalarizeMaskedLoad(CallInst *CI) {
  // Intrinsic operands: (pointer, constant alignment, mask, passthru).
  Value *Ptr = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  // The alignment operand of the intrinsic is required to be a ConstantInt.
  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  VectorType *VecType = dyn_cast<VectorType>(CI->getType());
  assert(VecType && "Unexpected return type of masked load intrinsic");

  Type *EltTy = CI->getType()->getVectorElementType();

  IRBuilder<> Builder(CI->getContext());
  // All new instructions are emitted immediately before the original call,
  // which stays in place until the very end (it anchors the block splits).
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  BasicBlock *CondBlock = nullptr;
  BasicBlock *PrevIfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  bool IsAllOnesMask =
      isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue();

  if (IsAllOnesMask) {
    // Every lane is loaded: the whole intrinsic is just an ordinary
    // vector load; the passthru is dead.
    Value *NewI = Builder.CreateAlignedLoad(Ptr, AlignVal);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction: a single lane is only
  // guaranteed the smaller of the vector alignment and the element size.
  AlignVal = std::min(AlignVal, VecType->getScalarSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy* so lanes can be addressed by GEP.
  Type *NewPtrType =
      EltTy->getPointerTo(cast<PointerType>(Ptr->getType())->getAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = VecType->getNumElements();

  Value *UndefVal = UndefValue::get(VecType);

  // The result vector
  Value *VResult = UndefVal;

  // Constant (but not all-ones) mask: no control flow needed — emit an
  // unconditional scalar load for each set lane and select against the
  // passthru for the clear ones.
  if (isa<ConstantVector>(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
        continue;
      Value *Gep =
          Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
      LoadInst *Load = Builder.CreateAlignedLoad(Gep, AlignVal);
      VResult =
          Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx));
    }
    Value *NewI = Builder.CreateSelect(Mask, VResult, Src0);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  // General (dynamic) mask: build the diamond chain shown in the header
  // comment, one cond.load/else pair per lane.
  PHINode *Phi = nullptr;
  Value *PrevPhi = UndefVal;

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {

    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  %to_load = icmp eq i1 %mask_1, true
    //  br i1 %to_load, label %cond.load, label %else
    //
    if (Idx > 0) {
      // Merge the value loaded on the taken path with the value carried
      // around the skipped path.
      Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
      Phi->addIncoming(VResult, CondBlock);
      Phi->addIncoming(PrevPhi, PrevIfBlock);
      PrevPhi = Phi;
      VResult = Phi;
    }

    Value *Predicate =
        Builder.CreateExtractElement(Mask, Builder.getInt32(Idx));
    Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate,
                                    ConstantInt::get(Predicate->getType(), 1));

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.load");
    Builder.SetInsertPoint(InsertPt);

    Value *Gep =
        Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
    LoadInst *Load = Builder.CreateAlignedLoad(Gep, AlignVal);
    VResult = Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx));

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    // splitBasicBlock left an unconditional branch in IfBlock; replace it
    // with the conditional branch on this lane's mask bit.
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr);
    OldBr->eraseFromParent();
    PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;
  }

  // Final merge of the last lane, then select against the passthru so that
  // unset lanes observe %passthru rather than undef.
  Phi = Builder.CreatePHI(VecType, 2, "res.phi.select");
  Phi->addIncoming(VResult, CondBlock);
  Phi->addIncoming(PrevPhi, PrevIfBlock);
  Value *NewI = Builder.CreateSelect(Mask, Phi, Src0);
  CI->replaceAllUsesWith(NewI);
  CI->eraseFromParent();
}

// Translate a masked store intrinsic, like
// void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
//                         <16 x i1> %mask)
// to a chain of basic blocks, that stores element one-by-one if
// the appropriate mask bit is set
//
//  %1 = bitcast i8* %addr to i32*
//  %2 = extractelement <16 x i1> %mask, i32 0
//  %3 = icmp eq i1 %2, true
//  br i1 %3, label %cond.store, label %else
//
// cond.store:                                       ; preds = %0
//  %4 =
extractelement <16 x i32> %val, i32 0 226 // %5 = getelementptr i32* %1, i32 0 227 // store i32 %4, i32* %5 228 // br label %else 229 // 230 // else: ; preds = %0, %cond.store 231 // %6 = extractelement <16 x i1> %mask, i32 1 232 // %7 = icmp eq i1 %6, true 233 // br i1 %7, label %cond.store1, label %else2 234 // 235 // cond.store1: ; preds = %else 236 // %8 = extractelement <16 x i32> %val, i32 1 237 // %9 = getelementptr i32* %1, i32 1 238 // store i32 %8, i32* %9 239 // br label %else2 240 // . . . 241 static void scalarizeMaskedStore(CallInst *CI) { 242 Value *Src = CI->getArgOperand(0); 243 Value *Ptr = CI->getArgOperand(1); 244 Value *Alignment = CI->getArgOperand(2); 245 Value *Mask = CI->getArgOperand(3); 246 247 unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue(); 248 VectorType *VecType = dyn_cast<VectorType>(Src->getType()); 249 assert(VecType && "Unexpected data type in masked store intrinsic"); 250 251 Type *EltTy = VecType->getElementType(); 252 253 IRBuilder<> Builder(CI->getContext()); 254 Instruction *InsertPt = CI; 255 BasicBlock *IfBlock = CI->getParent(); 256 Builder.SetInsertPoint(InsertPt); 257 Builder.SetCurrentDebugLocation(CI->getDebugLoc()); 258 259 // Short-cut if the mask is all-true. 260 bool IsAllOnesMask = 261 isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue(); 262 263 if (IsAllOnesMask) { 264 Builder.CreateAlignedStore(Src, Ptr, AlignVal); 265 CI->eraseFromParent(); 266 return; 267 } 268 269 // Adjust alignment for the scalar instruction. 
270 AlignVal = std::max(AlignVal, VecType->getScalarSizeInBits() / 8); 271 // Bitcast %addr fron i8* to EltTy* 272 Type *NewPtrType = 273 EltTy->getPointerTo(cast<PointerType>(Ptr->getType())->getAddressSpace()); 274 Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType); 275 unsigned VectorWidth = VecType->getNumElements(); 276 277 if (isa<ConstantVector>(Mask)) { 278 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { 279 if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue()) 280 continue; 281 Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx)); 282 Value *Gep = 283 Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx)); 284 Builder.CreateAlignedStore(OneElt, Gep, AlignVal); 285 } 286 CI->eraseFromParent(); 287 return; 288 } 289 290 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { 291 292 // Fill the "else" block, created in the previous iteration 293 // 294 // %mask_1 = extractelement <16 x i1> %mask, i32 Idx 295 // %to_store = icmp eq i1 %mask_1, true 296 // br i1 %to_store, label %cond.store, label %else 297 // 298 Value *Predicate = 299 Builder.CreateExtractElement(Mask, Builder.getInt32(Idx)); 300 Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate, 301 ConstantInt::get(Predicate->getType(), 1)); 302 303 // Create "cond" block 304 // 305 // %OneElt = extractelement <16 x i32> %Src, i32 Idx 306 // %EltAddr = getelementptr i32* %1, i32 0 307 // %store i32 %OneElt, i32* %EltAddr 308 // 309 BasicBlock *CondBlock = 310 IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store"); 311 Builder.SetInsertPoint(InsertPt); 312 313 Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx)); 314 Value *Gep = 315 Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx)); 316 Builder.CreateAlignedStore(OneElt, Gep, AlignVal); 317 318 // Create "else" block, fill it in the next iteration 319 BasicBlock *NewIfBlock = 320 CondBlock->splitBasicBlock(InsertPt->getIterator(), "else"); 321 
Builder.SetInsertPoint(InsertPt); 322 Instruction *OldBr = IfBlock->getTerminator(); 323 BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr); 324 OldBr->eraseFromParent(); 325 IfBlock = NewIfBlock; 326 } 327 CI->eraseFromParent(); 328 } 329 330 // Translate a masked gather intrinsic like 331 // <16 x i32 > @llvm.masked.gather.v16i32( <16 x i32*> %Ptrs, i32 4, 332 // <16 x i1> %Mask, <16 x i32> %Src) 333 // to a chain of basic blocks, with loading element one-by-one if 334 // the appropriate mask bit is set 335 // 336 // % Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind 337 // % Mask0 = extractelement <16 x i1> %Mask, i32 0 338 // % ToLoad0 = icmp eq i1 % Mask0, true 339 // br i1 % ToLoad0, label %cond.load, label %else 340 // 341 // cond.load: 342 // % Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0 343 // % Load0 = load i32, i32* % Ptr0, align 4 344 // % Res0 = insertelement <16 x i32> undef, i32 % Load0, i32 0 345 // br label %else 346 // 347 // else: 348 // %res.phi.else = phi <16 x i32>[% Res0, %cond.load], [undef, % 0] 349 // % Mask1 = extractelement <16 x i1> %Mask, i32 1 350 // % ToLoad1 = icmp eq i1 % Mask1, true 351 // br i1 % ToLoad1, label %cond.load1, label %else2 352 // 353 // cond.load1: 354 // % Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1 355 // % Load1 = load i32, i32* % Ptr1, align 4 356 // % Res1 = insertelement <16 x i32> %res.phi.else, i32 % Load1, i32 1 357 // br label %else2 358 // . . . 
359 // % Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src 360 // ret <16 x i32> %Result 361 static void scalarizeMaskedGather(CallInst *CI) { 362 Value *Ptrs = CI->getArgOperand(0); 363 Value *Alignment = CI->getArgOperand(1); 364 Value *Mask = CI->getArgOperand(2); 365 Value *Src0 = CI->getArgOperand(3); 366 367 VectorType *VecType = dyn_cast<VectorType>(CI->getType()); 368 369 assert(VecType && "Unexpected return type of masked load intrinsic"); 370 371 IRBuilder<> Builder(CI->getContext()); 372 Instruction *InsertPt = CI; 373 BasicBlock *IfBlock = CI->getParent(); 374 BasicBlock *CondBlock = nullptr; 375 BasicBlock *PrevIfBlock = CI->getParent(); 376 Builder.SetInsertPoint(InsertPt); 377 unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue(); 378 379 Builder.SetCurrentDebugLocation(CI->getDebugLoc()); 380 381 Value *UndefVal = UndefValue::get(VecType); 382 383 // The result vector 384 Value *VResult = UndefVal; 385 unsigned VectorWidth = VecType->getNumElements(); 386 387 // Shorten the way if the mask is a vector of constants. 
388 bool IsConstMask = isa<ConstantVector>(Mask); 389 390 if (IsConstMask) { 391 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { 392 if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue()) 393 continue; 394 Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx), 395 "Ptr" + Twine(Idx)); 396 LoadInst *Load = 397 Builder.CreateAlignedLoad(Ptr, AlignVal, "Load" + Twine(Idx)); 398 VResult = Builder.CreateInsertElement( 399 VResult, Load, Builder.getInt32(Idx), "Res" + Twine(Idx)); 400 } 401 Value *NewI = Builder.CreateSelect(Mask, VResult, Src0); 402 CI->replaceAllUsesWith(NewI); 403 CI->eraseFromParent(); 404 return; 405 } 406 407 PHINode *Phi = nullptr; 408 Value *PrevPhi = UndefVal; 409 410 for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) { 411 412 // Fill the "else" block, created in the previous iteration 413 // 414 // %Mask1 = extractelement <16 x i1> %Mask, i32 1 415 // %ToLoad1 = icmp eq i1 %Mask1, true 416 // br i1 %ToLoad1, label %cond.load, label %else 417 // 418 if (Idx > 0) { 419 Phi = Builder.CreatePHI(VecType, 2, "res.phi.else"); 420 Phi->addIncoming(VResult, CondBlock); 421 Phi->addIncoming(PrevPhi, PrevIfBlock); 422 PrevPhi = Phi; 423 VResult = Phi; 424 } 425 426 Value *Predicate = Builder.CreateExtractElement(Mask, Builder.getInt32(Idx), 427 "Mask" + Twine(Idx)); 428 Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate, 429 ConstantInt::get(Predicate->getType(), 1), 430 "ToLoad" + Twine(Idx)); 431 432 // Create "cond" block 433 // 434 // %EltAddr = getelementptr i32* %1, i32 0 435 // %Elt = load i32* %EltAddr 436 // VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx 437 // 438 CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load"); 439 Builder.SetInsertPoint(InsertPt); 440 441 Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx), 442 "Ptr" + Twine(Idx)); 443 LoadInst *Load = 444 Builder.CreateAlignedLoad(Ptr, AlignVal, "Load" + Twine(Idx)); 445 VResult = 
Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx), 446 "Res" + Twine(Idx)); 447 448 // Create "else" block, fill it in the next iteration 449 BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else"); 450 Builder.SetInsertPoint(InsertPt); 451 Instruction *OldBr = IfBlock->getTerminator(); 452 BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr); 453 OldBr->eraseFromParent(); 454 PrevIfBlock = IfBlock; 455 IfBlock = NewIfBlock; 456 } 457 458 Phi = Builder.CreatePHI(VecType, 2, "res.phi.select"); 459 Phi->addIncoming(VResult, CondBlock); 460 Phi->addIncoming(PrevPhi, PrevIfBlock); 461 Value *NewI = Builder.CreateSelect(Mask, Phi, Src0); 462 CI->replaceAllUsesWith(NewI); 463 CI->eraseFromParent(); 464 } 465 466 // Translate a masked scatter intrinsic, like 467 // void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*>* %Ptrs, i32 4, 468 // <16 x i1> %Mask) 469 // to a chain of basic blocks, that stores element one-by-one if 470 // the appropriate mask bit is set. 471 // 472 // % Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind 473 // % Mask0 = extractelement <16 x i1> % Mask, i32 0 474 // % ToStore0 = icmp eq i1 % Mask0, true 475 // br i1 %ToStore0, label %cond.store, label %else 476 // 477 // cond.store: 478 // % Elt0 = extractelement <16 x i32> %Src, i32 0 479 // % Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0 480 // store i32 %Elt0, i32* % Ptr0, align 4 481 // br label %else 482 // 483 // else: 484 // % Mask1 = extractelement <16 x i1> % Mask, i32 1 485 // % ToStore1 = icmp eq i1 % Mask1, true 486 // br i1 % ToStore1, label %cond.store1, label %else2 487 // 488 // cond.store1: 489 // % Elt1 = extractelement <16 x i32> %Src, i32 1 490 // % Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1 491 // store i32 % Elt1, i32* % Ptr1, align 4 492 // br label %else2 493 // . . . 
static void scalarizeMaskedScatter(CallInst *CI) {
  // Intrinsic operands: (value vector, vector of pointers, constant
  // alignment, mask).
  Value *Src = CI->getArgOperand(0);
  Value *Ptrs = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  assert(isa<VectorType>(Src->getType()) &&
         "Unexpected data type in masked scatter intrinsic");
  assert(isa<VectorType>(Ptrs->getType()) &&
         isa<PointerType>(Ptrs->getType()->getVectorElementType()) &&
         "Vector of pointers is expected in masked scatter intrinsic");

  IRBuilder<> Builder(CI->getContext());
  // New instructions go immediately before the original call, which anchors
  // the block splits until it is erased at the end.
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // The alignment operand of the intrinsic is required to be a ConstantInt.
  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  unsigned VectorWidth = Src->getType()->getVectorNumElements();

  // Shorten the way if the mask is a vector of constants.
  bool IsConstMask = isa<ConstantVector>(Mask);

  if (IsConstMask) {
    // No control flow needed: emit an unconditional scalar store for each
    // set lane.
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
        continue;
      Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx),
                                                   "Elt" + Twine(Idx));
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
                                                "Ptr" + Twine(Idx));
      Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }
  // General (dynamic) mask: build a chain of cond.store/else blocks, one
  // pair per lane.
  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = extractelement <16 x i1> %Mask, i32 Idx
    //  %ToStore = icmp eq i1 %Mask1, true
    //  br i1 %ToStore, label %cond.store, label %else
    //
    Value *Predicate = Builder.CreateExtractElement(Mask, Builder.getInt32(Idx),
                                                    "Mask" + Twine(Idx));
    Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate,
                                    ConstantInt::get(Predicate->getType(), 1),
                                    "ToStore" + Twine(Idx));

    // Create "cond" block
    //
    //  %Elt1 = extractelement <16 x i32> %Src, i32 1
    //  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
    //  %store i32 %Elt1, i32* %Ptr1
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx),
                                                 "Elt" + Twine(Idx));
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
                                              "Ptr" + Twine(Idx));
    Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
    Builder.SetInsertPoint(InsertPt);
    // Replace the unconditional branch left by splitBasicBlock with a
    // conditional branch on this lane's mask bit.
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr);
    OldBr->eraseFromParent();
    IfBlock = NewIfBlock;
  }
  CI->eraseFromParent();
}

// Pass entry point: repeatedly sweep the function's blocks until a sweep
// makes no change. A sweep is restarted from scratch whenever a
// scalarization changed the CFG, since the block list iterator is then
// stale.
bool ScalarizeMaskedMemIntrin::runOnFunction(Function &F) {
  if (skipFunction(F))
    return false;

  bool EverMadeChange = false;

  TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);

  bool MadeChange = true;
  while (MadeChange) {
    MadeChange = false;
    for (Function::iterator I = F.begin(); I != F.end();) {
      // Advance before optimizing: BB may be split by scalarization.
      BasicBlock *BB = &*I++;
      bool ModifiedDTOnIteration = false;
      MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration);

      // Restart BB iteration if the dominator tree of the Function was changed
      if (ModifiedDTOnIteration)
        break;
    }

    EverMadeChange |= MadeChange;
  }

  return EverMadeChange;
}

// Scan one block for masked memory intrinsic calls. Returns true if any call
// was scalarized; sets ModifiedDT (and returns immediately) when the CFG was
// changed, because CurInstIterator is then no longer safe to advance.
bool ScalarizeMaskedMemIntrin::optimizeBlock(BasicBlock &BB, bool &ModifiedDT) {
  bool MadeChange = false;

  BasicBlock::iterator CurInstIterator = BB.begin();
  while (CurInstIterator != BB.end()) {
    // Advance the iterator before optimizeCallInst may erase the call.
    if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
      MadeChange |= optimizeCallInst(CI, ModifiedDT);
    if (ModifiedDT)
      return true;
  }

  return MadeChange;
}

// Dispatch a single call: if it is a masked load/store/gather/scatter that
// the target cannot handle natively (per TTI legality hooks), scalarize it.
// Returns true (and sets ModifiedDT) exactly when CI was replaced; the
// scalarize* helpers erase CI, so it must not be touched afterwards.
bool ScalarizeMaskedMemIntrin::optimizeCallInst(CallInst *CI,
                                                bool &ModifiedDT) {

  IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
  if (II) {
    switch (II->getIntrinsicID()) {
    default:
      break;
    case Intrinsic::masked_load: {
      // Scalarize unsupported vector masked load
      if (!TTI->isLegalMaskedLoad(CI->getType())) {
        scalarizeMaskedLoad(CI);
        ModifiedDT = true;
        return true;
      }
      return false;
    }
    case Intrinsic::masked_store: {
      // Legality is keyed on the stored value's type (operand 0).
      if (!TTI->isLegalMaskedStore(CI->getArgOperand(0)->getType())) {
        scalarizeMaskedStore(CI);
        ModifiedDT = true;
        return true;
      }
      return false;
    }
    case Intrinsic::masked_gather: {
      if (!TTI->isLegalMaskedGather(CI->getType())) {
        scalarizeMaskedGather(CI);
        ModifiedDT = true;
        return true;
      }
      return false;
    }
    case Intrinsic::masked_scatter: {
      // Legality is keyed on the stored value's type (operand 0).
      if (!TTI->isLegalMaskedScatter(CI->getArgOperand(0)->getType())) {
        scalarizeMaskedScatter(CI);
        ModifiedDT = true;
        return true;
      }
      return false;
    }
    }
  }

  return false;
}