//===--- ExpandMemCmp.cpp - Expand memcmp() to load/stores ----------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This pass tries to expand memcmp() calls with a constant size into an
// inline sequence of optimally-sized loads and compares for the target,
// so the fast path avoids a library call entirely.
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/Target/TargetLowering.h"
#include "llvm/Target/TargetSubtargetInfo.h"
#include "llvm/Transforms/Scalar.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"

using namespace llvm;

#define DEBUG_TYPE "expandmemcmp"

STATISTIC(NumMemCmpCalls, "Number of memcmp calls");
STATISTIC(NumMemCmpNotConstant, "Number of memcmp calls without constant size");
STATISTIC(NumMemCmpGreaterThanMax,
          "Number of memcmp calls with size greater than max size");
STATISTIC(NumMemCmpInlined, "Number of inlined memcmp calls");

static cl::opt<unsigned> MemCmpNumLoadsPerBlock(
    "memcmp-num-loads-per-block", cl::Hidden, cl::init(1),
    cl::desc("The number of loads per basic block for inline expansion of "
             "memcmp that is only being compared against zero."));

namespace {


// This class provides helper functions to expand a memcmp library call into an
// inline expansion.
class MemCmpExpansion {
  // Holds the block (and its phi nodes) that computes the final -1/0/1
  // result when a pair of loaded values differs.
  struct ResultBlock {
    BasicBlock *BB = nullptr;
    PHINode *PhiSrc1 = nullptr;
    PHINode *PhiSrc2 = nullptr;

    ResultBlock() = default;
  };

  // The memcmp call being expanded.
  CallInst *const CI;
  ResultBlock ResBlock;
  // Constant number of bytes to compare (third memcmp argument).
  const uint64_t Size;
  // Largest load size (in bytes) used by the decomposition below.
  unsigned MaxLoadSize;
  // Number of loads in LoadSequence with a size greater than one byte.
  uint64_t NumLoadsNonOneByte;
  // How many load/compare pairs may be emitted into a single block (only
  // relevant for the zero-equality expansion).
  const uint64_t NumLoadsPerBlock;
  std::vector<BasicBlock *> LoadCmpBlocks;
  // Block the expansion falls through to; receives the result phi.
  BasicBlock *EndBlock;
  // Phi node in EndBlock selecting the final memcmp result.
  PHINode *PhiRes;
  // True if the memcmp result is only compared (in)equal to zero, which
  // allows a cheaper equality-only expansion.
  const bool IsUsedForZeroCmp;
  const DataLayout &DL;
  IRBuilder<> Builder;
  // Represents the decomposition in blocks of the expansion. For example,
  // comparing 33 bytes on X86+sse can be done with 2x16-byte loads and
  // 1x1-byte load, which would be represented as [{16, 0}, {16, 16}, {1, 32}].
  // TODO(courbet): Involve the target more in this computation. On X86, 7
  // bytes can be done more efficiently with two overlapping 4-byte loads than
  // covering the interval with [{4, 0},{2, 4},{1, 6}].
  struct LoadEntry {
    LoadEntry(unsigned LoadSize, uint64_t Offset)
        : LoadSize(LoadSize), Offset(Offset) {
      assert(Offset % LoadSize == 0 && "invalid load entry");
    }

    // Index of this load when addressing with a pointer of LoadSize-byte
    // element type (Offset is always a multiple of LoadSize).
    uint64_t getGEPIndex() const { return Offset / LoadSize; }

    // The size of the load for this block, in bytes.
    const unsigned LoadSize;
    // The offset of this load WRT the base pointer, in bytes.
    const uint64_t Offset;
  };
  SmallVector<LoadEntry, 8> LoadSequence;

  void createLoadCmpBlocks();
  void createResultBlock();
  void setupResultBlockPHINodes();
  void setupEndBlockPHINodes();
  Value *getCompareLoadPairs(unsigned BlockIndex, unsigned &LoadIndex);
  void emitLoadCompareBlock(unsigned BlockIndex);
  void emitLoadCompareBlockMultipleLoads(unsigned BlockIndex,
                                         unsigned &LoadIndex);
  void emitLoadCompareByteBlock(unsigned BlockIndex, unsigned GEPIndex);
  void emitMemCmpResultBlock();
  Value *getMemCmpExpansionZeroCase();
  Value *getMemCmpEqZeroOneBlock();
  Value *getMemCmpOneBlock();

 public:
  MemCmpExpansion(CallInst *CI, uint64_t Size,
                  const TargetTransformInfo::MemCmpExpansionOptions &Options,
                  unsigned MaxNumLoads, const bool IsUsedForZeroCmp,
                  unsigned NumLoadsPerBlock, const DataLayout &DL);

