1 //=== ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem      ===//
//===                                intrinsics                            ===//
3 //
4 //                     The LLVM Compiler Infrastructure
5 //
6 // This file is distributed under the University of Illinois Open Source
7 // License. See LICENSE.TXT for details.
8 //
9 //===----------------------------------------------------------------------===//
10 //
11 // This pass replaces masked memory intrinsics - when unsupported by the target
12 // - with a chain of basic blocks, that deal with the elements one-by-one if the
13 // appropriate mask bit is set.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "llvm/Analysis/TargetTransformInfo.h"
18 #include "llvm/IR/IRBuilder.h"
19 #include "llvm/Target/TargetSubtargetInfo.h"
20 
21 using namespace llvm;
22 
23 #define DEBUG_TYPE "scalarize-masked-mem-intrin"
24 
25 namespace {
26 
// Legacy-PM function pass that scalarizes masked load/store/gather/scatter
// intrinsics the target reports as not legal (see optimizeCallInst).
class ScalarizeMaskedMemIntrin : public FunctionPass {
  const TargetTransformInfo *TTI; // Set per-function in runOnFunction().

public:
  static char ID; // Pass identification, replacement for typeid
  explicit ScalarizeMaskedMemIntrin() : FunctionPass(ID), TTI(nullptr) {
    initializeScalarizeMaskedMemIntrinPass(*PassRegistry::getPassRegistry());
  }
  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "Scalarize Masked Memory Intrinsics";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    // Only TTI is needed, to query masked-op legality for the target.
    AU.addRequired<TargetTransformInfoWrapperPass>();
  }

private:
  // Both return true if any change was made; ModifiedDT is set when basic
  // blocks were split (the caller must then restart its block iteration).
  bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT);
  bool optimizeCallInst(CallInst *CI, bool &ModifiedDT);
};
49 } // namespace
50 
char ScalarizeMaskedMemIntrin::ID = 0;

// Register the pass with the legacy pass manager under the command-line name
// "scalarize-masked-mem-intrin". No dependent passes are listed between
// BEGIN/END because only the always-available TTI wrapper is required.
INITIALIZE_PASS_BEGIN(ScalarizeMaskedMemIntrin, "scalarize-masked-mem-intrin",
                      "Scalarize unsupported masked memory intrinsics", false,
                      false)
INITIALIZE_PASS_END(ScalarizeMaskedMemIntrin, "scalarize-masked-mem-intrin",
                    "Scalarize unsupported masked memory intrinsics", false,
                    false)
