//===--- ExpandMemCmp.cpp - Expand memcmp() to load/stores ----------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This pass tries to expand memcmp() calls into optimally-sized loads and
// compares for the target.
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/CodeGen/TargetLowering.h"
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/CodeGen/TargetSubtargetInfo.h"
#include "llvm/IR/CallSite.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/Support/CommandLine.h"

using namespace llvm;

#define DEBUG_TYPE "expandmemcmp"

STATISTIC(NumMemCmpCalls, "Number of memcmp calls");
STATISTIC(NumMemCmpNotConstant, "Number of memcmp calls without constant size");
STATISTIC(NumMemCmpGreaterThanMax,
          "Number of memcmp calls with size greater than max size");
STATISTIC(NumMemCmpInlined, "Number of inlined memcmp calls");

static cl::opt<unsigned> MemCmpNumLoadsPerBlock(
    "memcmp-num-loads-per-block", cl::Hidden, cl::init(1),
    cl::desc("The number of loads per basic block for inline expansion of "
             "memcmp that is only being compared against zero."));

namespace {

// This class provides helper functions to expand a memcmp library call into an
// inline expansion.
class MemCmpExpansion {
  struct ResultBlock {
    BasicBlock *BB = nullptr;
    PHINode *PhiSrc1 = nullptr;
    PHINode *PhiSrc2 = nullptr;

    ResultBlock() = default;
  };

  CallInst *const CI;
  ResultBlock ResBlock;
  const uint64_t Size;
  unsigned MaxLoadSize;
  uint64_t NumLoadsNonOneByte;
  const uint64_t NumLoadsPerBlock;
  std::vector<BasicBlock *> LoadCmpBlocks;
  BasicBlock *EndBlock;
  PHINode *PhiRes;
  const bool IsUsedForZeroCmp;
  const DataLayout &DL;
  IRBuilder<> Builder;
  // Represents the decomposition in blocks of the expansion. For example,
  // comparing 33 bytes on X86+sse can be done with 2x16-byte loads and
  // 1x1-byte load, which would be represented as [{16, 0}, {16, 16}, {1, 32}].
  // TODO(courbet): Involve the target more in this computation. On X86, 7
  // bytes can be done more efficiently with two overlapping 4-byte loads than
  // by covering the interval with [{4, 0}, {2, 4}, {1, 6}].
  struct LoadEntry {
    LoadEntry(unsigned LoadSize, uint64_t Offset)
        : LoadSize(LoadSize), Offset(Offset) {
      assert(Offset % LoadSize == 0 && "invalid load entry");
    }

    uint64_t getGEPIndex() const { return Offset / LoadSize; }

    // The size of the load for this block, in bytes.
    const unsigned LoadSize;
    // The offset of this load WRT the base pointer, in bytes.
    const uint64_t Offset;
  };
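
  // For example, with Options.LoadSizes = {8, 4, 2, 1} and Size = 15 (and
  // MaxNumLoads >= 4), the constructor below produces LoadSequence =
  // [{8, 0}, {4, 8}, {2, 12}, {1, 14}]. (Illustrative values; the available
  // load sizes come from the target.)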
  SmallVector<LoadEntry, 8> LoadSequence;

  void createLoadCmpBlocks();
  void createResultBlock();
  void setupResultBlockPHINodes();
  void setupEndBlockPHINodes();
  Value *getCompareLoadPairs(unsigned BlockIndex, unsigned &LoadIndex);
  void emitLoadCompareBlock(unsigned BlockIndex);
  void emitLoadCompareBlockMultipleLoads(unsigned BlockIndex,
                                         unsigned &LoadIndex);
  void emitLoadCompareByteBlock(unsigned BlockIndex, unsigned GEPIndex);
  void emitMemCmpResultBlock();
  Value *getMemCmpExpansionZeroCase();
  Value *getMemCmpEqZeroOneBlock();
  Value *getMemCmpOneBlock();

public:
  MemCmpExpansion(CallInst *CI, uint64_t Size,
                  const TargetTransformInfo::MemCmpExpansionOptions &Options,
                  unsigned MaxNumLoads, const bool IsUsedForZeroCmp,
                  unsigned NumLoadsPerBlock, const DataLayout &DL);

  unsigned getNumBlocks();
  uint64_t getNumLoads() const { return LoadSequence.size(); }

  Value *getMemCmpExpansion();
};

// Initialize the basic block structure required for expansion of memcmp call
// with given maximum load size and memcmp size parameter.
// This structure includes:
// 1. A list of load compare blocks - LoadCmpBlocks.
// 2. An EndBlock, split from original instruction point, which is the block to
// return from.
// 3. ResultBlock, block to branch to for early exit when a
// LoadCmpBlock finds a difference.
MemCmpExpansion::MemCmpExpansion(
    CallInst *const CI, uint64_t Size,
    const TargetTransformInfo::MemCmpExpansionOptions &Options,
    const unsigned MaxNumLoads, const bool IsUsedForZeroCmp,
    const unsigned NumLoadsPerBlock, const DataLayout &TheDataLayout)
    : CI(CI),
      Size(Size),
      MaxLoadSize(0),
      NumLoadsNonOneByte(0),
      NumLoadsPerBlock(NumLoadsPerBlock),
      IsUsedForZeroCmp(IsUsedForZeroCmp),
      DL(TheDataLayout),
      Builder(CI) {
  assert(Size > 0 && "zero blocks");
  // Scale the max size down if the target can load more bytes than we need.
  size_t LoadSizeIndex = 0;
  while (LoadSizeIndex < Options.LoadSizes.size() &&
         Options.LoadSizes[LoadSizeIndex] > Size) {
    ++LoadSizeIndex;
  }
  this->MaxLoadSize = Options.LoadSizes[LoadSizeIndex];
  // Compute the decomposition.
  uint64_t CurSize = Size;
  uint64_t Offset = 0;
  while (CurSize && LoadSizeIndex < Options.LoadSizes.size()) {
    const unsigned LoadSize = Options.LoadSizes[LoadSizeIndex];
    assert(LoadSize > 0 && "zero load size");
    const uint64_t NumLoadsForThisSize = CurSize / LoadSize;
    if (LoadSequence.size() + NumLoadsForThisSize > MaxNumLoads) {
      // Do not expand if the total number of loads is larger than what the
      // target allows. Note that it's important that we exit before completing
      // the expansion to avoid using a ton of memory to store the expansion
      // for large sizes.
      LoadSequence.clear();
      return;
    }
    if (NumLoadsForThisSize > 0) {
      for (uint64_t I = 0; I < NumLoadsForThisSize; ++I) {
        LoadSequence.push_back({LoadSize, Offset});
        Offset += LoadSize;
      }
      if (LoadSize > 1) {
        ++NumLoadsNonOneByte;
      }
      CurSize = CurSize % LoadSize;
    }
    ++LoadSizeIndex;
  }
  assert(LoadSequence.size() <= MaxNumLoads && "broken invariant");
}
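
// A zero-equality expansion packs NumLoadsPerBlock loads into each block: for
// example, 5 loads with NumLoadsPerBlock == 2 need ceil(5 / 2) == 3 blocks.
// All other expansions use one block per load.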
unsigned MemCmpExpansion::getNumBlocks() {
  if (IsUsedForZeroCmp)
    return getNumLoads() / NumLoadsPerBlock +
           (getNumLoads() % NumLoadsPerBlock != 0 ? 1 : 0);
  return getNumLoads();
}

void MemCmpExpansion::createLoadCmpBlocks() {
  for (unsigned i = 0; i < getNumBlocks(); i++) {
    BasicBlock *BB = BasicBlock::Create(CI->getContext(), "loadbb",
                                        EndBlock->getParent(), EndBlock);
    LoadCmpBlocks.push_back(BB);
  }
}

void MemCmpExpansion::createResultBlock() {
  ResBlock.BB = BasicBlock::Create(CI->getContext(), "res_block",
                                   EndBlock->getParent(), EndBlock);
}

// This function creates the IR instructions for loading and comparing 1 byte.
// It loads 1 byte from each source of the memcmp parameters with the given
// GEPIndex. It then subtracts the two loaded values and adds this result to
// the final phi node for selecting the memcmp result.
void MemCmpExpansion::emitLoadCompareByteBlock(unsigned BlockIndex,
                                               unsigned GEPIndex) {
  Value *Source1 = CI->getArgOperand(0);
  Value *Source2 = CI->getArgOperand(1);

  Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
  Type *LoadSizeType = Type::getInt8Ty(CI->getContext());
  // Cast source to LoadSizeType*.
  if (Source1->getType() != LoadSizeType)
    Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo());
  if (Source2->getType() != LoadSizeType)
    Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo());

  // Get the base address using the GEPIndex.
  if (GEPIndex != 0) {
    Source1 = Builder.CreateGEP(LoadSizeType, Source1,
                                ConstantInt::get(LoadSizeType, GEPIndex));
    Source2 = Builder.CreateGEP(LoadSizeType, Source2,
                                ConstantInt::get(LoadSizeType, GEPIndex));
  }

  Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
  Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);

  LoadSrc1 = Builder.CreateZExt(LoadSrc1, Type::getInt32Ty(CI->getContext()));
  LoadSrc2 = Builder.CreateZExt(LoadSrc2, Type::getInt32Ty(CI->getContext()));
  Value *Diff = Builder.CreateSub(LoadSrc1, LoadSrc2);

  PhiRes->addIncoming(Diff, LoadCmpBlocks[BlockIndex]);

  if (BlockIndex < (LoadCmpBlocks.size() - 1)) {
    // Early exit branch to EndBlock if a difference was found; otherwise,
    // continue to the next LoadCmpBlock.
    Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_NE, Diff,
                                    ConstantInt::get(Diff->getType(), 0));
    BranchInst *CmpBr =
        BranchInst::Create(EndBlock, LoadCmpBlocks[BlockIndex + 1], Cmp);
    Builder.Insert(CmpBr);
  } else {
    // The last block has an unconditional branch to EndBlock.
    BranchInst *CmpBr = BranchInst::Create(EndBlock);
    Builder.Insert(CmpBr);
  }
}

/// Generate an equality comparison for one or more pairs of loaded values.
/// This is used in the case where the memcmp() call is compared equal or not
/// equal to zero.
Value *MemCmpExpansion::getCompareLoadPairs(unsigned BlockIndex,
                                            unsigned &LoadIndex) {
  assert(LoadIndex < getNumLoads() &&
         "getCompareLoadPairs() called with no remaining loads");
  std::vector<Value *> XorList, OrList;
  Value *Diff;

  const unsigned NumLoads =
      std::min(getNumLoads() - LoadIndex, NumLoadsPerBlock);

  // For a single-block expansion, start inserting before the memcmp call.
  if (LoadCmpBlocks.empty())
    Builder.SetInsertPoint(CI);
  else
    Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);

  Value *Cmp = nullptr;
  // If we have multiple loads per block, we need to generate a composite
  // comparison using xor+or. The type for the combinations is the largest
  // load type.
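  // For example, two load pairs combine schematically as
  //   icmp ne (or (xor a0, b0), (xor a1, b1)), 0
  // where aN/bN are the values loaded from each source (names illustrative).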
  IntegerType *const MaxLoadType =
      NumLoads == 1 ? nullptr
                    : IntegerType::get(CI->getContext(), MaxLoadSize * 8);
  for (unsigned i = 0; i < NumLoads; ++i, ++LoadIndex) {
    const LoadEntry &CurLoadEntry = LoadSequence[LoadIndex];

    IntegerType *LoadSizeType =
        IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8);

    Value *Source1 = CI->getArgOperand(0);
    Value *Source2 = CI->getArgOperand(1);

    // Cast source to LoadSizeType*.
    if (Source1->getType() != LoadSizeType)
      Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo());
    if (Source2->getType() != LoadSizeType)
      Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo());

    // Get the base address using a GEP.
    if (CurLoadEntry.Offset != 0) {
      Source1 = Builder.CreateGEP(
          LoadSizeType, Source1,
          ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex()));
      Source2 = Builder.CreateGEP(
          LoadSizeType, Source2,
          ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex()));
    }

    // Get a constant or load a value for each source address.
    Value *LoadSrc1 = nullptr;
    if (auto *Source1C = dyn_cast<Constant>(Source1))
      LoadSrc1 = ConstantFoldLoadFromConstPtr(Source1C, LoadSizeType, DL);
    if (!LoadSrc1)
      LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);

    Value *LoadSrc2 = nullptr;
    if (auto *Source2C = dyn_cast<Constant>(Source2))
      LoadSrc2 = ConstantFoldLoadFromConstPtr(Source2C, LoadSizeType, DL);
    if (!LoadSrc2)
      LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);

    if (NumLoads != 1) {
      if (LoadSizeType != MaxLoadType) {
        LoadSrc1 = Builder.CreateZExt(LoadSrc1, MaxLoadType);
        LoadSrc2 = Builder.CreateZExt(LoadSrc2, MaxLoadType);
      }
      // If we have multiple loads per block, we need to generate a composite
      // comparison using xor+or.
      Diff = Builder.CreateXor(LoadSrc1, LoadSrc2);
      Diff = Builder.CreateZExt(Diff, MaxLoadType);
      XorList.push_back(Diff);
    } else {
      // If there's only one load per block, we just compare the loaded values.
      Cmp = Builder.CreateICmpNE(LoadSrc1, LoadSrc2);
    }
  }

  auto pairWiseOr = [&](std::vector<Value *> &InList) -> std::vector<Value *> {
    std::vector<Value *> OutList;
    for (unsigned i = 0; i < InList.size() - 1; i = i + 2) {
      Value *Or = Builder.CreateOr(InList[i], InList[i + 1]);
      OutList.push_back(Or);
    }
    if (InList.size() % 2 != 0)
      OutList.push_back(InList.back());
    return OutList;
  };

  if (!Cmp) {
    // Pairwise OR the XOR results.
    OrList = pairWiseOr(XorList);

    // Pairwise OR the OR results until one result left.
    while (OrList.size() != 1) {
      OrList = pairWiseOr(OrList);
    }
    Cmp = Builder.CreateICmpNE(OrList[0], ConstantInt::get(Diff->getType(), 0));
  }

  return Cmp;
}
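
// Emits one block of the zero-equality expansion: it compares up to
// NumLoadsPerBlock load pairs via getCompareLoadPairs() and branches to
// ResBlock as soon as any pair differs.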
void MemCmpExpansion::emitLoadCompareBlockMultipleLoads(unsigned BlockIndex,
                                                        unsigned &LoadIndex) {
  Value *Cmp = getCompareLoadPairs(BlockIndex, LoadIndex);

  BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1))
                           ? EndBlock
                           : LoadCmpBlocks[BlockIndex + 1];
  // Early exit branch to ResultBlock if a difference was found; otherwise,
  // continue to the next LoadCmpBlock or to EndBlock.
  BranchInst *CmpBr = BranchInst::Create(ResBlock.BB, NextBB, Cmp);
  Builder.Insert(CmpBr);

  // Add a phi edge for the last LoadCmpBlock to EndBlock with a value of 0
  // since early exit to ResultBlock was not taken (no difference was found in
  // any of the bytes).
  if (BlockIndex == LoadCmpBlocks.size() - 1) {
    Value *Zero = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 0);
    PhiRes->addIncoming(Zero, LoadCmpBlocks[BlockIndex]);
  }
}

// This function creates the IR instructions for loading and comparing using
// the given LoadSize. It loads the number of bytes specified by LoadSize from
// each source of the memcmp parameters. It then does a subtract to see if
// there was a difference in the loaded values. If a difference is found, it
// branches with an early exit to the ResultBlock for calculating which source
// was larger. Otherwise, it falls through to either the next LoadCmpBlock or
// the EndBlock if this is the last LoadCmpBlock. Loading 1 byte is handled
// with a special case through emitLoadCompareByteBlock, which can simply
// subtract the loaded values and feed the difference to the result phi node.
void MemCmpExpansion::emitLoadCompareBlock(unsigned BlockIndex) {
  // There is one load per block in this case, BlockIndex == LoadIndex.
  const LoadEntry &CurLoadEntry = LoadSequence[BlockIndex];

  if (CurLoadEntry.LoadSize == 1) {
    MemCmpExpansion::emitLoadCompareByteBlock(BlockIndex,
                                              CurLoadEntry.getGEPIndex());
    return;
  }

  Type *LoadSizeType =
      IntegerType::get(CI->getContext(), CurLoadEntry.LoadSize * 8);
  Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8);
  assert(CurLoadEntry.LoadSize <= MaxLoadSize && "Unexpected load type");

  Value *Source1 = CI->getArgOperand(0);
  Value *Source2 = CI->getArgOperand(1);

  Builder.SetInsertPoint(LoadCmpBlocks[BlockIndex]);
  // Cast source to LoadSizeType*.
  if (Source1->getType() != LoadSizeType)
    Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo());
  if (Source2->getType() != LoadSizeType)
    Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo());

  // Get the base address using a GEP.
  if (CurLoadEntry.Offset != 0) {
    Source1 = Builder.CreateGEP(
        LoadSizeType, Source1,
        ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex()));
    Source2 = Builder.CreateGEP(
        LoadSizeType, Source2,
        ConstantInt::get(LoadSizeType, CurLoadEntry.getGEPIndex()));
  }

  // Load LoadSizeType from the base address.
  Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
  Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);

  // memcmp orders byte-wise; byte-swap on little-endian targets so that an
  // unsigned integer comparison of the loaded words matches that order.
  if (DL.isLittleEndian()) {
    Function *Bswap = Intrinsic::getDeclaration(CI->getModule(),
                                                Intrinsic::bswap, LoadSizeType);
    LoadSrc1 = Builder.CreateCall(Bswap, LoadSrc1);
    LoadSrc2 = Builder.CreateCall(Bswap, LoadSrc2);
  }

  if (LoadSizeType != MaxLoadType) {
    LoadSrc1 = Builder.CreateZExt(LoadSrc1, MaxLoadType);
    LoadSrc2 = Builder.CreateZExt(LoadSrc2, MaxLoadType);
  }

  // Add the loaded values to the phi nodes for calculating memcmp result only
  // if result is not used in a zero equality.
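  // (emitMemCmpResultBlock() later turns these PHIs into the final -1/1 via
  // an unsigned compare and a select.)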
  if (!IsUsedForZeroCmp) {
    ResBlock.PhiSrc1->addIncoming(LoadSrc1, LoadCmpBlocks[BlockIndex]);
    ResBlock.PhiSrc2->addIncoming(LoadSrc2, LoadCmpBlocks[BlockIndex]);
  }

  Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_EQ, LoadSrc1, LoadSrc2);
  BasicBlock *NextBB = (BlockIndex == (LoadCmpBlocks.size() - 1))
                           ? EndBlock
                           : LoadCmpBlocks[BlockIndex + 1];
  // Early exit branch to ResultBlock if a difference was found; otherwise,
  // continue to the next LoadCmpBlock or to EndBlock.
  BranchInst *CmpBr = BranchInst::Create(NextBB, ResBlock.BB, Cmp);
  Builder.Insert(CmpBr);

  // Add a phi edge for the last LoadCmpBlock to EndBlock with a value of 0
  // since early exit to ResultBlock was not taken (no difference was found in
  // any of the bytes).
  if (BlockIndex == LoadCmpBlocks.size() - 1) {
    Value *Zero = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 0);
    PhiRes->addIncoming(Zero, LoadCmpBlocks[BlockIndex]);
  }
}

// This function populates the ResultBlock with a sequence to calculate the
// memcmp result. It compares the two loaded source values and returns -1 if
// src1 < src2 and 1 if src1 > src2.
void MemCmpExpansion::emitMemCmpResultBlock() {
  // Special case: if the memcmp result is only used in a zero equality, the
  // result does not need to be calculated. Reaching this block already means
  // a difference was found, so any nonzero value (here 1) suffices.
  if (IsUsedForZeroCmp) {
    BasicBlock::iterator InsertPt = ResBlock.BB->getFirstInsertionPt();
    Builder.SetInsertPoint(ResBlock.BB, InsertPt);
    Value *Res = ConstantInt::get(Type::getInt32Ty(CI->getContext()), 1);
    PhiRes->addIncoming(Res, ResBlock.BB);
    BranchInst *NewBr = BranchInst::Create(EndBlock);
    Builder.Insert(NewBr);
    return;
  }
  BasicBlock::iterator InsertPt = ResBlock.BB->getFirstInsertionPt();
  Builder.SetInsertPoint(ResBlock.BB, InsertPt);

  Value *Cmp = Builder.CreateICmp(ICmpInst::ICMP_ULT, ResBlock.PhiSrc1,
                                  ResBlock.PhiSrc2);

  Value *Res =
      Builder.CreateSelect(Cmp, ConstantInt::get(Builder.getInt32Ty(), -1),
                           ConstantInt::get(Builder.getInt32Ty(), 1));

  BranchInst *NewBr = BranchInst::Create(EndBlock);
  Builder.Insert(NewBr);
  PhiRes->addIncoming(Res, ResBlock.BB);
}

void MemCmpExpansion::setupResultBlockPHINodes() {
  Type *MaxLoadType = IntegerType::get(CI->getContext(), MaxLoadSize * 8);
  Builder.SetInsertPoint(ResBlock.BB);
  // Note: this assumes one load per block.
  ResBlock.PhiSrc1 =
      Builder.CreatePHI(MaxLoadType, NumLoadsNonOneByte, "phi.src1");
  ResBlock.PhiSrc2 =
      Builder.CreatePHI(MaxLoadType, NumLoadsNonOneByte, "phi.src2");
}

void MemCmpExpansion::setupEndBlockPHINodes() {
  Builder.SetInsertPoint(&EndBlock->front());
  PhiRes = Builder.CreatePHI(Type::getInt32Ty(CI->getContext()), 2, "phi.res");
}

Value *MemCmpExpansion::getMemCmpExpansionZeroCase() {
  unsigned LoadIndex = 0;
  // This loop populates each of the LoadCmpBlocks with the IR sequence to
  // handle multiple loads per block.
  for (unsigned I = 0; I < getNumBlocks(); ++I) {
    emitLoadCompareBlockMultipleLoads(I, LoadIndex);
  }

  emitMemCmpResultBlock();
  return PhiRes;
}

/// A memcmp expansion that compares equality with 0 and only has one block of
/// load and compare can bypass the compare, branch, and phi IR that is
/// required in the general case.
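/// For example, a compare done as two loads per source reduces to xor/or of
/// the loaded pairs, one icmp ne against 0, and a zext of the i1 to i32.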
Value *MemCmpExpansion::getMemCmpEqZeroOneBlock() {
  unsigned LoadIndex = 0;
  Value *Cmp = getCompareLoadPairs(0, LoadIndex);
  assert(LoadIndex == getNumLoads() && "some entries were not consumed");
  return Builder.CreateZExt(Cmp, Type::getInt32Ty(CI->getContext()));
}

/// A memcmp expansion that only has one block of load and compare can bypass
/// the compare, branch, and phi IR that is required in the general case.
Value *MemCmpExpansion::getMemCmpOneBlock() {
  assert(NumLoadsPerBlock == 1 && "Only handles one load pair per block");

  Type *LoadSizeType = IntegerType::get(CI->getContext(), Size * 8);
  Value *Source1 = CI->getArgOperand(0);
  Value *Source2 = CI->getArgOperand(1);

  // Cast source to LoadSizeType*.
  if (Source1->getType() != LoadSizeType)
    Source1 = Builder.CreateBitCast(Source1, LoadSizeType->getPointerTo());
  if (Source2->getType() != LoadSizeType)
    Source2 = Builder.CreateBitCast(Source2, LoadSizeType->getPointerTo());

  // Load LoadSizeType from the base address.
  Value *LoadSrc1 = Builder.CreateLoad(LoadSizeType, Source1);
  Value *LoadSrc2 = Builder.CreateLoad(LoadSizeType, Source2);

  if (DL.isLittleEndian() && Size != 1) {
    Function *Bswap = Intrinsic::getDeclaration(CI->getModule(),
                                                Intrinsic::bswap, LoadSizeType);
    LoadSrc1 = Builder.CreateCall(Bswap, LoadSrc1);
    LoadSrc2 = Builder.CreateCall(Bswap, LoadSrc2);
  }

  if (Size < 4) {
    // The i8 and i16 cases don't need compares. We zext the loaded values and
    // subtract them to get the suitable negative, zero, or positive i32
    // result.
    LoadSrc1 = Builder.CreateZExt(LoadSrc1, Builder.getInt32Ty());
    LoadSrc2 = Builder.CreateZExt(LoadSrc2, Builder.getInt32Ty());
    return Builder.CreateSub(LoadSrc1, LoadSrc2);
  }

  // The result of memcmp is negative, zero, or positive, so produce that by
  // subtracting 2 extended compare bits: sub (ugt, ult).
  // If a target prefers to use selects to get -1/0/1, they should be able
  // to transform this later. The inverse transform (going from selects to
  // math) may not be possible in the DAG because the selects got converted
  // into branches before we got there.
  Value *CmpUGT = Builder.CreateICmpUGT(LoadSrc1, LoadSrc2);
  Value *CmpULT = Builder.CreateICmpULT(LoadSrc1, LoadSrc2);
  Value *ZextUGT = Builder.CreateZExt(CmpUGT, Builder.getInt32Ty());
  Value *ZextULT = Builder.CreateZExt(CmpULT, Builder.getInt32Ty());
  return Builder.CreateSub(ZextUGT, ZextULT);
}

// This function expands the memcmp call into an inline expansion and returns
// the memcmp result.
Value *MemCmpExpansion::getMemCmpExpansion() {
  // A memcmp with a zero comparison and only one block of load and compare
  // does not need to set up any extra blocks. This case could be handled in
  // the DAG, but since we have all of the machinery to flexibly expand any
  // memcmp here, we choose to handle this case too to avoid fragmented
  // lowering.
  if ((!IsUsedForZeroCmp && NumLoadsPerBlock != 1) || getNumBlocks() != 1) {
    BasicBlock *StartBlock = CI->getParent();
    EndBlock = StartBlock->splitBasicBlock(CI, "endblock");
    setupEndBlockPHINodes();
    createResultBlock();

    // If return value of memcmp is not used in a zero equality, we need to
    // calculate which source was larger. The calculation requires the
    // two loaded source values of each load compare block.
    // These will be saved in the phi nodes created by setupResultBlockPHINodes.
    if (!IsUsedForZeroCmp) setupResultBlockPHINodes();

    // Create the number of required load compare basic blocks.
    createLoadCmpBlocks();

    // Update the terminator added by splitBasicBlock to branch to the first
    // LoadCmpBlock.
    StartBlock->getTerminator()->setSuccessor(0, LoadCmpBlocks[0]);
  }

  Builder.SetCurrentDebugLocation(CI->getDebugLoc());

  if (IsUsedForZeroCmp)
    return getNumBlocks() == 1 ? getMemCmpEqZeroOneBlock()
                               : getMemCmpExpansionZeroCase();

  // TODO: Handle more than one load pair per block in getMemCmpOneBlock().
  if (getNumBlocks() == 1 && NumLoadsPerBlock == 1) return getMemCmpOneBlock();

  for (unsigned I = 0; I < getNumBlocks(); ++I) {
    emitLoadCompareBlock(I);
  }

  emitMemCmpResultBlock();
  return PhiRes;
}

// This function checks to see if an expansion of memcmp can be generated.
// It checks for constant compare size that is less than the max inline size.
// If an expansion cannot occur, returns false to leave as a library call.
// Otherwise, the library call is replaced with a new IR instruction sequence.
/// We want to transform:
/// %call = call signext i32 @memcmp(i8* %0, i8* %1, i64 15)
/// To:
/// loadbb:
///  %0 = bitcast i32* %buffer2 to i8*
///  %1 = bitcast i32* %buffer1 to i8*
///  %2 = bitcast i8* %1 to i64*
///  %3 = bitcast i8* %0 to i64*
///  %4 = load i64, i64* %2
///  %5 = load i64, i64* %3
///  %6 = call i64 @llvm.bswap.i64(i64 %4)
///  %7 = call i64 @llvm.bswap.i64(i64 %5)
///  %8 = sub i64 %6, %7
///  %9 = icmp ne i64 %8, 0
///  br i1 %9, label %res_block, label %loadbb1
/// res_block:                                        ; preds = %loadbb2,
///                                                   ; %loadbb1, %loadbb
///  %phi.src1 = phi i64 [ %6, %loadbb ], [ %22, %loadbb1 ], [ %36, %loadbb2 ]
///  %phi.src2 = phi i64 [ %7, %loadbb ], [ %23, %loadbb1 ], [ %37, %loadbb2 ]
///  %10 = icmp ult i64 %phi.src1, %phi.src2
///  %11 = select i1 %10, i32 -1, i32 1
///  br label %endblock
/// loadbb1:                                          ; preds = %loadbb
///  %12 = bitcast i32* %buffer2 to i8*
///  %13 = bitcast i32* %buffer1 to i8*
///  %14 = bitcast i8* %13 to i32*
///  %15 = bitcast i8* %12 to i32*
///  %16 = getelementptr i32, i32* %14, i32 2
///  %17 = getelementptr i32, i32* %15, i32 2
///  %18 = load i32, i32* %16
///  %19 = load i32, i32* %17
///  %20 = call i32 @llvm.bswap.i32(i32 %18)
///  %21 = call i32 @llvm.bswap.i32(i32 %19)
///  %22 = zext i32 %20 to i64
///  %23 = zext i32 %21 to i64
///  %24 = sub i64 %22, %23
///  %25 = icmp ne i64 %24, 0
///  br i1 %25, label %res_block, label %loadbb2
/// loadbb2:                                          ; preds = %loadbb1
///  %26 = bitcast i32* %buffer2 to i8*
///  %27 = bitcast i32* %buffer1 to i8*
///  %28 = bitcast i8* %27 to i16*
///  %29 = bitcast i8* %26 to i16*
///  %30 = getelementptr i16, i16* %28, i16 6
///  %31 = getelementptr i16, i16* %29, i16 6
///  %32 = load i16, i16* %30
///  %33 = load i16, i16* %31
///  %34 = call i16 @llvm.bswap.i16(i16 %32)
///  %35 = call i16 @llvm.bswap.i16(i16 %33)
///  %36 = zext i16 %34 to i64
///  %37 = zext i16 %35 to i64
///  %38 = sub i64 %36, %37
///  %39 = icmp ne i64 %38, 0
///  br i1 %39, label %res_block, label %loadbb3
/// loadbb3:                                          ; preds = %loadbb2
///  %40 = bitcast i32* %buffer2 to i8*
///  %41 = bitcast i32* %buffer1 to i8*
///  %42 = getelementptr i8, i8* %41, i8 14
///  %43 = getelementptr i8, i8* %40, i8 14
///  %44 = load i8, i8* %42
///  %45 = load i8, i8* %43
///  %46 = zext i8 %44 to i32
///  %47 = zext i8 %45 to i32
///  %48 = sub i32 %46, %47
///  br label %endblock
/// endblock:                                         ; preds = %res_block,
///                                                   ; %loadbb3
///  %phi.res = phi i32 [ %48, %loadbb3 ], [ %11, %res_block ]
///  ret i32 %phi.res
static bool expandMemCmp(CallInst *CI, const TargetTransformInfo *TTI,
                         const TargetLowering *TLI, const DataLayout *DL) {
  NumMemCmpCalls++;

  // Early exit from expansion if -Oz.
  if (CI->getFunction()->optForMinSize())
    return false;

  // Early exit from expansion if size is not a constant.
  ConstantInt *SizeCast = dyn_cast<ConstantInt>(CI->getArgOperand(2));
  if (!SizeCast) {
    NumMemCmpNotConstant++;
    return false;
  }
  const uint64_t SizeVal = SizeCast->getZExtValue();

  if (SizeVal == 0) {
    return false;
  }

  // TTI call to check if target would like to expand memcmp. Also, get the
  // available load sizes.
  const bool IsUsedForZeroCmp = isOnlyUsedInZeroEqualityComparison(CI);
  const auto *const Options = TTI->enableMemCmpExpansion(IsUsedForZeroCmp);
  if (!Options) return false;

  const unsigned MaxNumLoads =
      TLI->getMaxExpandSizeMemcmp(CI->getFunction()->optForSize());

  MemCmpExpansion Expansion(CI, SizeVal, *Options, MaxNumLoads,
                            IsUsedForZeroCmp, MemCmpNumLoadsPerBlock, *DL);

  // Don't expand if this will require more loads than desired by the target.
  if (Expansion.getNumLoads() == 0) {
    NumMemCmpGreaterThanMax++;
    return false;
  }

  NumMemCmpInlined++;

  Value *Res = Expansion.getMemCmpExpansion();

  // Replace call with result of expansion and erase call.
  CI->replaceAllUsesWith(Res);
  CI->eraseFromParent();

  return true;
}

class ExpandMemCmpPass : public FunctionPass {
public:
  static char ID;

  ExpandMemCmpPass() : FunctionPass(ID) {
    initializeExpandMemCmpPassPass(*PassRegistry::getPassRegistry());
  }

  bool runOnFunction(Function &F) override {
    if (skipFunction(F)) return false;

    auto *TPC = getAnalysisIfAvailable<TargetPassConfig>();
    if (!TPC) {
      return false;
    }
    const TargetLowering* TL =
        TPC->getTM<TargetMachine>().getSubtargetImpl(F)->getTargetLowering();

    const TargetLibraryInfo *TLI =
        &getAnalysis<TargetLibraryInfoWrapperPass>().getTLI();
    const TargetTransformInfo *TTI =
        &getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
    auto PA = runImpl(F, TLI, TTI, TL);
    return !PA.areAllPreserved();
  }

private:
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<TargetLibraryInfoWrapperPass>();
    AU.addRequired<TargetTransformInfoWrapperPass>();
    FunctionPass::getAnalysisUsage(AU);
  }

  PreservedAnalyses runImpl(Function &F, const TargetLibraryInfo *TLI,
                            const TargetTransformInfo *TTI,
                            const TargetLowering* TL);
  // Returns true if a change was made.
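  // At most one memcmp is expanded per call: runImpl() rescans the function
  // from the beginning after every change, since expansion rewrites the
  // block structure.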
  bool runOnBlock(BasicBlock &BB, const TargetLibraryInfo *TLI,
                  const TargetTransformInfo *TTI, const TargetLowering* TL,
                  const DataLayout& DL);
};

bool ExpandMemCmpPass::runOnBlock(
    BasicBlock &BB, const TargetLibraryInfo *TLI,
    const TargetTransformInfo *TTI, const TargetLowering* TL,
    const DataLayout& DL) {
  for (Instruction& I : BB) {
    CallInst *CI = dyn_cast<CallInst>(&I);
    if (!CI) {
      continue;
    }
    LibFunc Func;
    if (TLI->getLibFunc(ImmutableCallSite(CI), Func) &&
        Func == LibFunc_memcmp && expandMemCmp(CI, TTI, TL, &DL)) {
      return true;
    }
  }
  return false;
}

PreservedAnalyses ExpandMemCmpPass::runImpl(
    Function &F, const TargetLibraryInfo *TLI, const TargetTransformInfo *TTI,
    const TargetLowering* TL) {
  const DataLayout& DL = F.getParent()->getDataLayout();
  bool MadeChanges = false;
  for (auto BBIt = F.begin(); BBIt != F.end();) {
    if (runOnBlock(*BBIt, TLI, TTI, TL, DL)) {
      MadeChanges = true;
      // If changes were made, restart the function from the beginning, since
      // the structure of the function was changed.
      BBIt = F.begin();
    } else {
      ++BBIt;
    }
  }
  return MadeChanges ? PreservedAnalyses::none() : PreservedAnalyses::all();
}

} // namespace

char ExpandMemCmpPass::ID = 0;
INITIALIZE_PASS_BEGIN(ExpandMemCmpPass, "expandmemcmp",
                      "Expand memcmp() to load/stores", false, false)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_END(ExpandMemCmpPass, "expandmemcmp",
                    "Expand memcmp() to load/stores", false, false)

FunctionPass *llvm::createExpandMemCmpPass() {
  return new ExpandMemCmpPass();
}