  unsigned getNumBlocks();
  // Number of loads in the decomposition; zero means "do not expand".
  uint64_t getNumLoads() const { return LoadSequence.size(); }

  Value *getMemCmpExpansion();
};

// Initialize the basic block structure required for expansion of memcmp call
// with given maximum load size and memcmp size parameter.
// This structure includes:
// 1. A list of load compare blocks - LoadCmpBlocks.
// 2. An EndBlock, split from original instruction point, which is the block to
// return from.
// 3. ResultBlock, block to branch to for early exit when a
// LoadCmpBlock finds a difference.
MemCmpExpansion::MemCmpExpansion(
    CallInst *const CI, uint64_t Size,
    const TargetTransformInfo::MemCmpExpansionOptions &Options,
    const unsigned MaxNumLoads, const bool IsUsedForZeroCmp,
    const unsigned NumLoadsPerBlock, const DataLayout &TheDataLayout)
    : CI(CI),
      Size(Size),
      MaxLoadSize(0),
      NumLoadsNonOneByte(0),
      NumLoadsPerBlock(NumLoadsPerBlock),
      IsUsedForZeroCmp(IsUsedForZeroCmp),
      DL(TheDataLayout),
      Builder(CI) {
  assert(Size > 0 && "zero blocks");
  // Scale the max size down if the target can load more bytes than we need.
  // Options.LoadSizes is expected to be sorted in decreasing order; skip
  // every size that is larger than the whole comparison.
  size_t LoadSizeIndex = 0;
  while (LoadSizeIndex < Options.LoadSizes.size() &&
         Options.LoadSizes[LoadSizeIndex] > Size) {
    ++LoadSizeIndex;
  }
  this->MaxLoadSize = Options.LoadSizes[LoadSizeIndex];
  // Compute the decomposition: greedily cover the remaining bytes with as
  // many loads of the current (largest remaining) size as fit.
  uint64_t CurSize = Size;
  uint64_t Offset = 0;
  while (CurSize && LoadSizeIndex < Options.LoadSizes.size()) {
    const unsigned LoadSize = Options.LoadSizes[LoadSizeIndex];
    assert(LoadSize > 0 && "zero load size");
    const uint64_t NumLoadsForThisSize = CurSize / LoadSize;
    if (LoadSequence.size() + NumLoadsForThisSize > MaxNumLoads) {
      // Do not expand if the total number of loads is larger than what the
      // target allows. Note that it's important that we exit before completing
      // the expansion to avoid using a ton of memory to store the expansion for
      // large sizes. An empty LoadSequence tells the caller not to expand
      // (see getNumLoads()).
      LoadSequence.clear();
      return;
    }
    if (NumLoadsForThisSize > 0) {
      for (uint64_t I = 0; I < NumLoadsForThisSize; ++I) {
        LoadSequence.push_back({LoadSize, Offset});
        Offset += LoadSize;
      }
      if (LoadSize > 1) {
        ++NumLoadsNonOneByte;
      }
      CurSize = CurSize % LoadSize;
    }
    ++LoadSizeIndex;
  }
  assert(LoadSequence.size() <= MaxNumLoads && "broken invariant");
}

// Returns the number of load/compare basic blocks the expansion needs. For
// zero-equality comparisons, up to NumLoadsPerBlock loads are packed into a
// single block; otherwise there is exactly one load per block.
unsigned MemCmpExpansion::getNumBlocks() {
  if (IsUsedForZeroCmp)
    return getNumLoads() / NumLoadsPerBlock +
           (getNumLoads() % NumLoadsPerBlock != 0 ? 1 : 0);
  return getNumLoads();
}

void MemCmpExpansion::createLoadCmpBlocks() {
  for (unsigned i = 0; i < getNumBlocks(); i++) {
    BasicBlock *BB = BasicBlock::Create(CI->getContext(), "loadbb",
                                        EndBlock->getParent(), EndBlock);
    LoadCmpBlocks.push_back(BB);
  }
}

void MemCmpExpansion::createResultBlock() {
  ResBlock.BB = BasicBlock::Create(CI->getContext(), "res_block",
                                   EndBlock->getParent(), EndBlock);
}

// This function creates the IR instructions for loading and comparing 1 byte.
// It loads 1 byte from each source of the memcmp parameters with the given
// GEPIndex. It then subtracts the two loaded values and adds this result to the
// final phi node for selecting the memcmp result.
void MemCmpExpansion::emitLoadCompareByteBlock(unsigned BlockIndex,
                                               unsigned GEPIndex) {
  Value *Source1 = CI->getArgOperand(0);
  Value *Source2 = CI->getArgOperand(1);

  Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
  Type *LoadSizeType = Type::getInt8Ty(CI->getContext());
  // Cast source to LoadSizeType*.
  if (Source1->getType() != LoadSizeType)
    Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo());
  if (Source2->getType() != LoadSizeType)
    Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo());

  // Get the base address using the GEPIndex.
  if (GEPIndex != 0) {
    Source1 = Builder.CreateGEP(LoadSizeType, Source1,
                                ConstantInt::get(LoadSizeType, GEPIndex));
    Source2 = Builder.CreateGEP(LoadSizeType, Source2,
                                ConstantInt::get(LoadSizeType, GEPIndex));
  }

  Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
  Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);

  // Zero-extend to i32 before subtracting so the difference carries the
  // correct sign for the memcmp result (an i8 subtraction would wrap).
  LoadSrc1 = Builder.CreateZExt(LoadSrc1, Type::getInt32Ty(CI->getContext()));
  LoadSrc2 = Builder.CreateZExt(LoadSrc2, Type::getInt32Ty(CI->getContext()));
  Value *Diff = Builder.CreateSub(LoadSrc1, LoadSrc2);

  PhiRes->addIncoming(Diff, LoadCmpBlocks[BlockIndex]);

  if (BlockIndex < (LoadCmpBlocks.size() - 1)) {
    // Early exit branch if difference found to EndBlock. Otherwise, continue to
    // next LoadCmpBlock.
    Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_NE, Diff,
                                    ConstantInt::get(Diff->getType(), 0));
    BranchInst *CmpBr =
        BranchInst::Create(EndBlock, LoadCmpBlocks[BlockIndex + 1], Cmp);
    Builder.Insert(CmpBr);
  } else {
    // The last block has an unconditional branch to EndBlock.
    BranchInst *CmpBr = BranchInst::Create(EndBlock);
    Builder.Insert(CmpBr);
  }
}

