12cab237bSDimitry Andric //===- ScalarizeMaskedMemIntrin.cpp - Scalarize unsupported masked mem ----===//
//                                    intrinsics
35517e702SDimitry Andric //
45517e702SDimitry Andric //                     The LLVM Compiler Infrastructure
55517e702SDimitry Andric //
65517e702SDimitry Andric // This file is distributed under the University of Illinois Open Source
75517e702SDimitry Andric // License. See LICENSE.TXT for details.
85517e702SDimitry Andric //
95517e702SDimitry Andric //===----------------------------------------------------------------------===//
105517e702SDimitry Andric //
115517e702SDimitry Andric // This pass replaces masked memory intrinsics - when unsupported by the target
125517e702SDimitry Andric // - with a chain of basic blocks, that deal with the elements one-by-one if the
135517e702SDimitry Andric // appropriate mask bit is set.
145517e702SDimitry Andric //
155517e702SDimitry Andric //===----------------------------------------------------------------------===//
165517e702SDimitry Andric 
#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Constant.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstrTypes.h"
#include "llvm/IR/Instruction.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
#include "llvm/IR/Value.h"
#include "llvm/Pass.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/MathExtras.h"
#include <algorithm>
#include <cassert>

using namespace llvm;
395517e702SDimitry Andric 
405517e702SDimitry Andric #define DEBUG_TYPE "scalarize-masked-mem-intrin"
415517e702SDimitry Andric 
425517e702SDimitry Andric namespace {
435517e702SDimitry Andric 
445517e702SDimitry Andric class ScalarizeMaskedMemIntrin : public FunctionPass {
452cab237bSDimitry Andric   const TargetTransformInfo *TTI = nullptr;
465517e702SDimitry Andric 
475517e702SDimitry Andric public:
485517e702SDimitry Andric   static char ID; // Pass identification, replacement for typeid
492cab237bSDimitry Andric 
ScalarizeMaskedMemIntrin()502cab237bSDimitry Andric   explicit ScalarizeMaskedMemIntrin() : FunctionPass(ID) {
515517e702SDimitry Andric     initializeScalarizeMaskedMemIntrinPass(*PassRegistry::getPassRegistry());
525517e702SDimitry Andric   }
532cab237bSDimitry Andric 
545517e702SDimitry Andric   bool runOnFunction(Function &F) override;
555517e702SDimitry Andric 
getPassName() const565517e702SDimitry Andric   StringRef getPassName() const override {
575517e702SDimitry Andric     return "Scalarize Masked Memory Intrinsics";
585517e702SDimitry Andric   }
595517e702SDimitry Andric 
getAnalysisUsage(AnalysisUsage & AU) const605517e702SDimitry Andric   void getAnalysisUsage(AnalysisUsage &AU) const override {
615517e702SDimitry Andric     AU.addRequired<TargetTransformInfoWrapperPass>();
625517e702SDimitry Andric   }
635517e702SDimitry Andric 
645517e702SDimitry Andric private:
655517e702SDimitry Andric   bool optimizeBlock(BasicBlock &BB, bool &ModifiedDT);
665517e702SDimitry Andric   bool optimizeCallInst(CallInst *CI, bool &ModifiedDT);
675517e702SDimitry Andric };
682cab237bSDimitry Andric 
692cab237bSDimitry Andric } // end anonymous namespace
705517e702SDimitry Andric 
715517e702SDimitry Andric char ScalarizeMaskedMemIntrin::ID = 0;
722cab237bSDimitry Andric 
73302affcbSDimitry Andric INITIALIZE_PASS(ScalarizeMaskedMemIntrin, DEBUG_TYPE,
74302affcbSDimitry Andric                 "Scalarize unsupported masked memory intrinsics", false, false)
755517e702SDimitry Andric 
createScalarizeMaskedMemIntrinPass()765517e702SDimitry Andric FunctionPass *llvm::createScalarizeMaskedMemIntrinPass() {
775517e702SDimitry Andric   return new ScalarizeMaskedMemIntrin();
785517e702SDimitry Andric }
795517e702SDimitry Andric 
isConstantIntVector(Value * Mask)80*b5893f02SDimitry Andric static bool isConstantIntVector(Value *Mask) {
81*b5893f02SDimitry Andric   Constant *C = dyn_cast<Constant>(Mask);
82*b5893f02SDimitry Andric   if (!C)
83*b5893f02SDimitry Andric     return false;
84*b5893f02SDimitry Andric 
85*b5893f02SDimitry Andric   unsigned NumElts = Mask->getType()->getVectorNumElements();
86*b5893f02SDimitry Andric   for (unsigned i = 0; i != NumElts; ++i) {
87*b5893f02SDimitry Andric     Constant *CElt = C->getAggregateElement(i);
88*b5893f02SDimitry Andric     if (!CElt || !isa<ConstantInt>(CElt))
89*b5893f02SDimitry Andric       return false;
90*b5893f02SDimitry Andric   }
91*b5893f02SDimitry Andric 
92*b5893f02SDimitry Andric   return true;
93*b5893f02SDimitry Andric }
94*b5893f02SDimitry Andric 
955517e702SDimitry Andric // Translate a masked load intrinsic like
965517e702SDimitry Andric // <16 x i32 > @llvm.masked.load( <16 x i32>* %addr, i32 align,
975517e702SDimitry Andric //                               <16 x i1> %mask, <16 x i32> %passthru)
985517e702SDimitry Andric // to a chain of basic blocks, with loading element one-by-one if
995517e702SDimitry Andric // the appropriate mask bit is set
1005517e702SDimitry Andric //
1015517e702SDimitry Andric //  %1 = bitcast i8* %addr to i32*
1025517e702SDimitry Andric //  %2 = extractelement <16 x i1> %mask, i32 0
103*b5893f02SDimitry Andric //  br i1 %2, label %cond.load, label %else
1045517e702SDimitry Andric //
1055517e702SDimitry Andric // cond.load:                                        ; preds = %0
106*b5893f02SDimitry Andric //  %3 = getelementptr i32* %1, i32 0
107*b5893f02SDimitry Andric //  %4 = load i32* %3
108*b5893f02SDimitry Andric //  %5 = insertelement <16 x i32> %passthru, i32 %4, i32 0
1095517e702SDimitry Andric //  br label %else
1105517e702SDimitry Andric //
1115517e702SDimitry Andric // else:                                             ; preds = %0, %cond.load
112*b5893f02SDimitry Andric //  %res.phi.else = phi <16 x i32> [ %5, %cond.load ], [ undef, %0 ]
113*b5893f02SDimitry Andric //  %6 = extractelement <16 x i1> %mask, i32 1
114*b5893f02SDimitry Andric //  br i1 %6, label %cond.load1, label %else2
1155517e702SDimitry Andric //
1165517e702SDimitry Andric // cond.load1:                                       ; preds = %else
117*b5893f02SDimitry Andric //  %7 = getelementptr i32* %1, i32 1
118*b5893f02SDimitry Andric //  %8 = load i32* %7
119*b5893f02SDimitry Andric //  %9 = insertelement <16 x i32> %res.phi.else, i32 %8, i32 1
1205517e702SDimitry Andric //  br label %else2
1215517e702SDimitry Andric //
1225517e702SDimitry Andric // else2:                                          ; preds = %else, %cond.load1
123*b5893f02SDimitry Andric //  %res.phi.else3 = phi <16 x i32> [ %9, %cond.load1 ], [ %res.phi.else, %else ]
124*b5893f02SDimitry Andric //  %10 = extractelement <16 x i1> %mask, i32 2
125*b5893f02SDimitry Andric //  br i1 %10, label %cond.load4, label %else5
1265517e702SDimitry Andric //
scalarizeMaskedLoad(CallInst * CI)1275517e702SDimitry Andric static void scalarizeMaskedLoad(CallInst *CI) {
1285517e702SDimitry Andric   Value *Ptr = CI->getArgOperand(0);
1295517e702SDimitry Andric   Value *Alignment = CI->getArgOperand(1);
1305517e702SDimitry Andric   Value *Mask = CI->getArgOperand(2);
1315517e702SDimitry Andric   Value *Src0 = CI->getArgOperand(3);
1325517e702SDimitry Andric 
1335517e702SDimitry Andric   unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
134*b5893f02SDimitry Andric   VectorType *VecType = cast<VectorType>(CI->getType());
1355517e702SDimitry Andric 
136*b5893f02SDimitry Andric   Type *EltTy = VecType->getElementType();
1375517e702SDimitry Andric 
1385517e702SDimitry Andric   IRBuilder<> Builder(CI->getContext());
1395517e702SDimitry Andric   Instruction *InsertPt = CI;
1405517e702SDimitry Andric   BasicBlock *IfBlock = CI->getParent();
1415517e702SDimitry Andric 
1425517e702SDimitry Andric   Builder.SetInsertPoint(InsertPt);
1435517e702SDimitry Andric   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
1445517e702SDimitry Andric 
1455517e702SDimitry Andric   // Short-cut if the mask is all-true.
146*b5893f02SDimitry Andric   if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
1475517e702SDimitry Andric     Value *NewI = Builder.CreateAlignedLoad(Ptr, AlignVal);
1485517e702SDimitry Andric     CI->replaceAllUsesWith(NewI);
1495517e702SDimitry Andric     CI->eraseFromParent();
1505517e702SDimitry Andric     return;
1515517e702SDimitry Andric   }
1525517e702SDimitry Andric 
1535517e702SDimitry Andric   // Adjust alignment for the scalar instruction.
154*b5893f02SDimitry Andric   AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
1555517e702SDimitry Andric   // Bitcast %addr fron i8* to EltTy*
1565517e702SDimitry Andric   Type *NewPtrType =
1575517e702SDimitry Andric       EltTy->getPointerTo(cast<PointerType>(Ptr->getType())->getAddressSpace());
1585517e702SDimitry Andric   Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
1595517e702SDimitry Andric   unsigned VectorWidth = VecType->getNumElements();
1605517e702SDimitry Andric 
1615517e702SDimitry Andric   // The result vector
162*b5893f02SDimitry Andric   Value *VResult = Src0;
1635517e702SDimitry Andric 
164*b5893f02SDimitry Andric   if (isConstantIntVector(Mask)) {
1655517e702SDimitry Andric     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
166*b5893f02SDimitry Andric       if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
1675517e702SDimitry Andric         continue;
1685517e702SDimitry Andric       Value *Gep =
1695517e702SDimitry Andric           Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
1705517e702SDimitry Andric       LoadInst *Load = Builder.CreateAlignedLoad(Gep, AlignVal);
1715517e702SDimitry Andric       VResult =
1725517e702SDimitry Andric           Builder.CreateInsertElement(VResult, Load, Builder.getInt32(Idx));
1735517e702SDimitry Andric     }
174*b5893f02SDimitry Andric     CI->replaceAllUsesWith(VResult);
1755517e702SDimitry Andric     CI->eraseFromParent();
1765517e702SDimitry Andric     return;
1775517e702SDimitry Andric   }
1785517e702SDimitry Andric 
1795517e702SDimitry Andric   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
1805517e702SDimitry Andric     // Fill the "else" block, created in the previous iteration
1815517e702SDimitry Andric     //
1825517e702SDimitry Andric     //  %res.phi.else3 = phi <16 x i32> [ %11, %cond.load1 ], [ %res.phi.else, %else ]
1835517e702SDimitry Andric     //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
184*b5893f02SDimitry Andric     //  br i1 %mask_1, label %cond.load, label %else
1855517e702SDimitry Andric     //
1865517e702SDimitry Andric 
1875517e702SDimitry Andric     Value *Predicate =
1885517e702SDimitry Andric         Builder.CreateExtractElement(Mask, Builder.getInt32(Idx));
1895517e702SDimitry Andric 
1905517e702SDimitry Andric     // Create "cond" block
1915517e702SDimitry Andric     //
1925517e702SDimitry Andric     //  %EltAddr = getelementptr i32* %1, i32 0
1935517e702SDimitry Andric     //  %Elt = load i32* %EltAddr
1945517e702SDimitry Andric     //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
1955517e702SDimitry Andric     //
196*b5893f02SDimitry Andric     BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt->getIterator(),
197*b5893f02SDimitry Andric                                                      "cond.load");
1985517e702SDimitry Andric     Builder.SetInsertPoint(InsertPt);
1995517e702SDimitry Andric 
2005517e702SDimitry Andric     Value *Gep =
2015517e702SDimitry Andric         Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
2025517e702SDimitry Andric     LoadInst *Load = Builder.CreateAlignedLoad(Gep, AlignVal);
203*b5893f02SDimitry Andric     Value *NewVResult = Builder.CreateInsertElement(VResult, Load,
204*b5893f02SDimitry Andric                                                     Builder.getInt32(Idx));
2055517e702SDimitry Andric 
2065517e702SDimitry Andric     // Create "else" block, fill it in the next iteration
2075517e702SDimitry Andric     BasicBlock *NewIfBlock =
2085517e702SDimitry Andric         CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
2095517e702SDimitry Andric     Builder.SetInsertPoint(InsertPt);
2105517e702SDimitry Andric     Instruction *OldBr = IfBlock->getTerminator();
211*b5893f02SDimitry Andric     BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
2125517e702SDimitry Andric     OldBr->eraseFromParent();
213*b5893f02SDimitry Andric     BasicBlock *PrevIfBlock = IfBlock;
2145517e702SDimitry Andric     IfBlock = NewIfBlock;
215*b5893f02SDimitry Andric 
216*b5893f02SDimitry Andric     // Create the phi to join the new and previous value.
217*b5893f02SDimitry Andric     PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
218*b5893f02SDimitry Andric     Phi->addIncoming(NewVResult, CondBlock);
219*b5893f02SDimitry Andric     Phi->addIncoming(VResult, PrevIfBlock);
220*b5893f02SDimitry Andric     VResult = Phi;
2215517e702SDimitry Andric   }
2225517e702SDimitry Andric 
223*b5893f02SDimitry Andric   CI->replaceAllUsesWith(VResult);
2245517e702SDimitry Andric   CI->eraseFromParent();
2255517e702SDimitry Andric }
2265517e702SDimitry Andric 
2275517e702SDimitry Andric // Translate a masked store intrinsic, like
2285517e702SDimitry Andric // void @llvm.masked.store(<16 x i32> %src, <16 x i32>* %addr, i32 align,
2295517e702SDimitry Andric //                               <16 x i1> %mask)
2305517e702SDimitry Andric // to a chain of basic blocks, that stores element one-by-one if
2315517e702SDimitry Andric // the appropriate mask bit is set
2325517e702SDimitry Andric //
2335517e702SDimitry Andric //   %1 = bitcast i8* %addr to i32*
2345517e702SDimitry Andric //   %2 = extractelement <16 x i1> %mask, i32 0
235*b5893f02SDimitry Andric //   br i1 %2, label %cond.store, label %else
2365517e702SDimitry Andric //
2375517e702SDimitry Andric // cond.store:                                       ; preds = %0
238*b5893f02SDimitry Andric //   %3 = extractelement <16 x i32> %val, i32 0
239*b5893f02SDimitry Andric //   %4 = getelementptr i32* %1, i32 0
240*b5893f02SDimitry Andric //   store i32 %3, i32* %4
2415517e702SDimitry Andric //   br label %else
2425517e702SDimitry Andric //
2435517e702SDimitry Andric // else:                                             ; preds = %0, %cond.store
244*b5893f02SDimitry Andric //   %5 = extractelement <16 x i1> %mask, i32 1
245*b5893f02SDimitry Andric //   br i1 %5, label %cond.store1, label %else2
2465517e702SDimitry Andric //
2475517e702SDimitry Andric // cond.store1:                                      ; preds = %else
248*b5893f02SDimitry Andric //   %6 = extractelement <16 x i32> %val, i32 1
249*b5893f02SDimitry Andric //   %7 = getelementptr i32* %1, i32 1
250*b5893f02SDimitry Andric //   store i32 %6, i32* %7
2515517e702SDimitry Andric //   br label %else2
2525517e702SDimitry Andric //   . . .
scalarizeMaskedStore(CallInst * CI)2535517e702SDimitry Andric static void scalarizeMaskedStore(CallInst *CI) {
2545517e702SDimitry Andric   Value *Src = CI->getArgOperand(0);
2555517e702SDimitry Andric   Value *Ptr = CI->getArgOperand(1);
2565517e702SDimitry Andric   Value *Alignment = CI->getArgOperand(2);
2575517e702SDimitry Andric   Value *Mask = CI->getArgOperand(3);
2585517e702SDimitry Andric 
2595517e702SDimitry Andric   unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
260*b5893f02SDimitry Andric   VectorType *VecType = cast<VectorType>(Src->getType());
2615517e702SDimitry Andric 
2625517e702SDimitry Andric   Type *EltTy = VecType->getElementType();
2635517e702SDimitry Andric 
2645517e702SDimitry Andric   IRBuilder<> Builder(CI->getContext());
2655517e702SDimitry Andric   Instruction *InsertPt = CI;
2665517e702SDimitry Andric   BasicBlock *IfBlock = CI->getParent();
2675517e702SDimitry Andric   Builder.SetInsertPoint(InsertPt);
2685517e702SDimitry Andric   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
2695517e702SDimitry Andric 
2705517e702SDimitry Andric   // Short-cut if the mask is all-true.
271*b5893f02SDimitry Andric   if (isa<Constant>(Mask) && cast<Constant>(Mask)->isAllOnesValue()) {
2725517e702SDimitry Andric     Builder.CreateAlignedStore(Src, Ptr, AlignVal);
2735517e702SDimitry Andric     CI->eraseFromParent();
2745517e702SDimitry Andric     return;
2755517e702SDimitry Andric   }
2765517e702SDimitry Andric 
2775517e702SDimitry Andric   // Adjust alignment for the scalar instruction.
278*b5893f02SDimitry Andric   AlignVal = MinAlign(AlignVal, EltTy->getPrimitiveSizeInBits() / 8);
2795517e702SDimitry Andric   // Bitcast %addr fron i8* to EltTy*
2805517e702SDimitry Andric   Type *NewPtrType =
2815517e702SDimitry Andric       EltTy->getPointerTo(cast<PointerType>(Ptr->getType())->getAddressSpace());
2825517e702SDimitry Andric   Value *FirstEltPtr = Builder.CreateBitCast(Ptr, NewPtrType);
2835517e702SDimitry Andric   unsigned VectorWidth = VecType->getNumElements();
2845517e702SDimitry Andric 
285*b5893f02SDimitry Andric   if (isConstantIntVector(Mask)) {
2865517e702SDimitry Andric     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
287*b5893f02SDimitry Andric       if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
2885517e702SDimitry Andric         continue;
2895517e702SDimitry Andric       Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx));
2905517e702SDimitry Andric       Value *Gep =
2915517e702SDimitry Andric           Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
2925517e702SDimitry Andric       Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
2935517e702SDimitry Andric     }
2945517e702SDimitry Andric     CI->eraseFromParent();
2955517e702SDimitry Andric     return;
2965517e702SDimitry Andric   }
2975517e702SDimitry Andric 
2985517e702SDimitry Andric   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
2995517e702SDimitry Andric     // Fill the "else" block, created in the previous iteration
3005517e702SDimitry Andric     //
3015517e702SDimitry Andric     //  %mask_1 = extractelement <16 x i1> %mask, i32 Idx
302*b5893f02SDimitry Andric     //  br i1 %mask_1, label %cond.store, label %else
3035517e702SDimitry Andric     //
3045517e702SDimitry Andric     Value *Predicate =
3055517e702SDimitry Andric         Builder.CreateExtractElement(Mask, Builder.getInt32(Idx));
3065517e702SDimitry Andric 
3075517e702SDimitry Andric     // Create "cond" block
3085517e702SDimitry Andric     //
3095517e702SDimitry Andric     //  %OneElt = extractelement <16 x i32> %Src, i32 Idx
3105517e702SDimitry Andric     //  %EltAddr = getelementptr i32* %1, i32 0
3115517e702SDimitry Andric     //  %store i32 %OneElt, i32* %EltAddr
3125517e702SDimitry Andric     //
3135517e702SDimitry Andric     BasicBlock *CondBlock =
3145517e702SDimitry Andric         IfBlock->splitBasicBlock(InsertPt->getIterator(), "cond.store");
3155517e702SDimitry Andric     Builder.SetInsertPoint(InsertPt);
3165517e702SDimitry Andric 
3175517e702SDimitry Andric     Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx));
3185517e702SDimitry Andric     Value *Gep =
3195517e702SDimitry Andric         Builder.CreateInBoundsGEP(EltTy, FirstEltPtr, Builder.getInt32(Idx));
3205517e702SDimitry Andric     Builder.CreateAlignedStore(OneElt, Gep, AlignVal);
3215517e702SDimitry Andric 
3225517e702SDimitry Andric     // Create "else" block, fill it in the next iteration
3235517e702SDimitry Andric     BasicBlock *NewIfBlock =
3245517e702SDimitry Andric         CondBlock->splitBasicBlock(InsertPt->getIterator(), "else");
3255517e702SDimitry Andric     Builder.SetInsertPoint(InsertPt);
3265517e702SDimitry Andric     Instruction *OldBr = IfBlock->getTerminator();
327*b5893f02SDimitry Andric     BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
3285517e702SDimitry Andric     OldBr->eraseFromParent();
3295517e702SDimitry Andric     IfBlock = NewIfBlock;
3305517e702SDimitry Andric   }
3315517e702SDimitry Andric   CI->eraseFromParent();
3325517e702SDimitry Andric }
3335517e702SDimitry Andric 
3345517e702SDimitry Andric // Translate a masked gather intrinsic like
3355517e702SDimitry Andric // <16 x i32 > @llvm.masked.gather.v16i32( <16 x i32*> %Ptrs, i32 4,
3365517e702SDimitry Andric //                               <16 x i1> %Mask, <16 x i32> %Src)
3375517e702SDimitry Andric // to a chain of basic blocks, with loading element one-by-one if
3385517e702SDimitry Andric // the appropriate mask bit is set
3395517e702SDimitry Andric //
3405517e702SDimitry Andric // %Ptrs = getelementptr i32, i32* %base, <16 x i64> %ind
3415517e702SDimitry Andric // %Mask0 = extractelement <16 x i1> %Mask, i32 0
342*b5893f02SDimitry Andric // br i1 %Mask0, label %cond.load, label %else
3435517e702SDimitry Andric //
3445517e702SDimitry Andric // cond.load:
3455517e702SDimitry Andric // %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
3465517e702SDimitry Andric // %Load0 = load i32, i32* %Ptr0, align 4
3475517e702SDimitry Andric // %Res0 = insertelement <16 x i32> undef, i32 %Load0, i32 0
3485517e702SDimitry Andric // br label %else
3495517e702SDimitry Andric //
3505517e702SDimitry Andric // else:
3515517e702SDimitry Andric // %res.phi.else = phi <16 x i32>[%Res0, %cond.load], [undef, %0]
3525517e702SDimitry Andric // %Mask1 = extractelement <16 x i1> %Mask, i32 1
353*b5893f02SDimitry Andric // br i1 %Mask1, label %cond.load1, label %else2
3545517e702SDimitry Andric //
3555517e702SDimitry Andric // cond.load1:
3565517e702SDimitry Andric // %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
3575517e702SDimitry Andric // %Load1 = load i32, i32* %Ptr1, align 4
3585517e702SDimitry Andric // %Res1 = insertelement <16 x i32> %res.phi.else, i32 %Load1, i32 1
3595517e702SDimitry Andric // br label %else2
3605517e702SDimitry Andric // . . .
3615517e702SDimitry Andric // %Result = select <16 x i1> %Mask, <16 x i32> %res.phi.select, <16 x i32> %Src
3625517e702SDimitry Andric // ret <16 x i32> %Result
scalarizeMaskedGather(CallInst * CI)3635517e702SDimitry Andric static void scalarizeMaskedGather(CallInst *CI) {
3645517e702SDimitry Andric   Value *Ptrs = CI->getArgOperand(0);
3655517e702SDimitry Andric   Value *Alignment = CI->getArgOperand(1);
3665517e702SDimitry Andric   Value *Mask = CI->getArgOperand(2);
3675517e702SDimitry Andric   Value *Src0 = CI->getArgOperand(3);
3685517e702SDimitry Andric 
369*b5893f02SDimitry Andric   VectorType *VecType = cast<VectorType>(CI->getType());
3705517e702SDimitry Andric 
3715517e702SDimitry Andric   IRBuilder<> Builder(CI->getContext());
3725517e702SDimitry Andric   Instruction *InsertPt = CI;
3735517e702SDimitry Andric   BasicBlock *IfBlock = CI->getParent();
3745517e702SDimitry Andric   Builder.SetInsertPoint(InsertPt);
3755517e702SDimitry Andric   unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
3765517e702SDimitry Andric 
3775517e702SDimitry Andric   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
3785517e702SDimitry Andric 
3795517e702SDimitry Andric   // The result vector
380*b5893f02SDimitry Andric   Value *VResult = Src0;
3815517e702SDimitry Andric   unsigned VectorWidth = VecType->getNumElements();
3825517e702SDimitry Andric 
3835517e702SDimitry Andric   // Shorten the way if the mask is a vector of constants.
384*b5893f02SDimitry Andric   if (isConstantIntVector(Mask)) {
3855517e702SDimitry Andric     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
386*b5893f02SDimitry Andric       if (cast<Constant>(Mask)->getAggregateElement(Idx)->isNullValue())
3875517e702SDimitry Andric         continue;
3885517e702SDimitry Andric       Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
3895517e702SDimitry Andric                                                 "Ptr" + Twine(Idx));
3905517e702SDimitry Andric       LoadInst *Load =
3915517e702SDimitry Andric           Builder.CreateAlignedLoad(Ptr, AlignVal, "Load" + Twine(Idx));
3925517e702SDimitry Andric       VResult = Builder.CreateInsertElement(
3935517e702SDimitry Andric           VResult, Load, Builder.getInt32(Idx), "Res" + Twine(Idx));
3945517e702SDimitry Andric     }
395*b5893f02SDimitry Andric     CI->replaceAllUsesWith(VResult);
3965517e702SDimitry Andric     CI->eraseFromParent();
3975517e702SDimitry Andric     return;
3985517e702SDimitry Andric   }
3995517e702SDimitry Andric 
4005517e702SDimitry Andric   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
4015517e702SDimitry Andric     // Fill the "else" block, created in the previous iteration
4025517e702SDimitry Andric     //
4035517e702SDimitry Andric     //  %Mask1 = extractelement <16 x i1> %Mask, i32 1
404*b5893f02SDimitry Andric     //  br i1 %Mask1, label %cond.load, label %else
4055517e702SDimitry Andric     //
4065517e702SDimitry Andric 
4075517e702SDimitry Andric     Value *Predicate = Builder.CreateExtractElement(Mask, Builder.getInt32(Idx),
4085517e702SDimitry Andric                                                     "Mask" + Twine(Idx));
4095517e702SDimitry Andric 
4105517e702SDimitry Andric     // Create "cond" block
4115517e702SDimitry Andric     //
4125517e702SDimitry Andric     //  %EltAddr = getelementptr i32* %1, i32 0
4135517e702SDimitry Andric     //  %Elt = load i32* %EltAddr
4145517e702SDimitry Andric     //  VResult = insertelement <16 x i32> VResult, i32 %Elt, i32 Idx
4155517e702SDimitry Andric     //
416*b5893f02SDimitry Andric     BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.load");
4175517e702SDimitry Andric     Builder.SetInsertPoint(InsertPt);
4185517e702SDimitry Andric 
4195517e702SDimitry Andric     Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
4205517e702SDimitry Andric                                               "Ptr" + Twine(Idx));
4215517e702SDimitry Andric     LoadInst *Load =
4225517e702SDimitry Andric         Builder.CreateAlignedLoad(Ptr, AlignVal, "Load" + Twine(Idx));
423*b5893f02SDimitry Andric     Value *NewVResult = Builder.CreateInsertElement(VResult, Load,
424*b5893f02SDimitry Andric                                                     Builder.getInt32(Idx),
4255517e702SDimitry Andric                                                     "Res" + Twine(Idx));
4265517e702SDimitry Andric 
4275517e702SDimitry Andric     // Create "else" block, fill it in the next iteration
4285517e702SDimitry Andric     BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
4295517e702SDimitry Andric     Builder.SetInsertPoint(InsertPt);
4305517e702SDimitry Andric     Instruction *OldBr = IfBlock->getTerminator();
431*b5893f02SDimitry Andric     BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
4325517e702SDimitry Andric     OldBr->eraseFromParent();
433*b5893f02SDimitry Andric     BasicBlock *PrevIfBlock = IfBlock;
4345517e702SDimitry Andric     IfBlock = NewIfBlock;
435*b5893f02SDimitry Andric 
436*b5893f02SDimitry Andric     PHINode *Phi = Builder.CreatePHI(VecType, 2, "res.phi.else");
437*b5893f02SDimitry Andric     Phi->addIncoming(NewVResult, CondBlock);
438*b5893f02SDimitry Andric     Phi->addIncoming(VResult, PrevIfBlock);
439*b5893f02SDimitry Andric     VResult = Phi;
4405517e702SDimitry Andric   }
4415517e702SDimitry Andric 
442*b5893f02SDimitry Andric   CI->replaceAllUsesWith(VResult);
4435517e702SDimitry Andric   CI->eraseFromParent();
4445517e702SDimitry Andric }
4455517e702SDimitry Andric 
4465517e702SDimitry Andric // Translate a masked scatter intrinsic, like
// void @llvm.masked.scatter.v16i32(<16 x i32> %Src, <16 x i32*> %Ptrs, i32 4,
4485517e702SDimitry Andric //                                  <16 x i1> %Mask)
4495517e702SDimitry Andric // to a chain of basic blocks, that stores element one-by-one if
4505517e702SDimitry Andric // the appropriate mask bit is set.
4515517e702SDimitry Andric //
4525517e702SDimitry Andric // %Ptrs = getelementptr i32, i32* %ptr, <16 x i64> %ind
4535517e702SDimitry Andric // %Mask0 = extractelement <16 x i1> %Mask, i32 0
454*b5893f02SDimitry Andric // br i1 %Mask0, label %cond.store, label %else
4555517e702SDimitry Andric //
4565517e702SDimitry Andric // cond.store:
4575517e702SDimitry Andric // %Elt0 = extractelement <16 x i32> %Src, i32 0
4585517e702SDimitry Andric // %Ptr0 = extractelement <16 x i32*> %Ptrs, i32 0
4595517e702SDimitry Andric // store i32 %Elt0, i32* %Ptr0, align 4
4605517e702SDimitry Andric // br label %else
4615517e702SDimitry Andric //
4625517e702SDimitry Andric // else:
4635517e702SDimitry Andric // %Mask1 = extractelement <16 x i1> %Mask, i32 1
464*b5893f02SDimitry Andric // br i1 %Mask1, label %cond.store1, label %else2
4655517e702SDimitry Andric //
4665517e702SDimitry Andric // cond.store1:
4675517e702SDimitry Andric // %Elt1 = extractelement <16 x i32> %Src, i32 1
4685517e702SDimitry Andric // %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
4695517e702SDimitry Andric // store i32 %Elt1, i32* %Ptr1, align 4
4705517e702SDimitry Andric // br label %else2
4715517e702SDimitry Andric //   . . .
scalarizeMaskedScatter(CallInst * CI)4725517e702SDimitry Andric static void scalarizeMaskedScatter(CallInst *CI) {
4735517e702SDimitry Andric   Value *Src = CI->getArgOperand(0);
4745517e702SDimitry Andric   Value *Ptrs = CI->getArgOperand(1);
4755517e702SDimitry Andric   Value *Alignment = CI->getArgOperand(2);
4765517e702SDimitry Andric   Value *Mask = CI->getArgOperand(3);
4775517e702SDimitry Andric 
4785517e702SDimitry Andric   assert(isa<VectorType>(Src->getType()) &&
4795517e702SDimitry Andric          "Unexpected data type in masked scatter intrinsic");
4805517e702SDimitry Andric   assert(isa<VectorType>(Ptrs->getType()) &&
4815517e702SDimitry Andric          isa<PointerType>(Ptrs->getType()->getVectorElementType()) &&
4825517e702SDimitry Andric          "Vector of pointers is expected in masked scatter intrinsic");
4835517e702SDimitry Andric 
4845517e702SDimitry Andric   IRBuilder<> Builder(CI->getContext());
4855517e702SDimitry Andric   Instruction *InsertPt = CI;
4865517e702SDimitry Andric   BasicBlock *IfBlock = CI->getParent();
4875517e702SDimitry Andric   Builder.SetInsertPoint(InsertPt);
4885517e702SDimitry Andric   Builder.SetCurrentDebugLocation(CI->getDebugLoc());
4895517e702SDimitry Andric 
4905517e702SDimitry Andric   unsigned AlignVal = cast<ConstantInt>(Alignment)->getZExtValue();
4915517e702SDimitry Andric   unsigned VectorWidth = Src->getType()->getVectorNumElements();
4925517e702SDimitry Andric 
4935517e702SDimitry Andric   // Shorten the way if the mask is a vector of constants.
494*b5893f02SDimitry Andric   if (isConstantIntVector(Mask)) {
4955517e702SDimitry Andric     for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
496*b5893f02SDimitry Andric       if (cast<ConstantVector>(Mask)->getAggregateElement(Idx)->isNullValue())
4975517e702SDimitry Andric         continue;
4985517e702SDimitry Andric       Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx),
4995517e702SDimitry Andric                                                    "Elt" + Twine(Idx));
5005517e702SDimitry Andric       Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
5015517e702SDimitry Andric                                                 "Ptr" + Twine(Idx));
5025517e702SDimitry Andric       Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
5035517e702SDimitry Andric     }
5045517e702SDimitry Andric     CI->eraseFromParent();
5055517e702SDimitry Andric     return;
5065517e702SDimitry Andric   }
507*b5893f02SDimitry Andric 
5085517e702SDimitry Andric   for (unsigned Idx = 0; Idx < VectorWidth; ++Idx) {
5095517e702SDimitry Andric     // Fill the "else" block, created in the previous iteration
5105517e702SDimitry Andric     //
5115517e702SDimitry Andric     //  %Mask1 = extractelement <16 x i1> %Mask, i32 Idx
512*b5893f02SDimitry Andric     //  br i1 %Mask1, label %cond.store, label %else
5135517e702SDimitry Andric     //
5145517e702SDimitry Andric     Value *Predicate = Builder.CreateExtractElement(Mask, Builder.getInt32(Idx),
5155517e702SDimitry Andric                                                     "Mask" + Twine(Idx));
5165517e702SDimitry Andric 
5175517e702SDimitry Andric     // Create "cond" block
5185517e702SDimitry Andric     //
5195517e702SDimitry Andric     //  %Elt1 = extractelement <16 x i32> %Src, i32 1
5205517e702SDimitry Andric     //  %Ptr1 = extractelement <16 x i32*> %Ptrs, i32 1
5215517e702SDimitry Andric     //  %store i32 %Elt1, i32* %Ptr1
5225517e702SDimitry Andric     //
5235517e702SDimitry Andric     BasicBlock *CondBlock = IfBlock->splitBasicBlock(InsertPt, "cond.store");
5245517e702SDimitry Andric     Builder.SetInsertPoint(InsertPt);
5255517e702SDimitry Andric 
5265517e702SDimitry Andric     Value *OneElt = Builder.CreateExtractElement(Src, Builder.getInt32(Idx),
5275517e702SDimitry Andric                                                  "Elt" + Twine(Idx));
5285517e702SDimitry Andric     Value *Ptr = Builder.CreateExtractElement(Ptrs, Builder.getInt32(Idx),
5295517e702SDimitry Andric                                               "Ptr" + Twine(Idx));
5305517e702SDimitry Andric     Builder.CreateAlignedStore(OneElt, Ptr, AlignVal);
5315517e702SDimitry Andric 
5325517e702SDimitry Andric     // Create "else" block, fill it in the next iteration
5335517e702SDimitry Andric     BasicBlock *NewIfBlock = CondBlock->splitBasicBlock(InsertPt, "else");
5345517e702SDimitry Andric     Builder.SetInsertPoint(InsertPt);
5355517e702SDimitry Andric     Instruction *OldBr = IfBlock->getTerminator();
536*b5893f02SDimitry Andric     BranchInst::Create(CondBlock, NewIfBlock, Predicate, OldBr);
5375517e702SDimitry Andric     OldBr->eraseFromParent();
5385517e702SDimitry Andric     IfBlock = NewIfBlock;
5395517e702SDimitry Andric   }
5405517e702SDimitry Andric   CI->eraseFromParent();
5415517e702SDimitry Andric }
5425517e702SDimitry Andric 
runOnFunction(Function & F)5435517e702SDimitry Andric bool ScalarizeMaskedMemIntrin::runOnFunction(Function &F) {
5445517e702SDimitry Andric   bool EverMadeChange = false;
5455517e702SDimitry Andric 
5465517e702SDimitry Andric   TTI = &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
5475517e702SDimitry Andric 
5485517e702SDimitry Andric   bool MadeChange = true;
5495517e702SDimitry Andric   while (MadeChange) {
5505517e702SDimitry Andric     MadeChange = false;
5515517e702SDimitry Andric     for (Function::iterator I = F.begin(); I != F.end();) {
5525517e702SDimitry Andric       BasicBlock *BB = &*I++;
5535517e702SDimitry Andric       bool ModifiedDTOnIteration = false;
5545517e702SDimitry Andric       MadeChange |= optimizeBlock(*BB, ModifiedDTOnIteration);
5555517e702SDimitry Andric 
5565517e702SDimitry Andric       // Restart BB iteration if the dominator tree of the Function was changed
5575517e702SDimitry Andric       if (ModifiedDTOnIteration)
5585517e702SDimitry Andric         break;
5595517e702SDimitry Andric     }
5605517e702SDimitry Andric 
5615517e702SDimitry Andric     EverMadeChange |= MadeChange;
5625517e702SDimitry Andric   }
5635517e702SDimitry Andric 
5645517e702SDimitry Andric   return EverMadeChange;
5655517e702SDimitry Andric }
5665517e702SDimitry Andric 
optimizeBlock(BasicBlock & BB,bool & ModifiedDT)5675517e702SDimitry Andric bool ScalarizeMaskedMemIntrin::optimizeBlock(BasicBlock &BB, bool &ModifiedDT) {
5685517e702SDimitry Andric   bool MadeChange = false;
5695517e702SDimitry Andric 
5705517e702SDimitry Andric   BasicBlock::iterator CurInstIterator = BB.begin();
5715517e702SDimitry Andric   while (CurInstIterator != BB.end()) {
5725517e702SDimitry Andric     if (CallInst *CI = dyn_cast<CallInst>(&*CurInstIterator++))
5735517e702SDimitry Andric       MadeChange |= optimizeCallInst(CI, ModifiedDT);
5745517e702SDimitry Andric     if (ModifiedDT)
5755517e702SDimitry Andric       return true;
5765517e702SDimitry Andric   }
5775517e702SDimitry Andric 
5785517e702SDimitry Andric   return MadeChange;
5795517e702SDimitry Andric }
5805517e702SDimitry Andric 
optimizeCallInst(CallInst * CI,bool & ModifiedDT)5815517e702SDimitry Andric bool ScalarizeMaskedMemIntrin::optimizeCallInst(CallInst *CI,
5825517e702SDimitry Andric                                                 bool &ModifiedDT) {
5835517e702SDimitry Andric   IntrinsicInst *II = dyn_cast<IntrinsicInst>(CI);
5845517e702SDimitry Andric   if (II) {
5855517e702SDimitry Andric     switch (II->getIntrinsicID()) {
5865517e702SDimitry Andric     default:
5875517e702SDimitry Andric       break;
5882cab237bSDimitry Andric     case Intrinsic::masked_load:
5895517e702SDimitry Andric       // Scalarize unsupported vector masked load
5905517e702SDimitry Andric       if (!TTI->isLegalMaskedLoad(CI->getType())) {
5915517e702SDimitry Andric         scalarizeMaskedLoad(CI);
5925517e702SDimitry Andric         ModifiedDT = true;
5935517e702SDimitry Andric         return true;
5945517e702SDimitry Andric       }
5955517e702SDimitry Andric       return false;
5962cab237bSDimitry Andric     case Intrinsic::masked_store:
5975517e702SDimitry Andric       if (!TTI->isLegalMaskedStore(CI->getArgOperand(0)->getType())) {
5985517e702SDimitry Andric         scalarizeMaskedStore(CI);
5995517e702SDimitry Andric         ModifiedDT = true;
6005517e702SDimitry Andric         return true;
6015517e702SDimitry Andric       }
6025517e702SDimitry Andric       return false;
6032cab237bSDimitry Andric     case Intrinsic::masked_gather:
6045517e702SDimitry Andric       if (!TTI->isLegalMaskedGather(CI->getType())) {
6055517e702SDimitry Andric         scalarizeMaskedGather(CI);
6065517e702SDimitry Andric         ModifiedDT = true;
6075517e702SDimitry Andric         return true;
6085517e702SDimitry Andric       }
6095517e702SDimitry Andric       return false;
6102cab237bSDimitry Andric     case Intrinsic::masked_scatter:
6115517e702SDimitry Andric       if (!TTI->isLegalMaskedScatter(CI->getArgOperand(0)->getType())) {
6125517e702SDimitry Andric         scalarizeMaskedScatter(CI);
6135517e702SDimitry Andric         ModifiedDT = true;
6145517e702SDimitry Andric         return true;
6155517e702SDimitry Andric       }
6165517e702SDimitry Andric       return false;
6175517e702SDimitry Andric     }
6185517e702SDimitry Andric   }
6195517e702SDimitry Andric 
6205517e702SDimitry Andric   return false;
6215517e702SDimitry Andric }
622