58 
// Factory used by pipeline-construction code to add this pass; ownership of
// the returned pass transfers to the pass manager.
FunctionPass *llvm::createScalarizeMaskedMemIntrinPass() {
  return new ScalarizeMaskedMemIntrin();
}
62 
63 // Translate a masked load intrinsic like
64 // <16 x i32 > @llvm.masked.load( <16 x i32>* %addr, i32 align,
65 //                               <16 x i1> %mask, <16 x i32> %passthru)
66 // to a chain of basic blocks, with loading element one-by-one if
67 // the appropriate mask bit is set
68 //
69 //  %1 = bitcast i8* %addr to i32*
70 //  %2 = extractelement <16 x i1> %mask, i32 0
71 //  %3 = icmp eq i1 %2, true
72 //  br i1 %3, label %cond.load, label %else
73 //
74 // cond.load:                                        ; preds = %0
75 //  %4 = getelementptr i32* %1, i32 0
76 //  %5 = load i32* %4
77 //  %6 = insertelement <16 x i32> undef, i32 %5, i32 0
78 //  br label %else
79 //
80 // else:                                             ; preds = %0, %cond.load
81 //  %res.phi.else = phi <16 x i32> [ %6, %cond.load ], [ undef, %0 ]
82 //  %7 = extractelement <16 x i1> %mask, i32 1
83 //  %8 = icmp eq i1 %7, true
84 //  br i1 %8, label %cond.load1, label %else2
85 //
86 // cond.load1:                                       ; preds = %else
87 //  %9 = getelementptr i32* %1, i32 1
88 //  %10 = load i32* %9
89 //  %11 = insertelement <16 x i32> %res.phi.else, i32 %10, i32 1
90 //  br label %else2
91 //
92 // else2:                                          ; preds = %else, %cond.load1
93 //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
94 //  %12 = extractelement <16 x i1> %mask, i32 2
95 //  %13 = icmp eq i1 %12, true
96 //  br i1 %13, label %cond.load4, label %else5
97 //
static void scalarizeMaskedLoad(CallInst *CI) {
  // Intrinsic operands: (pointer, alignment, mask, passthru).
  Value *Ptr = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  VectorType *VecType = dyn_cast<VectorType>(CI->getType());
  assert(VecType && "Unexpected return type of masked load intrinsic");

  Type *EltTy = CI->getType()->getVectorElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  BasicBlock *CondBlock = nullptr;
  BasicBlock *PrevIfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true: a single unconditional vector load.
  bool IsAllOnesMask =
      isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue();

  if (IsAllOnesMask) {
    Value *NewI = Builder.CreateAlignedLoad(Ptr, AlignVal);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction: an element load cannot claim
  // more alignment than the element size guarantees.
  AlignVal = std::min(AlignVal, VecType->getScalarSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(cast<PointerType>(Ptr->getType())->getAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = VecType->getNumElements();

  Value *UndefVal = UndefValue::get(VecType);

  // The result vector
  Value *VResult = UndefVal;

  // If the mask is a compile-time constant vector, no control flow is needed:
  // emit unconditional loads for exactly the lanes whose mask bit is set.
  if (isa<ConstantVector>(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
        continue;
      Value *Gep =
          Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
      LoadInst *Load = Builder.CreateAlignedLoad(Gep, AlignVal);
      VResult =
          Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx));
    }
    // The select fills masked-off lanes with the passthru value.
    Value *NewI = Builder.CreateSelect(Mask, VResult, Src0);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  // General case: build the diamond chain shown in the comment above this
  // function, one (cond.load, else) pair per lane.
  PHINode *Phi = nullptr;
  Value *PrevPhi = UndefVal;

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {

    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  %to_load = icmp eq i1 %mask_1, true
    //  br i1 %to_load, label %cond.load, label %else
    //
    if (Idx > 0) {
      // Merge the value loaded on the taken path with the value carried
      // around the skipped path.
      Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
      Phi->addIncoming(VResult, CondBlock);
      Phi->addIncoming(PrevPhi, PrevIfBlock);
      PrevPhi = Phi;
      VResult = Phi;
    }

    Value *Predicate =
        Builder.CreateExtractElement(Mask, Builder.getInt32(Idx));
    Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate,
                                    ConstantInt::get(Predicate->getType(), 1));

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.load");
    Builder.SetInsertPoint(InsertPt);

    Value *Gep =
        Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
    LoadInst *Load = Builder.CreateAlignedLoad(Gep, AlignVal);
    VResult = Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx));

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    // Replace the unconditional branch left by splitBasicBlock with the
    // conditional branch that skips this lane's load when the bit is clear.
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr);
    OldBr->eraseFromParent();
    PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;
  }

  // Final merge for the last lane, then blend with passthru.
  Phi = Builder.CreatePHI(VecType, 2, "res.phi.select");
  Phi->addIncoming(VResult, CondBlock);
  Phi->addIncoming(PrevPhi, PrevIfBlock);
  Value *NewI = Builder.CreateSelect(Mask, Phi, Src0);
  CI->replaceAllUsesWith(NewI);
  CI->eraseFromParent();
}
216 
217 // Translate a masked store intrinsic, like
218 // void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
219 //                               <16 x i1> %mask)
220 // to a chain of basic blocks, that stores element one-by-one if
221 // the appropriate mask bit is set
222 //
223 //   %1 = bitcast i8* %addr to i32*
224 //   %2 = extractelement <16 x i1> %mask, i32 0
225 //   %3 = icmp eq i1 %2, true
226 //   br i1 %3, label %cond.store, label %else
227 //
228 // cond.store:                                       ; preds = %0
229 //   %4 = extractelement <16 x i32> %val, i32 0
230 //   %5 = getelementptr i32* %1, i32 0
231 //   store i32 %4, i32* %5
232 //   br label %else
233 //
234 // else:                                             ; preds = %0, %cond.store
235 //   %6 = extractelement <16 x i1> %mask, i32 1
236 //   %7 = icmp eq i1 %6, true
237 //   br i1 %7, label %cond.store1, label %else2
238 //
239 // cond.store1:                                      ; preds = %else
240 //   %8 = extractelement <16 x i32> %val, i32 1
241 //   %9 = getelementptr i32* %1, i32 1
242 //   store i32 %8, i32* %9
243 //   br label %else2
244 //   . . .
245 static void scalarizeMaskedStore(CallInst *CI) {
246   Value *Src = CI->getArgOperand(0);
247   Value *Ptr = CI->getArgOperand(1);
248   Value *Alignment = CI->getArgOperand(2);
249   Value *Mask = CI->getArgOperand(3);
250 
251   unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
252   VectorType *VecType = dyn_cast<VectorType>(Src->getType());
253   assert(VecType && "Unexpected data type in masked store intrinsic");
254 
255   Type *EltTy = VecType->getElementType();
256 
257   IRBuilder<> Builder(CI->getContext());
258   Instruction *InsertPt = CI;
259   BasicBlock *IfBlock = CI->getParent();
260   Builder.SetInsertPoint(InsertPt);
261   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
262 
263   // Short-cut if the mask is all-true.
264   bool IsAllOnesMask =
265       isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue();
266 
267   if (IsAllOnesMask) {
268     Builder.CreateAlignedStore(Src, Ptr, AlignVal);
269     CI->eraseFromParent();
270     return;
271   }
272 
273   // Adjust alignment for the scalar instruction.
274   AlignVal = std::max(AlignVal, VecType->getScalarSizeInBits() / 8);
275   // Bitcast %addr fron i8* to EltTy*
276   Type *NewPtrType =
277       EltTy->getPointerTo(cast<PointerType>(Ptr->getType())->getAddressSpace());
278   Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
279   unsigned VectorWidth = VecType->getNumElements();
280 
281   if (isa<ConstantVector>(Mask)) {
282     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
283       if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
284         continue;
285       Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx));
286       Value *Gep =
287           Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
288       Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
289     }
290     CI->eraseFromParent();
291     return;
292   }
293 
294   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
295 
296     // Fill the "else" block, created in the previous iteration
297     //
298     //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
299     //  %to_store = icmp eq i1 %mask_1, true
300     //  br i1 %to_store, label %cond.store, label %else
301     //
302     Value *Predicate =
303         Builder.CreateExtractElement(Mask, Builder.getInt32(Idx));
304     Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate,
305                                     ConstantInt::get(Predicate->getType(), 1));
306 
307     // Create "cond" block
308     //
309     //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
310     //  %EltAddr = getelementptr i32* %1, i32 0
311     //  %store i32 %OneElt, i32* %EltAddr
312     //
313     BasicBlock *CondBlock =
314         IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
315     Builder.SetInsertPoint(InsertPt);
316 
317     Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx));
318     Value *Gep =
319         Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
320     Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
321 
322     // Create "else" block, fill it in the next iteration
323     BasicBlock *NewIfBlock =
324         CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
325     Builder.SetInsertPoint(InsertPt);
326     Instruction *OldBr = IfBlock->getTerminator();
327     BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr);
328     OldBr->eraseFromParent();
329     IfBlock = NewIfBlock;
330   }
331   CI->eraseFromParent();
332 }
333 
334 // Translate a masked gather intrinsic like
335 // <16 x i32 > @llvm.masked.gather.v16i32( <16 x i32*> %Ptrs, i32 4,
336 //                               <16 x i1> %Mask, <16 x i32> %Src)
337 // to a chain of basic blocks, with loading element one-by-one if
338 // the appropriate mask bit is set
339 //
340 // % Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
341 // % Mask0 = extractelement <16 x i1> %Mask, i32 0
342 // % ToLoad0 = icmp eq i1 % Mask0, true
343 // br i1 % ToLoad0, label %cond.load, label %else
344 //
345 // cond.load:
346 // % Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
347 // % Load0 = load i32, i32* % Ptr0, align 4
348 // % Res0 = insertelement <16 x i32> undef, i32 % Load0, i32 0
349 // br label %else
350 //
351 // else:
352 // %res.phi.else = phi <16 x i32>[% Res0, %cond.load], [undef, % 0]
353 // % Mask1 = extractelement <16 x i1> %Mask, i32 1
354 // % ToLoad1 = icmp eq i1 % Mask1, true
355 // br i1 % ToLoad1, label %cond.load1, label %else2
356 //
357 // cond.load1:
358 // % Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
359 // % Load1 = load i32, i32* % Ptr1, align 4
360 // % Res1 = insertelement <16 x i32> %res.phi.else, i32 % Load1, i32 1
361 // br label %else2
362 // . . .
363 // % Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
364 // ret <16 x i32> %Result
365 static void scalarizeMaskedGather(CallInst *CI) {
366   Value *Ptrs = CI->getArgOperand(0);
367   Value *Alignment = CI->getArgOperand(1);
368   Value *Mask = CI->getArgOperand(2);
369   Value *Src0 = CI->getArgOperand(3);
370 
371   VectorType *VecType = dyn_cast<VectorType>(CI->getType());
372 
373   assert(VecType && "Unexpected return type of masked load intrinsic");
374 
375   IRBuilder<> Builder(CI->getContext());
376   Instruction *InsertPt = CI;
377   BasicBlock *IfBlock = CI->getParent();
378   BasicBlock *CondBlock = nullptr;
379   BasicBlock *PrevIfBlock = CI->getParent();
380   Builder.SetInsertPoint(InsertPt);
381   unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
382 
383   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
384 
385   Value *UndefVal = UndefValue::get(VecType);
386 
387   // The result vector
388   Value *VResult = UndefVal;
389   unsigned VectorWidth = VecType->getNumElements();
390 
391   // Shorten the way if the mask is a vector of constants.
392   bool IsConstMask = isa<ConstantVector>(Mask);
393 
394   if (IsConstMask) {
395     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
396       if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
397         continue;
398       Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
399                                                 "Ptr" + Twine(Idx));
400       LoadInst *Load =
401           Builder.CreateAlignedLoad(Ptr, AlignVal, "Load" + Twine(Idx));
402       VResult = Builder.CreateInsertElement(
403           VResult, Load, Builder.getInt32(Idx), "Res" + Twine(Idx));
404     }
405     Value *NewI = Builder.CreateSelect(Mask, VResult, Src0);
406     CI->replaceAllUsesWith(NewI);
407     CI->eraseFromParent();
408     return;
409   }
410 
411   PHINode *Phi = nullptr;
412   Value *PrevPhi = UndefVal;
413 
414   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
415 
416     // Fill the "else" block, created in the previous iteration
417     //
418     //  %Mask1 = extractelement <16 x i1> %Mask, i32 1
419     //  %ToLoad1 = icmp eq i1 %Mask1, true
420     //  br i1 %ToLoad1, label %cond.load, label %else
421     //
422     if (Idx > 0) {
423       Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
424       Phi->addIncoming(VResult, CondBlock);
425       Phi->addIncoming(PrevPhi, PrevIfBlock);
426       PrevPhi = Phi;
427       VResult = Phi;
428     }
429 
430     Value *Predicate = Builder.CreateExtractElement(Mask, Builder.getInt32(Idx),
431                                                     "Mask" + Twine(Idx));
432     Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate,
433                                     ConstantInt::get(Predicate->getType(), 1),
434                                     "ToLoad" + Twine(Idx));
435 
436     // Create "cond" block
437     //
438     //  %EltAddr = getelementptr i32* %1, i32 0
439     //  %Elt = load i32* %EltAddr
440     //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
441     //
442     CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load");
443     Builder.SetInsertPoint(InsertPt);
444 
445     Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
446                                               "Ptr" + Twine(Idx));
447     LoadInst *Load =
448         Builder.CreateAlignedLoad(Ptr, AlignVal, "Load" + Twine(Idx));
449     VResult = Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx),
450                                           "Res" + Twine(Idx));
451 
452     // Create "else" block, fill it in the next iteration
453     BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
454     Builder.SetInsertPoint(InsertPt);
455     Instruction *OldBr = IfBlock->getTerminator();
456     BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr);
457     OldBr->eraseFromParent();
458     PrevIfBlock = IfBlock;
459     IfBlock = NewIfBlock;
460   }
461 
462   Phi = Builder.CreatePHI(VecType, 2, "res.phi.select");
463   Phi->addIncoming(VResult, CondBlock);
464   Phi->addIncoming(PrevPhi, PrevIfBlock);
465   Value *NewI = Builder.CreateSelect(Mask, Phi, Src0);
466   CI->replaceAllUsesWith(NewI);
467   CI->eraseFromParent();
468 }
469 
470 // Translate a masked scatter intrinsic, like
471 // void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*>* %Ptrs, i32 4,
472 //                                  <16 x i1> %Mask)
473 // to a chain of basic blocks, that stores element one-by-one if
474 // the appropriate mask bit is set.
475 //
476 // % Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
477 // % Mask0 = extractelement <16 x i1> % Mask, i32 0
478 // % ToStore0 = icmp eq i1 % Mask0, true
479 // br i1 %ToStore0, label %cond.store, label %else
480 //
481 // cond.store:
482 // % Elt0 = extractelement <16 x i32> %Src, i32 0
483 // % Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
484 // store i32 %Elt0, i32* % Ptr0, align 4
485 // br label %else
486 //
487 // else:
488 // % Mask1 = extractelement <16 x i1> % Mask, i32 1
489 // % ToStore1 = icmp eq i1 % Mask1, true
490 // br i1 % ToStore1, label %cond.store1, label %else2
491 //
492 // cond.store1:
493 // % Elt1 = extractelement <16 x i32> %Src, i32 1
494 // % Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
495 // store i32 % Elt1, i32* % Ptr1, align 4
496 // br label %else2
497 //   . . .
static void scalarizeMaskedScatter(CallInst *CI) {
  // Intrinsic operands: (value, vector of pointers, alignment, mask).
  Value *Src = CI->getArgOperand(0);
  Value *Ptrs = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  assert(isa<VectorType>(Src->getType()) &&
         "Unexpected data type in masked scatter intrinsic");
  assert(isa<VectorType>(Ptrs->getType()) &&
         isa<PointerType>(Ptrs->getType()->getVectorElementType()) &&
         "Vector of pointers is expected in masked scatter intrinsic");

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
  unsigned VectorWidth = Src->getType()->getVectorNumElements();

  // Shorten the way if the mask is a vector of constants: emit unconditional
  // stores for exactly the lanes whose mask bit is known set.
  bool IsConstMask = isa<ConstantVector>(Mask);

  if (IsConstMask) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<ConstantVector>(Mask)->getOperand(Idx)->isNullValue())
        continue;
      Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx),
                                                   "Elt" + Twine(Idx));
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
                                                "Ptr" + Twine(Idx));
      Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }
  // General case: build one (cond.store, else) block pair per lane, as
  // sketched in the comment above this function.
  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  % Mask1 = extractelement <16 x i1> % Mask, i32 Idx
    //  % ToStore = icmp eq i1 % Mask1, true
    //  br i1 % ToStore, label %cond.store, label %else
    //
    Value *Predicate = Builder.CreateExtractElement(Mask, Builder.getInt32(Idx),
                                                    "Mask" + Twine(Idx));
    Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Predicate,
                                    ConstantInt::get(Predicate->getType(), 1),
                                    "ToStore" + Twine(Idx));

    // Create "cond" block
    //
    //  % Elt1 = extractelement <16 x i32> %Src, i32 1
    //  % Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
    //  %store i32 % Elt1, i32* % Ptr1
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx),
                                                 "Elt" + Twine(Idx));
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
                                              "Ptr" + Twine(Idx));
    Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
    Builder.SetInsertPoint(InsertPt);
    // Replace the unconditional branch left by splitBasicBlock with the
    // conditional branch that skips this lane's store when the bit is clear.
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Cmp, OldBr);
    OldBr->eraseFromParent();
    IfBlock = NewIfBlock;
  }
  CI->eraseFromParent();
}
573 
574 bool ScalarizeMaskedMemIntrin::runOnFunction(Function &F) {
575   if (skipFunction(F))
576     return false;
577 
578   bool EverMadeChange = false;
579 
580   TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
581 
582   bool MadeChange = true;
583   while (MadeChange) {
584     MadeChange = false;
585     for (Function::iterator I = F.begin(); I != F.end();) {
586       BasicBlock *BB = &*I++;
587       bool ModifiedDTOnIteration = false;
588       MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration);
589 
590       // Restart BB iteration if the dominator tree of the Function was changed
591       if (ModifiedDTOnIteration)
592         break;
593     }
594 
595     EverMadeChange |= MadeChange;
596   }
597 
598   return EverMadeChange;
599 }
600 
601 bool ScalarizeMaskedMemIntrin::optimizeBlock(BasicBlock &BB, bool &ModifiedDT) {
602   bool MadeChange = false;
603 
604   BasicBlock::iterator CurInstIterator = BB.begin();
605   while (CurInstIterator != BB.end()) {
606     if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
607       MadeChange |= optimizeCallInst(CI, ModifiedDT);
608     if (ModifiedDT)
609       return true;
610   }
611 
612   return MadeChange;
613 }
614 
615 bool ScalarizeMaskedMemIntrin::optimizeCallInst(CallInst *CI,
616                                                 bool &ModifiedDT) {
617 
618   IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
619   if (II) {
620     switch (II->getIntrinsicID()) {
621     default:
622       break;
623     case Intrinsic::masked_load: {
624       // Scalarize unsupported vector masked load
625       if (!TTI->isLegalMaskedLoad(CI->getType())) {
626         scalarizeMaskedLoad(CI);
627         ModifiedDT = true;
628         return true;
629       }
630       return false;
631     }
632     case Intrinsic::masked_store: {
633       if (!TTI->isLegalMaskedStore(CI->getArgOperand(0)->getType())) {
634         scalarizeMaskedStore(CI);
635         ModifiedDT = true;
636         return true;
637       }
638       return false;
639     }
640     case Intrinsic::masked_gather: {
641       if (!TTI->isLegalMaskedGather(CI->getType())) {
642         scalarizeMaskedGather(CI);
643         ModifiedDT = true;
644         return true;
645       }
646       return false;
647     }
648     case Intrinsic::masked_scatter: {
649       if (!TTI->isLegalMaskedScatter(CI->getArgOperand(0)->getType())) {
650         scalarizeMaskedScatter(CI);
651         ModifiedDT = true;
652         return true;
653       }
654       return false;
655     }
656     }
657   }
658 
659   return false;
660 }
661