/// Generate an equality comparison for one or more pairs of loaded values.
/// This is used in the case where the memcmp() call is compared equal or not
/// equal to zero.
Value *MemCmpExpansion::getCompareLoadPairs(unsigned BlockIndex,
                                            unsigned &LoadIndex) {
  assert(LoadIndex < getNumLoads() &&
         "getCompareLoadPairs() called with no remaining loads");
  std::vector<Value *> XorList, OrList;
  Value *Diff;

  // Consume at most NumLoadsPerBlock entries of LoadSequence in this block.
  const unsigned NumLoads =
      std::min(getNumLoads() - LoadIndex, NumLoadsPerBlock);

  // For a single-block expansion, start inserting before the memcmp call.
  if (LoadCmpBlocks.empty())
    Builder.SetInsertPoint(CI);
  else
    Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);

  Value *Cmp = nullptr;
  // If we have multiple loads per block, we need to generate a composite
  // comparison using xor+or. The type for the combinations is the largest load
  // type.
  IntegerType *const MaxLoadType =
      NumLoads == 1 ? nullptr
                    : IntegerType::get(CI->getContext(), MaxLoadSize * 8);
  for (unsigned i = 0; i < NumLoads; ++i, ++LoadIndex) {
    const LoadEntry &CurLoadEntry = LoadSequence[LoadIndex];

    IntegerType *LoadSizeType =
        IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8);

    Value *Source1 = CI->getArgOperand(0);
    Value *Source2 = CI->getArgOperand(1);

    // Cast source to LoadSizeType*.
    if (Source1->getType() != LoadSizeType)
      Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo());
    if (Source2->getType() != LoadSizeType)
      Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo());

    // Get the base address using a GEP.
    if (CurLoadEntry.Offset != 0) {
      Source1 = Builder.CreateGEP(
          LoadSizeType, Source1,
          ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex()));
      Source2 = Builder.CreateGEP(
          LoadSizeType, Source2,
          ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex()));
    }

    // Get a constant or load a value for each source address. Constant
    // folding avoids emitting a load when comparing against a constant
    // buffer (e.g. a string literal).
    Value *LoadSrc1 = nullptr;
    if (auto *Source1C = dyn_cast<Constant>(Source1))
      LoadSrc1 = ConstantFoldLoadFromConstPtr(Source1C, LoadSizeType, DL);
    if (!LoadSrc1)
      LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);

    Value *LoadSrc2 = nullptr;
    if (auto *Source2C = dyn_cast<Constant>(Source2))
      LoadSrc2 = ConstantFoldLoadFromConstPtr(Source2C, LoadSizeType, DL);
    if (!LoadSrc2)
      LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);

    if (NumLoads != 1) {
      if (LoadSizeType != MaxLoadType) {
        LoadSrc1 = Builder.CreateZExt(LoadSrc1, MaxLoadType);
        LoadSrc2 = Builder.CreateZExt(LoadSrc2, MaxLoadType);
      }
      // If we have multiple loads per block, we need to generate a composite
      // comparison using xor+or.
      Diff = Builder.CreateXor(LoadSrc1, LoadSrc2);
      Diff = Builder.CreateZExt(Diff, MaxLoadType);
      XorList.push_back(Diff);
    } else {
      // If there's only one load per block, we just compare the loaded values.
      Cmp = Builder.CreateICmpNE(LoadSrc1, LoadSrc2);
    }
  }

  // OR adjacent entries pairwise, halving the list each pass.
  auto pairWiseOr = [&](std::vector<Value *> &InList) -> std::vector<Value *> {
    std::vector<Value *> OutList;
    for (unsigned i = 0; i < InList.size() - 1; i = i + 2) {
      Value *Or = Builder.CreateOr(InList[i], InList[i + 1]);
      OutList.push_back(Or);
    }
    if (InList.size() % 2 != 0)
      OutList.push_back(InList.back());
    return OutList;
  };

  if (!Cmp) {
    // Pairwise OR the XOR results.
    OrList = pairWiseOr(XorList);

    // Pairwise OR the OR results until one result left.
    while (OrList.size() != 1) {
      OrList = pairWiseOr(OrList);
    }
    // Any non-zero bit means some pair of loaded values differed.
    Cmp = Builder.CreateICmpNE(OrList[0], ConstantInt::get(Diff->getType(), 0));
  }

  return Cmp;
}

