//===- TruncInstCombine.cpp -----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// TruncInstCombine - looks for expression dags post-dominated by TruncInst and
// for each eligible dag, it will create a reduced bit-width expression, replace
// the old expression with this new one and remove the old expression.
// An eligible expression dag is one that:
//   1. Contains only supported instructions.
//   2. Has only supported leaves: ZExtInst, SExtInst, TruncInst and Constant
//      values.
//   3. Can be evaluated into a type with a reduced legal bit-width.
//   4. Has no instruction with users outside the dag.
//      The only exception is for {ZExt, SExt}Inst with operand type equal to
//      the new reduced type evaluated in (3).
//
// The motivation for this optimization is that evaluating an expression using
// a smaller bit-width is preferable, especially for vectorization where we can
// fit more values in one vectorized instruction. In addition, this optimization
// may decrease the number of cast instructions, but will not increase it.
24 // 25 //===----------------------------------------------------------------------===// 26 27 #include "AggressiveInstCombineInternal.h" 28 #include "llvm/ADT/STLExtras.h" 29 #include "llvm/ADT/Statistic.h" 30 #include "llvm/Analysis/ConstantFolding.h" 31 #include "llvm/Analysis/TargetLibraryInfo.h" 32 #include "llvm/IR/DataLayout.h" 33 #include "llvm/IR/Dominators.h" 34 #include "llvm/IR/IRBuilder.h" 35 #include "llvm/IR/Instruction.h" 36 #include "llvm/Support/KnownBits.h" 37 38 using namespace llvm; 39 40 #define DEBUG_TYPE "aggressive-instcombine" 41 42 STATISTIC( 43 NumDAGsReduced, 44 "Number of truncations eliminated by reducing bit width of expression DAG"); 45 STATISTIC(NumInstrsReduced, 46 "Number of instructions whose bit width was reduced"); 47 48 /// Given an instruction and a container, it fills all the relevant operands of 49 /// that instruction, with respect to the Trunc expression dag optimizaton. 50 static void getRelevantOperands(Instruction *I, SmallVectorImpl<Value *> &Ops) { 51 unsigned Opc = I->getOpcode(); 52 switch (Opc) { 53 case Instruction::Trunc: 54 case Instruction::ZExt: 55 case Instruction::SExt: 56 // These CastInst are considered leaves of the evaluated expression, thus, 57 // their operands are not relevent. 58 break; 59 case Instruction::Add: 60 case Instruction::Sub: 61 case Instruction::Mul: 62 case Instruction::And: 63 case Instruction::Or: 64 case Instruction::Xor: 65 case Instruction::Shl: 66 case Instruction::LShr: 67 case Instruction::AShr: 68 case Instruction::UDiv: 69 case Instruction::URem: 70 Ops.push_back(I->getOperand(0)); 71 Ops.push_back(I->getOperand(1)); 72 break; 73 case Instruction::Select: 74 Ops.push_back(I->getOperand(1)); 75 Ops.push_back(I->getOperand(2)); 76 break; 77 default: 78 llvm_unreachable("Unreachable!"); 79 } 80 } 81 82 bool TruncInstCombine::buildTruncExpressionDag() { 83 SmallVector<Value *, 8> Worklist; 84 SmallVector<Instruction *, 8> Stack; 85 // Clear old expression dag. 
86 InstInfoMap.clear(); 87 88 Worklist.push_back(CurrentTruncInst->getOperand(0)); 89 90 while (!Worklist.empty()) { 91 Value *Curr = Worklist.back(); 92 93 if (isa<Constant>(Curr)) { 94 Worklist.pop_back(); 95 continue; 96 } 97 98 auto *I = dyn_cast<Instruction>(Curr); 99 if (!I) 100 return false; 101 102 if (!Stack.empty() && Stack.back() == I) { 103 // Already handled all instruction operands, can remove it from both the 104 // Worklist and the Stack, and add it to the instruction info map. 105 Worklist.pop_back(); 106 Stack.pop_back(); 107 // Insert I to the Info map. 108 InstInfoMap.insert(std::make_pair(I, Info())); 109 continue; 110 } 111 112 if (InstInfoMap.count(I)) { 113 Worklist.pop_back(); 114 continue; 115 } 116 117 // Add the instruction to the stack before start handling its operands. 118 Stack.push_back(I); 119 120 unsigned Opc = I->getOpcode(); 121 switch (Opc) { 122 case Instruction::Trunc: 123 case Instruction::ZExt: 124 case Instruction::SExt: 125 // trunc(trunc(x)) -> trunc(x) 126 // trunc(ext(x)) -> ext(x) if the source type is smaller than the new dest 127 // trunc(ext(x)) -> trunc(x) if the source type is larger than the new 128 // dest 129 break; 130 case Instruction::Add: 131 case Instruction::Sub: 132 case Instruction::Mul: 133 case Instruction::And: 134 case Instruction::Or: 135 case Instruction::Xor: 136 case Instruction::Shl: 137 case Instruction::LShr: 138 case Instruction::AShr: 139 case Instruction::UDiv: 140 case Instruction::URem: 141 case Instruction::Select: { 142 SmallVector<Value *, 2> Operands; 143 getRelevantOperands(I, Operands); 144 append_range(Worklist, Operands); 145 break; 146 } 147 default: 148 // TODO: Can handle more cases here: 149 // 1. shufflevector, extractelement, insertelement 150 // 2. sdiv, srem 151 // 3. phi node(and loop handling) 152 // ... 
153 return false; 154 } 155 } 156 return true; 157 } 158 159 unsigned TruncInstCombine::getMinBitWidth() { 160 SmallVector<Value *, 8> Worklist; 161 SmallVector<Instruction *, 8> Stack; 162 163 Value *Src = CurrentTruncInst->getOperand(0); 164 Type *DstTy = CurrentTruncInst->getType(); 165 unsigned TruncBitWidth = DstTy->getScalarSizeInBits(); 166 unsigned OrigBitWidth = 167 CurrentTruncInst->getOperand(0)->getType()->getScalarSizeInBits(); 168 169 if (isa<Constant>(Src)) 170 return TruncBitWidth; 171 172 Worklist.push_back(Src); 173 InstInfoMap[cast<Instruction>(Src)].ValidBitWidth = TruncBitWidth; 174 175 while (!Worklist.empty()) { 176 Value *Curr = Worklist.back(); 177 178 if (isa<Constant>(Curr)) { 179 Worklist.pop_back(); 180 continue; 181 } 182 183 // Otherwise, it must be an instruction. 184 auto *I = cast<Instruction>(Curr); 185 186 auto &Info = InstInfoMap[I]; 187 188 SmallVector<Value *, 2> Operands; 189 getRelevantOperands(I, Operands); 190 191 if (!Stack.empty() && Stack.back() == I) { 192 // Already handled all instruction operands, can remove it from both, the 193 // Worklist and the Stack, and update MinBitWidth. 194 Worklist.pop_back(); 195 Stack.pop_back(); 196 for (auto *Operand : Operands) 197 if (auto *IOp = dyn_cast<Instruction>(Operand)) 198 Info.MinBitWidth = 199 std::max(Info.MinBitWidth, InstInfoMap[IOp].MinBitWidth); 200 continue; 201 } 202 203 // Add the instruction to the stack before start handling its operands. 204 Stack.push_back(I); 205 unsigned ValidBitWidth = Info.ValidBitWidth; 206 207 // Update minimum bit-width before handling its operands. This is required 208 // when the instruction is part of a loop. 
209 Info.MinBitWidth = std::max(Info.MinBitWidth, Info.ValidBitWidth); 210 211 for (auto *Operand : Operands) 212 if (auto *IOp = dyn_cast<Instruction>(Operand)) { 213 // If we already calculated the minimum bit-width for this valid 214 // bit-width, or for a smaller valid bit-width, then just keep the 215 // answer we already calculated. 216 unsigned IOpBitwidth = InstInfoMap.lookup(IOp).ValidBitWidth; 217 if (IOpBitwidth >= ValidBitWidth) 218 continue; 219 InstInfoMap[IOp].ValidBitWidth = ValidBitWidth; 220 Worklist.push_back(IOp); 221 } 222 } 223 unsigned MinBitWidth = InstInfoMap.lookup(cast<Instruction>(Src)).MinBitWidth; 224 assert(MinBitWidth >= TruncBitWidth); 225 226 if (MinBitWidth > TruncBitWidth) { 227 // In this case reducing expression with vector type might generate a new 228 // vector type, which is not preferable as it might result in generating 229 // sub-optimal code. 230 if (DstTy->isVectorTy()) 231 return OrigBitWidth; 232 // Use the smallest integer type in the range [MinBitWidth, OrigBitWidth). 233 Type *Ty = DL.getSmallestLegalIntType(DstTy->getContext(), MinBitWidth); 234 // Update minimum bit-width with the new destination type bit-width if 235 // succeeded to find such, otherwise, with original bit-width. 236 MinBitWidth = Ty ? Ty->getScalarSizeInBits() : OrigBitWidth; 237 } else { // MinBitWidth == TruncBitWidth 238 // In this case the expression can be evaluated with the trunc instruction 239 // destination type, and trunc instruction can be omitted. However, we 240 // should not perform the evaluation if the original type is a legal scalar 241 // type and the target type is illegal. 
242 bool FromLegal = MinBitWidth == 1 || DL.isLegalInteger(OrigBitWidth); 243 bool ToLegal = MinBitWidth == 1 || DL.isLegalInteger(MinBitWidth); 244 if (!DstTy->isVectorTy() && FromLegal && !ToLegal) 245 return OrigBitWidth; 246 } 247 return MinBitWidth; 248 } 249 250 Type *TruncInstCombine::getBestTruncatedType() { 251 if (!buildTruncExpressionDag()) 252 return nullptr; 253 254 // We don't want to duplicate instructions, which isn't profitable. Thus, we 255 // can't shrink something that has multiple users, unless all users are 256 // post-dominated by the trunc instruction, i.e., were visited during the 257 // expression evaluation. 258 unsigned DesiredBitWidth = 0; 259 for (auto Itr : InstInfoMap) { 260 Instruction *I = Itr.first; 261 if (I->hasOneUse()) 262 continue; 263 bool IsExtInst = (isa<ZExtInst>(I) || isa<SExtInst>(I)); 264 for (auto *U : I->users()) 265 if (auto *UI = dyn_cast<Instruction>(U)) 266 if (UI != CurrentTruncInst && !InstInfoMap.count(UI)) { 267 if (!IsExtInst) 268 return nullptr; 269 // If this is an extension from the dest type, we can eliminate it, 270 // even if it has multiple users. Thus, update the DesiredBitWidth and 271 // validate all extension instructions agrees on same DesiredBitWidth. 272 unsigned ExtInstBitWidth = 273 I->getOperand(0)->getType()->getScalarSizeInBits(); 274 if (DesiredBitWidth && DesiredBitWidth != ExtInstBitWidth) 275 return nullptr; 276 DesiredBitWidth = ExtInstBitWidth; 277 } 278 } 279 280 unsigned OrigBitWidth = 281 CurrentTruncInst->getOperand(0)->getType()->getScalarSizeInBits(); 282 283 // Initialize MinBitWidth for shift instructions with the minimum number 284 // that is greater than shift amount (i.e. shift amount + 1). 285 // For `lshr` adjust MinBitWidth so that all potentially truncated 286 // bits of the value-to-be-shifted are zeros. 
287 // For `ashr` adjust MinBitWidth so that all potentially truncated 288 // bits of the value-to-be-shifted are sign bits (all zeros or ones) 289 // and even one (first) untruncated bit is sign bit. 290 // Exit early if MinBitWidth is not less than original bitwidth. 291 for (auto &Itr : InstInfoMap) { 292 Instruction *I = Itr.first; 293 if (I->isShift()) { 294 KnownBits KnownRHS = computeKnownBits(I->getOperand(1)); 295 unsigned MinBitWidth = KnownRHS.getMaxValue() 296 .uadd_sat(APInt(OrigBitWidth, 1)) 297 .getLimitedValue(OrigBitWidth); 298 if (MinBitWidth == OrigBitWidth) 299 return nullptr; 300 if (I->getOpcode() == Instruction::LShr) { 301 KnownBits KnownLHS = computeKnownBits(I->getOperand(0)); 302 MinBitWidth = 303 std::max(MinBitWidth, KnownLHS.getMaxValue().getActiveBits()); 304 } 305 if (I->getOpcode() == Instruction::AShr) { 306 unsigned NumSignBits = ComputeNumSignBits(I->getOperand(0)); 307 MinBitWidth = std::max(MinBitWidth, OrigBitWidth - NumSignBits + 1); 308 } 309 if (MinBitWidth >= OrigBitWidth) 310 return nullptr; 311 Itr.second.MinBitWidth = MinBitWidth; 312 } 313 if (I->getOpcode() == Instruction::UDiv || 314 I->getOpcode() == Instruction::URem) { 315 unsigned MinBitWidth = 0; 316 for (const auto &Op : I->operands()) { 317 KnownBits Known = computeKnownBits(Op); 318 MinBitWidth = 319 std::max(Known.getMaxValue().getActiveBits(), MinBitWidth); 320 if (MinBitWidth >= OrigBitWidth) 321 return nullptr; 322 } 323 Itr.second.MinBitWidth = MinBitWidth; 324 } 325 } 326 327 // Calculate minimum allowed bit-width allowed for shrinking the currently 328 // visited truncate's operand. 329 unsigned MinBitWidth = getMinBitWidth(); 330 331 // Check that we can shrink to smaller bit-width than original one and that 332 // it is similar to the DesiredBitWidth is such exists. 
333 if (MinBitWidth >= OrigBitWidth || 334 (DesiredBitWidth && DesiredBitWidth != MinBitWidth)) 335 return nullptr; 336 337 return IntegerType::get(CurrentTruncInst->getContext(), MinBitWidth); 338 } 339 340 /// Given a reduced scalar type \p Ty and a \p V value, return a reduced type 341 /// for \p V, according to its type, if it vector type, return the vector 342 /// version of \p Ty, otherwise return \p Ty. 343 static Type *getReducedType(Value *V, Type *Ty) { 344 assert(Ty && !Ty->isVectorTy() && "Expect Scalar Type"); 345 if (auto *VTy = dyn_cast<VectorType>(V->getType())) 346 return VectorType::get(Ty, VTy->getElementCount()); 347 return Ty; 348 } 349 350 Value *TruncInstCombine::getReducedOperand(Value *V, Type *SclTy) { 351 Type *Ty = getReducedType(V, SclTy); 352 if (auto *C = dyn_cast<Constant>(V)) { 353 C = ConstantExpr::getIntegerCast(C, Ty, false); 354 // If we got a constantexpr back, try to simplify it with DL info. 355 return ConstantFoldConstant(C, DL, &TLI); 356 } 357 358 auto *I = cast<Instruction>(V); 359 Info Entry = InstInfoMap.lookup(I); 360 assert(Entry.NewValue); 361 return Entry.NewValue; 362 } 363 364 void TruncInstCombine::ReduceExpressionDag(Type *SclTy) { 365 NumInstrsReduced += InstInfoMap.size(); 366 for (auto &Itr : InstInfoMap) { // Forward 367 Instruction *I = Itr.first; 368 TruncInstCombine::Info &NodeInfo = Itr.second; 369 370 assert(!NodeInfo.NewValue && "Instruction has been evaluated"); 371 372 IRBuilder<> Builder(I); 373 Value *Res = nullptr; 374 unsigned Opc = I->getOpcode(); 375 switch (Opc) { 376 case Instruction::Trunc: 377 case Instruction::ZExt: 378 case Instruction::SExt: { 379 Type *Ty = getReducedType(I, SclTy); 380 // If the source type of the cast is the type we're trying for then we can 381 // just return the source. There's no need to insert it because it is not 382 // new. 
383 if (I->getOperand(0)->getType() == Ty) { 384 assert(!isa<TruncInst>(I) && "Cannot reach here with TruncInst"); 385 NodeInfo.NewValue = I->getOperand(0); 386 continue; 387 } 388 // Otherwise, must be the same type of cast, so just reinsert a new one. 389 // This also handles the case of zext(trunc(x)) -> zext(x). 390 Res = Builder.CreateIntCast(I->getOperand(0), Ty, 391 Opc == Instruction::SExt); 392 393 // Update Worklist entries with new value if needed. 394 // There are three possible changes to the Worklist: 395 // 1. Update Old-TruncInst -> New-TruncInst. 396 // 2. Remove Old-TruncInst (if New node is not TruncInst). 397 // 3. Add New-TruncInst (if Old node was not TruncInst). 398 auto *Entry = find(Worklist, I); 399 if (Entry != Worklist.end()) { 400 if (auto *NewCI = dyn_cast<TruncInst>(Res)) 401 *Entry = NewCI; 402 else 403 Worklist.erase(Entry); 404 } else if (auto *NewCI = dyn_cast<TruncInst>(Res)) 405 Worklist.push_back(NewCI); 406 break; 407 } 408 case Instruction::Add: 409 case Instruction::Sub: 410 case Instruction::Mul: 411 case Instruction::And: 412 case Instruction::Or: 413 case Instruction::Xor: 414 case Instruction::Shl: 415 case Instruction::LShr: 416 case Instruction::AShr: 417 case Instruction::UDiv: 418 case Instruction::URem: { 419 Value *LHS = getReducedOperand(I->getOperand(0), SclTy); 420 Value *RHS = getReducedOperand(I->getOperand(1), SclTy); 421 Res = Builder.CreateBinOp((Instruction::BinaryOps)Opc, LHS, RHS); 422 // Preserve `exact` flag since truncation doesn't change exactness 423 if (auto *PEO = dyn_cast<PossiblyExactOperator>(I)) 424 if (auto *ResI = dyn_cast<Instruction>(Res)) 425 ResI->setIsExact(PEO->isExact()); 426 break; 427 } 428 case Instruction::Select: { 429 Value *Op0 = I->getOperand(0); 430 Value *LHS = getReducedOperand(I->getOperand(1), SclTy); 431 Value *RHS = getReducedOperand(I->getOperand(2), SclTy); 432 Res = Builder.CreateSelect(Op0, LHS, RHS); 433 break; 434 } 435 default: 436 llvm_unreachable("Unhandled 
instruction"); 437 } 438 439 NodeInfo.NewValue = Res; 440 if (auto *ResI = dyn_cast<Instruction>(Res)) 441 ResI->takeName(I); 442 } 443 444 Value *Res = getReducedOperand(CurrentTruncInst->getOperand(0), SclTy); 445 Type *DstTy = CurrentTruncInst->getType(); 446 if (Res->getType() != DstTy) { 447 IRBuilder<> Builder(CurrentTruncInst); 448 Res = Builder.CreateIntCast(Res, DstTy, false); 449 if (auto *ResI = dyn_cast<Instruction>(Res)) 450 ResI->takeName(CurrentTruncInst); 451 } 452 CurrentTruncInst->replaceAllUsesWith(Res); 453 454 // Erase old expression dag, which was replaced by the reduced expression dag. 455 // We iterate backward, which means we visit the instruction before we visit 456 // any of its operands, this way, when we get to the operand, we already 457 // removed the instructions (from the expression dag) that uses it. 458 CurrentTruncInst->eraseFromParent(); 459 for (auto I = InstInfoMap.rbegin(), E = InstInfoMap.rend(); I != E; ++I) { 460 // We still need to check that the instruction has no users before we erase 461 // it, because {SExt, ZExt}Inst Instruction might have other users that was 462 // not reduced, in such case, we need to keep that instruction. 463 if (I->first->use_empty()) 464 I->first->eraseFromParent(); 465 } 466 } 467 468 bool TruncInstCombine::run(Function &F) { 469 bool MadeIRChange = false; 470 471 // Collect all TruncInst in the function into the Worklist for evaluating. 472 for (auto &BB : F) { 473 // Ignore unreachable basic block. 474 if (!DT.isReachableFromEntry(&BB)) 475 continue; 476 for (auto &I : BB) 477 if (auto *CI = dyn_cast<TruncInst>(&I)) 478 Worklist.push_back(CI); 479 } 480 481 // Process all TruncInst in the Worklist, for each instruction: 482 // 1. Check if it dominates an eligible expression dag to be reduced. 483 // 2. Create a reduced expression dag and replace the old one with it. 
484 while (!Worklist.empty()) { 485 CurrentTruncInst = Worklist.pop_back_val(); 486 487 if (Type *NewDstSclTy = getBestTruncatedType()) { 488 LLVM_DEBUG( 489 dbgs() << "ICE: TruncInstCombine reducing type of expression dag " 490 "dominated by: " 491 << CurrentTruncInst << '\n'); 492 ReduceExpressionDag(NewDstSclTy); 493 ++NumDAGsReduced; 494 MadeIRChange = true; 495 } 496 } 497 498 return MadeIRChange; 499 } 500