//===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
//                                    intrinsics
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass replaces masked memory intrinsics - when unsupported by the target
// - with a chain of basic blocks that deal with the elements one by one if the
// appropriate mask bit is set.
//
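// A typical standalone invocation (assuming the new-PM pass name matches
// DEBUG_TYPE below) might be:
//
//   opt -S -passes=scalarize-masked-mem-intrin input.ll
//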
//===----------------------------------------------------------------------===//

#include "llvm/Transforms/Scalar/ScalarizeMaskedMemIntrin.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Transforms/Scalar.h"
#include <algorithm>
#include <cassert>

using namespace llvm;

#define DEBUG_TYPE "scalarize-masked-mem-intrin"

namespace {

class ScalarizeMaskedMemIntrinLegacyPass : public FunctionPass {
public:
  static char ID; // Pass identification, replacement for typeid

  explicit ScalarizeMaskedMemIntrinLegacyPass() : FunctionPass(ID) {
    initializeScalarizeMaskedMemIntrinLegacyPassPass(
        *PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override;

  StringRef getPassName() const override {
    return "Scalarize Masked Memory Intrinsics";
  }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetTransformInfoWrapperPass>();
  }
};

} // end anonymous namespace

static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
                          const TargetTransformInfo &TTI, const DataLayout &DL);
static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
                             const TargetTransformInfo &TTI,
                             const DataLayout &DL);

char ScalarizeMaskedMemIntrinLegacyPass::ID = 0;

INITIALIZE_PASS(ScalarizeMaskedMemIntrinLegacyPass, DEBUG_TYPE,
                "Scalarize unsupported masked memory intrinsics", false, false)

FunctionPass *llvm::createScalarizeMaskedMemIntrinLegacyPass() {
  return new ScalarizeMaskedMemIntrinLegacyPass();
}
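
// A minimal sketch of how a backend would schedule this pass with the legacy
// pass manager (the exact hook, e.g. TargetPassConfig::addIRPasses, is an
// assumption here):
//
//   PM.add(createScalarizeMaskedMemIntrinLegacyPass());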

static bool isConstantIntVector(Value *Mask) {
  Constant *C = dyn_cast<Constant>(Mask);
  if (!C)
    return false;

  unsigned NumElts = cast<FixedVectorType>(Mask->getType())->getNumElements();
  for (unsigned i = 0; i != NumElts; ++i) {
    Constant *CElt = C->getAggregateElement(i);
    if (!CElt || !isa<ConstantInt>(CElt))
      return false;
  }

  return true;
}

// Translate a masked load intrinsic like
// <16 x i32> @llvm.masked.load(<16 x i32>* %addr, i32 align,
//                              <16 x i1> %mask, <16 x i32> %passthru)
// to a chain of basic blocks, loading the elements one by one if the
// appropriate mask bit is set
//
//  %1 = bitcast i8* %addr to i32*
//  %2 = extractelement <16 x i1> %mask, i32 0
//  br i1 %2, label %cond.load, label %else
//
// cond.load:                                        ; preds = %0
//  %3 = getelementptr i32* %1, i32 0
//  %4 = load i32* %3
//  %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
//  br label %else
//
// else:                                             ; preds = %0, %cond.load
//  %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ undef, %0 ]
//  %6 = extractelement <16 x i1> %mask, i32 1
//  br i1 %6, label %cond.load1, label %else2
//
// cond.load1:                                       ; preds = %else
//  %7 = getelementptr i32* %1, i32 1
//  %8 = load i32* %7
//  %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
//  br label %else2
//
// else2:                                          ; preds = %else, %cond.load1
//  %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
//  %10 = extractelement <16 x i1> %mask, i32 2
//  br i1 %10, label %cond.load4, label %else5
//
static void scalarizeMaskedLoad(CallInst *CI, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
  VectorType *VecType = cast<FixedVectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Value *NewI = Builder.CreateAlignedLoad(VecType, Ptr, AlignVal);
    CI->replaceAllUsesWith(NewI);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  const Align AdjustedAlignVal =
      commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();

  // The result vector
  Value *VResult = Src0;

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
      LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
      VResult = Builder.CreateInsertElement(VResult, Load, Idx);
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = and i16 %scalar_mask, (1 << Idx)
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %cond, label %cond.load, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(),
                                                     "cond.load");
    Builder.SetInsertPoint(InsertPt);

    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Gep, AdjustedAlignVal);
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked store intrinsic, like
// void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
//                         <16 x i1> %mask)
// to a chain of basic blocks that store the elements one by one if the
// appropriate mask bit is set
//
//   %1 = bitcast i8* %addr to i32*
//   %2 = extractelement <16 x i1> %mask, i32 0
//   br i1 %2, label %cond.store, label %else
//
// cond.store:                                       ; preds = %0
//   %3 = extractelement <16 x i32> %val, i32 0
//   %4 = getelementptr i32* %1, i32 0
//   store i32 %3, i32* %4
//   br label %else
//
// else:                                             ; preds = %0, %cond.store
//   %5 = extractelement <16 x i1> %mask, i32 1
//   br i1 %5, label %cond.store1, label %else2
//
// cond.store1:                                      ; preds = %else
//   %6 = extractelement <16 x i32> %val, i32 1
//   %7 = getelementptr i32* %1, i32 1
//   store i32 %6, i32* %7
//   br label %else2
//   . . .
static void scalarizeMaskedStore(CallInst *CI, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  const Align AlignVal = cast<ConstantInt>(Alignment)->getAlignValue();
  auto *VecType = cast<VectorType>(Src->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Short-cut if the mask is all-true.
  if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
    Builder.CreateAlignedStore(Src, Ptr, AlignVal);
    CI->eraseFromParent();
    return;
  }

  // Adjust alignment for the scalar instruction.
  const Align AdjustedAlignVal =
      commonAlignment(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
  // Bitcast %addr from i8* to EltTy*
  Type *NewPtrType =
      EltTy->getPointerTo(Ptr->getType()->getPointerAddressSpace());
  Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
  unsigned VectorWidth = cast<FixedVectorType>(VecType)->getNumElements();

  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt = Builder.CreateExtractElement(Src, Idx);
      Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
      Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = and i16 %scalar_mask, (1 << Idx)
    //  %cond = icmp ne i16 %mask_1, 0
    //  br i1 %cond, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx);
    }

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  store i32 %OneElt, i32* %EltAddr
    //
    BasicBlock *CondBlock =
        IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Value *Gep = Builder.CreateConstInBoundsGEP1_32(EltTy, FirstEltPtr, Idx);
    Builder.CreateAlignedStore(OneElt, Gep, AdjustedAlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    IfBlock = NewIfBlock;
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked gather intrinsic like
// <16 x i32> @llvm.masked.gather.v16i32(<16 x i32*> %Ptrs, i32 4,
//                                       <16 x i1> %Mask, <16 x i32> %Src)
// to a chain of basic blocks, loading the elements one by one if the
// appropriate mask bit is set
//
// %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
// %Mask0 = extractelement <16 x i1> %Mask, i32 0
// br i1 %Mask0, label %cond.load, label %else
//
// cond.load:
// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
// %Load0 = load i32, i32* %Ptr0, align 4
// %Res0 = insertelement <16 x i32> undef, i32 %Load0, i32 0
// br label %else
//
// else:
// %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [undef, %0]
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
// br i1 %Mask1, label %cond.load1, label %else2
//
// cond.load1:
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
// %Load1 = load i32, i32* %Ptr1, align 4
// %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
// br label %else2
// . . .
// %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.else2, <16 x i32> %Src
// ret <16 x i32> %Result
static void scalarizeMaskedGather(CallInst *CI, bool &ModifiedDT) {
  Value *Ptrs = CI->getArgOperand(0);
  Value *Alignment = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);
  Value *Src0 = CI->getArgOperand(3);

  auto *VecType = cast<FixedVectorType>(CI->getType());
  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // The result vector
  Value *VResult = Src0;
  unsigned VectorWidth = VecType->getNumElements();

  // Take a shortcut if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      LoadInst *Load =
          Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
      VResult =
          Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));
    }
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = and i16 %scalar_mask, (1 << Idx)
    //  %cond = icmp ne i16 %Mask1, 0
    //  br i1 %cond, label %cond.load, label %else
    //

    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load");
    Builder.SetInsertPoint(InsertPt);

    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    LoadInst *Load =
        Builder.CreateAlignedLoad(EltTy, Ptr, AlignVal, "Load" + Twine(Idx));
    Value *NewVResult =
        Builder.CreateInsertElement(VResult, Load, Idx, "Res" + Twine(Idx));

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    Phi->addIncoming(NewVResult, CondBlock);
    Phi->addIncoming(VResult, PrevIfBlock);
    VResult = Phi;
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

// Translate a masked scatter intrinsic, like
// void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*> %Ptrs, i32 4,
//                                  <16 x i1> %Mask)
// to a chain of basic blocks that store the elements one by one if the
// appropriate mask bit is set.
//
// %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
// %Mask0 = extractelement <16 x i1> %Mask, i32 0
// br i1 %Mask0, label %cond.store, label %else
//
// cond.store:
// %Elt0 = extractelement <16 x i32> %Src, i32 0
// %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
// store i32 %Elt0, i32* %Ptr0, align 4
// br label %else
//
// else:
// %Mask1 = extractelement <16 x i1> %Mask, i32 1
// br i1 %Mask1, label %cond.store1, label %else2
//
// cond.store1:
// %Elt1 = extractelement <16 x i32> %Src, i32 1
// %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
// store i32 %Elt1, i32* %Ptr1, align 4
// br label %else2
//   . . .
static void scalarizeMaskedScatter(CallInst *CI, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptrs = CI->getArgOperand(1);
  Value *Alignment = CI->getArgOperand(2);
  Value *Mask = CI->getArgOperand(3);

  auto *SrcFVTy = cast<FixedVectorType>(Src->getType());

  assert(
      isa<VectorType>(Ptrs->getType()) &&
      isa<PointerType>(cast<VectorType>(Ptrs->getType())->getElementType()) &&
      "Vector of pointers is expected in masked scatter intrinsic");

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();
  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  MaybeAlign AlignVal = cast<ConstantInt>(Alignment)->getMaybeAlignValue();
  unsigned VectorWidth = SrcFVTy->getNumElements();

  // Take a shortcut if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
      Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %Mask1 = and i16 %scalar_mask, (1 << Idx)
    //  %cond = icmp ne i16 %Mask1, 0
    //  br i1 %cond, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %Elt1 = extractelement <16 x i32> %Src, i32 1
    //  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
    //  store i32 %Elt1, i32* %Ptr1
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
    Value *Ptr = Builder.CreateExtractElement(Ptrs, Idx, "Ptr" + Twine(Idx));
    Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    IfBlock = NewIfBlock;
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

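// Translate a masked expandload intrinsic like
// <16 x i32> @llvm.masked.expandload.v16i32(i32* %ptr, <16 x i1> %mask,
//                                           <16 x i32> %passthru)
// to a chain of basic blocks that load the enabled lanes from consecutive
// memory locations, bumping the pointer past each loaded element. A sketch of
// the first lane (value names are illustrative only):
//
// cond.load:
//  %Load0 = load i32, i32* %ptr, align 1
//  %Res0 = insertelement <16 x i32> %passthru, i32 %Load0, i32 0
//  %NewPtr = getelementptr i32, i32* %ptr, i32 1
//  br label %else
//
// else:
//  %res.phi.else = phi <16 x i32> [ %Res0, %cond.load ], [ %passthru, %0 ]
//  %ptr.phi.else = phi i32* [ %NewPtr, %cond.load ], [ %ptr, %0 ]
//  . . .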
static void scalarizeMaskedExpandLoad(CallInst *CI, bool &ModifiedDT) {
  Value *Ptr = CI->getArgOperand(0);
  Value *Mask = CI->getArgOperand(1);
  Value *PassThru = CI->getArgOperand(2);

  auto *VecType = cast<FixedVectorType>(CI->getType());

  Type *EltTy = VecType->getElementType();

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  unsigned VectorWidth = VecType->getNumElements();

  // The result vector
  Value *VResult = PassThru;

  // Take a shortcut if the mask is a vector of constants: build a vector of
  // loads/undefs for the enabled lanes, then blend it with the pass-through
  // value using a shufflevector.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    VResult = UndefValue::get(VecType);
    SmallVector<int, 16> ShuffleMask(VectorWidth, UndefMaskElem);
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      Value *InsertElt;
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue()) {
        InsertElt = UndefValue::get(EltTy);
        ShuffleMask[Idx] = Idx + VectorWidth;
      } else {
        Value *NewPtr =
            Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
        InsertElt = Builder.CreateAlignedLoad(EltTy, NewPtr, Align(1),
                                              "Load" + Twine(Idx));
        ShuffleMask[Idx] = Idx;
        ++MemIndex;
      }
      VResult = Builder.CreateInsertElement(VResult, InsertElt, Idx,
                                            "Res" + Twine(Idx));
    }
    VResult = Builder.CreateShuffleVector(VResult, PassThru, ShuffleMask);
    CI->replaceAllUsesWith(VResult);
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.load, label %else
    //

    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  %Elt = load i32* %EltAddr
    //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
    //
    BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(),
                                                     "cond.load");
    Builder.SetInsertPoint(InsertPt);

    LoadInst *Load = Builder.CreateAlignedLoad(EltTy, Ptr, Align(1));
    Value *NewVResult = Builder.CreateInsertElement(VResult, Load, Idx);

    // Move the pointer if there are more blocks to come.
    Value *NewPtr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Create the phi to join the new and previous value.
    PHINode *ResultPhi = Builder.CreatePHI(VecType, 2, "res.phi.else");
    ResultPhi->addIncoming(NewVResult, CondBlock);
    ResultPhi->addIncoming(VResult, PrevIfBlock);
    VResult = ResultPhi;

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }

  CI->replaceAllUsesWith(VResult);
  CI->eraseFromParent();

  ModifiedDT = true;
}

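// Translate a masked compressstore intrinsic like
// void @llvm.masked.compressstore.v16i32(<16 x i32> %value, i32* %ptr,
//                                        <16 x i1> %mask)
// to a chain of basic blocks that store the enabled lanes to consecutive
// memory locations, bumping the pointer past each stored element. A sketch of
// the first lane (value names are illustrative only):
//
// cond.store:
//  %Elt0 = extractelement <16 x i32> %value, i32 0
//  store i32 %Elt0, i32* %ptr, align 1
//  %NewPtr = getelementptr i32, i32* %ptr, i32 1
//  br label %else
//
// else:
//  %ptr.phi.else = phi i32* [ %NewPtr, %cond.store ], [ %ptr, %0 ]
//  . . .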
static void scalarizeMaskedCompressStore(CallInst *CI, bool &ModifiedDT) {
  Value *Src = CI->getArgOperand(0);
  Value *Ptr = CI->getArgOperand(1);
  Value *Mask = CI->getArgOperand(2);

  auto *VecType = cast<FixedVectorType>(Src->getType());

  IRBuilder<> Builder(CI->getContext());
  Instruction *InsertPt = CI;
  BasicBlock *IfBlock = CI->getParent();

  Builder.SetInsertPoint(InsertPt);
  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  Type *EltTy = VecType->getElementType();

  unsigned VectorWidth = VecType->getNumElements();

  // Take a shortcut if the mask is a vector of constants.
  if (isConstantIntVector(Mask)) {
    unsigned MemIndex = 0;
    for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
      if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
        continue;
      Value *OneElt =
          Builder.CreateExtractElement(Src, Idx, "Elt" + Twine(Idx));
      Value *NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, MemIndex);
      Builder.CreateAlignedStore(OneElt, NewPtr, Align(1));
      ++MemIndex;
    }
    CI->eraseFromParent();
    return;
  }

  // If the mask is not v1i1, use scalar bit test operations. This generates
  // better results on X86 at least.
  Value *SclrMask;
  if (VectorWidth != 1) {
    Type *SclrMaskTy = Builder.getIntNTy(VectorWidth);
    SclrMask = Builder.CreateBitCast(Mask, SclrMaskTy, "scalar_mask");
  }

  for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
    // Fill the "else" block, created in the previous iteration
    //
    //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
    //  br i1 %mask_1, label %cond.store, label %else
    //
    Value *Predicate;
    if (VectorWidth != 1) {
      Value *Mask = Builder.getInt(APInt::getOneBitSet(VectorWidth, Idx));
      Predicate = Builder.CreateICmpNE(Builder.CreateAnd(SclrMask, Mask),
                                       Builder.getIntN(VectorWidth, 0));
    } else {
      Predicate = Builder.CreateExtractElement(Mask, Idx, "Mask" + Twine(Idx));
    }

    // Create "cond" block
    //
    //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
    //  %EltAddr = getelementptr i32* %1, i32 0
    //  store i32 %OneElt, i32* %EltAddr
    //
    BasicBlock *CondBlock =
        IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
    Builder.SetInsertPoint(InsertPt);

    Value *OneElt = Builder.CreateExtractElement(Src, Idx);
    Builder.CreateAlignedStore(OneElt, Ptr, Align(1));

    // Move the pointer if there are more blocks to come.
    Value *NewPtr;
    if ((Idx + 1) != VectorWidth)
      NewPtr = Builder.CreateConstInBoundsGEP1_32(EltTy, Ptr, 1);

    // Create "else" block, fill it in the next iteration
    BasicBlock *NewIfBlock =
        CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
    Builder.SetInsertPoint(InsertPt);
    Instruction *OldBr = IfBlock->getTerminator();
    BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
    OldBr->eraseFromParent();
    BasicBlock *PrevIfBlock = IfBlock;
    IfBlock = NewIfBlock;

    // Add a PHI for the pointer if this isn't the last iteration.
    if ((Idx + 1) != VectorWidth) {
      PHINode *PtrPhi = Builder.CreatePHI(Ptr->getType(), 2, "ptr.phi.else");
      PtrPhi->addIncoming(NewPtr, CondBlock);
      PtrPhi->addIncoming(Ptr, PrevIfBlock);
      Ptr = PtrPhi;
    }
  }
  CI->eraseFromParent();

  ModifiedDT = true;
}

static bool runImpl(Function &F, const TargetTransformInfo &TTI) {
  bool EverMadeChange = false;
  bool MadeChange = true;
  auto &DL = F.getParent()->getDataLayout();
  while (MadeChange) {
    MadeChange = false;
    for (Function::iterator I = F.begin(); I != F.end();) {
      BasicBlock *BB = &*I++;
      bool ModifiedDTOnIteration = false;
      MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration, TTI, DL);

      // Restart BB iteration if the dominator tree of the Function was changed
      if (ModifiedDTOnIteration)
        break;
    }

    EverMadeChange |= MadeChange;
  }
  return EverMadeChange;
}

bool ScalarizeMaskedMemIntrinLegacyPass::runOnFunction(Function &F) {
  auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
  return runImpl(F, TTI);
}

PreservedAnalyses
ScalarizeMaskedMemIntrinPass::run(Function &F, FunctionAnalysisManager &AM) {
  auto &TTI = AM.getResult<TargetIRAnalysis>(F);
  if (!runImpl(F, TTI))
    return PreservedAnalyses::all();
  PreservedAnalyses PA;
  PA.preserve<TargetIRAnalysis>();
  return PA;
}

static bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT,
                          const TargetTransformInfo &TTI,
                          const DataLayout &DL) {
  bool MadeChange = false;

  BasicBlock::iterator CurInstIterator = BB.begin();
  while (CurInstIterator != BB.end()) {
    if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
      MadeChange |= optimizeCallInst(CI, ModifiedDT, TTI, DL);
    if (ModifiedDT)
      return true;
  }

  return MadeChange;
}

static bool optimizeCallInst(CallInst *CI, bool &ModifiedDT,
                             const TargetTransformInfo &TTI,
                             const DataLayout &DL) {
  IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
  if (II) {
    // The scalarization code below does not work for scalable vectors.
    if (isa<ScalableVectorType>(II->getType()) ||
        any_of(II->arg_operands(),
               [](Value *V) { return isa<ScalableVectorType>(V->getType()); }))
      return false;

    switch (II->getIntrinsicID()) {
    default:
      break;
    case Intrinsic::masked_load:
      // Scalarize unsupported vector masked load
      if (TTI.isLegalMaskedLoad(
              CI->getType(),
              cast<ConstantInt>(CI->getArgOperand(1))->getAlignValue()))
        return false;
      scalarizeMaskedLoad(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_store:
      if (TTI.isLegalMaskedStore(
              CI->getArgOperand(0)->getType(),
              cast<ConstantInt>(CI->getArgOperand(2))->getAlignValue()))
        return false;
      scalarizeMaskedStore(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_gather: {
      unsigned AlignmentInt =
          cast<ConstantInt>(CI->getArgOperand(1))->getZExtValue();
      Type *LoadTy = CI->getType();
      Align Alignment =
          DL.getValueOrABITypeAlignment(MaybeAlign(AlignmentInt), LoadTy);
      if (TTI.isLegalMaskedGather(LoadTy, Alignment))
        return false;
      scalarizeMaskedGather(CI, ModifiedDT);
      return true;
    }
    case Intrinsic::masked_scatter: {
      unsigned AlignmentInt =
          cast<ConstantInt>(CI->getArgOperand(2))->getZExtValue();
      Type *StoreTy = CI->getArgOperand(0)->getType();
      Align Alignment =
          DL.getValueOrABITypeAlignment(MaybeAlign(AlignmentInt), StoreTy);
      if (TTI.isLegalMaskedScatter(StoreTy, Alignment))
        return false;
      scalarizeMaskedScatter(CI, ModifiedDT);
      return true;
    }
    case Intrinsic::masked_expandload:
      if (TTI.isLegalMaskedExpandLoad(CI->getType()))
        return false;
      scalarizeMaskedExpandLoad(CI, ModifiedDT);
      return true;
    case Intrinsic::masked_compressstore:
      if (TTI.isLegalMaskedCompressStore(CI->getArgOperand(0)->getType()))
        return false;
      scalarizeMaskedCompressStore(CI, ModifiedDT);
      return true;
    }
  }

  return false;
}