void MemCmpExpansion::emitLoadCompareBlockMultipleLoads(unsigned BlockIndex,
                                                        unsigned &LoadIndex) {
  Value *Cmp = getCompareLoadPairs(BlockIndex, LoadIndex);

  BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1))
                           ? EndBlock
                           : LoadCmpBlocks[BlockIndex + 1];
  // Early exit branch if difference found to ResultBlock. Otherwise,
  // continue to next LoadCmpBlock or EndBlock.
  BranchInst *CmpBr = BranchInst::Create(ResBlock.BB, NextBB, Cmp);
  Builder.Insert(CmpBr);

  // Add a phi edge for the last LoadCmpBlock to EndBlock with a value of 0
  // since early exit to ResultBlock was not taken (no difference was found in
  // any of the bytes).
  if (BlockIndex == LoadCmpBlocks.size() - 1) {
    Value *Zero = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 0);
    PhiRes->addIncoming(Zero, LoadCmpBlocks[BlockIndex]);
  }
}

// This function creates the IR instructions for loading and comparing using the
// given LoadSize. It loads the number of bytes specified by LoadSize from each
// source of the memcmp parameters. It then does a subtract to see if there was
// a difference in the loaded values. If a difference is found, it branches
// with an early exit to the ResultBlock for calculating which source was
// larger. Otherwise, it falls through to the either the next LoadCmpBlock or
// the EndBlock if this is the last LoadCmpBlock. Loading 1 byte is handled with
// a special case through emitLoadCompareByteBlock. The special handling can
// simply subtract the loaded values and add it to the result phi node.
void MemCmpExpansion::emitLoadCompareBlock(unsigned BlockIndex) {
  // There is one load per block in this case, BlockIndex == LoadIndex.
  const LoadEntry &CurLoadEntry = LoadSequence[BlockIndex];

  if (CurLoadEntry.LoadSize == 1) {
    MemCmpExpansion::emitLoadCompareByteBlock(BlockIndex,
                                              CurLoadEntry.getGEPIndex());
    return;
  }

  Type *LoadSizeType =
      IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8);
  Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8);
  assert(CurLoadEntry.LoadSize <= MaxLoadSize && "Unexpected load type");

  Value *Source1 = CI->getArgOperand(0);
  Value *Source2 = CI->getArgOperand(1);

  Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
  // Cast source to LoadSizeType*.
  if (Source1->getType() != LoadSizeType)
    Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo());
  if (Source2->getType() != LoadSizeType)
    Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo());

  // Get the base address using a GEP.
  if (CurLoadEntry.Offset != 0) {
    Source1 = Builder.CreateGEP(
        LoadSizeType, Source1,
        ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex()));
    Source2 = Builder.CreateGEP(
        LoadSizeType, Source2,
        ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex()));
  }

  // Load LoadSizeType from the base address.
  Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
  Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);

  // memcmp compares bytes in memory order, so on little-endian targets the
  // loaded integers must be byte-swapped before an unsigned comparison.
  if (DL.isLittleEndian()) {
    Function *Bswap = Intrinsic::getDeclaration(CI->getModule(),
                                                Intrinsic::bswap, LoadSizeType);
    LoadSrc1 = Builder.CreateCall(Bswap, LoadSrc1);
    LoadSrc2 = Builder.CreateCall(Bswap, LoadSrc2);
  }

  if (LoadSizeType != MaxLoadType) {
    LoadSrc1 = Builder.CreateZExt(LoadSrc1, MaxLoadType);
    LoadSrc2 = Builder.CreateZExt(LoadSrc2, MaxLoadType);
  }

  // Add the loaded values to the phi nodes for calculating memcmp result only
  // if result is not used in a zero equality.
  if (!IsUsedForZeroCmp) {
    ResBlock.PhiSrc1->addIncoming(LoadSrc1, LoadCmpBlocks[BlockIndex]);
    ResBlock.PhiSrc2->addIncoming(LoadSrc2, LoadCmpBlocks[BlockIndex]);
  }

  Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, LoadSrc1, LoadSrc2);
  BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1))
                           ? EndBlock
                           : LoadCmpBlocks[BlockIndex + 1];
  // Early exit branch if difference found to ResultBlock. Otherwise, continue
  // to next LoadCmpBlock or EndBlock.
  BranchInst *CmpBr = BranchInst::Create(NextBB, ResBlock.BB, Cmp);
  Builder.Insert(CmpBr);

  // Add a phi edge for the last LoadCmpBlock to EndBlock with a value of 0
  // since early exit to ResultBlock was not taken (no difference was found in
  // any of the bytes).
  if (BlockIndex == LoadCmpBlocks.size() - 1) {
    Value *Zero = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 0);
    PhiRes->addIncoming(Zero, LoadCmpBlocks[BlockIndex]);
  }
}

