1 //===--- ExpandMemCmp.cpp - Expand memcmp() to load/stores ----------------===//
2 //
3 //                     The LLVM Compiler Infrastructure
4 //
5 // This file is distributed under the University of Illinois Open Source
6 // License. See LICENSE.TXT for details.
7 //
8 //===----------------------------------------------------------------------===//
9 //
// This pass tries to expand memcmp() calls with a known constant size into an
// inline sequence of target-friendly loads and compares, avoiding the library
// call on the fast path.
13 //
14 //===----------------------------------------------------------------------===//
15 
16 #include "llvm/ADT/Statistic.h"
17 #include "llvm/Analysis/ConstantFolding.h"
18 #include "llvm/Analysis/TargetLibraryInfo.h"
19 #include "llvm/Analysis/TargetTransformInfo.h"
20 #include "llvm/Analysis/ValueTracking.h"
21 #include "llvm/CodeGen/TargetPassConfig.h"
22 #include "llvm/IR/IRBuilder.h"
23 #include "llvm/Target/TargetLowering.h"
24 #include "llvm/Target/TargetSubtargetInfo.h"
25 #include "llvm/Transforms/Scalar.h"
26 #include "llvm/Transforms/Utils/BasicBlockUtils.h"
27 
28 using namespace llvm;
29 
30 #define DEBUG_TYPE "expandmemcmp"
31 
// Statistics reported under -stats; they track how often expansion is
// attempted, rejected, and performed.
STATISTIC(NumMemCmpCalls, "Number of memcmp calls");
STATISTIC(NumMemCmpNotConstant, "Number of memcmp calls without constant size");
STATISTIC(NumMemCmpGreaterThanMax,
          "Number of memcmp calls with size greater than max size");
STATISTIC(NumMemCmpInlined, "Number of inlined memcmp calls");

// Command-line knob for how many load pairs may be packed into a single basic
// block when the memcmp result is only compared against zero.
static cl::opt<unsigned> MemCmpNumLoadsPerBlock(
    "memcmp-num-loads-per-block", cl::Hidden, cl::init(1),
    cl::desc("The number of loads per basic block for inline expansion of "
             "memcmp that is only being compared against zero."));
