//===--- ExpandMemCmp.cpp - Expand memcmp() to load/stores ----------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This pass tries to expand memcmp() calls into optimally-sized loads and
// compares for the target.
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/CallSite.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/Support/CommandLine.h"

using namespace llvm;

#define DEBUG_TYPE "expandmemcmp"

STATISTIC(NumMemCmpCalls, "Number of memcmp calls");
STATISTIC(NumMemCmpNotConstant, "Number of memcmp calls without constant size");
STATISTIC(NumMemCmpGreaterThanMax,
          "Number of memcmp calls with size greater than max size");
STATISTIC(NumMemCmpInlined, "Number of inlined memcmp calls");

static cl::opt<unsigned> MemCmpEqZeroNumLoadsPerBlock(
    "memcmp-num-loads-per-block", cl::Hidden, cl::init(1),
    cl::desc("The number of loads per basic block for inline expansion of "
             "memcmp that is only being compared against zero."));

namespace {

// This class provides helper functions to expand a memcmp library call into an
// inline expansion.
class MemCmpExpansion {
  struct ResultBlock {
    BasicBlock *BB = nullptr;
    PHINode *PhiSrc1 = nullptr;
    PHINode *PhiSrc2 = nullptr;

    ResultBlock() = default;
  };

  CallInst *const CI;
  ResultBlock ResBlock;
  const uint64_t Size;
  unsigned MaxLoadSize;
  uint64_t NumLoadsNonOneByte;
  const uint64_t NumLoadsPerBlockForZeroCmp;
  std::vector<BasicBlock *> LoadCmpBlocks;
  BasicBlock *EndBlock;
  PHINode *PhiRes;
  const bool IsUsedForZeroCmp;
  const DataLayout &DL;
  IRBuilder<> Builder;
  // Represents the decomposition in blocks of the expansion. For example,
  // comparing 33 bytes on X86+sse can be done with 2x16-byte loads and
  // 1x1-byte load, which would be represented as [{16, 0}, {16, 16}, {1, 32}].
  // TODO(courbet): Involve the target more in this computation. On X86, 7
  // bytes can be done more efficiently with two overlapping 4-byte loads than
  // covering the interval with [{4, 0}, {2, 4}, {1, 6}].
  struct LoadEntry {
    LoadEntry(unsigned LoadSize, uint64_t Offset)
        : LoadSize(LoadSize), Offset(Offset) {
      assert(Offset % LoadSize == 0 && "invalid load entry");
    }

    uint64_t getGEPIndex() const { return Offset / LoadSize; }

    // The size of the load for this block, in bytes.
    const unsigned LoadSize;
    // The offset of this load WRT the base pointer, in bytes.
    const uint64_t Offset;
  };
  SmallVector<LoadEntry, 8> LoadSequence;
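
  // Illustrative example (assuming a target whose LoadSizes are {8, 4, 2, 1}):
  // a 15-byte memcmp decomposes greedily into
  //   [{8, 0}, {4, 8}, {2, 12}, {1, 14}],
  // i.e. largest loads first, each starting where the previous one ended. The
  // constructor below computes this sequence.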

  void createLoadCmpBlocks();
  void createResultBlock();
  void setupResultBlockPHINodes();
  void setupEndBlockPHINodes();
  Value *getCompareLoadPairs(unsigned BlockIndex, unsigned &LoadIndex);
  void emitLoadCompareBlock(unsigned BlockIndex);
  void emitLoadCompareBlockMultipleLoads(unsigned BlockIndex,
                                         unsigned &LoadIndex);
  void emitLoadCompareByteBlock(unsigned BlockIndex, unsigned GEPIndex);
  void emitMemCmpResultBlock();
  Value *getMemCmpExpansionZeroCase();
  Value *getMemCmpEqZeroOneBlock();
  Value *getMemCmpOneBlock();

public:
  MemCmpExpansion(CallInst *CI, uint64_t Size,
                  const TargetTransformInfo::MemCmpExpansionOptions &Options,
                  unsigned MaxNumLoads, const bool IsUsedForZeroCmp,
                  unsigned MaxLoadsPerBlockForZeroCmp,
                  const DataLayout &TheDataLayout);

  unsigned getNumBlocks();
  uint64_t getNumLoads() const { return LoadSequence.size(); }

  Value *getMemCmpExpansion();
};

// Initialize the basic block structure required for expansion of a memcmp call
// with the given maximum load size and memcmp size parameter.
// This structure includes:
// 1. A list of load compare blocks - LoadCmpBlocks.
// 2. An EndBlock, split from the original instruction point, which is the
//    block to return from.
// 3. ResultBlock, the block to branch to for early exit when a LoadCmpBlock
//    finds a difference.
MemCmpExpansion::MemCmpExpansion(
    CallInst *const CI, uint64_t Size,
    const TargetTransformInfo::MemCmpExpansionOptions &Options,
    const unsigned MaxNumLoads, const bool IsUsedForZeroCmp,
    const unsigned MaxLoadsPerBlockForZeroCmp, const DataLayout &TheDataLayout)
    : CI(CI),
      Size(Size),
      MaxLoadSize(0),
      NumLoadsNonOneByte(0),
      NumLoadsPerBlockForZeroCmp(MaxLoadsPerBlockForZeroCmp),
      IsUsedForZeroCmp(IsUsedForZeroCmp),
      DL(TheDataLayout),
      Builder(CI) {
  assert(Size > 0 && "zero blocks");
  // Scale the max size down if the target can load more bytes than we need.
  size_t LoadSizeIndex = 0;
  while (LoadSizeIndex < Options.LoadSizes.size() &&
         Options.LoadSizes[LoadSizeIndex] > Size) {
    ++LoadSizeIndex;
  }
  this->MaxLoadSize = Options.LoadSizes[LoadSizeIndex];
  // Compute the decomposition.
  uint64_t CurSize = Size;
  uint64_t Offset = 0;
  while (CurSize && LoadSizeIndex < Options.LoadSizes.size()) {
    const unsigned LoadSize = Options.LoadSizes[LoadSizeIndex];
    assert(LoadSize > 0 && "zero load size");
    const uint64_t NumLoadsForThisSize = CurSize / LoadSize;
    if (LoadSequence.size() + NumLoadsForThisSize > MaxNumLoads) {
      // Do not expand if the total number of loads is larger than what the
      // target allows. Note that it's important that we exit before completing
      // the expansion to avoid using a ton of memory to store the expansion
      // for large sizes.
      LoadSequence.clear();
      return;
    }
    if (NumLoadsForThisSize > 0) {
      for (uint64_t I = 0; I < NumLoadsForThisSize; ++I) {
        LoadSequence.push_back({LoadSize, Offset});
        Offset += LoadSize;
      }
      if (LoadSize > 1) {
        ++NumLoadsNonOneByte;
      }
      CurSize = CurSize % LoadSize;
    }
    ++LoadSizeIndex;
  }
  assert(LoadSequence.size() <= MaxNumLoads && "broken invariant");
}
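
// Illustrative example (assumption: a zero-equality comparison with
// NumLoadsPerBlockForZeroCmp == 2): a sequence of 5 loads is grouped by
// getNumBlocks() below into ceil(5 / 2) == 3 blocks. In the general (ordered)
// case there is exactly one load per block.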

unsigned MemCmpExpansion::getNumBlocks() {
  if (IsUsedForZeroCmp)
    return getNumLoads() / NumLoadsPerBlockForZeroCmp +
           (getNumLoads() % NumLoadsPerBlockForZeroCmp != 0 ? 1 : 0);
  return getNumLoads();
}

void MemCmpExpansion::createLoadCmpBlocks() {
  for (unsigned i = 0; i < getNumBlocks(); i++) {
    BasicBlock *BB = BasicBlock::Create(CI->getContext(), "loadbb",
                                        EndBlock->getParent(), EndBlock);
    LoadCmpBlocks.push_back(BB);
  }
}

void MemCmpExpansion::createResultBlock() {
  ResBlock.BB = BasicBlock::Create(CI->getContext(), "res_block",
                                   EndBlock->getParent(), EndBlock);
}

// This function creates the IR instructions for loading and comparing 1 byte.
// It loads 1 byte from each source of the memcmp parameters with the given
// GEPIndex. It then subtracts the two loaded values and adds the result to the
// final phi node for selecting the memcmp result.
void MemCmpExpansion::emitLoadCompareByteBlock(unsigned BlockIndex,
                                               unsigned GEPIndex) {
  Value *Source1 = CI->getArgOperand(0);
  Value *Source2 = CI->getArgOperand(1);

  Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
  Type *LoadSizeType = Type::getInt8Ty(CI->getContext());
  // Cast source to LoadSizeType*.
  if (Source1->getType() != LoadSizeType)
    Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo());
  if (Source2->getType() != LoadSizeType)
    Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo());

  // Get the base address using the GEPIndex.
  if (GEPIndex != 0) {
    Source1 = Builder.CreateGEP(LoadSizeType, Source1,
                                ConstantInt::get(LoadSizeType, GEPIndex));
    Source2 = Builder.CreateGEP(LoadSizeType, Source2,
                                ConstantInt::get(LoadSizeType, GEPIndex));
  }

  Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
  Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);

  LoadSrc1 = Builder.CreateZExt(LoadSrc1, Type::getInt32Ty(CI->getContext()));
  LoadSrc2 = Builder.CreateZExt(LoadSrc2, Type::getInt32Ty(CI->getContext()));
  Value *Diff = Builder.CreateSub(LoadSrc1, LoadSrc2);

  PhiRes->addIncoming(Diff, LoadCmpBlocks[BlockIndex]);

  if (BlockIndex < (LoadCmpBlocks.size() - 1)) {
    // Early exit: branch to EndBlock if a difference was found; otherwise,
    // continue to the next LoadCmpBlock.
    Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_NE, Diff,
                                    ConstantInt::get(Diff->getType(), 0));
    BranchInst *CmpBr =
        BranchInst::Create(EndBlock, LoadCmpBlocks[BlockIndex + 1], Cmp);
    Builder.Insert(CmpBr);
  } else {
    // The last block has an unconditional branch to EndBlock.
    BranchInst *CmpBr = BranchInst::Create(EndBlock);
    Builder.Insert(CmpBr);
  }
}
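
// For illustration, a non-final byte block emitted by the function above looks
// roughly like (value names are made up for readability):
//   %b1 = load i8, i8* %p1      ; byte from source 1
//   %b2 = load i8, i8* %p2      ; byte from source 2
//   %x1 = zext i8 %b1 to i32
//   %x2 = zext i8 %b2 to i32
//   %d  = sub i32 %x1, %x2      ; already a valid memcmp result for 1 byte
//   %c  = icmp ne i32 %d, 0
//   br i1 %c, label %endblock, label %loadbb<N+1>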

/// Generate an equality comparison for one or more pairs of loaded values.
/// This is used in the case where the memcmp() call is compared equal or not
/// equal to zero.
Value *MemCmpExpansion::getCompareLoadPairs(unsigned BlockIndex,
                                            unsigned &LoadIndex) {
  assert(LoadIndex < getNumLoads() &&
         "getCompareLoadPairs() called with no remaining loads");
  std::vector<Value *> XorList, OrList;
  Value *Diff;

  const unsigned NumLoads =
      std::min(getNumLoads() - LoadIndex, NumLoadsPerBlockForZeroCmp);

  // For a single-block expansion, start inserting before the memcmp call.
  if (LoadCmpBlocks.empty())
    Builder.SetInsertPoint(CI);
  else
    Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);

  Value *Cmp = nullptr;
  // If we have multiple loads per block, we need to generate a composite
  // comparison using xor+or. The type for the combinations is the largest load
  // type.
  IntegerType *const MaxLoadType =
      NumLoads == 1 ? nullptr
                    : IntegerType::get(CI->getContext(), MaxLoadSize * 8);
  for (unsigned i = 0; i < NumLoads; ++i, ++LoadIndex) {
    const LoadEntry &CurLoadEntry = LoadSequence[LoadIndex];

    IntegerType *LoadSizeType =
        IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8);

    Value *Source1 = CI->getArgOperand(0);
    Value *Source2 = CI->getArgOperand(1);

    // Cast source to LoadSizeType*.
    if (Source1->getType() != LoadSizeType)
      Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo());
    if (Source2->getType() != LoadSizeType)
      Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo());

    // Get the base address using a GEP.
    if (CurLoadEntry.Offset != 0) {
      Source1 = Builder.CreateGEP(
          LoadSizeType, Source1,
          ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex()));
      Source2 = Builder.CreateGEP(
          LoadSizeType, Source2,
          ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex()));
    }

    // Get a constant or load a value for each source address.
    Value *LoadSrc1 = nullptr;
    if (auto *Source1C = dyn_cast<Constant>(Source1))
      LoadSrc1 = ConstantFoldLoadFromConstPtr(Source1C, LoadSizeType, DL);
    if (!LoadSrc1)
      LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);

    Value *LoadSrc2 = nullptr;
    if (auto *Source2C = dyn_cast<Constant>(Source2))
      LoadSrc2 = ConstantFoldLoadFromConstPtr(Source2C, LoadSizeType, DL);
    if (!LoadSrc2)
      LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);

    if (NumLoads != 1) {
      if (LoadSizeType != MaxLoadType) {
        LoadSrc1 = Builder.CreateZExt(LoadSrc1, MaxLoadType);
        LoadSrc2 = Builder.CreateZExt(LoadSrc2, MaxLoadType);
      }
      // If we have multiple loads per block, we need to generate a composite
      // comparison using xor+or.
      Diff = Builder.CreateXor(LoadSrc1, LoadSrc2);
      Diff = Builder.CreateZExt(Diff, MaxLoadType);
      XorList.push_back(Diff);
    } else {
      // If there's only one load per block, we just compare the loaded values.
      Cmp = Builder.CreateICmpNE(LoadSrc1, LoadSrc2);
    }
  }

  auto pairWiseOr = [&](std::vector<Value *> &InList) -> std::vector<Value *> {
    std::vector<Value *> OutList;
    for (unsigned i = 0; i < InList.size() - 1; i = i + 2) {
      Value *Or = Builder.CreateOr(InList[i], InList[i + 1]);
      OutList.push_back(Or);
    }
    if (InList.size() % 2 != 0)
      OutList.push_back(InList.back());
    return OutList;
  };

  if (!Cmp) {
    // Pairwise OR the XOR results.
    OrList = pairWiseOr(XorList);

    // Pairwise OR the OR results until one result left.
    while (OrList.size() != 1) {
      OrList = pairWiseOr(OrList);
    }
    Cmp = Builder.CreateICmpNE(OrList[0], ConstantInt::get(Diff->getType(), 0));
  }

  return Cmp;
}
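
// Sketch of the composite comparison built above (assuming 4 loads per block):
//   d0 = x0 ^ y0   d1 = x1 ^ y1   d2 = x2 ^ y2   d3 = x3 ^ y3
//   o0 = d0 | d1   o1 = d2 | d3
//   o  = o0 | o1
//   cmp = (o != 0)
// Any differing pair makes its xor nonzero, and that nonzero value survives
// the OR reduction, so `cmp` is true iff at least one pair differs.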

void MemCmpExpansion::emitLoadCompareBlockMultipleLoads(unsigned BlockIndex,
                                                        unsigned &LoadIndex) {
  Value *Cmp = getCompareLoadPairs(BlockIndex, LoadIndex);

  BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1))
                           ? EndBlock
                           : LoadCmpBlocks[BlockIndex + 1];
  // Early exit: branch to ResultBlock if a difference was found; otherwise,
  // continue to the next LoadCmpBlock or to EndBlock.
  BranchInst *CmpBr = BranchInst::Create(ResBlock.BB, NextBB, Cmp);
  Builder.Insert(CmpBr);

  // Add a phi edge for the last LoadCmpBlock to EndBlock with a value of 0
  // since early exit to ResultBlock was not taken (no difference was found in
  // any of the bytes).
  if (BlockIndex == LoadCmpBlocks.size() - 1) {
    Value *Zero = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 0);
    PhiRes->addIncoming(Zero, LoadCmpBlocks[BlockIndex]);
  }
}
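
// Overall control flow of a multi-block expansion (illustrative):
//
//   loadbb --> loadbb1 --> ... --> loadbbN --> endblock
//      \           \                  \           ^
//       +-----------+------------------+--> res_block
//
// Each block falls through on equality and exits early on the first
// difference, either to res_block or, for a 1-byte tail block whose subtract
// already is the final result, directly to endblock.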

// This function creates the IR instructions for loading and comparing using
// the given LoadSize. It loads the number of bytes specified by LoadSize from
// each source of the memcmp parameters and compares the loaded values for
// equality. If a difference is found, it branches with an early exit to the
// ResultBlock for calculating which source was larger. Otherwise, it falls
// through to either the next LoadCmpBlock or the EndBlock if this is the last
// LoadCmpBlock. Loading 1 byte is handled with a special case through
// emitLoadCompareByteBlock, which simply subtracts the loaded values and feeds
// the difference to the result phi node.
void MemCmpExpansion::emitLoadCompareBlock(unsigned BlockIndex) {
  // There is one load per block in this case, BlockIndex == LoadIndex.
  const LoadEntry &CurLoadEntry = LoadSequence[BlockIndex];

  if (CurLoadEntry.LoadSize == 1) {
    MemCmpExpansion::emitLoadCompareByteBlock(BlockIndex,
                                              CurLoadEntry.getGEPIndex());
    return;
  }

  Type *LoadSizeType =
      IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8);
  Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8);
  assert(CurLoadEntry.LoadSize <= MaxLoadSize && "Unexpected load type");

  Value *Source1 = CI->getArgOperand(0);
  Value *Source2 = CI->getArgOperand(1);

  Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
  // Cast source to LoadSizeType*.
  if (Source1->getType() != LoadSizeType)
    Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo());
  if (Source2->getType() != LoadSizeType)
    Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo());

  // Get the base address using a GEP.
  if (CurLoadEntry.Offset != 0) {
    Source1 = Builder.CreateGEP(
        LoadSizeType, Source1,
        ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex()));
    Source2 = Builder.CreateGEP(
        LoadSizeType, Source2,
        ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex()));
  }

  // Load LoadSizeType from the base address.
  Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
  Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);

  if (DL.isLittleEndian()) {
    Function *Bswap = Intrinsic::getDeclaration(
        CI->getModule(), Intrinsic::bswap, LoadSizeType);
    LoadSrc1 = Builder.CreateCall(Bswap, LoadSrc1);
    LoadSrc2 = Builder.CreateCall(Bswap, LoadSrc2);
  }

  if (LoadSizeType != MaxLoadType) {
    LoadSrc1 = Builder.CreateZExt(LoadSrc1, MaxLoadType);
    LoadSrc2 = Builder.CreateZExt(LoadSrc2, MaxLoadType);
  }

  // Add the loaded values to the phi nodes for calculating the memcmp result
  // only if the result is not used in a zero-equality comparison.
  if (!IsUsedForZeroCmp) {
    ResBlock.PhiSrc1->addIncoming(LoadSrc1, LoadCmpBlocks[BlockIndex]);
    ResBlock.PhiSrc2->addIncoming(LoadSrc2, LoadCmpBlocks[BlockIndex]);
  }

  Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, LoadSrc1, LoadSrc2);
  BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1))
                           ? EndBlock
                           : LoadCmpBlocks[BlockIndex + 1];
  // Early exit: branch to ResultBlock if a difference was found; otherwise,
  // continue to the next LoadCmpBlock or to EndBlock.
  BranchInst *CmpBr = BranchInst::Create(NextBB, ResBlock.BB, Cmp);
  Builder.Insert(CmpBr);

  // Add a phi edge for the last LoadCmpBlock to EndBlock with a value of 0
  // since early exit to ResultBlock was not taken (no difference was found in
  // any of the bytes).
  if (BlockIndex == LoadCmpBlocks.size() - 1) {
    Value *Zero = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 0);
    PhiRes->addIncoming(Zero, LoadCmpBlocks[BlockIndex]);
  }
}

// This function populates the ResultBlock with a sequence to calculate the
// memcmp result. It compares the two loaded source values and returns -1 if
// src1 < src2 and 1 if src1 > src2.
void MemCmpExpansion::emitMemCmpResultBlock() {
  // Special case: if the memcmp result is only used in a zero-equality
  // comparison, the exact value does not need to be calculated; any nonzero
  // value (here 1) will do.
  if (IsUsedForZeroCmp) {
    BasicBlock::iterator InsertPt = ResBlock.BB->getFirstInsertionPt();
    Builder.SetInsertPoint(ResBlock.BB, InsertPt);
    Value *Res = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 1);
    PhiRes->addIncoming(Res, ResBlock.BB);
    BranchInst *NewBr = BranchInst::Create(EndBlock);
    Builder.Insert(NewBr);
    return;
  }
  BasicBlock::iterator InsertPt = ResBlock.BB->getFirstInsertionPt();
  Builder.SetInsertPoint(ResBlock.BB, InsertPt);

  Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_ULT, ResBlock.PhiSrc1,
                                  ResBlock.PhiSrc2);

  Value *Res =
      Builder.CreateSelect(Cmp, ConstantInt::get(Builder.getInt32Ty(), -1),
                           ConstantInt::get(Builder.getInt32Ty(), 1));

  BranchInst *NewBr = BranchInst::Create(EndBlock);
  Builder.Insert(NewBr);
  PhiRes->addIncoming(Res, ResBlock.BB);
}
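
// Why bswap + unsigned compare matches memcmp semantics (worked example on a
// little-endian target): comparing the 4 bytes {0x00, 0x01, 0x02, 0x03} vs
// {0x00, 0x01, 0x03, 0x02}, the raw i32 loads are 0x03020100 and 0x02030100.
// After bswap they become 0x00010203 and 0x00010302, so the first differing
// byte ends up most significant and decides the unsigned comparison, exactly
// as the byte-wise memcmp contract requires.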

void MemCmpExpansion::setupResultBlockPHINodes() {
  Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8);
  Builder.SetInsertPoint(ResBlock.BB);
  // Note: this assumes one load per block.
  ResBlock.PhiSrc1 =
      Builder.CreatePHI(MaxLoadType, NumLoadsNonOneByte, "phi.src1");
  ResBlock.PhiSrc2 =
      Builder.CreatePHI(MaxLoadType, NumLoadsNonOneByte, "phi.src2");
}

void MemCmpExpansion::setupEndBlockPHINodes() {
  Builder.SetInsertPoint(&EndBlock->front());
  PhiRes = Builder.CreatePHI(Type::getInt32Ty(CI->getContext()), 2, "phi.res");
}

Value *MemCmpExpansion::getMemCmpExpansionZeroCase() {
  unsigned LoadIndex = 0;
  // This loop populates each of the LoadCmpBlocks with the IR sequence to
  // handle multiple loads per block.
  for (unsigned I = 0; I < getNumBlocks(); ++I) {
    emitLoadCompareBlockMultipleLoads(I, LoadIndex);
  }

  emitMemCmpResultBlock();
  return PhiRes;
}

/// A memcmp expansion that compares equality with 0 and only has one block of
/// load and compare can bypass the compare, branch, and phi IR that is
/// required in the general case.
Value *MemCmpExpansion::getMemCmpEqZeroOneBlock() {
  unsigned LoadIndex = 0;
  Value *Cmp = getCompareLoadPairs(0, LoadIndex);
  assert(LoadIndex == getNumLoads() && "some entries were not consumed");
  return Builder.CreateZExt(Cmp, Type::getInt32Ty(CI->getContext()));
}
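
// Illustrative source-level view of the one-block zero-equality case
// (assuming an 8-byte load is legal for the target): `memcmp(a, b, 8) == 0`
// expands to the equivalent of `*(const uint64_t *)a != *(const uint64_t *)b`
// zero-extended to i32. No byte swap is needed because equality is
// endianness-agnostic.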

/// A memcmp expansion that only has one block of load and compare can bypass
/// the compare, branch, and phi IR that is required in the general case.
Value *MemCmpExpansion::getMemCmpOneBlock() {
  Type *LoadSizeType = IntegerType::get(CI->getContext(), Size * 8);
  Value *Source1 = CI->getArgOperand(0);
  Value *Source2 = CI->getArgOperand(1);

  // Cast source to LoadSizeType*.
  if (Source1->getType() != LoadSizeType)
    Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo());
  if (Source2->getType() != LoadSizeType)
    Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo());

  // Load LoadSizeType from the base address.
  Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
  Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);

  if (DL.isLittleEndian() && Size != 1) {
    Function *Bswap = Intrinsic::getDeclaration(
        CI->getModule(), Intrinsic::bswap, LoadSizeType);
    LoadSrc1 = Builder.CreateCall(Bswap, LoadSrc1);
    LoadSrc2 = Builder.CreateCall(Bswap, LoadSrc2);
  }

  if (Size < 4) {
    // The i8 and i16 cases don't need compares. We zext the loaded values and
    // subtract them to get the suitable negative, zero, or positive i32
    // result.
    LoadSrc1 = Builder.CreateZExt(LoadSrc1, Builder.getInt32Ty());
    LoadSrc2 = Builder.CreateZExt(LoadSrc2, Builder.getInt32Ty());
    return Builder.CreateSub(LoadSrc1, LoadSrc2);
  }

  // The result of memcmp is negative, zero, or positive, so produce that by
  // subtracting 2 extended compare bits: sub (ugt, ult).
  // If a target prefers to use selects to get -1/0/1, they should be able
  // to transform this later. The inverse transform (going from selects to
  // math) may not be possible in the DAG because the selects got converted
  // into branches before we got there.
  Value *CmpUGT = Builder.CreateICmpUGT(LoadSrc1, LoadSrc2);
  Value *CmpULT = Builder.CreateICmpULT(LoadSrc1, LoadSrc2);
  Value *ZextUGT = Builder.CreateZExt(CmpUGT, Builder.getInt32Ty());
  Value *ZextULT = Builder.CreateZExt(CmpULT, Builder.getInt32Ty());
  return Builder.CreateSub(ZextUGT, ZextULT);
}
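
// Worked example of the sub(ugt, ult) trick above: if the (byte-swapped)
// values satisfy a > b, the result is 1 - 0 == 1; if a < b, 0 - 1 == -1; if
// equal, 0 - 0 == 0. For the Size < 4 path, the zext'd operands are at most
// 16 bits wide, so the i32 subtraction cannot wrap and its sign is already
// the memcmp sign.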

// This function expands the memcmp call into an inline expansion and returns
// the memcmp result.
Value *MemCmpExpansion::getMemCmpExpansion() {
  // Create the basic block framework for a multi-block expansion.
  if (getNumBlocks() != 1) {
    BasicBlock *StartBlock = CI->getParent();
    EndBlock = StartBlock->splitBasicBlock(CI, "endblock");
    setupEndBlockPHINodes();
    createResultBlock();

    // If the return value of memcmp is not used in a zero equality, we need
    // to calculate which source was larger. The calculation requires the two
    // loaded source values of each load compare block. These will be saved in
    // the phi nodes created by setupResultBlockPHINodes.
    if (!IsUsedForZeroCmp) setupResultBlockPHINodes();

    // Create the number of required load compare basic blocks.
    createLoadCmpBlocks();

    // Update the terminator added by splitBasicBlock to branch to the first
    // LoadCmpBlock.
    StartBlock->getTerminator()->setSuccessor(0, LoadCmpBlocks[0]);
  }

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  if (IsUsedForZeroCmp)
    return getNumBlocks() == 1 ? getMemCmpEqZeroOneBlock()
                               : getMemCmpExpansionZeroCase();

  if (getNumBlocks() == 1)
    return getMemCmpOneBlock();

  for (unsigned I = 0; I < getNumBlocks(); ++I) {
    emitLoadCompareBlock(I);
  }

  emitMemCmpResultBlock();
  return PhiRes;
}

// This function checks to see if an expansion of memcmp can be generated.
// It checks for a constant compare size that is less than the max inline size.
// If an expansion cannot occur, returns false to leave as a library call.
// Otherwise, the library call is replaced with a new IR instruction sequence.
/// We want to transform:
/// %call = call signext i32 @memcmp(i8* %0, i8* %1, i64 15)
/// To:
/// loadbb:
///  %0 = bitcast i32* %buffer2 to i8*
///  %1 = bitcast i32* %buffer1 to i8*
///  %2 = bitcast i8* %1 to i64*
///  %3 = bitcast i8* %0 to i64*
///  %4 = load i64, i64* %2
///  %5 = load i64, i64* %3
///  %6 = call i64 @llvm.bswap.i64(i64 %4)
///  %7 = call i64 @llvm.bswap.i64(i64 %5)
///  %8 = sub i64 %6, %7
///  %9 = icmp ne i64 %8, 0
///  br i1 %9, label %res_block, label %loadbb1
/// res_block:                                  ; preds = %loadbb2, %loadbb1,
///                                             ;         %loadbb
///  %phi.src1 = phi i64 [ %6, %loadbb ], [ %22, %loadbb1 ], [ %36, %loadbb2 ]
///  %phi.src2 = phi i64 [ %7, %loadbb ], [ %23, %loadbb1 ], [ %37, %loadbb2 ]
///  %10 = icmp ult i64 %phi.src1, %phi.src2
///  %11 = select i1 %10, i32 -1, i32 1
///  br label %endblock
/// loadbb1:                                    ; preds = %loadbb
///  %12 = bitcast i32* %buffer2 to i8*
///  %13 = bitcast i32* %buffer1 to i8*
///  %14 = bitcast i8* %13 to i32*
///  %15 = bitcast i8* %12 to i32*
///  %16 = getelementptr i32, i32* %14, i32 2
///  %17 = getelementptr i32, i32* %15, i32 2
///  %18 = load i32, i32* %16
///  %19 = load i32, i32* %17
///  %20 = call i32 @llvm.bswap.i32(i32 %18)
///  %21 = call i32 @llvm.bswap.i32(i32 %19)
///  %22 = zext i32 %20 to i64
///  %23 = zext i32 %21 to i64
///  %24 = sub i64 %22, %23
///  %25 = icmp ne i64 %24, 0
///  br i1 %25, label %res_block, label %loadbb2
/// loadbb2:                                    ; preds = %loadbb1
///  %26 = bitcast i32* %buffer2 to i8*
///  %27 = bitcast i32* %buffer1 to i8*
///  %28 = bitcast i8* %27 to i16*
///  %29 = bitcast i8* %26 to i16*
///  %30 = getelementptr i16, i16* %28, i16 6
///  %31 = getelementptr i16, i16* %29, i16 6
///  %32 = load i16, i16* %30
///  %33 = load i16, i16* %31
///  %34 = call i16 @llvm.bswap.i16(i16 %32)
///  %35 = call i16 @llvm.bswap.i16(i16 %33)
///  %36 = zext i16 %34 to i64
///  %37 = zext i16 %35 to i64
///  %38 = sub i64 %36, %37
///  %39 = icmp ne i64 %38, 0
///  br i1 %39, label %res_block, label %loadbb3
/// loadbb3:                                    ; preds = %loadbb2
///  %40 = bitcast i32* %buffer2 to i8*
///  %41 = bitcast i32* %buffer1 to i8*
///  %42 = getelementptr i8, i8* %41, i8 14
///  %43 = getelementptr i8, i8* %40, i8 14
///  %44 = load i8, i8* %42
///  %45 = load i8, i8* %43
///  %46 = zext i8 %44 to i32
///  %47 = zext i8 %45 to i32
///  %48 = sub i32 %46, %47
///  br label %endblock
/// endblock:                                   ; preds = %res_block, %loadbb3
///  %phi.res = phi i32 [ %48, %loadbb3 ], [ %11, %res_block ]
///  ret i32 %phi.res
static bool expandMemCmp(CallInst *CI, const TargetTransformInfo *TTI,
                         const TargetLowering *TLI, const DataLayout *DL) {
  NumMemCmpCalls++;

  // Early exit from expansion if -Oz.
  if (CI->getFunction()->optForMinSize())
    return false;

  // Early exit from expansion if size is not a constant.
  ConstantInt *SizeCast = dyn_cast<ConstantInt>(CI->getArgOperand(2));
  if (!SizeCast) {
    NumMemCmpNotConstant++;
    return false;
  }
  const uint64_t SizeVal = SizeCast->getZExtValue();

  if (SizeVal == 0) {
    return false;
  }

  // TTI call to check if the target would like to expand memcmp. Also, get
  // the available load sizes.
  const bool IsUsedForZeroCmp = isOnlyUsedInZeroEqualityComparison(CI);
  const auto *const Options = TTI->enableMemCmpExpansion(IsUsedForZeroCmp);
  if (!Options) return false;

  const unsigned MaxNumLoads =
      TLI->getMaxExpandSizeMemcmp(CI->getFunction()->optForSize());

  unsigned NumLoadsPerBlock = MemCmpEqZeroNumLoadsPerBlock.getNumOccurrences()
                                  ? MemCmpEqZeroNumLoadsPerBlock
                                  : TLI->getMemcmpEqZeroLoadsPerBlock();

  MemCmpExpansion Expansion(CI, SizeVal, *Options, MaxNumLoads,
                            IsUsedForZeroCmp, NumLoadsPerBlock, *DL);

  // Don't expand if this will require more loads than desired by the target.
  if (Expansion.getNumLoads() == 0) {
    NumMemCmpGreaterThanMax++;
    return false;
  }

  NumMemCmpInlined++;

  Value *Res = Expansion.getMemCmpExpansion();

  // Replace the call with the result of the expansion and erase the call.
  CI->replaceAllUsesWith(Res);
  CI->eraseFromParent();

  return true;
}
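
// For illustration: a call like `memcmp(p, q, 16)` on a target advertising
// 8-byte loads (with MaxNumLoads >= 2) becomes two 8-byte load/compare blocks
// in the general case, while a call with a non-constant or over-budget size is
// left as a libcall.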

class ExpandMemCmpPass : public FunctionPass {
public:
  static char ID;

  ExpandMemCmpPass() : FunctionPass(ID) {
    initializeExpandMemCmpPassPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    if (skipFunction(F)) return false;

    auto *TPC = getAnalysisIfAvailable<TargetPassConfig>();
    if (!TPC) {
      return false;
    }
    const TargetLowering *TL =
        TPC->getTM<TargetMachine>().getSubtargetImpl(F)->getTargetLowering();

    const TargetLibraryInfo *TLI =
        &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
    const TargetTransformInfo *TTI =
        &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto PA = runImpl(F, TLI, TTI, TL);
    return !PA.areAllPreserved();
  }

private:
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetLibraryInfoWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
    FunctionPass::getAnalysisUsage(AU);
  }

  PreservedAnalyses runImpl(Function &F, const TargetLibraryInfo *TLI,
                            const TargetTransformInfo *TTI,
                            const TargetLowering *TL);
  // Returns true if a change was made.
  bool runOnBlock(BasicBlock &BB, const TargetLibraryInfo *TLI,
                  const TargetTransformInfo *TTI, const TargetLowering *TL,
                  const DataLayout &DL);
};

bool ExpandMemCmpPass::runOnBlock(BasicBlock &BB, const TargetLibraryInfo *TLI,
                                  const TargetTransformInfo *TTI,
                                  const TargetLowering *TL,
                                  const DataLayout &DL) {
  for (Instruction &I : BB) {
    CallInst *CI = dyn_cast<CallInst>(&I);
    if (!CI) {
      continue;
    }
    LibFunc Func;
    if (TLI->getLibFunc(ImmutableCallSite(CI), Func) &&
        Func == LibFunc_memcmp && expandMemCmp(CI, TTI, TL, &DL)) {
      return true;
    }
  }
  return false;
}

PreservedAnalyses ExpandMemCmpPass::runImpl(Function &F,
                                            const TargetLibraryInfo *TLI,
                                            const TargetTransformInfo *TTI,
                                            const TargetLowering *TL) {
  const DataLayout &DL = F.getParent()->getDataLayout();
  bool MadeChanges = false;
  for (auto BBIt = F.begin(); BBIt != F.end();) {
    if (runOnBlock(*BBIt, TLI, TTI, TL, DL)) {
      MadeChanges = true;
      // If changes were made, restart the function from the beginning, since
      // the structure of the function was changed.
      BBIt = F.begin();
    } else {
      ++BBIt;
    }
  }
  return MadeChanges ? PreservedAnalyses::none() : PreservedAnalyses::all();
}

} // namespace

char ExpandMemCmpPass::ID = 0;
INITIALIZE_PASS_BEGIN(ExpandMemCmpPass, "expandmemcmp",
                      "Expand memcmp() to load/stores", false, false)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_END(ExpandMemCmpPass, "expandmemcmp",
                    "Expand memcmp() to load/stores", false, false)

FunctionPass *llvm::createExpandMemCmpPass() {
  return new ExpandMemCmpPass();
}