// This function populates the ResultBlock with a sequence to calculate the
// memcmp result. It compares the two loaded source values and returns -1 if
// src1 < src2 and 1 if src1 > src2.
void MemCmpExpansion::emitMemCmpResultBlock() {
  // Special case: if memcmp result is used in a zero equality, result does not
  // need to be calculated and can simply return 1.
  if (IsUsedForZeroCmp) {
    BasicBlock::iterator InsertPt = ResBlock.BB->getFirstInsertionPt();
    Builder.SetInsertPoint(ResBlock.BB, InsertPt);
    Value *Res = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 1);
    PhiRes->addIncoming(Res, ResBlock.BB);
    BranchInst *NewBr = BranchInst::Create(EndBlock);
    Builder.Insert(NewBr);
    return;
  }
  BasicBlock::iterator InsertPt = ResBlock.BB->getFirstInsertionPt();
  Builder.SetInsertPoint(ResBlock.BB, InsertPt);

  // The values were byte-swapped to memory order before being fed into the
  // phis, so an unsigned comparison gives the lexicographic memcmp order.
  Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_ULT, ResBlock.PhiSrc1,
                                  ResBlock.PhiSrc2);

  Value *Res =
      Builder.CreateSelect(Cmp, ConstantInt::get(Builder.getInt32Ty(), -1),
                           ConstantInt::get(Builder.getInt32Ty(), 1));

  BranchInst *NewBr = BranchInst::Create(EndBlock);
  Builder.Insert(NewBr);
  PhiRes->addIncoming(Res, ResBlock.BB);
}

void MemCmpExpansion::setupResultBlockPHINodes() {
  Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8);
  Builder.SetInsertPoint(ResBlock.BB);
  // Note: this assumes one load per block.
  ResBlock.PhiSrc1 =
      Builder.CreatePHI(MaxLoadType, NumLoadsNonOneByte, "phi.src1");
  ResBlock.PhiSrc2 =
      Builder.CreatePHI(MaxLoadType, NumLoadsNonOneByte, "phi.src2");
}

void MemCmpExpansion::setupEndBlockPHINodes() {
  Builder.SetInsertPoint(&EndBlock->front());
  PhiRes = Builder.CreatePHI(Type::getInt32Ty(CI->getContext()), 2, "phi.res");
}

Value *MemCmpExpansion::getMemCmpExpansionZeroCase() {
  unsigned LoadIndex = 0;
  // This loop populates each of the LoadCmpBlocks with the IR sequence to
  // handle multiple loads per block.
  for (unsigned I = 0; I < getNumBlocks(); ++I) {
    emitLoadCompareBlockMultipleLoads(I, LoadIndex);
  }

  emitMemCmpResultBlock();
  return PhiRes;
}

/// A memcmp expansion that compares equality with 0 and only has one block of
/// load and compare can bypass the compare, branch, and phi IR that is required
/// in the general case.
Value *MemCmpExpansion::getMemCmpEqZeroOneBlock() {
  unsigned LoadIndex = 0;
  Value *Cmp = getCompareLoadPairs(0, LoadIndex);
  assert(LoadIndex == getNumLoads() && "some entries were not consumed");
  return Builder.CreateZExt(Cmp, Type::getInt32Ty(CI->getContext()));
}