42 
43 namespace {
44 
45 
46 // This class provides helper functions to expand a memcmp library call into an
47 // inline expansion.
class MemCmpExpansion {
  // State of the block that computes the final -1/0/1 result when a
  // difference is found: the block itself plus the PHIs that carry the two
  // differing (byte-swapped, zero-extended) loaded values into it.
  struct ResultBlock {
    BasicBlock *BB = nullptr;
    PHINode *PhiSrc1 = nullptr;
    PHINode *PhiSrc2 = nullptr;

    ResultBlock() = default;
  };

  // The memcmp call being expanded.
  CallInst *const CI;
  ResultBlock ResBlock;
  // Number of bytes compared (the constant third argument of memcmp).
  const uint64_t Size;
  // Largest load size (in bytes) usable for this expansion.
  unsigned MaxLoadSize;
  // Number of load *sizes* greater than one byte used by the decomposition;
  // used to size the ResultBlock PHIs.
  uint64_t NumLoadsNonOneByte;
  const uint64_t NumLoadsPerBlock;
  std::vector<BasicBlock *> LoadCmpBlocks;
  // Block holding the final result PHI; the original block is split here.
  BasicBlock *EndBlock;
  // PHI in EndBlock selecting the memcmp result.
  PHINode *PhiRes;
  // True when the memcmp result is only compared (==/!=) against zero, which
  // allows a cheaper xor/or-based expansion.
  const bool IsUsedForZeroCmp;
  const DataLayout &DL;
  IRBuilder<> Builder;
  // Represents the decomposition in blocks of the expansion. For example,
  // comparing 33 bytes on X86+sse can be done with 2x16-byte loads and
  // 1x1-byte load, which would be represented as [{16, 0}, {16, 16}, {1, 32}].
  // TODO(courbet): Involve the target more in this computation. On X86, 7
  // bytes can be done more efficiently with two overlapping 4-byte loads than
  // covering the interval with [{4, 0}, {2, 4}, {1, 6}].
  struct LoadEntry {
    LoadEntry(unsigned LoadSize, uint64_t Offset)
        : LoadSize(LoadSize), Offset(Offset) {
      assert(Offset % LoadSize == 0 && "invalid load entry");
    }

    // Index of this load when the base pointer is viewed as LoadSize-wide.
    uint64_t getGEPIndex() const { return Offset / LoadSize; }

    // The size of the load for this block, in bytes.
    const unsigned LoadSize;
    // The offset of this load WRT the base pointer, in bytes.
    const uint64_t Offset;
  };
  SmallVector<LoadEntry, 8> LoadSequence;

  void createLoadCmpBlocks();
  void createResultBlock();
  void setupResultBlockPHINodes();
  void setupEndBlockPHINodes();
  Value *getCompareLoadPairs(unsigned BlockIndex, unsigned &LoadIndex);
  void emitLoadCompareBlock(unsigned BlockIndex);
  void emitLoadCompareBlockMultipleLoads(unsigned BlockIndex,
                                         unsigned &LoadIndex);
  void emitLoadCompareByteBlock(unsigned BlockIndex, unsigned GEPIndex);
  void emitMemCmpResultBlock();
  Value *getMemCmpExpansionZeroCase();
  Value *getMemCmpEqZeroOneBlock();
  Value *getMemCmpOneBlock();

 public:
  MemCmpExpansion(CallInst *CI, uint64_t Size,
                  const TargetTransformInfo::MemCmpExpansionOptions &Options,
                  unsigned MaxNumLoads, const bool IsUsedForZeroCmp,
                  unsigned NumLoadsPerBlock, const DataLayout &DL);

  unsigned getNumBlocks();
  // Zero means "do not expand" (budget exceeded or no usable load size).
  uint64_t getNumLoads() const { return LoadSequence.size(); }

  Value *getMemCmpExpansion();
};
115 
116 // Initialize the basic block structure required for expansion of memcmp call
117 // with given maximum load size and memcmp size parameter.
118 // This structure includes:
119 // 1. A list of load compare blocks - LoadCmpBlocks.
120 // 2. An EndBlock, split from original instruction point, which is the block to
121 // return from.
122 // 3. ResultBlock, block to branch to for early exit when a
123 // LoadCmpBlock finds a difference.
124 MemCmpExpansion::MemCmpExpansion(
125     CallInst *const CI, uint64_t Size,
126     const TargetTransformInfo::MemCmpExpansionOptions &Options,
127     const unsigned MaxNumLoads, const bool IsUsedForZeroCmp,
128     const unsigned NumLoadsPerBlock, const DataLayout &TheDataLayout)
129     : CI(CI),
130       Size(Size),
131       MaxLoadSize(0),
132       NumLoadsNonOneByte(0),
133       NumLoadsPerBlock(NumLoadsPerBlock),
134       IsUsedForZeroCmp(IsUsedForZeroCmp),
135       DL(TheDataLayout),
136       Builder(CI) {
137   assert(Size > 0 && "zero blocks");
138   // Scale the max size down if the target can load more bytes than we need.
139   size_t LoadSizeIndex = 0;
140   while (LoadSizeIndex < Options.LoadSizes.size() &&
141          Options.LoadSizes[LoadSizeIndex] > Size) {
142     ++LoadSizeIndex;
143   }
144   this->MaxLoadSize = Options.LoadSizes[LoadSizeIndex];
145   // Compute the decomposition.
146   uint64_t CurSize = Size;
147   uint64_t Offset = 0;
148   while (CurSize && LoadSizeIndex < Options.LoadSizes.size()) {
149     const unsigned LoadSize = Options.LoadSizes[LoadSizeIndex];
150     assert(LoadSize > 0 && "zero load size");
151     const uint64_t NumLoadsForThisSize = CurSize / LoadSize;
152     if (LoadSequence.size() + NumLoadsForThisSize > MaxNumLoads) {
153       // Do not expand if the total number of loads is larger than what the
154       // target allows. Note that it's important that we exit before completing
155       // the expansion to avoid using a ton of memory to store the expansion for
156       // large sizes.
157       LoadSequence.clear();
158       return;
159     }
160     if (NumLoadsForThisSize > 0) {
161       for (uint64_t I = 0; I < NumLoadsForThisSize; ++I) {
162         LoadSequence.push_back({LoadSize, Offset});
163         Offset += LoadSize;
164       }
165       if (LoadSize > 1) {
166         ++NumLoadsNonOneByte;
167       }
168       CurSize = CurSize % LoadSize;
169     }
170     ++LoadSizeIndex;
171   }
172   assert(LoadSequence.size() <= MaxNumLoads && "broken invariant");
173 }
174 
175 unsigned MemCmpExpansion::getNumBlocks() {
176   if (IsUsedForZeroCmp)
177     return getNumLoads() / NumLoadsPerBlock +
178            (getNumLoads() % NumLoadsPerBlock != 0 ? 1 : 0);
179   return getNumLoads();
180 }
181 
182 void MemCmpExpansion::createLoadCmpBlocks() {
183   for (unsigned i = 0; i < getNumBlocks(); i++) {
184     BasicBlock *BB = BasicBlock::Create(CI->getContext(), "loadbb",
185                                         EndBlock->getParent(), EndBlock);
186     LoadCmpBlocks.push_back(BB);
187   }
188 }
189 
// Create the block that will compute the final -1/1 result, inserted just
// before EndBlock.
void MemCmpExpansion::createResultBlock() {
  ResBlock.BB = BasicBlock::Create(CI->getContext(), "res_block",
                                   EndBlock->getParent(), EndBlock);
}
194 
195 // This function creates the IR instructions for loading and comparing 1 byte.
196 // It loads 1 byte from each source of the memcmp parameters with the given
197 // GEPIndex. It then subtracts the two loaded values and adds this result to the
198 // final phi node for selecting the memcmp result.
199 void MemCmpExpansion::emitLoadCompareByteBlock(unsigned BlockIndex,
200                                                unsigned GEPIndex) {
201   Value *Source1 = CI->getArgOperand(0);
202   Value *Source2 = CI->getArgOperand(1);
203 
204   Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
205   Type *LoadSizeType = Type::getInt8Ty(CI->getContext());
206   // Cast source to LoadSizeType*.
207   if (Source1->getType() != LoadSizeType)
208     Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo());
209   if (Source2->getType() != LoadSizeType)
210     Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo());
211 
212   // Get the base address using the GEPIndex.
213   if (GEPIndex != 0) {
214     Source1 = Builder.CreateGEP(LoadSizeType, Source1,
215                                 ConstantInt::get(LoadSizeType, GEPIndex));
216     Source2 = Builder.CreateGEP(LoadSizeType, Source2,
217                                 ConstantInt::get(LoadSizeType, GEPIndex));
218   }
219 
220   Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
221   Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);
222 
223   LoadSrc1 = Builder.CreateZExt(LoadSrc1, Type::getInt32Ty(CI->getContext()));
224   LoadSrc2 = Builder.CreateZExt(LoadSrc2, Type::getInt32Ty(CI->getContext()));
225   Value *Diff = Builder.CreateSub(LoadSrc1, LoadSrc2);
226 
227   PhiRes->addIncoming(Diff, LoadCmpBlocks[BlockIndex]);
228 
229   if (BlockIndex < (LoadCmpBlocks.size() - 1)) {
230     // Early exit branch if difference found to EndBlock. Otherwise, continue to
231     // next LoadCmpBlock,
232     Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_NE, Diff,
233                                     ConstantInt::get(Diff->getType(), 0));
234     BranchInst *CmpBr =
235         BranchInst::Create(EndBlock, LoadCmpBlocks[BlockIndex + 1], Cmp);
236     Builder.Insert(CmpBr);
237   } else {
238     // The last block has an unconditional branch to EndBlock.
239     BranchInst *CmpBr = BranchInst::Create(EndBlock);
240     Builder.Insert(CmpBr);
241   }
242 }
243 
244 /// Generate an equality comparison for one or more pairs of loaded values.
245 /// This is used in the case where the memcmp() call is compared equal or not
246 /// equal to zero.
Value *MemCmpExpansion::getCompareLoadPairs(unsigned BlockIndex,
                                            unsigned &LoadIndex) {
  assert(LoadIndex < getNumLoads() &&
         "getCompareLoadPairs() called with no remaining loads");
  std::vector<Value *> XorList, OrList;
  // Holds the last xor result; only written (and read) when NumLoads > 1.
  Value *Diff;

  // Consume up to NumLoadsPerBlock entries of LoadSequence, advancing
  // LoadIndex for the caller.
  const unsigned NumLoads =
      std::min(getNumLoads() - LoadIndex, NumLoadsPerBlock);

  // For a single-block expansion, start inserting before the memcmp call.
  if (LoadCmpBlocks.empty())
    Builder.SetInsertPoint(CI);
  else
    Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);

  Value *Cmp = nullptr;
  // If we have multiple loads per block, we need to generate a composite
  // comparison using xor+or. The type for the combinations is the largest load
  // type.
  IntegerType *const MaxLoadType =
      NumLoads == 1 ? nullptr
                    : IntegerType::get(CI->getContext(), MaxLoadSize * 8);
  for (unsigned i = 0; i < NumLoads; ++i, ++LoadIndex) {
    const LoadEntry &CurLoadEntry = LoadSequence[LoadIndex];

    IntegerType *LoadSizeType =
        IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8);

    Value *Source1 = CI->getArgOperand(0);
    Value *Source2 = CI->getArgOperand(1);

    // Cast source to LoadSizeType*.
    if (Source1->getType() != LoadSizeType)
      Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo());
    if (Source2->getType() != LoadSizeType)
      Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo());

    // Get the base address using a GEP.
    if (CurLoadEntry.Offset != 0) {
      Source1 = Builder.CreateGEP(
          LoadSizeType, Source1,
          ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex()));
      Source2 = Builder.CreateGEP(
          LoadSizeType, Source2,
          ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex()));
    }

    // Get a constant or load a value for each source address. Constant
    // folding lets comparisons against constant buffers avoid the load.
    Value *LoadSrc1 = nullptr;
    if (auto *Source1C = dyn_cast<Constant>(Source1))
      LoadSrc1 = ConstantFoldLoadFromConstPtr(Source1C, LoadSizeType, DL);
    if (!LoadSrc1)
      LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);

    Value *LoadSrc2 = nullptr;
    if (auto *Source2C = dyn_cast<Constant>(Source2))
      LoadSrc2 = ConstantFoldLoadFromConstPtr(Source2C, LoadSizeType, DL);
    if (!LoadSrc2)
      LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);

    if (NumLoads != 1) {
      if (LoadSizeType != MaxLoadType) {
        LoadSrc1 = Builder.CreateZExt(LoadSrc1, MaxLoadType);
        LoadSrc2 = Builder.CreateZExt(LoadSrc2, MaxLoadType);
      }
      // If we have multiple loads per block, we need to generate a composite
      // comparison using xor+or.
      Diff = Builder.CreateXor(LoadSrc1, LoadSrc2);
      Diff = Builder.CreateZExt(Diff, MaxLoadType);
      XorList.push_back(Diff);
    } else {
      // If there's only one load per block, we just compare the loaded values.
      Cmp = Builder.CreateICmpNE(LoadSrc1, LoadSrc2);
    }
  }

  // Reduce a list of values to half its size by OR-ing adjacent pairs
  // (keeping the odd element, if any).
  auto pairWiseOr = [&](std::vector<Value *> &InList) -> std::vector<Value *> {
    std::vector<Value *> OutList;
    for (unsigned i = 0; i < InList.size() - 1; i = i + 2) {
      Value *Or = Builder.CreateOr(InList[i], InList[i + 1]);
      OutList.push_back(Or);
    }
    if (InList.size() % 2 != 0)
      OutList.push_back(InList.back());
    return OutList;
  };

  if (!Cmp) {
    // Pairwise OR the XOR results.
    OrList = pairWiseOr(XorList);

    // Pairwise OR the OR results until one result left.
    while (OrList.size() != 1) {
      OrList = pairWiseOr(OrList);
    }
    // Any set bit in the OR-tree means some pair of bytes differed.
    Cmp = Builder.CreateICmpNE(OrList[0], ConstantInt::get(Diff->getType(), 0));
  }

  return Cmp;
}
348 
349 void MemCmpExpansion::emitLoadCompareBlockMultipleLoads(unsigned BlockIndex,
350                                                         unsigned &LoadIndex) {
351   Value *Cmp = getCompareLoadPairs(BlockIndex, LoadIndex);
352 
353   BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1))
354                            ? EndBlock
355                            : LoadCmpBlocks[BlockIndex + 1];
356   // Early exit branch if difference found to ResultBlock. Otherwise,
357   // continue to next LoadCmpBlock or EndBlock.
358   BranchInst *CmpBr = BranchInst::Create(ResBlock.BB, NextBB, Cmp);
359   Builder.Insert(CmpBr);
360 
361   // Add a phi edge for the last LoadCmpBlock to Endblock with a value of 0
362   // since early exit to ResultBlock was not taken (no difference was found in
363   // any of the bytes).
364   if (BlockIndex == LoadCmpBlocks.size() - 1) {
365     Value *Zero = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 0);
366     PhiRes->addIncoming(Zero, LoadCmpBlocks[BlockIndex]);
367   }
368 }
369 
// This function creates the IR instructions for loading and comparing using
// the given LoadSize. It loads the number of bytes specified by LoadSize from
// each source of the memcmp parameters and compares the loaded values. If a
// difference is found, it branches with an early exit to the ResultBlock for
// calculating which source was larger. Otherwise, it falls through to either
// the next LoadCmpBlock or the EndBlock if this is the last LoadCmpBlock.
// Loading 1 byte is handled with a special case through
// emitLoadCompareByteBlock, which can simply subtract the loaded values and
// add the difference to the result phi node.
void MemCmpExpansion::emitLoadCompareBlock(unsigned BlockIndex) {
  // There is one load per block in this case, BlockIndex == LoadIndex.
  const LoadEntry &CurLoadEntry = LoadSequence[BlockIndex];

  // Single-byte loads take the cheaper subtract-based path.
  if (CurLoadEntry.LoadSize == 1) {
    MemCmpExpansion::emitLoadCompareByteBlock(BlockIndex,
                                              CurLoadEntry.getGEPIndex());
    return;
  }

  Type *LoadSizeType =
      IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8);
  Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8);
  assert(CurLoadEntry.LoadSize <= MaxLoadSize && "Unexpected load type");

  Value *Source1 = CI->getArgOperand(0);
  Value *Source2 = CI->getArgOperand(1);

  Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
  // Cast source to LoadSizeType*.
  if (Source1->getType() != LoadSizeType)
    Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo());
  if (Source2->getType() != LoadSizeType)
    Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo());

  // Get the base address using a GEP.
  if (CurLoadEntry.Offset != 0) {
    Source1 = Builder.CreateGEP(
        LoadSizeType, Source1,
        ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex()));
    Source2 = Builder.CreateGEP(
        LoadSizeType, Source2,
        ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex()));
  }

  // Load LoadSizeType from the base address.
  Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
  Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);

  // memcmp compares bytes in memory order; an integer compare matches that
  // only on big-endian targets, so byteswap on little-endian ones.
  if (DL.isLittleEndian()) {
    Function *Bswap = Intrinsic::getDeclaration(CI->getModule(),
                                                Intrinsic::bswap, LoadSizeType);
    LoadSrc1 = Builder.CreateCall(Bswap, LoadSrc1);
    LoadSrc2 = Builder.CreateCall(Bswap, LoadSrc2);
  }

  // Widen to MaxLoadType so all incoming values to the ResultBlock PHIs have
  // the same type.
  if (LoadSizeType != MaxLoadType) {
    LoadSrc1 = Builder.CreateZExt(LoadSrc1, MaxLoadType);
    LoadSrc2 = Builder.CreateZExt(LoadSrc2, MaxLoadType);
  }

  // Add the loaded values to the phi nodes for calculating memcmp result only
  // if result is not used in a zero equality.
  if (!IsUsedForZeroCmp) {
    ResBlock.PhiSrc1->addIncoming(LoadSrc1, LoadCmpBlocks[BlockIndex]);
    ResBlock.PhiSrc2->addIncoming(LoadSrc2, LoadCmpBlocks[BlockIndex]);
  }

  Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, LoadSrc1, LoadSrc2);
  BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1))
                           ? EndBlock
                           : LoadCmpBlocks[BlockIndex + 1];
  // Early exit branch if difference found to ResultBlock. Otherwise, continue
  // to next LoadCmpBlock or EndBlock.
  BranchInst *CmpBr = BranchInst::Create(NextBB, ResBlock.BB, Cmp);
  Builder.Insert(CmpBr);

  // Add a phi edge for the last LoadCmpBlock to Endblock with a value of 0
  // since early exit to ResultBlock was not taken (no difference was found in
  // any of the bytes).
  if (BlockIndex == LoadCmpBlocks.size() - 1) {
    Value *Zero = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 0);
    PhiRes->addIncoming(Zero, LoadCmpBlocks[BlockIndex]);
  }
}
454 
455 // This function populates the ResultBlock with a sequence to calculate the
456 // memcmp result. It compares the two loaded source values and returns -1 if
457 // src1 < src2 and 1 if src1 > src2.
458 void MemCmpExpansion::emitMemCmpResultBlock() {
459   // Special case: if memcmp result is used in a zero equality, result does not
460   // need to be calculated and can simply return 1.
461   if (IsUsedForZeroCmp) {
462     BasicBlock::iterator InsertPt = ResBlock.BB->getFirstInsertionPt();
463     Builder.SetInsertPoint(ResBlock.BB, InsertPt);
464     Value *Res = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 1);
465     PhiRes->addIncoming(Res, ResBlock.BB);
466     BranchInst *NewBr = BranchInst::Create(EndBlock);
467     Builder.Insert(NewBr);
468     return;
469   }
470   BasicBlock::iterator InsertPt = ResBlock.BB->getFirstInsertionPt();
471   Builder.SetInsertPoint(ResBlock.BB, InsertPt);
472 
473   Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_ULT, ResBlock.PhiSrc1,
474                                   ResBlock.PhiSrc2);
475 
476   Value *Res =
477       Builder.CreateSelect(Cmp, ConstantInt::get(Builder.getInt32Ty(), -1),
478                            ConstantInt::get(Builder.getInt32Ty(), 1));
479 
480   BranchInst *NewBr = BranchInst::Create(EndBlock);
481   Builder.Insert(NewBr);
482   PhiRes->addIncoming(Res, ResBlock.BB);
483 }
484 
485 void MemCmpExpansion::setupResultBlockPHINodes() {
486   Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8);
487   Builder.SetInsertPoint(ResBlock.BB);
488   // Note: this assumes one load per block.
489   ResBlock.PhiSrc1 =
490       Builder.CreatePHI(MaxLoadType, NumLoadsNonOneByte, "phi.src1");
491   ResBlock.PhiSrc2 =
492       Builder.CreatePHI(MaxLoadType, NumLoadsNonOneByte, "phi.src2");
493 }
494 
495 void MemCmpExpansion::setupEndBlockPHINodes() {
496   Builder.SetInsertPoint(&EndBlock->front());
497   PhiRes = Builder.CreatePHI(Type::getInt32Ty(CI->getContext()), 2, "phi.res");
498 }
499 
500 Value *MemCmpExpansion::getMemCmpExpansionZeroCase() {
501   unsigned LoadIndex = 0;
502   // This loop populates each of the LoadCmpBlocks with the IR sequence to
503   // handle multiple loads per block.
504   for (unsigned I = 0; I < getNumBlocks(); ++I) {
505     emitLoadCompareBlockMultipleLoads(I, LoadIndex);
506   }
507 
508   emitMemCmpResultBlock();
509   return PhiRes;
510 }
511 
512 /// A memcmp expansion that compares equality with 0 and only has one block of
513 /// load and compare can bypass the compare, branch, and phi IR that is required
514 /// in the general case.
515 Value *MemCmpExpansion::getMemCmpEqZeroOneBlock() {
516   unsigned LoadIndex = 0;
517   Value *Cmp = getCompareLoadPairs(0, LoadIndex);
518   assert(LoadIndex == getNumLoads() && "some entries were not consumed");
519   return Builder.CreateZExt(Cmp, Type::getInt32Ty(CI->getContext()));
520 }
521 
522 /// A memcmp expansion that only has one block of load and compare can bypass
523 /// the compare, branch, and phi IR that is required in the general case.
524 Value *MemCmpExpansion::getMemCmpOneBlock() {
525   assert(NumLoadsPerBlock == 1 && "Only handles one load pair per block");
526 
527   Type *LoadSizeType = IntegerType::get(CI->getContext(), Size * 8);
528   Value *Source1 = CI->getArgOperand(0);
529   Value *Source2 = CI->getArgOperand(1);
530 
531   // Cast source to LoadSizeType*.
532   if (Source1->getType() != LoadSizeType)
533     Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo());
534   if (Source2->getType() != LoadSizeType)
535     Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo());
536 
537   // Load LoadSizeType from the base address.
538   Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
539   Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);
540 
541   if (DL.isLittleEndian() && Size != 1) {
542     Function *Bswap = Intrinsic::getDeclaration(CI->getModule(),
543                                                 Intrinsic::bswap, LoadSizeType);
544     LoadSrc1 = Builder.CreateCall(Bswap, LoadSrc1);
545     LoadSrc2 = Builder.CreateCall(Bswap, LoadSrc2);
546   }
547 
548   if (Size < 4) {
549     // The i8 and i16 cases don't need compares. We zext the loaded values and
550     // subtract them to get the suitable negative, zero, or positive i32 result.
551     LoadSrc1 = Builder.CreateZExt(LoadSrc1, Builder.getInt32Ty());
552     LoadSrc2 = Builder.CreateZExt(LoadSrc2, Builder.getInt32Ty());
553     return Builder.CreateSub(LoadSrc1, LoadSrc2);
554   }
555 
556   // The result of memcmp is negative, zero, or positive, so produce that by
557   // subtracting 2 extended compare bits: sub (ugt, ult).
558   // If a target prefers to use selects to get -1/0/1, they should be able
559   // to transform this later. The inverse transform (going from selects to math)
560   // may not be possible in the DAG because the selects got converted into
561   // branches before we got there.
562   Value *CmpUGT = Builder.CreateICmpUGT(LoadSrc1, LoadSrc2);
563   Value *CmpULT = Builder.CreateICmpULT(LoadSrc1, LoadSrc2);
564   Value *ZextUGT = Builder.CreateZExt(CmpUGT, Builder.getInt32Ty());
565   Value *ZextULT = Builder.CreateZExt(CmpULT, Builder.getInt32Ty());
566   return Builder.CreateSub(ZextUGT, ZextULT);
567 }
568 
569 // This function expands the memcmp call into an inline expansion and returns
570 // the memcmp result.
Value *MemCmpExpansion::getMemCmpExpansion() {
  // A memcmp with zero-comparison with only one block of load and compare does
  // not need to set up any extra blocks. This case could be handled in the DAG,
  // but since we have all of the machinery to flexibly expand any memcpy here,
  // we choose to handle this case too to avoid fragmented lowering.
  // Any other case requires splitting the block at CI and wiring up the
  // load/compare blocks and the result block.
  if ((!IsUsedForZeroCmp && NumLoadsPerBlock != 1) || getNumBlocks() != 1) {
    BasicBlock *StartBlock = CI->getParent();
    EndBlock = StartBlock->splitBasicBlock(CI, "endblock");
    setupEndBlockPHINodes();
    createResultBlock();

    // If return value of memcmp is not used in a zero equality, we need to
    // calculate which source was larger. The calculation requires the
    // two loaded source values of each load compare block.
    // These will be saved in the phi nodes created by setupResultBlockPHINodes.
    if (!IsUsedForZeroCmp) setupResultBlockPHINodes();

    // Create the number of required load compare basic blocks.
    createLoadCmpBlocks();

    // Update the terminator added by splitBasicBlock to branch to the first
    // LoadCmpBlock.
    StartBlock->getTerminator()->setSuccessor(0, LoadCmpBlocks[0]);
  }

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  // Dispatch to the specialized emitters: zero-equality (one or many blocks),
  // single-block general case, or the full multi-block general case.
  if (IsUsedForZeroCmp)
    return getNumBlocks() == 1 ? getMemCmpEqZeroOneBlock()
                               : getMemCmpExpansionZeroCase();

  // TODO: Handle more than one load pair per block in getMemCmpOneBlock().
  if (getNumBlocks() == 1 && NumLoadsPerBlock == 1) return getMemCmpOneBlock();

  for (unsigned I = 0; I < getNumBlocks(); ++I) {
    emitLoadCompareBlock(I);
  }

  emitMemCmpResultBlock();
  return PhiRes;
}
612 
613 // This function checks to see if an expansion of memcmp can be generated.
614 // It checks for constant compare size that is less than the max inline size.
615 // If an expansion cannot occur, returns false to leave as a library call.
616 // Otherwise, the library call is replaced with a new IR instruction sequence.
617 /// We want to transform:
618 /// %call = call signext i32 @memcmp(i8* %0, i8* %1, i64 15)
619 /// To:
620 /// loadbb:
621 ///  %0 = bitcast i32* %buffer2 to i8*
622 ///  %1 = bitcast i32* %buffer1 to i8*
623 ///  %2 = bitcast i8* %1 to i64*
624 ///  %3 = bitcast i8* %0 to i64*
625 ///  %4 = load i64, i64* %2
626 ///  %5 = load i64, i64* %3
627 ///  %6 = call i64 @llvm.bswap.i64(i64 %4)
628 ///  %7 = call i64 @llvm.bswap.i64(i64 %5)
629 ///  %8 = sub i64 %6, %7
630 ///  %9 = icmp ne i64 %8, 0
631 ///  br i1 %9, label %res_block, label %loadbb1
632 /// res_block:                                        ; preds = %loadbb2,
633 /// %loadbb1, %loadbb
634 ///  %phi.src1 = phi i64 [ %6, %loadbb ], [ %22, %loadbb1 ], [ %36, %loadbb2 ]
635 ///  %phi.src2 = phi i64 [ %7, %loadbb ], [ %23, %loadbb1 ], [ %37, %loadbb2 ]
636 ///  %10 = icmp ult i64 %phi.src1, %phi.src2
637 ///  %11 = select i1 %10, i32 -1, i32 1
638 ///  br label %endblock
639 /// loadbb1:                                          ; preds = %loadbb
640 ///  %12 = bitcast i32* %buffer2 to i8*
641 ///  %13 = bitcast i32* %buffer1 to i8*
642 ///  %14 = bitcast i8* %13 to i32*
643 ///  %15 = bitcast i8* %12 to i32*
644 ///  %16 = getelementptr i32, i32* %14, i32 2
645 ///  %17 = getelementptr i32, i32* %15, i32 2
646 ///  %18 = load i32, i32* %16
647 ///  %19 = load i32, i32* %17
648 ///  %20 = call i32 @llvm.bswap.i32(i32 %18)
649 ///  %21 = call i32 @llvm.bswap.i32(i32 %19)
650 ///  %22 = zext i32 %20 to i64
651 ///  %23 = zext i32 %21 to i64
652 ///  %24 = sub i64 %22, %23
653 ///  %25 = icmp ne i64 %24, 0
654 ///  br i1 %25, label %res_block, label %loadbb2
655 /// loadbb2:                                          ; preds = %loadbb1
656 ///  %26 = bitcast i32* %buffer2 to i8*
657 ///  %27 = bitcast i32* %buffer1 to i8*
658 ///  %28 = bitcast i8* %27 to i16*
659 ///  %29 = bitcast i8* %26 to i16*
660 ///  %30 = getelementptr i16, i16* %28, i16 6
661 ///  %31 = getelementptr i16, i16* %29, i16 6
662 ///  %32 = load i16, i16* %30
663 ///  %33 = load i16, i16* %31
664 ///  %34 = call i16 @llvm.bswap.i16(i16 %32)
665 ///  %35 = call i16 @llvm.bswap.i16(i16 %33)
666 ///  %36 = zext i16 %34 to i64
667 ///  %37 = zext i16 %35 to i64
668 ///  %38 = sub i64 %36, %37
669 ///  %39 = icmp ne i64 %38, 0
670 ///  br i1 %39, label %res_block, label %loadbb3
671 /// loadbb3:                                          ; preds = %loadbb2
672 ///  %40 = bitcast i32* %buffer2 to i8*
673 ///  %41 = bitcast i32* %buffer1 to i8*
674 ///  %42 = getelementptr i8, i8* %41, i8 14
675 ///  %43 = getelementptr i8, i8* %40, i8 14
676 ///  %44 = load i8, i8* %42
677 ///  %45 = load i8, i8* %43
678 ///  %46 = zext i8 %44 to i32
679 ///  %47 = zext i8 %45 to i32
680 ///  %48 = sub i32 %46, %47
681 ///  br label %endblock
682 /// endblock:                                         ; preds = %res_block,
683 /// %loadbb3
684 ///  %phi.res = phi i32 [ %48, %loadbb3 ], [ %11, %res_block ]
685 ///  ret i32 %phi.res
686 static bool expandMemCmp(CallInst *CI, const TargetTransformInfo *TTI,
687                          const TargetLowering *TLI, const DataLayout *DL) {
688   NumMemCmpCalls++;
689 
690   // Early exit from expansion if -Oz.
691   if (CI->getFunction()->optForMinSize())
692     return false;
693 
694   // Early exit from expansion if size is not a constant.
695   ConstantInt *SizeCast = dyn_cast<ConstantInt>(CI->getArgOperand(2));
696   if (!SizeCast) {
697     NumMemCmpNotConstant++;
698     return false;
699   }
700   const uint64_t SizeVal = SizeCast->getZExtValue();
701 
702   if (SizeVal == 0) {
703     return false;
704   }
705 
706   // TTI call to check if target would like to expand memcmp. Also, get the
707   // available load sizes.
708   const bool IsUsedForZeroCmp = isOnlyUsedInZeroEqualityComparison(CI);
709   const auto *const Options = TTI->enableMemCmpExpansion(IsUsedForZeroCmp);
710   if (!Options) return false;
711 
712   const unsigned MaxNumLoads =
713       TLI->getMaxExpandSizeMemcmp(CI->getFunction()->optForSize());
714 
715   MemCmpExpansion Expansion(CI, SizeVal, *Options, MaxNumLoads,
716                             IsUsedForZeroCmp, MemCmpNumLoadsPerBlock, *DL);
717 
718   // Don't expand if this will require more loads than desired by the target.
719   if (Expansion.getNumLoads() == 0) {
720     NumMemCmpGreaterThanMax++;
721     return false;
722   }
723 
724   NumMemCmpInlined++;
725 
726   Value *Res = Expansion.getMemCmpExpansion();
727 
728   // Replace call with result of expansion and erase call.
729   CI->replaceAllUsesWith(Res);
730   CI->eraseFromParent();
731 
732   return true;
733 }
734 
735 
736 
737 class ExpandMemCmpPass : public FunctionPass {
738 public:
739   static char ID;
740 
741   ExpandMemCmpPass() : FunctionPass(ID) {
742     initializeExpandMemCmpPassPass(*PassRegistry::getPassRegistry());
743   }
744 
745   bool runOnFunction(Function &F) override {
746     if (skipFunction(F)) return false;
747 
748     auto *TPC = getAnalysisIfAvailable<TargetPassConfig>();
749     if (!TPC) {
750       return false;
751     }
752     const TargetLowering* TL =
753         TPC->getTM<TargetMachine>().getSubtargetImpl(F)->getTargetLowering();
754 
755     const TargetLibraryInfo *TLI =
756         &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
757     const TargetTransformInfo *TTI =
758         &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
759     auto PA = runImpl(F, TLI, TTI, TL);
760     return !PA.areAllPreserved();
761   }
762 
763 private:
764   void getAnalysisUsage(AnalysisUsage &AU) const override {
765     AU.addRequired<TargetLibraryInfoWrapperPass>();
766     AU.addRequired<TargetTransformInfoWrapperPass>();
767     FunctionPass::getAnalysisUsage(AU);
768   }
769 
770   PreservedAnalyses runImpl(Function &F, const TargetLibraryInfo *TLI,
771                             const TargetTransformInfo *TTI,
772                             const TargetLowering* TL);
773   // Returns true if a change was made.
774   bool runOnBlock(BasicBlock &BB, const TargetLibraryInfo *TLI,
775                   const TargetTransformInfo *TTI, const TargetLowering* TL,
776                   const DataLayout& DL);
777 };
778 
779 bool ExpandMemCmpPass::runOnBlock(
780     BasicBlock &BB, const TargetLibraryInfo *TLI,
781     const TargetTransformInfo *TTI, const TargetLowering* TL,
782     const DataLayout& DL) {
783   for (Instruction& I : BB) {
784     CallInst *CI = dyn_cast<CallInst>(&I);
785     if (!CI) {
786       continue;
787     }
788     LibFunc Func;
789     if (TLI->getLibFunc(ImmutableCallSite(CI), Func) &&
790         Func == LibFunc_memcmp && expandMemCmp(CI, TTI, TL, &DL)) {
791       return true;
792     }
793   }
794   return false;
795 }
796 
797 
798 PreservedAnalyses ExpandMemCmpPass::runImpl(
799     Function &F, const TargetLibraryInfo *TLI, const TargetTransformInfo *TTI,
800     const TargetLowering* TL) {
801   const DataLayout& DL = F.getParent()->getDataLayout();
802   bool MadeChanges = false;
803   for (auto BBIt = F.begin(); BBIt != F.end();) {
804     if (runOnBlock(*BBIt, TLI, TTI, TL, DL)) {
805       MadeChanges = true;
806       // If changes were made, restart the function from the beginning, since
807       // the structure of the function was changed.
808       BBIt = F.begin();
809     } else {
810       ++BBIt;
811     }
812   }
813   return MadeChanges ? PreservedAnalyses::none() : PreservedAnalyses::all();
814 }
815 
816 } // namespace
817 
char ExpandMemCmpPass::ID = 0;
// Register the pass and its analysis dependencies with the legacy PM.
INITIALIZE_PASS_BEGIN(ExpandMemCmpPass, "expandmemcmp",
                      "Expand memcmp() to load/stores", false, false)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_END(ExpandMemCmpPass, "expandmemcmp",
                    "Expand memcmp() to load/stores", false, false)

// Public factory used by the pass pipeline to create this pass.
FunctionPass *llvm::createExpandMemCmpPass() {
  return new ExpandMemCmpPass();
}
829