1 //===--- ExpandMemCmp.cpp - Expand memcmp() to load/stores ----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This pass tries to expand memcmp() calls into optimally-sized loads and 10 // compares for the target. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "llvm/ADT/Statistic.h" 15 #include "llvm/Analysis/ConstantFolding.h" 16 #include "llvm/Analysis/LazyBlockFrequencyInfo.h" 17 #include "llvm/Analysis/ProfileSummaryInfo.h" 18 #include "llvm/Analysis/TargetLibraryInfo.h" 19 #include "llvm/Analysis/TargetTransformInfo.h" 20 #include "llvm/Analysis/ValueTracking.h" 21 #include "llvm/CodeGen/TargetLowering.h" 22 #include "llvm/CodeGen/TargetPassConfig.h" 23 #include "llvm/CodeGen/TargetSubtargetInfo.h" 24 #include "llvm/IR/IRBuilder.h" 25 #include "llvm/InitializePasses.h" 26 #include "llvm/Transforms/Utils/SizeOpts.h" 27 28 using namespace llvm; 29 30 #define DEBUG_TYPE "expandmemcmp" 31 32 STATISTIC(NumMemCmpCalls, "Number of memcmp calls"); 33 STATISTIC(NumMemCmpNotConstant, "Number of memcmp calls without constant size"); 34 STATISTIC(NumMemCmpGreaterThanMax, 35 "Number of memcmp calls with size greater than max size"); 36 STATISTIC(NumMemCmpInlined, "Number of inlined memcmp calls"); 37 38 static cl::opt<unsigned> MemCmpEqZeroNumLoadsPerBlock( 39 "memcmp-num-loads-per-block", cl::Hidden, cl::init(1), 40 cl::desc("The number of loads per basic block for inline expansion of " 41 "memcmp that is only being compared against zero.")); 42 43 static cl::opt<unsigned> MaxLoadsPerMemcmp( 44 "max-loads-per-memcmp", cl::Hidden, 45 cl::desc("Set maximum number of loads used in expanded memcmp")); 46 47 static cl::opt<unsigned> MaxLoadsPerMemcmpOptSize( 48 "max-loads-per-memcmp-opt-size", cl::Hidden, 49 cl::desc("Set maximum number of loads used in expanded memcmp for -Os/Oz")); 50 51 namespace { 52 53 54 // This class provides helper functions to expand a memcmp library call into an 55 // inline expansion. 56 class MemCmpExpansion { 57 struct ResultBlock { 58 BasicBlock *BB = nullptr; 59 PHINode *PhiSrc1 = nullptr; 60 PHINode *PhiSrc2 = nullptr; 61 62 ResultBlock() = default; 63 }; 64 65 CallInst *const CI; 66 ResultBlock ResBlock; 67 const uint64_t Size; 68 unsigned MaxLoadSize; 69 uint64_t NumLoadsNonOneByte; 70 const uint64_t NumLoadsPerBlockForZeroCmp; 71 std::vector<BasicBlock *> LoadCmpBlocks; 72 BasicBlock *EndBlock; 73 PHINode *PhiRes; 74 const bool IsUsedForZeroCmp; 75 const DataLayout &DL; 76 IRBuilder<> Builder; 77 // Represents the decomposition in blocks of the expansion. For example, 78 // comparing 33 bytes on X86+sse can be done with 2x16-byte loads and 79 // 1x1-byte load, which would be represented as [{16, 0}, {16, 16}, {1, 32}. 80 struct LoadEntry { 81 LoadEntry(unsigned LoadSize, uint64_t Offset) 82 : LoadSize(LoadSize), Offset(Offset) { 83 } 84 85 // The size of the load for this block, in bytes. 86 unsigned LoadSize; 87 // The offset of this load from the base pointer, in bytes. 88 uint64_t Offset; 89 }; 90 using LoadEntryVector = SmallVector<LoadEntry, 8>; 91 LoadEntryVector LoadSequence; 92 93 void createLoadCmpBlocks(); 94 void createResultBlock(); 95 void setupResultBlockPHINodes(); 96 void setupEndBlockPHINodes(); 97 Value *getCompareLoadPairs(unsigned BlockIndex, unsigned &LoadIndex); 98 void emitLoadCompareBlock(unsigned BlockIndex); 99 void emitLoadCompareBlockMultipleLoads(unsigned BlockIndex, 100 unsigned &LoadIndex); 101 void emitLoadCompareByteBlock(unsigned BlockIndex, unsigned OffsetBytes); 102 void emitMemCmpResultBlock(); 103 Value *getMemCmpExpansionZeroCase(); 104 Value *getMemCmpEqZeroOneBlock(); 105 Value *getMemCmpOneBlock(); 106 struct LoadPair { 107 Value *Lhs = nullptr; 108 Value *Rhs = nullptr; 109 }; 110 LoadPair getLoadPair(Type *LoadSizeType, bool NeedsBSwap, Type *CmpSizeType, 111 unsigned OffsetBytes); 112 113 static LoadEntryVector 114 computeGreedyLoadSequence(uint64_t Size, llvm::ArrayRef<unsigned> LoadSizes, 115 unsigned MaxNumLoads, unsigned &NumLoadsNonOneByte); 116 static LoadEntryVector 117 computeOverlappingLoadSequence(uint64_t Size, unsigned MaxLoadSize, 118 unsigned MaxNumLoads, 119 unsigned &NumLoadsNonOneByte); 120 121 public: 122 MemCmpExpansion(CallInst *CI, uint64_t Size, 123 const TargetTransformInfo::MemCmpExpansionOptions &Options, 124 const bool IsUsedForZeroCmp, const DataLayout &TheDataLayout); 125 126 unsigned getNumBlocks(); 127 uint64_t getNumLoads() const { return LoadSequence.size(); } 128 129 Value *getMemCmpExpansion(); 130 }; 131 132 MemCmpExpansion::LoadEntryVector MemCmpExpansion::computeGreedyLoadSequence( 133 uint64_t Size, llvm::ArrayRef<unsigned> LoadSizes, 134 const unsigned MaxNumLoads, unsigned &NumLoadsNonOneByte) { 135 NumLoadsNonOneByte = 0; 136 LoadEntryVector LoadSequence; 137 uint64_t Offset = 0; 138 while (Size && !LoadSizes.empty()) { 139 const unsigned LoadSize = LoadSizes.front(); 140 const uint64_t NumLoadsForThisSize = Size / LoadSize; 141 if (LoadSequence.size() + NumLoadsForThisSize > MaxNumLoads) { 142 // Do not expand if the total number of loads is larger than what the 143 // target allows. Note that it's important that we exit before completing 144 // the expansion to avoid using a ton of memory to store the expansion for 145 // large sizes. 146 return {}; 147 } 148 if (NumLoadsForThisSize > 0) { 149 for (uint64_t I = 0; I < NumLoadsForThisSize; ++I) { 150 LoadSequence.push_back({LoadSize, Offset}); 151 Offset += LoadSize; 152 } 153 if (LoadSize > 1) 154 ++NumLoadsNonOneByte; 155 Size = Size % LoadSize; 156 } 157 LoadSizes = LoadSizes.drop_front(); 158 } 159 return LoadSequence; 160 } 161 162 MemCmpExpansion::LoadEntryVector 163 MemCmpExpansion::computeOverlappingLoadSequence(uint64_t Size, 164 const unsigned MaxLoadSize, 165 const unsigned MaxNumLoads, 166 unsigned &NumLoadsNonOneByte) { 167 // These are already handled by the greedy approach. 168 if (Size < 2 || MaxLoadSize < 2) 169 return {}; 170 171 // We try to do as many non-overlapping loads as possible starting from the 172 // beginning. 173 const uint64_t NumNonOverlappingLoads = Size / MaxLoadSize; 174 assert(NumNonOverlappingLoads && "there must be at least one load"); 175 // There remain 0 to (MaxLoadSize - 1) bytes to load, this will be done with 176 // an overlapping load. 177 Size = Size - NumNonOverlappingLoads * MaxLoadSize; 178 // Bail if we do not need an overloapping store, this is already handled by 179 // the greedy approach. 180 if (Size == 0) 181 return {}; 182 // Bail if the number of loads (non-overlapping + potential overlapping one) 183 // is larger than the max allowed. 184 if ((NumNonOverlappingLoads + 1) > MaxNumLoads) 185 return {}; 186 187 // Add non-overlapping loads. 188 LoadEntryVector LoadSequence; 189 uint64_t Offset = 0; 190 for (uint64_t I = 0; I < NumNonOverlappingLoads; ++I) { 191 LoadSequence.push_back({MaxLoadSize, Offset}); 192 Offset += MaxLoadSize; 193 } 194 195 // Add the last overlapping load. 196 assert(Size > 0 && Size < MaxLoadSize && "broken invariant"); 197 LoadSequence.push_back({MaxLoadSize, Offset - (MaxLoadSize - Size)}); 198 NumLoadsNonOneByte = 1; 199 return LoadSequence; 200 } 201 202 // Initialize the basic block structure required for expansion of memcmp call 203 // with given maximum load size and memcmp size parameter. 204 // This structure includes: 205 // 1. A list of load compare blocks - LoadCmpBlocks. 206 // 2. An EndBlock, split from original instruction point, which is the block to 207 // return from. 208 // 3. ResultBlock, block to branch to for early exit when a 209 // LoadCmpBlock finds a difference. 210 MemCmpExpansion::MemCmpExpansion( 211 CallInst *const CI, uint64_t Size, 212 const TargetTransformInfo::MemCmpExpansionOptions &Options, 213 const bool IsUsedForZeroCmp, const DataLayout &TheDataLayout) 214 : CI(CI), Size(Size), MaxLoadSize(0), NumLoadsNonOneByte(0), 215 NumLoadsPerBlockForZeroCmp(Options.NumLoadsPerBlock), 216 IsUsedForZeroCmp(IsUsedForZeroCmp), DL(TheDataLayout), Builder(CI) { 217 assert(Size > 0 && "zero blocks"); 218 // Scale the max size down if the target can load more bytes than we need. 219 llvm::ArrayRef<unsigned> LoadSizes(Options.LoadSizes); 220 while (!LoadSizes.empty() && LoadSizes.front() > Size) { 221 LoadSizes = LoadSizes.drop_front(); 222 } 223 assert(!LoadSizes.empty() && "cannot load Size bytes"); 224 MaxLoadSize = LoadSizes.front(); 225 // Compute the decomposition. 226 unsigned GreedyNumLoadsNonOneByte = 0; 227 LoadSequence = computeGreedyLoadSequence(Size, LoadSizes, Options.MaxNumLoads, 228 GreedyNumLoadsNonOneByte); 229 NumLoadsNonOneByte = GreedyNumLoadsNonOneByte; 230 assert(LoadSequence.size() <= Options.MaxNumLoads && "broken invariant"); 231 // If we allow overlapping loads and the load sequence is not already optimal, 232 // use overlapping loads. 233 if (Options.AllowOverlappingLoads && 234 (LoadSequence.empty() || LoadSequence.size() > 2)) { 235 unsigned OverlappingNumLoadsNonOneByte = 0; 236 auto OverlappingLoads = computeOverlappingLoadSequence( 237 Size, MaxLoadSize, Options.MaxNumLoads, OverlappingNumLoadsNonOneByte); 238 if (!OverlappingLoads.empty() && 239 (LoadSequence.empty() || 240 OverlappingLoads.size() < LoadSequence.size())) { 241 LoadSequence = OverlappingLoads; 242 NumLoadsNonOneByte = OverlappingNumLoadsNonOneByte; 243 } 244 } 245 assert(LoadSequence.size() <= Options.MaxNumLoads && "broken invariant"); 246 } 247 248 unsigned MemCmpExpansion::getNumBlocks() { 249 if (IsUsedForZeroCmp) 250 return getNumLoads() / NumLoadsPerBlockForZeroCmp + 251 (getNumLoads() % NumLoadsPerBlockForZeroCmp != 0 ? 1 : 0); 252 return getNumLoads(); 253 } 254 255 void MemCmpExpansion::createLoadCmpBlocks() { 256 for (unsigned i = 0; i < getNumBlocks(); i++) { 257 BasicBlock *BB = BasicBlock::Create(CI->getContext(), "loadbb", 258 EndBlock->getParent(), EndBlock); 259 LoadCmpBlocks.push_back(BB); 260 } 261 } 262 263 void MemCmpExpansion::createResultBlock() { 264 ResBlock.BB = BasicBlock::Create(CI->getContext(), "res_block", 265 EndBlock->getParent(), EndBlock); 266 } 267 268 MemCmpExpansion::LoadPair MemCmpExpansion::getLoadPair(Type *LoadSizeType, 269 bool NeedsBSwap, 270 Type *CmpSizeType, 271 unsigned OffsetBytes) { 272 // Get the memory source at offset `OffsetBytes`. 273 Value *LhsSource = CI->getArgOperand(0); 274 Value *RhsSource = CI->getArgOperand(1); 275 if (OffsetBytes > 0) { 276 auto *ByteType = Type::getInt8Ty(CI->getContext()); 277 LhsSource = Builder.CreateConstGEP1_64( 278 ByteType, Builder.CreateBitCast(LhsSource, ByteType->getPointerTo()), 279 OffsetBytes); 280 RhsSource = Builder.CreateConstGEP1_64( 281 ByteType, Builder.CreateBitCast(RhsSource, ByteType->getPointerTo()), 282 OffsetBytes); 283 } 284 LhsSource = Builder.CreateBitCast(LhsSource, LoadSizeType->getPointerTo()); 285 RhsSource = Builder.CreateBitCast(RhsSource, LoadSizeType->getPointerTo()); 286 287 // Create a constant or a load from the source. 288 Value *Lhs = nullptr; 289 if (auto *C = dyn_cast<Constant>(LhsSource)) 290 Lhs = ConstantFoldLoadFromConstPtr(C, LoadSizeType, DL); 291 if (!Lhs) 292 Lhs = Builder.CreateLoad(LoadSizeType, LhsSource); 293 294 Value *Rhs = nullptr; 295 if (auto *C = dyn_cast<Constant>(RhsSource)) 296 Rhs = ConstantFoldLoadFromConstPtr(C, LoadSizeType, DL); 297 if (!Rhs) 298 Rhs = Builder.CreateLoad(LoadSizeType, RhsSource); 299 300 // Swap bytes if required. 301 if (NeedsBSwap) { 302 Function *Bswap = Intrinsic::getDeclaration(CI->getModule(), 303 Intrinsic::bswap, LoadSizeType); 304 Lhs = Builder.CreateCall(Bswap, Lhs); 305 Rhs = Builder.CreateCall(Bswap, Rhs); 306 } 307 308 // Zero extend if required. 309 if (CmpSizeType != nullptr && CmpSizeType != LoadSizeType) { 310 Lhs = Builder.CreateZExt(Lhs, CmpSizeType); 311 Rhs = Builder.CreateZExt(Rhs, CmpSizeType); 312 } 313 return {Lhs, Rhs}; 314 } 315 316 // This function creates the IR instructions for loading and comparing 1 byte. 317 // It loads 1 byte from each source of the memcmp parameters with the given 318 // GEPIndex. It then subtracts the two loaded values and adds this result to the 319 // final phi node for selecting the memcmp result. 320 void MemCmpExpansion::emitLoadCompareByteBlock(unsigned BlockIndex, 321 unsigned OffsetBytes) { 322 Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]); 323 const LoadPair Loads = 324 getLoadPair(Type::getInt8Ty(CI->getContext()), /*NeedsBSwap=*/false, 325 Type::getInt32Ty(CI->getContext()), OffsetBytes); 326 Value *Diff = Builder.CreateSub(Loads.Lhs, Loads.Rhs); 327 328 PhiRes->addIncoming(Diff, LoadCmpBlocks[BlockIndex]); 329 330 if (BlockIndex < (LoadCmpBlocks.size() - 1)) { 331 // Early exit branch if difference found to EndBlock. Otherwise, continue to 332 // next LoadCmpBlock, 333 Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_NE, Diff, 334 ConstantInt::get(Diff->getType(), 0)); 335 BranchInst *CmpBr = 336 BranchInst::Create(EndBlock, LoadCmpBlocks[BlockIndex + 1], Cmp); 337 Builder.Insert(CmpBr); 338 } else { 339 // The last block has an unconditional branch to EndBlock. 340 BranchInst *CmpBr = BranchInst::Create(EndBlock); 341 Builder.Insert(CmpBr); 342 } 343 } 344 345 /// Generate an equality comparison for one or more pairs of loaded values. 346 /// This is used in the case where the memcmp() call is compared equal or not 347 /// equal to zero. 348 Value *MemCmpExpansion::getCompareLoadPairs(unsigned BlockIndex, 349 unsigned &LoadIndex) { 350 assert(LoadIndex < getNumLoads() && 351 "getCompareLoadPairs() called with no remaining loads"); 352 std::vector<Value *> XorList, OrList; 353 Value *Diff = nullptr; 354 355 const unsigned NumLoads = 356 std::min(getNumLoads() - LoadIndex, NumLoadsPerBlockForZeroCmp); 357 358 // For a single-block expansion, start inserting before the memcmp call. 359 if (LoadCmpBlocks.empty()) 360 Builder.SetInsertPoint(CI); 361 else 362 Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]); 363 364 Value *Cmp = nullptr; 365 // If we have multiple loads per block, we need to generate a composite 366 // comparison using xor+or. The type for the combinations is the largest load 367 // type. 368 IntegerType *const MaxLoadType = 369 NumLoads == 1 ? nullptr 370 : IntegerType::get(CI->getContext(), MaxLoadSize * 8); 371 for (unsigned i = 0; i < NumLoads; ++i, ++LoadIndex) { 372 const LoadEntry &CurLoadEntry = LoadSequence[LoadIndex]; 373 const LoadPair Loads = getLoadPair( 374 IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8), 375 /*NeedsBSwap=*/false, MaxLoadType, CurLoadEntry.Offset); 376 377 if (NumLoads != 1) { 378 // If we have multiple loads per block, we need to generate a composite 379 // comparison using xor+or. 380 Diff = Builder.CreateXor(Loads.Lhs, Loads.Rhs); 381 Diff = Builder.CreateZExt(Diff, MaxLoadType); 382 XorList.push_back(Diff); 383 } else { 384 // If there's only one load per block, we just compare the loaded values. 385 Cmp = Builder.CreateICmpNE(Loads.Lhs, Loads.Rhs); 386 } 387 } 388 389 auto pairWiseOr = [&](std::vector<Value *> &InList) -> std::vector<Value *> { 390 std::vector<Value *> OutList; 391 for (unsigned i = 0; i < InList.size() - 1; i = i + 2) { 392 Value *Or = Builder.CreateOr(InList[i], InList[i + 1]); 393 OutList.push_back(Or); 394 } 395 if (InList.size() % 2 != 0) 396 OutList.push_back(InList.back()); 397 return OutList; 398 }; 399 400 if (!Cmp) { 401 // Pairwise OR the XOR results. 402 OrList = pairWiseOr(XorList); 403 404 // Pairwise OR the OR results until one result left. 405 while (OrList.size() != 1) { 406 OrList = pairWiseOr(OrList); 407 } 408 409 assert(Diff && "Failed to find comparison diff"); 410 Cmp = Builder.CreateICmpNE(OrList[0], ConstantInt::get(Diff->getType(), 0)); 411 } 412 413 return Cmp; 414 } 415 416 void MemCmpExpansion::emitLoadCompareBlockMultipleLoads(unsigned BlockIndex, 417 unsigned &LoadIndex) { 418 Value *Cmp = getCompareLoadPairs(BlockIndex, LoadIndex); 419 420 BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1)) 421 ? EndBlock 422 : LoadCmpBlocks[BlockIndex + 1]; 423 // Early exit branch if difference found to ResultBlock. Otherwise, 424 // continue to next LoadCmpBlock or EndBlock. 425 BranchInst *CmpBr = BranchInst::Create(ResBlock.BB, NextBB, Cmp); 426 Builder.Insert(CmpBr); 427 428 // Add a phi edge for the last LoadCmpBlock to Endblock with a value of 0 429 // since early exit to ResultBlock was not taken (no difference was found in 430 // any of the bytes). 431 if (BlockIndex == LoadCmpBlocks.size() - 1) { 432 Value *Zero = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 0); 433 PhiRes->addIncoming(Zero, LoadCmpBlocks[BlockIndex]); 434 } 435 } 436 437 // This function creates the IR intructions for loading and comparing using the 438 // given LoadSize. It loads the number of bytes specified by LoadSize from each 439 // source of the memcmp parameters. It then does a subtract to see if there was 440 // a difference in the loaded values. If a difference is found, it branches 441 // with an early exit to the ResultBlock for calculating which source was 442 // larger. Otherwise, it falls through to the either the next LoadCmpBlock or 443 // the EndBlock if this is the last LoadCmpBlock. Loading 1 byte is handled with 444 // a special case through emitLoadCompareByteBlock. The special handling can 445 // simply subtract the loaded values and add it to the result phi node. 446 void MemCmpExpansion::emitLoadCompareBlock(unsigned BlockIndex) { 447 // There is one load per block in this case, BlockIndex == LoadIndex. 448 const LoadEntry &CurLoadEntry = LoadSequence[BlockIndex]; 449 450 if (CurLoadEntry.LoadSize == 1) { 451 MemCmpExpansion::emitLoadCompareByteBlock(BlockIndex, CurLoadEntry.Offset); 452 return; 453 } 454 455 Type *LoadSizeType = 456 IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8); 457 Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8); 458 assert(CurLoadEntry.LoadSize <= MaxLoadSize && "Unexpected load type"); 459 460 Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]); 461 462 const LoadPair Loads = 463 getLoadPair(LoadSizeType, /*NeedsBSwap=*/DL.isLittleEndian(), MaxLoadType, 464 CurLoadEntry.Offset); 465 466 // Add the loaded values to the phi nodes for calculating memcmp result only 467 // if result is not used in a zero equality. 468 if (!IsUsedForZeroCmp) { 469 ResBlock.PhiSrc1->addIncoming(Loads.Lhs, LoadCmpBlocks[BlockIndex]); 470 ResBlock.PhiSrc2->addIncoming(Loads.Rhs, LoadCmpBlocks[BlockIndex]); 471 } 472 473 Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, Loads.Lhs, Loads.Rhs); 474 BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1)) 475 ? EndBlock 476 : LoadCmpBlocks[BlockIndex + 1]; 477 // Early exit branch if difference found to ResultBlock. Otherwise, continue 478 // to next LoadCmpBlock or EndBlock. 479 BranchInst *CmpBr = BranchInst::Create(NextBB, ResBlock.BB, Cmp); 480 Builder.Insert(CmpBr); 481 482 // Add a phi edge for the last LoadCmpBlock to Endblock with a value of 0 483 // since early exit to ResultBlock was not taken (no difference was found in 484 // any of the bytes). 485 if (BlockIndex == LoadCmpBlocks.size() - 1) { 486 Value *Zero = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 0); 487 PhiRes->addIncoming(Zero, LoadCmpBlocks[BlockIndex]); 488 } 489 } 490 491 // This function populates the ResultBlock with a sequence to calculate the 492 // memcmp result. It compares the two loaded source values and returns -1 if 493 // src1 < src2 and 1 if src1 > src2. 494 void MemCmpExpansion::emitMemCmpResultBlock() { 495 // Special case: if memcmp result is used in a zero equality, result does not 496 // need to be calculated and can simply return 1. 497 if (IsUsedForZeroCmp) { 498 BasicBlock::iterator InsertPt = ResBlock.BB->getFirstInsertionPt(); 499 Builder.SetInsertPoint(ResBlock.BB, InsertPt); 500 Value *Res = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 1); 501 PhiRes->addIncoming(Res, ResBlock.BB); 502 BranchInst *NewBr = BranchInst::Create(EndBlock); 503 Builder.Insert(NewBr); 504 return; 505 } 506 BasicBlock::iterator InsertPt = ResBlock.BB->getFirstInsertionPt(); 507 Builder.SetInsertPoint(ResBlock.BB, InsertPt); 508 509 Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_ULT, ResBlock.PhiSrc1, 510 ResBlock.PhiSrc2); 511 512 Value *Res = 513 Builder.CreateSelect(Cmp, ConstantInt::get(Builder.getInt32Ty(), -1), 514 ConstantInt::get(Builder.getInt32Ty(), 1)); 515 516 BranchInst *NewBr = BranchInst::Create(EndBlock); 517 Builder.Insert(NewBr); 518 PhiRes->addIncoming(Res, ResBlock.BB); 519 } 520 521 void MemCmpExpansion::setupResultBlockPHINodes() { 522 Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8); 523 Builder.SetInsertPoint(ResBlock.BB); 524 // Note: this assumes one load per block. 525 ResBlock.PhiSrc1 = 526 Builder.CreatePHI(MaxLoadType, NumLoadsNonOneByte, "phi.src1"); 527 ResBlock.PhiSrc2 = 528 Builder.CreatePHI(MaxLoadType, NumLoadsNonOneByte, "phi.src2"); 529 } 530 531 void MemCmpExpansion::setupEndBlockPHINodes() { 532 Builder.SetInsertPoint(&EndBlock->front()); 533 PhiRes = Builder.CreatePHI(Type::getInt32Ty(CI->getContext()), 2, "phi.res"); 534 } 535 536 Value *MemCmpExpansion::getMemCmpExpansionZeroCase() { 537 unsigned LoadIndex = 0; 538 // This loop populates each of the LoadCmpBlocks with the IR sequence to 539 // handle multiple loads per block. 540 for (unsigned I = 0; I < getNumBlocks(); ++I) { 541 emitLoadCompareBlockMultipleLoads(I, LoadIndex); 542 } 543 544 emitMemCmpResultBlock(); 545 return PhiRes; 546 } 547 548 /// A memcmp expansion that compares equality with 0 and only has one block of 549 /// load and compare can bypass the compare, branch, and phi IR that is required 550 /// in the general case. 551 Value *MemCmpExpansion::getMemCmpEqZeroOneBlock() { 552 unsigned LoadIndex = 0; 553 Value *Cmp = getCompareLoadPairs(0, LoadIndex); 554 assert(LoadIndex == getNumLoads() && "some entries were not consumed"); 555 return Builder.CreateZExt(Cmp, Type::getInt32Ty(CI->getContext())); 556 } 557 558 /// A memcmp expansion that only has one block of load and compare can bypass 559 /// the compare, branch, and phi IR that is required in the general case. 560 Value *MemCmpExpansion::getMemCmpOneBlock() { 561 Type *LoadSizeType = IntegerType::get(CI->getContext(), Size * 8); 562 bool NeedsBSwap = DL.isLittleEndian() && Size != 1; 563 564 // The i8 and i16 cases don't need compares. We zext the loaded values and 565 // subtract them to get the suitable negative, zero, or positive i32 result. 566 if (Size < 4) { 567 const LoadPair Loads = 568 getLoadPair(LoadSizeType, NeedsBSwap, Builder.getInt32Ty(), 569 /*Offset*/ 0); 570 return Builder.CreateSub(Loads.Lhs, Loads.Rhs); 571 } 572 573 const LoadPair Loads = getLoadPair(LoadSizeType, NeedsBSwap, LoadSizeType, 574 /*Offset*/ 0); 575 // The result of memcmp is negative, zero, or positive, so produce that by 576 // subtracting 2 extended compare bits: sub (ugt, ult). 577 // If a target prefers to use selects to get -1/0/1, they should be able 578 // to transform this later. The inverse transform (going from selects to math) 579 // may not be possible in the DAG because the selects got converted into 580 // branches before we got there. 581 Value *CmpUGT = Builder.CreateICmpUGT(Loads.Lhs, Loads.Rhs); 582 Value *CmpULT = Builder.CreateICmpULT(Loads.Lhs, Loads.Rhs); 583 Value *ZextUGT = Builder.CreateZExt(CmpUGT, Builder.getInt32Ty()); 584 Value *ZextULT = Builder.CreateZExt(CmpULT, Builder.getInt32Ty()); 585 return Builder.CreateSub(ZextUGT, ZextULT); 586 } 587 588 // This function expands the memcmp call into an inline expansion and returns 589 // the memcmp result. 590 Value *MemCmpExpansion::getMemCmpExpansion() { 591 // Create the basic block framework for a multi-block expansion. 592 if (getNumBlocks() != 1) { 593 BasicBlock *StartBlock = CI->getParent(); 594 EndBlock = StartBlock->splitBasicBlock(CI, "endblock"); 595 setupEndBlockPHINodes(); 596 createResultBlock(); 597 598 // If return value of memcmp is not used in a zero equality, we need to 599 // calculate which source was larger. The calculation requires the 600 // two loaded source values of each load compare block. 601 // These will be saved in the phi nodes created by setupResultBlockPHINodes. 602 if (!IsUsedForZeroCmp) setupResultBlockPHINodes(); 603 604 // Create the number of required load compare basic blocks. 605 createLoadCmpBlocks(); 606 607 // Update the terminator added by splitBasicBlock to branch to the first 608 // LoadCmpBlock. 609 StartBlock->getTerminator()->setSuccessor(0, LoadCmpBlocks[0]); 610 } 611 612 Builder.SetCurrentDebugLocation(CI->getDebugLoc()); 613 614 if (IsUsedForZeroCmp) 615 return getNumBlocks() == 1 ? getMemCmpEqZeroOneBlock() 616 : getMemCmpExpansionZeroCase(); 617 618 if (getNumBlocks() == 1) 619 return getMemCmpOneBlock(); 620 621 for (unsigned I = 0; I < getNumBlocks(); ++I) { 622 emitLoadCompareBlock(I); 623 } 624 625 emitMemCmpResultBlock(); 626 return PhiRes; 627 } 628 629 // This function checks to see if an expansion of memcmp can be generated. 630 // It checks for constant compare size that is less than the max inline size. 631 // If an expansion cannot occur, returns false to leave as a library call. 632 // Otherwise, the library call is replaced with a new IR instruction sequence. 633 /// We want to transform: 634 /// %call = call signext i32 @memcmp(i8* %0, i8* %1, i64 15) 635 /// To: 636 /// loadbb: 637 /// %0 = bitcast i32* %buffer2 to i8* 638 /// %1 = bitcast i32* %buffer1 to i8* 639 /// %2 = bitcast i8* %1 to i64* 640 /// %3 = bitcast i8* %0 to i64* 641 /// %4 = load i64, i64* %2 642 /// %5 = load i64, i64* %3 643 /// %6 = call i64 @llvm.bswap.i64(i64 %4) 644 /// %7 = call i64 @llvm.bswap.i64(i64 %5) 645 /// %8 = sub i64 %6, %7 646 /// %9 = icmp ne i64 %8, 0 647 /// br i1 %9, label %res_block, label %loadbb1 648 /// res_block: ; preds = %loadbb2, 649 /// %loadbb1, %loadbb 650 /// %phi.src1 = phi i64 [ %6, %loadbb ], [ %22, %loadbb1 ], [ %36, %loadbb2 ] 651 /// %phi.src2 = phi i64 [ %7, %loadbb ], [ %23, %loadbb1 ], [ %37, %loadbb2 ] 652 /// %10 = icmp ult i64 %phi.src1, %phi.src2 653 /// %11 = select i1 %10, i32 -1, i32 1 654 /// br label %endblock 655 /// loadbb1: ; preds = %loadbb 656 /// %12 = bitcast i32* %buffer2 to i8* 657 /// %13 = bitcast i32* %buffer1 to i8* 658 /// %14 = bitcast i8* %13 to i32* 659 /// %15 = bitcast i8* %12 to i32* 660 /// %16 = getelementptr i32, i32* %14, i32 2 661 /// %17 = getelementptr i32, i32* %15, i32 2 662 /// %18 = load i32, i32* %16 663 /// %19 = load i32, i32* %17 664 /// %20 = call i32 @llvm.bswap.i32(i32 %18) 665 /// %21 = call i32 @llvm.bswap.i32(i32 %19) 666 /// %22 = zext i32 %20 to i64 667 /// %23 = zext i32 %21 to i64 668 /// %24 = sub i64 %22, %23 669 /// %25 = icmp ne i64 %24, 0 670 /// br i1 %25, label %res_block, label %loadbb2 671 /// loadbb2: ; preds = %loadbb1 672 /// %26 = bitcast i32* %buffer2 to i8* 673 /// %27 = bitcast i32* %buffer1 to i8* 674 /// %28 = bitcast i8* %27 to i16* 675 /// %29 = bitcast i8* %26 to i16* 676 /// %30 = getelementptr i16, i16* %28, i16 6 677 /// %31 = getelementptr i16, i16* %29, i16 6 678 /// %32 = load i16, i16* %30 679 /// %33 = load i16, i16* %31 680 /// %34 = call i16 @llvm.bswap.i16(i16 %32) 681 /// %35 = call i16 @llvm.bswap.i16(i16 %33) 682 /// %36 = zext i16 %34 to i64 683 /// %37 = zext i16 %35 to i64 684 /// %38 = sub i64 %36, %37 685 /// %39 = icmp ne i64 %38, 0 686 /// br i1 %39, label %res_block, label %loadbb3 687 /// loadbb3: ; preds = %loadbb2 688 /// %40 = bitcast i32* %buffer2 to i8* 689 /// %41 = bitcast i32* %buffer1 to i8* 690 /// %42 = getelementptr i8, i8* %41, i8 14 691 /// %43 = getelementptr i8, i8* %40, i8 14 692 /// %44 = load i8, i8* %42 693 /// %45 = load i8, i8* %43 694 /// %46 = zext i8 %44 to i32 695 /// %47 = zext i8 %45 to i32 696 /// %48 = sub i32 %46, %47 697 /// br label %endblock 698 /// endblock: ; preds = %res_block, 699 /// %loadbb3 700 /// %phi.res = phi i32 [ %48, %loadbb3 ], [ %11, %res_block ] 701 /// ret i32 %phi.res 702 static bool expandMemCmp(CallInst *CI, const TargetTransformInfo *TTI, 703 const TargetLowering *TLI, const DataLayout *DL, 704 ProfileSummaryInfo *PSI, BlockFrequencyInfo *BFI) { 705 NumMemCmpCalls++; 706 707 // Early exit from expansion if -Oz. 708 if (CI->getFunction()->hasMinSize()) 709 return false; 710 711 // Early exit from expansion if size is not a constant. 712 ConstantInt *SizeCast = dyn_cast<ConstantInt>(CI->getArgOperand(2)); 713 if (!SizeCast) { 714 NumMemCmpNotConstant++; 715 return false; 716 } 717 const uint64_t SizeVal = SizeCast->getZExtValue(); 718 719 if (SizeVal == 0) { 720 return false; 721 } 722 // TTI call to check if target would like to expand memcmp. Also, get the 723 // available load sizes. 724 const bool IsUsedForZeroCmp = isOnlyUsedInZeroEqualityComparison(CI); 725 bool OptForSize = CI->getFunction()->hasOptSize() || 726 llvm::shouldOptimizeForSize(CI->getParent(), PSI, BFI); 727 auto Options = TTI->enableMemCmpExpansion(OptForSize, 728 IsUsedForZeroCmp); 729 if (!Options) return false; 730 731 if (MemCmpEqZeroNumLoadsPerBlock.getNumOccurrences()) 732 Options.NumLoadsPerBlock = MemCmpEqZeroNumLoadsPerBlock; 733 734 if (OptForSize && 735 MaxLoadsPerMemcmpOptSize.getNumOccurrences()) 736 Options.MaxNumLoads = MaxLoadsPerMemcmpOptSize; 737 738 if (!OptForSize && MaxLoadsPerMemcmp.getNumOccurrences()) 739 Options.MaxNumLoads = MaxLoadsPerMemcmp; 740 741 MemCmpExpansion Expansion(CI, SizeVal, Options, IsUsedForZeroCmp, *DL); 742 743 // Don't expand if this will require more loads than desired by the target. 744 if (Expansion.getNumLoads() == 0) { 745 NumMemCmpGreaterThanMax++; 746 return false; 747 } 748 749 NumMemCmpInlined++; 750 751 Value *Res = Expansion.getMemCmpExpansion(); 752 753 // Replace call with result of expansion and erase call. 754 CI->replaceAllUsesWith(Res); 755 CI->eraseFromParent(); 756 757 return true; 758 } 759 760 761 762 class ExpandMemCmpPass : public FunctionPass { 763 public: 764 static char ID; 765 766 ExpandMemCmpPass() : FunctionPass(ID) { 767 initializeExpandMemCmpPassPass(*PassRegistry::getPassRegistry()); 768 } 769 770 bool runOnFunction(Function &F) override { 771 if (skipFunction(F)) return false; 772 773 auto *TPC = getAnalysisIfAvailable<TargetPassConfig>(); 774 if (!TPC) { 775 return false; 776 } 777 const TargetLowering* TL = 778 TPC->getTM<TargetMachine>().getSubtargetImpl(F)->getTargetLowering(); 779 780 const TargetLibraryInfo *TLI = 781 &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F); 782 const TargetTransformInfo *TTI = 783 &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); 784 auto *PSI = &getAnalysis<ProfileSummaryInfoWrapperPass>().getPSI(); 785 auto *BFI = (PSI && PSI->hasProfileSummary()) ? 786 &getAnalysis<LazyBlockFrequencyInfoPass>().getBFI() : 787 nullptr; 788 auto PA = runImpl(F, TLI, TTI, TL, PSI, BFI); 789 return !PA.areAllPreserved(); 790 } 791 792 private: 793 void getAnalysisUsage(AnalysisUsage &AU) const override { 794 AU.addRequired<TargetLibraryInfoWrapperPass>(); 795 AU.addRequired<TargetTransformInfoWrapperPass>(); 796 AU.addRequired<ProfileSummaryInfoWrapperPass>(); 797 LazyBlockFrequencyInfoPass::getLazyBFIAnalysisUsage(AU); 798 FunctionPass::getAnalysisUsage(AU); 799 } 800 801 PreservedAnalyses runImpl(Function &F, const TargetLibraryInfo *TLI, 802 const TargetTransformInfo *TTI, 803 const TargetLowering* TL, 804 ProfileSummaryInfo *PSI, BlockFrequencyInfo *BFI); 805 // Returns true if a change was made. 806 bool runOnBlock(BasicBlock &BB, const TargetLibraryInfo *TLI, 807 const TargetTransformInfo *TTI, const TargetLowering* TL, 808 const DataLayout& DL, ProfileSummaryInfo *PSI, 809 BlockFrequencyInfo *BFI); 810 }; 811 812 bool ExpandMemCmpPass::runOnBlock( 813 BasicBlock &BB, const TargetLibraryInfo *TLI, 814 const TargetTransformInfo *TTI, const TargetLowering* TL, 815 const DataLayout& DL, ProfileSummaryInfo *PSI, BlockFrequencyInfo *BFI) { 816 for (Instruction& I : BB) { 817 CallInst *CI = dyn_cast<CallInst>(&I); 818 if (!CI) { 819 continue; 820 } 821 LibFunc Func; 822 if (TLI->getLibFunc(ImmutableCallSite(CI), Func) && 823 (Func == LibFunc_memcmp || Func == LibFunc_bcmp) && 824 expandMemCmp(CI, TTI, TL, &DL, PSI, BFI)) { 825 return true; 826 } 827 } 828 return false; 829 } 830 831 832 PreservedAnalyses ExpandMemCmpPass::runImpl( 833 Function &F, const TargetLibraryInfo *TLI, const TargetTransformInfo *TTI, 834 const TargetLowering* TL, ProfileSummaryInfo *PSI, 835 BlockFrequencyInfo *BFI) { 836 const DataLayout& DL = F.getParent()->getDataLayout(); 837 bool MadeChanges = false; 838 for (auto BBIt = F.begin(); BBIt != F.end();) { 839 if (runOnBlock(*BBIt, TLI, TTI, TL, DL, PSI, BFI)) { 840 MadeChanges = true; 841 // If changes were made, restart the function from the beginning, since 842 // the structure of the function was changed. 843 BBIt = F.begin(); 844 } else { 845 ++BBIt; 846 } 847 } 848 return MadeChanges ? PreservedAnalyses::none() : PreservedAnalyses::all(); 849 } 850 851 } // namespace 852 853 char ExpandMemCmpPass::ID = 0; 854 INITIALIZE_PASS_BEGIN(ExpandMemCmpPass, "expandmemcmp", 855 "Expand memcmp() to load/stores", false, false) 856 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) 857 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) 858 INITIALIZE_PASS_DEPENDENCY(LazyBlockFrequencyInfoPass) 859 INITIALIZE_PASS_DEPENDENCY(ProfileSummaryInfoWrapperPass) 860 INITIALIZE_PASS_END(ExpandMemCmpPass, "expandmemcmp", 861 "Expand memcmp() to load/stores", false, false) 862 863 FunctionPass *llvm::createExpandMemCmpPass() { 864 return new ExpandMemCmpPass(); 865 } 866