/// A memcmp expansion that only has one block of load and compare can bypass
/// the compare, branch, and phi IR that is required in the general case.
Value *MemCmpExpansion::getMemCmpOneBlock() {
  assert(NumLoadsPerBlock == 1 && "Only handles one load pair per block");

  Type *LoadSizeType = IntegerType::get(CI->getContext(), Size * 8);
  Value *Source1 = CI->getArgOperand(0);
  Value *Source2 = CI->getArgOperand(1);

  // Cast source to LoadSizeType*.
  if (Source1->getType() != LoadSizeType)
    Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo());
  if (Source2->getType() != LoadSizeType)
    Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo());

  // Load LoadSizeType from the base address.
  Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
  Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);

  // Byte-swap to memory order on little-endian targets; a single byte needs
  // no swap.
  if (DL.isLittleEndian() && Size != 1) {
    Function *Bswap = Intrinsic::getDeclaration(CI->getModule(),
                                                Intrinsic::bswap, LoadSizeType);
    LoadSrc1 = Builder.CreateCall(Bswap, LoadSrc1);
    LoadSrc2 = Builder.CreateCall(Bswap, LoadSrc2);
  }

  if (Size < 4) {
    // The i8 and i16 cases don't need compares. We zext the loaded values and
    // subtract them to get the suitable negative, zero, or positive i32 result.
    LoadSrc1 = Builder.CreateZExt(LoadSrc1, Builder.getInt32Ty());
    LoadSrc2 = Builder.CreateZExt(LoadSrc2, Builder.getInt32Ty());
    return Builder.CreateSub(LoadSrc1, LoadSrc2);
  }

  // The result of memcmp is negative, zero, or positive, so produce that by
  // subtracting 2 extended compare bits: sub (ugt, ult).
  // If a target prefers to use selects to get -1/0/1, they should be able
  // to transform this later. The inverse transform (going from selects to math)
  // may not be possible in the DAG because the selects got converted into
  // branches before we got there.
  Value *CmpUGT = Builder.CreateICmpUGT(LoadSrc1, LoadSrc2);
  Value *CmpULT = Builder.CreateICmpULT(LoadSrc1, LoadSrc2);
  Value *ZextUGT = Builder.CreateZExt(CmpUGT, Builder.getInt32Ty());
  Value *ZextULT = Builder.CreateZExt(CmpULT, Builder.getInt32Ty());
  return Builder.CreateSub(ZextUGT, ZextULT);
}

// This function expands the memcmp call into an inline expansion and returns
// the memcmp result.
Value *MemCmpExpansion::getMemCmpExpansion() {
  // A memcmp with zero-comparison with only one block of load and compare does
  // not need to set up any extra blocks. This case could be handled in the DAG,
  // but since we have all of the machinery to flexibly expand any memcmp here,
  // we choose to handle this case too to avoid fragmented lowering.
  if ((!IsUsedForZeroCmp && NumLoadsPerBlock != 1) || getNumBlocks() != 1) {
    BasicBlock *StartBlock = CI->getParent();
    EndBlock = StartBlock->splitBasicBlock(CI, "endblock");
    setupEndBlockPHINodes();
    createResultBlock();

    // If return value of memcmp is not used in a zero equality, we need to
    // calculate which source was larger. The calculation requires the
    // two loaded source values of each load compare block.
    // These will be saved in the phi nodes created by setupResultBlockPHINodes.
    if (!IsUsedForZeroCmp) setupResultBlockPHINodes();

    // Create the number of required load compare basic blocks.
    createLoadCmpBlocks();

    // Update the terminator added by splitBasicBlock to branch to the first
    // LoadCmpBlock.
    StartBlock->getTerminator()->setSuccessor(0, LoadCmpBlocks[0]);
  }

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  if (IsUsedForZeroCmp)
    return getNumBlocks() == 1 ? getMemCmpEqZeroOneBlock()
                               : getMemCmpExpansionZeroCase();

  // TODO: Handle more than one load pair per block in getMemCmpOneBlock().
  if (getNumBlocks() == 1 && NumLoadsPerBlock == 1) return getMemCmpOneBlock();

  for (unsigned I = 0; I < getNumBlocks(); ++I) {
    emitLoadCompareBlock(I);
  }

  emitMemCmpResultBlock();
  return PhiRes;
}

