//===--- ExpandMemCmp.cpp - Expand memcmp() to load/stores ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass tries to expand memcmp() calls into optimally-sized loads and
// compares for the target.
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/CallSite.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/InitializePasses.h"

using namespace llvm;

#define DEBUG_TYPE "expandmemcmp"

STATISTIC(NumMemCmpCalls, "Number of memcmp calls");
STATISTIC(NumMemCmpNotConstant, "Number of memcmp calls without constant size");
STATISTIC(NumMemCmpGreaterThanMax,
          "Number of memcmp calls with size greater than max size");
STATISTIC(NumMemCmpInlined, "Number of inlined memcmp calls");

static cl::opt<unsigned> MemCmpEqZeroNumLoadsPerBlock(
    "memcmp-num-loads-per-block", cl::Hidden, cl::init(1),
    cl::desc("The number of loads per basic block for inline expansion of "
             "memcmp that is only being compared against zero."));

static cl::opt<unsigned> MaxLoadsPerMemcmp(
    "max-loads-per-memcmp", cl::Hidden,
    cl::desc("Set maximum number of loads used in expanded memcmp"));

static cl::opt<unsigned> MaxLoadsPerMemcmpOptSize(
    "max-loads-per-memcmp-opt-size", cl::Hidden,
    cl::desc("Set maximum number of loads used in expanded memcmp for -Os/Oz"));

namespace {

// This class provides helper functions to expand a memcmp library call into an
// inline expansion.
class MemCmpExpansion {
  struct ResultBlock {
    BasicBlock *BB = nullptr;
    PHINode *PhiSrc1 = nullptr;
    PHINode *PhiSrc2 = nullptr;

    ResultBlock() = default;
  };

  CallInst *const CI;
  ResultBlock ResBlock;
  const uint64_t Size;
  unsigned MaxLoadSize;
  uint64_t NumLoadsNonOneByte;
  const uint64_t NumLoadsPerBlockForZeroCmp;
  std::vector<BasicBlock *> LoadCmpBlocks;
  BasicBlock *EndBlock;
  PHINode *PhiRes;
  const bool IsUsedForZeroCmp;
  const DataLayout &DL;
  IRBuilder<> Builder;
  // Represents the decomposition in blocks of the expansion. For example,
  // comparing 33 bytes on X86+sse can be done with 2x16-byte loads and
  // 1x1-byte load, which would be represented as [{16, 0}, {16, 16}, {1, 32}].
  struct LoadEntry {
    LoadEntry(unsigned LoadSize, uint64_t Offset)
        : LoadSize(LoadSize), Offset(Offset) {}

    // The size of the load for this block, in bytes.
    unsigned LoadSize;
    // The offset of this load from the base pointer, in bytes.
    uint64_t Offset;
  };
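  // A full decomposition is a sequence of such entries; for example
  // (illustrative, assuming the common 64-bit load sizes {8, 4, 2, 1}), a
  // 15-byte compare decomposes greedily into
  //   [{8, 0}, {4, 8}, {2, 12}, {1, 14}].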
  using LoadEntryVector = SmallVector<LoadEntry, 8>;
  LoadEntryVector LoadSequence;

  void createLoadCmpBlocks();
  void createResultBlock();
  void setupResultBlockPHINodes();
  void setupEndBlockPHINodes();
  Value *getCompareLoadPairs(unsigned BlockIndex, unsigned &LoadIndex);
  void emitLoadCompareBlock(unsigned BlockIndex);
  void emitLoadCompareBlockMultipleLoads(unsigned BlockIndex,
                                         unsigned &LoadIndex);
  void emitLoadCompareByteBlock(unsigned BlockIndex, unsigned OffsetBytes);
  void emitMemCmpResultBlock();
  Value *getMemCmpExpansionZeroCase();
  Value *getMemCmpEqZeroOneBlock();
  Value *getMemCmpOneBlock();
  Value *getPtrToElementAtOffset(Value *Source, Type *LoadSizeType,
                                 uint64_t OffsetBytes);

  static LoadEntryVector
  computeGreedyLoadSequence(uint64_t Size, llvm::ArrayRef<unsigned> LoadSizes,
                            unsigned MaxNumLoads,
                            unsigned &NumLoadsNonOneByte);
  static LoadEntryVector
  computeOverlappingLoadSequence(uint64_t Size, unsigned MaxLoadSize,
                                 unsigned MaxNumLoads,
                                 unsigned &NumLoadsNonOneByte);

public:
  MemCmpExpansion(CallInst *CI, uint64_t Size,
                  const TargetTransformInfo::MemCmpExpansionOptions &Options,
                  const bool IsUsedForZeroCmp, const DataLayout &TheDataLayout);

  unsigned getNumBlocks();
  uint64_t getNumLoads() const { return LoadSequence.size(); }

  Value *getMemCmpExpansion();
};

MemCmpExpansion::LoadEntryVector MemCmpExpansion::computeGreedyLoadSequence(
    uint64_t Size, llvm::ArrayRef<unsigned> LoadSizes,
    const unsigned MaxNumLoads, unsigned &NumLoadsNonOneByte) {
  NumLoadsNonOneByte = 0;
  LoadEntryVector LoadSequence;
  uint64_t Offset = 0;
  while (Size && !LoadSizes.empty()) {
    const unsigned LoadSize = LoadSizes.front();
    const uint64_t NumLoadsForThisSize = Size / LoadSize;
    if (LoadSequence.size() + NumLoadsForThisSize > MaxNumLoads) {
      // Do not expand if the total number of loads is larger than what the
      // target allows. Note that it's important that we exit before completing
      // the expansion to avoid using a ton of memory to store the expansion
      // for large sizes.
      return {};
    }
    if (NumLoadsForThisSize > 0) {
      for (uint64_t I = 0; I < NumLoadsForThisSize; ++I) {
        LoadSequence.push_back({LoadSize, Offset});
        Offset += LoadSize;
      }
      if (LoadSize > 1)
        ++NumLoadsNonOneByte;
      Size = Size % LoadSize;
    }
    LoadSizes = LoadSizes.drop_front();
  }
  return LoadSequence;
}

MemCmpExpansion::LoadEntryVector
MemCmpExpansion::computeOverlappingLoadSequence(uint64_t Size,
                                                const unsigned MaxLoadSize,
                                                const unsigned MaxNumLoads,
                                                unsigned &NumLoadsNonOneByte) {
  // These are already handled by the greedy approach.
  if (Size < 2 || MaxLoadSize < 2)
    return {};

  // We try to do as many non-overlapping loads as possible starting from the
  // beginning.
  const uint64_t NumNonOverlappingLoads = Size / MaxLoadSize;
  assert(NumNonOverlappingLoads && "there must be at least one load");
  // There remain 0 to (MaxLoadSize - 1) bytes to load, this will be done with
  // an overlapping load.
  Size = Size - NumNonOverlappingLoads * MaxLoadSize;
  // Bail if we do not need an overlapping load; this is already handled by
  // the greedy approach.
  if (Size == 0)
    return {};
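  // Otherwise, a single extra MaxLoadSize-byte load placed so that it ends
  // exactly at the end of the buffer covers the tail. For example
  // (illustrative): Size == 15 with MaxLoadSize == 8 becomes
  //   [{8, 0}, {8, 7}]
  // where the second load overlaps the first by one byte.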
  // Bail if the number of loads (non-overlapping + potential overlapping one)
  // is larger than the max allowed.
  if ((NumNonOverlappingLoads + 1) > MaxNumLoads)
    return {};

  // Add non-overlapping loads.
  LoadEntryVector LoadSequence;
  uint64_t Offset = 0;
  for (uint64_t I = 0; I < NumNonOverlappingLoads; ++I) {
    LoadSequence.push_back({MaxLoadSize, Offset});
    Offset += MaxLoadSize;
  }

  // Add the last overlapping load.
  assert(Size > 0 && Size < MaxLoadSize && "broken invariant");
  LoadSequence.push_back({MaxLoadSize, Offset - (MaxLoadSize - Size)});
  NumLoadsNonOneByte = 1;
  return LoadSequence;
}

// Initialize the basic block structure required for expansion of memcmp call
// with given maximum load size and memcmp size parameter.
// This structure includes:
// 1. A list of load compare blocks - LoadCmpBlocks.
// 2. An EndBlock, split from original instruction point, which is the block
//    to return from.
// 3. ResultBlock, block to branch to for early exit when a LoadCmpBlock finds
//    a difference.
MemCmpExpansion::MemCmpExpansion(
    CallInst *const CI, uint64_t Size,
    const TargetTransformInfo::MemCmpExpansionOptions &Options,
    const bool IsUsedForZeroCmp, const DataLayout &TheDataLayout)
    : CI(CI), Size(Size), MaxLoadSize(0), NumLoadsNonOneByte(0),
      NumLoadsPerBlockForZeroCmp(Options.NumLoadsPerBlock),
      IsUsedForZeroCmp(IsUsedForZeroCmp), DL(TheDataLayout), Builder(CI) {
  assert(Size > 0 && "zero blocks");
  // Scale the max size down if the target can load more bytes than we need.
  llvm::ArrayRef<unsigned> LoadSizes(Options.LoadSizes);
  while (!LoadSizes.empty() && LoadSizes.front() > Size) {
    LoadSizes = LoadSizes.drop_front();
  }
  assert(!LoadSizes.empty() && "cannot load Size bytes");
  MaxLoadSize = LoadSizes.front();
  // Compute the decomposition.
  unsigned GreedyNumLoadsNonOneByte = 0;
  LoadSequence = computeGreedyLoadSequence(Size, LoadSizes,
                                           Options.MaxNumLoads,
                                           GreedyNumLoadsNonOneByte);
  NumLoadsNonOneByte = GreedyNumLoadsNonOneByte;
  assert(LoadSequence.size() <= Options.MaxNumLoads && "broken invariant");
  // If the target allows overlapping loads and the greedy sequence is not
  // already optimal (i.e. it needs more than two loads, or no greedy sequence
  // exists at all), try an overlapping load sequence instead.
  if (Options.AllowOverlappingLoads &&
      (LoadSequence.empty() || LoadSequence.size() > 2)) {
    unsigned OverlappingNumLoadsNonOneByte = 0;
    auto OverlappingLoads = computeOverlappingLoadSequence(
        Size, MaxLoadSize, Options.MaxNumLoads, OverlappingNumLoadsNonOneByte);
    if (!OverlappingLoads.empty() &&
        (LoadSequence.empty() ||
         OverlappingLoads.size() < LoadSequence.size())) {
      LoadSequence = OverlappingLoads;
      NumLoadsNonOneByte = OverlappingNumLoadsNonOneByte;
    }
  }
  assert(LoadSequence.size() <= Options.MaxNumLoads && "broken invariant");
}

unsigned MemCmpExpansion::getNumBlocks() {
  if (IsUsedForZeroCmp)
    return getNumLoads() / NumLoadsPerBlockForZeroCmp +
           (getNumLoads() % NumLoadsPerBlockForZeroCmp != 0 ? 1 : 0);
  return getNumLoads();
}

void MemCmpExpansion::createLoadCmpBlocks() {
  for (unsigned i = 0; i < getNumBlocks(); i++) {
    BasicBlock *BB = BasicBlock::Create(CI->getContext(), "loadbb",
                                        EndBlock->getParent(), EndBlock);
    LoadCmpBlocks.push_back(BB);
  }
}

void MemCmpExpansion::createResultBlock() {
  ResBlock.BB = BasicBlock::Create(CI->getContext(), "res_block",
                                   EndBlock->getParent(), EndBlock);
}

/// Return a pointer to an element of type `LoadSizeType` at offset
/// `OffsetBytes`.
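/// For example, OffsetBytes == 8 with an i32 `LoadSizeType` emits a byte-wise
/// GEP of 8 followed by a bitcast of the result to i32* (this code uses typed
/// pointers, hence the explicit bitcasts).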
Value *MemCmpExpansion::getPtrToElementAtOffset(Value *Source,
                                                Type *LoadSizeType,
                                                uint64_t OffsetBytes) {
  if (OffsetBytes > 0) {
    auto *ByteType = Type::getInt8Ty(CI->getContext());
    Source = Builder.CreateConstGEP1_64(
        ByteType, Builder.CreateBitCast(Source, ByteType->getPointerTo()),
        OffsetBytes);
  }
  return Builder.CreateBitCast(Source, LoadSizeType->getPointerTo());
}

// This function creates the IR instructions for loading and comparing 1 byte.
// It loads 1 byte from each source of the memcmp parameters at the given
// offset. It then subtracts the two loaded values and feeds the difference
// into the final phi node for selecting the memcmp result.
void MemCmpExpansion::emitLoadCompareByteBlock(unsigned BlockIndex,
                                               unsigned OffsetBytes) {
  Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
  Type *LoadSizeType = Type::getInt8Ty(CI->getContext());
  Value *Source1 =
      getPtrToElementAtOffset(CI->getArgOperand(0), LoadSizeType, OffsetBytes);
  Value *Source2 =
      getPtrToElementAtOffset(CI->getArgOperand(1), LoadSizeType, OffsetBytes);

  Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
  Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);

  LoadSrc1 = Builder.CreateZExt(LoadSrc1, Type::getInt32Ty(CI->getContext()));
  LoadSrc2 = Builder.CreateZExt(LoadSrc2, Type::getInt32Ty(CI->getContext()));
  Value *Diff = Builder.CreateSub(LoadSrc1, LoadSrc2);

  PhiRes->addIncoming(Diff, LoadCmpBlocks[BlockIndex]);

  if (BlockIndex < (LoadCmpBlocks.size() - 1)) {
    // Early exit to EndBlock if a difference was found; otherwise, continue
    // to the next LoadCmpBlock.
    Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_NE, Diff,
                                    ConstantInt::get(Diff->getType(), 0));
    BranchInst *CmpBr =
        BranchInst::Create(EndBlock, LoadCmpBlocks[BlockIndex + 1], Cmp);
    Builder.Insert(CmpBr);
  } else {
    // The last block has an unconditional branch to EndBlock.
    BranchInst *CmpBr = BranchInst::Create(EndBlock);
    Builder.Insert(CmpBr);
  }
}

/// Generate an equality comparison for one or more pairs of loaded values.
/// This is used in the case where the memcmp() call is compared equal or not
/// equal to zero.
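/// For example (illustrative), a block with two i64 load pairs reduces to:
///  %d0  = xor i64 %a0, %b0
///  %d1  = xor i64 %a1, %b1
///  %or  = or i64 %d0, %d1
///  %cmp = icmp ne i64 %or, 0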
Value *MemCmpExpansion::getCompareLoadPairs(unsigned BlockIndex,
                                            unsigned &LoadIndex) {
  assert(LoadIndex < getNumLoads() &&
         "getCompareLoadPairs() called with no remaining loads");
  std::vector<Value *> XorList, OrList;
  Value *Diff = nullptr;

  const unsigned NumLoads =
      std::min(getNumLoads() - LoadIndex, NumLoadsPerBlockForZeroCmp);

  // For a single-block expansion, start inserting before the memcmp call.
  if (LoadCmpBlocks.empty())
    Builder.SetInsertPoint(CI);
  else
    Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);

  Value *Cmp = nullptr;
  // If we have multiple loads per block, we need to generate a composite
  // comparison using xor+or. The type for the combinations is the largest
  // load type.
  IntegerType *const MaxLoadType =
      NumLoads == 1 ? nullptr
                    : IntegerType::get(CI->getContext(), MaxLoadSize * 8);
  for (unsigned i = 0; i < NumLoads; ++i, ++LoadIndex) {
    const LoadEntry &CurLoadEntry = LoadSequence[LoadIndex];

    IntegerType *LoadSizeType =
        IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8);

    Value *Source1 = getPtrToElementAtOffset(CI->getArgOperand(0),
                                             LoadSizeType, CurLoadEntry.Offset);
    Value *Source2 = getPtrToElementAtOffset(CI->getArgOperand(1),
                                             LoadSizeType, CurLoadEntry.Offset);

    // Get a constant or load a value for each source address.
    Value *LoadSrc1 = nullptr;
    if (auto *Source1C = dyn_cast<Constant>(Source1))
      LoadSrc1 = ConstantFoldLoadFromConstPtr(Source1C, LoadSizeType, DL);
    if (!LoadSrc1)
      LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);

    Value *LoadSrc2 = nullptr;
    if (auto *Source2C = dyn_cast<Constant>(Source2))
      LoadSrc2 = ConstantFoldLoadFromConstPtr(Source2C, LoadSizeType, DL);
    if (!LoadSrc2)
      LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);

    if (NumLoads != 1) {
      if (LoadSizeType != MaxLoadType) {
        LoadSrc1 = Builder.CreateZExt(LoadSrc1, MaxLoadType);
        LoadSrc2 = Builder.CreateZExt(LoadSrc2, MaxLoadType);
      }
      // If we have multiple loads per block, we need to generate a composite
      // comparison using xor+or.
      Diff = Builder.CreateXor(LoadSrc1, LoadSrc2);
      Diff = Builder.CreateZExt(Diff, MaxLoadType);
      XorList.push_back(Diff);
    } else {
      // If there's only one load per block, we just compare the loaded
      // values.
      Cmp = Builder.CreateICmpNE(LoadSrc1, LoadSrc2);
    }
  }

  auto pairWiseOr = [&](std::vector<Value *> &InList) -> std::vector<Value *> {
    std::vector<Value *> OutList;
    for (unsigned i = 0; i < InList.size() - 1; i = i + 2) {
      Value *Or = Builder.CreateOr(InList[i], InList[i + 1]);
      OutList.push_back(Or);
    }
    if (InList.size() % 2 != 0)
      OutList.push_back(InList.back());
    return OutList;
  };

  if (!Cmp) {
    // Pairwise OR the XOR results.
    OrList = pairWiseOr(XorList);

    // Pairwise OR the OR results until one result left.
    while (OrList.size() != 1) {
      OrList = pairWiseOr(OrList);
    }

    assert(Diff && "Failed to find comparison diff");
    Cmp = Builder.CreateICmpNE(OrList[0],
                               ConstantInt::get(Diff->getType(), 0));
  }

  return Cmp;
}

void MemCmpExpansion::emitLoadCompareBlockMultipleLoads(unsigned BlockIndex,
                                                        unsigned &LoadIndex) {
  Value *Cmp = getCompareLoadPairs(BlockIndex, LoadIndex);

  BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1))
                           ? EndBlock
                           : LoadCmpBlocks[BlockIndex + 1];
  // Early exit to ResultBlock if a difference was found; otherwise, continue
  // to the next LoadCmpBlock or to EndBlock.
  BranchInst *CmpBr = BranchInst::Create(ResBlock.BB, NextBB, Cmp);
  Builder.Insert(CmpBr);

  // Add a phi edge for the last LoadCmpBlock to EndBlock with a value of 0
  // since early exit to ResultBlock was not taken (no difference was found in
  // any of the bytes).
  if (BlockIndex == LoadCmpBlocks.size() - 1) {
    Value *Zero = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 0);
    PhiRes->addIncoming(Zero, LoadCmpBlocks[BlockIndex]);
  }
}

// This function creates the IR instructions for loading and comparing using
// the given LoadSize. It loads the number of bytes specified by LoadSize from
// each source of the memcmp parameters. It then compares the loaded values to
// see if there was a difference. If a difference is found, it branches with
// an early exit to the ResultBlock for calculating which source was larger.
// Otherwise, it falls through to either the next LoadCmpBlock or the EndBlock
// if this is the last LoadCmpBlock. Loading 1 byte is handled with a special
// case through emitLoadCompareByteBlock, which can simply subtract the loaded
// values and feed the difference to the result phi node.
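// Note: on little-endian targets the loaded values are byteswapped first, so
// that an unsigned comparison of the resulting integers matches memcmp's
// lexicographic (big-endian) byte-order semantics.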
void MemCmpExpansion::emitLoadCompareBlock(unsigned BlockIndex) {
  // There is one load per block in this case, BlockIndex == LoadIndex.
  const LoadEntry &CurLoadEntry = LoadSequence[BlockIndex];

  if (CurLoadEntry.LoadSize == 1) {
    MemCmpExpansion::emitLoadCompareByteBlock(BlockIndex, CurLoadEntry.Offset);
    return;
  }

  Type *LoadSizeType =
      IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8);
  Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8);
  assert(CurLoadEntry.LoadSize <= MaxLoadSize && "Unexpected load type");

  Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);

  Value *Source1 = getPtrToElementAtOffset(CI->getArgOperand(0), LoadSizeType,
                                           CurLoadEntry.Offset);
  Value *Source2 = getPtrToElementAtOffset(CI->getArgOperand(1), LoadSizeType,
                                           CurLoadEntry.Offset);

  // Load LoadSizeType from the base address.
  Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
  Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);

  if (DL.isLittleEndian()) {
    Function *Bswap = Intrinsic::getDeclaration(CI->getModule(),
                                                Intrinsic::bswap, LoadSizeType);
    LoadSrc1 = Builder.CreateCall(Bswap, LoadSrc1);
    LoadSrc2 = Builder.CreateCall(Bswap, LoadSrc2);
  }

  if (LoadSizeType != MaxLoadType) {
    LoadSrc1 = Builder.CreateZExt(LoadSrc1, MaxLoadType);
    LoadSrc2 = Builder.CreateZExt(LoadSrc2, MaxLoadType);
  }

  // Add the loaded values to the phi nodes for calculating the memcmp result
  // only if the result is not used in a zero equality.
  if (!IsUsedForZeroCmp) {
    ResBlock.PhiSrc1->addIncoming(LoadSrc1, LoadCmpBlocks[BlockIndex]);
    ResBlock.PhiSrc2->addIncoming(LoadSrc2, LoadCmpBlocks[BlockIndex]);
  }

  Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, LoadSrc1, LoadSrc2);
  BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1))
                           ? EndBlock
                           : LoadCmpBlocks[BlockIndex + 1];
  // Early exit to ResultBlock if a difference was found; otherwise, continue
  // to the next LoadCmpBlock or to EndBlock.
  BranchInst *CmpBr = BranchInst::Create(NextBB, ResBlock.BB, Cmp);
  Builder.Insert(CmpBr);

  // Add a phi edge for the last LoadCmpBlock to EndBlock with a value of 0
  // since early exit to ResultBlock was not taken (no difference was found in
  // any of the bytes).
  if (BlockIndex == LoadCmpBlocks.size() - 1) {
    Value *Zero = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 0);
    PhiRes->addIncoming(Zero, LoadCmpBlocks[BlockIndex]);
  }
}

// This function populates the ResultBlock with a sequence to calculate the
// memcmp result. It compares the two loaded source values and returns -1 if
// src1 < src2 and 1 if src1 > src2.
void MemCmpExpansion::emitMemCmpResultBlock() {
  // Special case: if the memcmp result is only used in a zero equality, the
  // result does not need to be calculated and we can simply return 1.
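  // (Any nonzero value would do, since callers only ever compare the result
  // against zero.)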
  if (IsUsedForZeroCmp) {
    BasicBlock::iterator InsertPt = ResBlock.BB->getFirstInsertionPt();
    Builder.SetInsertPoint(ResBlock.BB, InsertPt);
    Value *Res = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 1);
    PhiRes->addIncoming(Res, ResBlock.BB);
    BranchInst *NewBr = BranchInst::Create(EndBlock);
    Builder.Insert(NewBr);
    return;
  }
  BasicBlock::iterator InsertPt = ResBlock.BB->getFirstInsertionPt();
  Builder.SetInsertPoint(ResBlock.BB, InsertPt);

  Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_ULT, ResBlock.PhiSrc1,
                                  ResBlock.PhiSrc2);

  Value *Res =
      Builder.CreateSelect(Cmp, ConstantInt::get(Builder.getInt32Ty(), -1),
                           ConstantInt::get(Builder.getInt32Ty(), 1));

  BranchInst *NewBr = BranchInst::Create(EndBlock);
  Builder.Insert(NewBr);
  PhiRes->addIncoming(Res, ResBlock.BB);
}

void MemCmpExpansion::setupResultBlockPHINodes() {
  Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8);
  Builder.SetInsertPoint(ResBlock.BB);
  // Note: this assumes one load per block.
  ResBlock.PhiSrc1 =
      Builder.CreatePHI(MaxLoadType, NumLoadsNonOneByte, "phi.src1");
  ResBlock.PhiSrc2 =
      Builder.CreatePHI(MaxLoadType, NumLoadsNonOneByte, "phi.src2");
}

void MemCmpExpansion::setupEndBlockPHINodes() {
  Builder.SetInsertPoint(&EndBlock->front());
  PhiRes = Builder.CreatePHI(Type::getInt32Ty(CI->getContext()), 2, "phi.res");
}

Value *MemCmpExpansion::getMemCmpExpansionZeroCase() {
  unsigned LoadIndex = 0;
  // This loop populates each of the LoadCmpBlocks with the IR sequence to
  // handle multiple loads per block.
  for (unsigned I = 0; I < getNumBlocks(); ++I) {
    emitLoadCompareBlockMultipleLoads(I, LoadIndex);
  }

  emitMemCmpResultBlock();
  return PhiRes;
}

/// A memcmp expansion that compares equality with 0 and only has one block of
/// load and compare can bypass the compare, branch, and phi IR that is
/// required in the general case.
Value *MemCmpExpansion::getMemCmpEqZeroOneBlock() {
  unsigned LoadIndex = 0;
  Value *Cmp = getCompareLoadPairs(0, LoadIndex);
  assert(LoadIndex == getNumLoads() && "some entries were not consumed");
  return Builder.CreateZExt(Cmp, Type::getInt32Ty(CI->getContext()));
}

/// A memcmp expansion that only has one block of load and compare can bypass
/// the compare, branch, and phi IR that is required in the general case.
Value *MemCmpExpansion::getMemCmpOneBlock() {
  Type *LoadSizeType = IntegerType::get(CI->getContext(), Size * 8);
  Value *Source1 = CI->getArgOperand(0);
  Value *Source2 = CI->getArgOperand(1);

  // Cast source to LoadSizeType*.
  if (Source1->getType() != LoadSizeType)
    Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo());
  if (Source2->getType() != LoadSizeType)
    Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo());

  // Load LoadSizeType from the base address.
  Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
  Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);

  if (DL.isLittleEndian() && Size != 1) {
    Function *Bswap = Intrinsic::getDeclaration(CI->getModule(),
                                                Intrinsic::bswap, LoadSizeType);
    LoadSrc1 = Builder.CreateCall(Bswap, LoadSrc1);
    LoadSrc2 = Builder.CreateCall(Bswap, LoadSrc2);
  }

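  // For i8 and i16 the loads fit in an i32 with room to spare, so the zext'd
  // difference below cannot overflow and its sign already encodes the memcmp
  // result; wider sizes need the compare-based sequence further down.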
  if (Size < 4) {
    // The i8 and i16 cases don't need compares. We zext the loaded values and
    // subtract them to get the suitable negative, zero, or positive i32
    // result.
    LoadSrc1 = Builder.CreateZExt(LoadSrc1, Builder.getInt32Ty());
    LoadSrc2 = Builder.CreateZExt(LoadSrc2, Builder.getInt32Ty());
    return Builder.CreateSub(LoadSrc1, LoadSrc2);
  }

  // The result of memcmp is negative, zero, or positive, so produce that by
  // subtracting 2 extended compare bits: sub (ugt, ult).
  // If a target prefers to use selects to get -1/0/1, they should be able
  // to transform this later. The inverse transform (going from selects to
  // math) may not be possible in the DAG because the selects got converted
  // into branches before we got there.
  Value *CmpUGT = Builder.CreateICmpUGT(LoadSrc1, LoadSrc2);
  Value *CmpULT = Builder.CreateICmpULT(LoadSrc1, LoadSrc2);
  Value *ZextUGT = Builder.CreateZExt(CmpUGT, Builder.getInt32Ty());
  Value *ZextULT = Builder.CreateZExt(CmpULT, Builder.getInt32Ty());
  return Builder.CreateSub(ZextUGT, ZextULT);
}

// This function expands the memcmp call into an inline expansion and returns
// the memcmp result.
Value *MemCmpExpansion::getMemCmpExpansion() {
  // Create the basic block framework for a multi-block expansion.
  if (getNumBlocks() != 1) {
    BasicBlock *StartBlock = CI->getParent();
    EndBlock = StartBlock->splitBasicBlock(CI, "endblock");
    setupEndBlockPHINodes();
    createResultBlock();

    // If the return value of memcmp is not used in a zero equality, we need
    // to calculate which source was larger. The calculation requires the two
    // loaded source values of each load compare block. These will be saved in
    // the phi nodes created by setupResultBlockPHINodes.
    if (!IsUsedForZeroCmp)
      setupResultBlockPHINodes();

    // Create the number of required load compare basic blocks.
    createLoadCmpBlocks();

    // Update the terminator added by splitBasicBlock to branch to the first
    // LoadCmpBlock.
    StartBlock->getTerminator()->setSuccessor(0, LoadCmpBlocks[0]);
  }

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  if (IsUsedForZeroCmp)
    return getNumBlocks() == 1 ? getMemCmpEqZeroOneBlock()
                               : getMemCmpExpansionZeroCase();

  if (getNumBlocks() == 1)
    return getMemCmpOneBlock();

  for (unsigned I = 0; I < getNumBlocks(); ++I) {
    emitLoadCompareBlock(I);
  }

  emitMemCmpResultBlock();
  return PhiRes;
}

// This function checks to see if an expansion of memcmp can be generated.
// It checks for a constant compare size that is less than the max inline
// size. If an expansion cannot occur, it returns false to leave the call as a
// library call. Otherwise, the library call is replaced with a new IR
// instruction sequence.
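/// The example below assumes a little-endian target with a maximum load size
/// of 8 bytes expanding the 15-byte compare shown in the call (hence the
/// i64/i32/i16/i8 load sequence and the llvm.bswap calls).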
/// We want to transform:
/// %call = call signext i32 @memcmp(i8* %0, i8* %1, i64 15)
/// To:
/// loadbb:
///  %0 = bitcast i32* %buffer2 to i8*
///  %1 = bitcast i32* %buffer1 to i8*
///  %2 = bitcast i8* %1 to i64*
///  %3 = bitcast i8* %0 to i64*
///  %4 = load i64, i64* %2
///  %5 = load i64, i64* %3
///  %6 = call i64 @llvm.bswap.i64(i64 %4)
///  %7 = call i64 @llvm.bswap.i64(i64 %5)
///  %8 = sub i64 %6, %7
///  %9 = icmp ne i64 %8, 0
///  br i1 %9, label %res_block, label %loadbb1
/// res_block:                                        ; preds = %loadbb2,
/// %loadbb1, %loadbb
///  %phi.src1 = phi i64 [ %6, %loadbb ], [ %22, %loadbb1 ], [ %36, %loadbb2 ]
///  %phi.src2 = phi i64 [ %7, %loadbb ], [ %23, %loadbb1 ], [ %37, %loadbb2 ]
///  %10 = icmp ult i64 %phi.src1, %phi.src2
///  %11 = select i1 %10, i32 -1, i32 1
///  br label %endblock
/// loadbb1:                                          ; preds = %loadbb
///  %12 = bitcast i32* %buffer2 to i8*
///  %13 = bitcast i32* %buffer1 to i8*
///  %14 = bitcast i8* %13 to i32*
///  %15 = bitcast i8* %12 to i32*
///  %16 = getelementptr i32, i32* %14, i32 2
///  %17 = getelementptr i32, i32* %15, i32 2
///  %18 = load i32, i32* %16
///  %19 = load i32, i32* %17
///  %20 = call i32 @llvm.bswap.i32(i32 %18)
///  %21 = call i32 @llvm.bswap.i32(i32 %19)
///  %22 = zext i32 %20 to i64
///  %23 = zext i32 %21 to i64
///  %24 = sub i64 %22, %23
///  %25 = icmp ne i64 %24, 0
///  br i1 %25, label %res_block, label %loadbb2
/// loadbb2:                                          ; preds = %loadbb1
///  %26 = bitcast i32* %buffer2 to i8*
///  %27 = bitcast i32* %buffer1 to i8*
///  %28 = bitcast i8* %27 to i16*
///  %29 = bitcast i8* %26 to i16*
///  %30 = getelementptr i16, i16* %28, i16 6
///  %31 = getelementptr i16, i16* %29, i16 6
///  %32 = load i16, i16* %30
///  %33 = load i16, i16* %31
///  %34 = call i16 @llvm.bswap.i16(i16 %32)
///  %35 = call i16 @llvm.bswap.i16(i16 %33)
///  %36 = zext i16 %34 to i64
///  %37 = zext i16 %35 to i64
///  %38 = sub i64 %36, %37
///  %39 = icmp ne i64 %38, 0
///  br i1 %39, label %res_block, label %loadbb3
/// loadbb3:                                          ; preds = %loadbb2
///  %40 = bitcast i32* %buffer2 to i8*
///  %41 = bitcast i32* %buffer1 to i8*
///  %42 = getelementptr i8, i8* %41, i8 14
///  %43 = getelementptr i8, i8* %40, i8 14
///  %44 = load i8, i8* %42
///  %45 = load i8, i8* %43
///  %46 = zext i8 %44 to i32
///  %47 = zext i8 %45 to i32
///  %48 = sub i32 %46, %47
///  br label %endblock
/// endblock:                                         ; preds = %res_block,
/// %loadbb3
///  %phi.res = phi i32 [ %48, %loadbb3 ], [ %11, %res_block ]
///  ret i32 %phi.res
static bool expandMemCmp(CallInst *CI, const TargetTransformInfo *TTI,
                         const TargetLowering *TLI, const DataLayout *DL) {
  NumMemCmpCalls++;

  // Early exit from expansion if -Oz.
  if (CI->getFunction()->hasMinSize())
    return false;

  // Early exit from expansion if size is not a constant.
  ConstantInt *SizeCast = dyn_cast<ConstantInt>(CI->getArgOperand(2));
  if (!SizeCast) {
    NumMemCmpNotConstant++;
    return false;
  }
  const uint64_t SizeVal = SizeCast->getZExtValue();

  if (SizeVal == 0) {
    return false;
  }
  // TTI call to check if target would like to expand memcmp. Also, get the
  // available load sizes.
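  // The returned Options carries the target's supported load widths
  // (LoadSizes, largest first), MaxNumLoads, NumLoadsPerBlock, and
  // AllowOverlappingLoads, all of which drive the MemCmpExpansion
  // decomposition above.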
  const bool IsUsedForZeroCmp = isOnlyUsedInZeroEqualityComparison(CI);
  auto Options = TTI->enableMemCmpExpansion(CI->getFunction()->hasOptSize(),
                                            IsUsedForZeroCmp);
  if (!Options)
    return false;

  if (MemCmpEqZeroNumLoadsPerBlock.getNumOccurrences())
    Options.NumLoadsPerBlock = MemCmpEqZeroNumLoadsPerBlock;

  if (CI->getFunction()->hasOptSize() &&
      MaxLoadsPerMemcmpOptSize.getNumOccurrences())
    Options.MaxNumLoads = MaxLoadsPerMemcmpOptSize;

  if (!CI->getFunction()->hasOptSize() &&
      MaxLoadsPerMemcmp.getNumOccurrences())
    Options.MaxNumLoads = MaxLoadsPerMemcmp;

  MemCmpExpansion Expansion(CI, SizeVal, Options, IsUsedForZeroCmp, *DL);

  // Don't expand if this will require more loads than desired by the target.
  if (Expansion.getNumLoads() == 0) {
    NumMemCmpGreaterThanMax++;
    return false;
  }

  NumMemCmpInlined++;

  Value *Res = Expansion.getMemCmpExpansion();

  // Replace call with result of expansion and erase call.
  CI->replaceAllUsesWith(Res);
  CI->eraseFromParent();

  return true;
}

class ExpandMemCmpPass : public FunctionPass {
public:
  static char ID;

  ExpandMemCmpPass() : FunctionPass(ID) {
    initializeExpandMemCmpPassPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    if (skipFunction(F))
      return false;

    auto *TPC = getAnalysisIfAvailable<TargetPassConfig>();
    if (!TPC) {
      return false;
    }
    const TargetLowering *TL =
        TPC->getTM<TargetMachine>().getSubtargetImpl(F)->getTargetLowering();

    const TargetLibraryInfo *TLI =
        &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
    const TargetTransformInfo *TTI =
        &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto PA = runImpl(F, TLI, TTI, TL);
    return !PA.areAllPreserved();
  }

private:
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetLibraryInfoWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
    FunctionPass::getAnalysisUsage(AU);
  }

  PreservedAnalyses runImpl(Function &F, const TargetLibraryInfo *TLI,
                            const TargetTransformInfo *TTI,
                            const TargetLowering *TL);
  // Returns true if a change was made.
  bool runOnBlock(BasicBlock &BB, const TargetLibraryInfo *TLI,
                  const TargetTransformInfo *TTI, const TargetLowering *TL,
                  const DataLayout &DL);
};

bool ExpandMemCmpPass::runOnBlock(BasicBlock &BB, const TargetLibraryInfo *TLI,
                                  const TargetTransformInfo *TTI,
                                  const TargetLowering *TL,
                                  const DataLayout &DL) {
  for (Instruction &I : BB) {
    CallInst *CI = dyn_cast<CallInst>(&I);
    if (!CI) {
      continue;
    }
    LibFunc Func;
    if (TLI->getLibFunc(ImmutableCallSite(CI), Func) &&
        (Func == LibFunc_memcmp || Func == LibFunc_bcmp) &&
        expandMemCmp(CI, TTI, TL, &DL)) {
      return true;
    }
  }
  return false;
}

PreservedAnalyses ExpandMemCmpPass::runImpl(Function &F,
                                            const TargetLibraryInfo *TLI,
                                            const TargetTransformInfo *TTI,
                                            const TargetLowering *TL) {
  const DataLayout &DL = F.getParent()->getDataLayout();
  bool MadeChanges = false;
  for (auto BBIt = F.begin(); BBIt != F.end();) {
    if (runOnBlock(*BBIt, TLI, TTI, TL, DL)) {
      MadeChanges = true;
      // If changes were made, restart the function from the beginning, since
      // the structure of the function was changed.
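      // (Expansion splits the containing block, so iterators into the old
      // block list may be invalidated; rescanning from the entry block is the
      // simple, safe choice.)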
      BBIt = F.begin();
    } else {
      ++BBIt;
    }
  }
  return MadeChanges ? PreservedAnalyses::none() : PreservedAnalyses::all();
}

} // namespace

char ExpandMemCmpPass::ID = 0;
INITIALIZE_PASS_BEGIN(ExpandMemCmpPass, "expandmemcmp",
                      "Expand memcmp() to load/stores", false, false)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_END(ExpandMemCmpPass, "expandmemcmp",
                    "Expand memcmp() to load/stores", false, false)

FunctionPass *llvm::createExpandMemCmpPass() {
  return new ExpandMemCmpPass();
}