// This function checks to see if an expansion of memcmp can be generated.
// It checks for constant compare size that is less than the max inline size.
// If an expansion cannot occur, returns false to leave as a library call.
// Otherwise, the library call is replaced with a new IR instruction sequence.
617 /// We want to transform: 618 /// %call = call signext i32 @memcmp(i8* %0, i8* %1, i64 15) 619 /// To: 620 /// loadbb: 621 /// %0 = bitcast i32* %buffer2 to i8* 622 /// %1 = bitcast i32* %buffer1 to i8* 623 /// %2 = bitcast i8* %1 to i64* 624 /// %3 = bitcast i8* %0 to i64* 625 /// %4 = load i64, i64* %2 626 /// %5 = load i64, i64* %3 627 /// %6 = call i64 @llvm.bswap.i64(i64 %4) 628 /// %7 = call i64 @llvm.bswap.i64(i64 %5) 629 /// %8 = sub i64 %6, %7 630 /// %9 = icmp ne i64 %8, 0 631 /// br i1 %9, label %res_block, label %loadbb1 632 /// res_block: ; preds = %loadbb2, 633 /// %loadbb1, %loadbb 634 /// %phi.src1 = phi i64 [ %6, %loadbb ], [ %22, %loadbb1 ], [ %36, %loadbb2 ] 635 /// %phi.src2 = phi i64 [ %7, %loadbb ], [ %23, %loadbb1 ], [ %37, %loadbb2 ] 636 /// %10 = icmp ult i64 %phi.src1, %phi.src2 637 /// %11 = select i1 %10, i32 -1, i32 1 638 /// br label %endblock 639 /// loadbb1: ; preds = %loadbb 640 /// %12 = bitcast i32* %buffer2 to i8* 641 /// %13 = bitcast i32* %buffer1 to i8* 642 /// %14 = bitcast i8* %13 to i32* 643 /// %15 = bitcast i8* %12 to i32* 644 /// %16 = getelementptr i32, i32* %14, i32 2 645 /// %17 = getelementptr i32, i32* %15, i32 2 646 /// %18 = load i32, i32* %16 647 /// %19 = load i32, i32* %17 648 /// %20 = call i32 @llvm.bswap.i32(i32 %18) 649 /// %21 = call i32 @llvm.bswap.i32(i32 %19) 650 /// %22 = zext i32 %20 to i64 651 /// %23 = zext i32 %21 to i64 652 /// %24 = sub i64 %22, %23 653 /// %25 = icmp ne i64 %24, 0 654 /// br i1 %25, label %res_block, label %loadbb2 655 /// loadbb2: ; preds = %loadbb1 656 /// %26 = bitcast i32* %buffer2 to i8* 657 /// %27 = bitcast i32* %buffer1 to i8* 658 /// %28 = bitcast i8* %27 to i16* 659 /// %29 = bitcast i8* %26 to i16* 660 /// %30 = getelementptr i16, i16* %28, i16 6 661 /// %31 = getelementptr i16, i16* %29, i16 6 662 /// %32 = load i16, i16* %30 663 /// %33 = load i16, i16* %31 664 /// %34 = call i16 @llvm.bswap.i16(i16 %32) 665 /// %35 = call i16 @llvm.bswap.i16(i16 %33) 666 /// 
%36 = zext i16 %34 to i64 667 /// %37 = zext i16 %35 to i64 668 /// %38 = sub i64 %36, %37 669 /// %39 = icmp ne i64 %38, 0 670 /// br i1 %39, label %res_block, label %loadbb3 671 /// loadbb3: ; preds = %loadbb2 672 /// %40 = bitcast i32* %buffer2 to i8* 673 /// %41 = bitcast i32* %buffer1 to i8* 674 /// %42 = getelementptr i8, i8* %41, i8 14 675 /// %43 = getelementptr i8, i8* %40, i8 14 676 /// %44 = load i8, i8* %42 677 /// %45 = load i8, i8* %43 678 /// %46 = zext i8 %44 to i32 679 /// %47 = zext i8 %45 to i32 680 /// %48 = sub i32 %46, %47 681 /// br label %endblock 682 /// endblock: ; preds = %res_block, 683 /// %loadbb3 684 /// %phi.res = phi i32 [ %48, %loadbb3 ], [ %11, %res_block ] 685 /// ret i32 %phi.res 686 static bool expandMemCmp(CallInst *CI, const TargetTransformInfo *TTI, 687 const TargetLowering *TLI, const DataLayout *DL) { 688 NumMemCmpCalls++; 689 690 // Early exit from expansion if -Oz. 691 if (CI->getFunction()->optForMinSize()) 692 return false; 693 694 // Early exit from expansion if size is not a constant. 695 ConstantInt *SizeCast = dyn_cast<ConstantInt>(CI->getArgOperand(2)); 696 if (!SizeCast) { 697 NumMemCmpNotConstant++; 698 return false; 699 } 700 const uint64_t SizeVal = SizeCast->getZExtValue(); 701 702 if (SizeVal == 0) { 703 return false; 704 } 705 706 // TTI call to check if target would like to expand memcmp. Also, get the 707 // available load sizes. 708 const bool IsUsedForZeroCmp = isOnlyUsedInZeroEqualityComparison(CI); 709 const auto *const Options = TTI->enableMemCmpExpansion(IsUsedForZeroCmp); 710 if (!Options) return false; 711 712 const unsigned MaxNumLoads = 713 TLI->getMaxExpandSizeMemcmp(CI->getFunction()->optForSize()); 714 715 MemCmpExpansion Expansion(CI, SizeVal, *Options, MaxNumLoads, 716 IsUsedForZeroCmp, MemCmpNumLoadsPerBlock, *DL); 717 718 // Don't expand if this will require more loads than desired by the target. 
719 if (Expansion.getNumLoads() == 0) { 720 NumMemCmpGreaterThanMax++; 721 return false; 722 } 723 724 NumMemCmpInlined++; 725 726 Value *Res = Expansion.getMemCmpExpansion(); 727 728 // Replace call with result of expansion and erase call. 729 CI->replaceAllUsesWith(Res); 730 CI->eraseFromParent(); 731 732 return true; 733 } 734 735 736 737 class ExpandMemCmpPass : public FunctionPass { 738 public: 739 static char ID; 740 741 ExpandMemCmpPass() : FunctionPass(ID) { 742 initializeExpandMemCmpPassPass(*PassRegistry::getPassRegistry()); 743 } 744 745 bool runOnFunction(Function &F) override { 746 if (skipFunction(F)) return false; 747 748 auto *TPC = getAnalysisIfAvailable<TargetPassConfig>(); 749 if (!TPC) { 750 return false; 751 } 752 const TargetLowering* TL = 753 TPC->getTM<TargetMachine>().getSubtargetImpl(F)->getTargetLowering(); 754 755 const TargetLibraryInfo *TLI = 756 &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(); 757 const TargetTransformInfo *TTI = 758 &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F); 759 auto PA = runImpl(F, TLI, TTI, TL); 760 return !PA.areAllPreserved(); 761 } 762 763 private: 764 void getAnalysisUsage(AnalysisUsage &AU) const override { 765 AU.addRequired<TargetLibraryInfoWrapperPass>(); 766 AU.addRequired<TargetTransformInfoWrapperPass>(); 767 FunctionPass::getAnalysisUsage(AU); 768 } 769 770 PreservedAnalyses runImpl(Function &F, const TargetLibraryInfo *TLI, 771 const TargetTransformInfo *TTI, 772 const TargetLowering* TL); 773 // Returns true if a change was made. 
774 bool runOnBlock(BasicBlock &BB, const TargetLibraryInfo *TLI, 775 const TargetTransformInfo *TTI, const TargetLowering* TL, 776 const DataLayout& DL); 777 }; 778 779 bool ExpandMemCmpPass::runOnBlock( 780 BasicBlock &BB, const TargetLibraryInfo *TLI, 781 const TargetTransformInfo *TTI, const TargetLowering* TL, 782 const DataLayout& DL) { 783 for (Instruction& I : BB) { 784 CallInst *CI = dyn_cast<CallInst>(&I); 785 if (!CI) { 786 continue; 787 } 788 LibFunc Func; 789 if (TLI->getLibFunc(ImmutableCallSite(CI), Func) && 790 Func == LibFunc_memcmp && expandMemCmp(CI, TTI, TL, &DL)) { 791 return true; 792 } 793 } 794 return false; 795 } 796 797 798 PreservedAnalyses ExpandMemCmpPass::runImpl( 799 Function &F, const TargetLibraryInfo *TLI, const TargetTransformInfo *TTI, 800 const TargetLowering* TL) { 801 const DataLayout& DL = F.getParent()->getDataLayout(); 802 bool MadeChanges = false; 803 for (auto BBIt = F.begin(); BBIt != F.end();) { 804 if (runOnBlock(*BBIt, TLI, TTI, TL, DL)) { 805 MadeChanges = true; 806 // If changes were made, restart the function from the beginning, since 807 // the structure of the function was changed. 808 BBIt = F.begin(); 809 } else { 810 ++BBIt; 811 } 812 } 813 return MadeChanges ? PreservedAnalyses::none() : PreservedAnalyses::all(); 814 } 815 816 } // namespace 817 818 char ExpandMemCmpPass::ID = 0; 819 INITIALIZE_PASS_BEGIN(ExpandMemCmpPass, "expandmemcmp", 820 "Expand memcmp() to load/stores", false, false) 821 INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass) 822 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass) 823 INITIALIZE_PASS_END(ExpandMemCmpPass, "expandmemcmp", 824 "Expand memcmp() to load/stores", false, false) 825 826 FunctionPass *llvm::createExpandMemCmpPass() { 827 return new ExpandMemCmpPass(); 828 } 829