//===- TruncInstCombine.cpp -----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// TruncInstCombine - looks for expression dags post-dominated by TruncInst and
// for each eligible dag, it will create a reduced bit-width expression, replace
// the old expression with this new one and remove the old expression.
// Eligible expression dag is such that:
//   1. Contains only supported instructions.
//   2. Supported leaves: ZExtInst, SExtInst, TruncInst and Constant value.
//   3. Can be evaluated into type with reduced legal bit-width.
//   4. All instructions in the dag must not have users outside the dag.
//      The only exception is for {ZExt, SExt}Inst with operand type equal to
//      the new reduced type evaluated in (3).
//
// The motivation for this optimization is that evaluating an expression using
// smaller bit-width is preferable, especially for vectorization where we can
// fit more values in one vectorized instruction. In addition, this optimization
// may decrease the number of cast instructions, but will not increase it.
//
//===----------------------------------------------------------------------===//

#include "AggressiveInstCombineInternal.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/ConstantFolding.h"
#include "llvm/Analysis/TargetLibraryInfo.h"
#include "llvm/Analysis/ValueTracking.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/Dominators.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instruction.h"
#include "llvm/Support/KnownBits.h"

using namespace llvm;

#define DEBUG_TYPE "aggressive-instcombine"

STATISTIC(
    NumDAGsReduced,
    "Number of truncations eliminated by reducing bit width of expression DAG");
STATISTIC(NumInstrsReduced,
          "Number of instructions whose bit width was reduced");

/// Given an instruction and a container, it fills all the relevant operands of
/// that instruction, with respect to the Trunc expression dag optimization.
/// Casts are treated as dag leaves, so none of their operands are collected;
/// for selects only the two value operands (not the condition) are relevant.
static void getRelevantOperands(Instruction *I, SmallVectorImpl<Value *> &Ops) {
  unsigned Opc = I->getOpcode();
  switch (Opc) {
  case Instruction::Trunc:
  case Instruction::ZExt:
  case Instruction::SExt:
    // These CastInst are considered leaves of the evaluated expression, thus,
    // their operands are not relevant.
    break;
  case Instruction::Add:
  case Instruction::Sub:
  case Instruction::Mul:
  case Instruction::And:
  case Instruction::Or:
  case Instruction::Xor:
  case Instruction::Shl:
  case Instruction::LShr:
    Ops.push_back(I->getOperand(0));
    Ops.push_back(I->getOperand(1));
    break;
  case Instruction::Select:
    // The condition operand (operand 0) keeps its original type and is
    // therefore not part of the evaluated expression.
    Ops.push_back(I->getOperand(1));
    Ops.push_back(I->getOperand(2));
    break;
  default:
    llvm_unreachable("Unreachable!");
  }
}

/// Walk the expression dag rooted at CurrentTruncInst's operand and record
/// every visited instruction in InstInfoMap. The explicit Worklist/Stack pair
/// implements an iterative DFS: an instruction is inserted into the map only
/// after all of its relevant operands have been handled, so map insertion
/// order is a post-order (operands before users). Returns false if the dag
/// contains an unsupported opcode or a non-constant, non-instruction value.
bool TruncInstCombine::buildTruncExpressionDag() {
  SmallVector<Value *, 8> Worklist;
  SmallVector<Instruction *, 8> Stack;
  // Clear old expression dag.
  InstInfoMap.clear();

  Worklist.push_back(CurrentTruncInst->getOperand(0));

  while (!Worklist.empty()) {
    Value *Curr = Worklist.back();

    // Constants are always acceptable leaves; nothing to record for them.
    if (isa<Constant>(Curr)) {
      Worklist.pop_back();
      continue;
    }

    // Anything that is neither a constant nor an instruction (e.g. an
    // argument) makes the dag ineligible.
    auto *I = dyn_cast<Instruction>(Curr);
    if (!I)
      return false;

    if (!Stack.empty() && Stack.back() == I) {
      // Already handled all instruction operands, can remove it from both the
      // Worklist and the Stack, and add it to the instruction info map.
      Worklist.pop_back();
      Stack.pop_back();
      // Insert I to the Info map.
      InstInfoMap.insert(std::make_pair(I, Info()));
      continue;
    }

    // Already visited through another path; skip.
    if (InstInfoMap.count(I)) {
      Worklist.pop_back();
      continue;
    }

    // Add the instruction to the stack before start handling its operands.
    Stack.push_back(I);

    unsigned Opc = I->getOpcode();
    switch (Opc) {
    case Instruction::Trunc:
    case Instruction::ZExt:
    case Instruction::SExt:
      // trunc(trunc(x)) -> trunc(x)
      // trunc(ext(x)) -> ext(x) if the source type is smaller than the new dest
      // trunc(ext(x)) -> trunc(x) if the source type is larger than the new
      // dest
      break;
    case Instruction::Add:
    case Instruction::Sub:
    case Instruction::Mul:
    case Instruction::And:
    case Instruction::Or:
    case Instruction::Xor:
    case Instruction::Shl:
    case Instruction::LShr:
    case Instruction::Select: {
      SmallVector<Value *, 2> Operands;
      getRelevantOperands(I, Operands);
      append_range(Worklist, Operands);
      break;
    }
    default:
      // TODO: Can handle more cases here:
      // 1. shufflevector, extractelement, insertelement
      // 2. udiv, urem
      // 3. ashr
      // 4. phi node(and loop handling)
      // ...
      return false;
    }
  }
  return true;
}

/// Propagate required ("valid") bit-widths from the trunc destination through
/// the dag and compute each node's MinBitWidth, then return the bit-width the
/// whole expression must be evaluated in. The result is normalized to a legal
/// integer width for scalar types, or to the original width when shrinking is
/// not worthwhile (vector retype, or legal-to-illegal scalar transition).
unsigned TruncInstCombine::getMinBitWidth() {
  SmallVector<Value *, 8> Worklist;
  SmallVector<Instruction *, 8> Stack;

  Value *Src = CurrentTruncInst->getOperand(0);
  Type *DstTy = CurrentTruncInst->getType();
  unsigned TruncBitWidth = DstTy->getScalarSizeInBits();
  unsigned OrigBitWidth =
      CurrentTruncInst->getOperand(0)->getType()->getScalarSizeInBits();

  // Truncating a constant needs no more than the destination width.
  if (isa<Constant>(Src))
    return TruncBitWidth;

  Worklist.push_back(Src);
  InstInfoMap[cast<Instruction>(Src)].ValidBitWidth = TruncBitWidth;

  while (!Worklist.empty()) {
    Value *Curr = Worklist.back();

    if (isa<Constant>(Curr)) {
      Worklist.pop_back();
      continue;
    }

    // Otherwise, it must be an instruction.
    auto *I = cast<Instruction>(Curr);

    auto &Info = InstInfoMap[I];

    SmallVector<Value *, 2> Operands;
    getRelevantOperands(I, Operands);

    if (!Stack.empty() && Stack.back() == I) {
      // Already handled all instruction operands, can remove it from both the
      // Worklist and the Stack, and update MinBitWidth.
      Worklist.pop_back();
      Stack.pop_back();
      // A node must be at least as wide as any of its operands requires.
      for (auto *Operand : Operands)
        if (auto *IOp = dyn_cast<Instruction>(Operand))
          Info.MinBitWidth =
              std::max(Info.MinBitWidth, InstInfoMap[IOp].MinBitWidth);
      continue;
    }

    // Add the instruction to the stack before start handling its operands.
    Stack.push_back(I);
    unsigned ValidBitWidth = Info.ValidBitWidth;

    // Update minimum bit-width before handling its operands. This is required
    // when the instruction is part of a loop.
    Info.MinBitWidth = std::max(Info.MinBitWidth, Info.ValidBitWidth);

    for (auto *Operand : Operands)
      if (auto *IOp = dyn_cast<Instruction>(Operand)) {
        // If we already calculated the minimum bit-width for this valid
        // bit-width, or for a smaller valid bit-width, then just keep the
        // answer we already calculated.
        unsigned IOpBitwidth = InstInfoMap.lookup(IOp).ValidBitWidth;
        if (IOpBitwidth >= ValidBitWidth)
          continue;
        InstInfoMap[IOp].ValidBitWidth = ValidBitWidth;
        Worklist.push_back(IOp);
      }
  }
  unsigned MinBitWidth = InstInfoMap.lookup(cast<Instruction>(Src)).MinBitWidth;
  assert(MinBitWidth >= TruncBitWidth);

  if (MinBitWidth > TruncBitWidth) {
    // In this case reducing expression with vector type might generate a new
    // vector type, which is not preferable as it might result in generating
    // sub-optimal code.
    if (DstTy->isVectorTy())
      return OrigBitWidth;
    // Use the smallest integer type in the range [MinBitWidth, OrigBitWidth).
    Type *Ty = DL.getSmallestLegalIntType(DstTy->getContext(), MinBitWidth);
    // Update minimum bit-width with the new destination type bit-width if
    // succeeded to find such, otherwise, with original bit-width.
    MinBitWidth = Ty ? Ty->getScalarSizeInBits() : OrigBitWidth;
  } else { // MinBitWidth == TruncBitWidth
    // In this case the expression can be evaluated with the trunc instruction
    // destination type, and trunc instruction can be omitted. However, we
    // should not perform the evaluation if the original type is a legal scalar
    // type and the target type is illegal.
    bool FromLegal = MinBitWidth == 1 || DL.isLegalInteger(OrigBitWidth);
    bool ToLegal = MinBitWidth == 1 || DL.isLegalInteger(MinBitWidth);
    if (!DstTy->isVectorTy() && FromLegal && !ToLegal)
      return OrigBitWidth;
  }
  return MinBitWidth;
}

/// Decide whether the expression dag dominated by CurrentTruncInst can be
/// profitably evaluated in a narrower type. Returns that scalar integer type,
/// or nullptr when the dag is ineligible (unsupported nodes, external users,
/// shift amounts too large, or no width reduction possible).
Type *TruncInstCombine::getBestTruncatedType() {
  if (!buildTruncExpressionDag())
    return nullptr;

  // We don't want to duplicate instructions, which isn't profitable. Thus, we
  // can't shrink something that has multiple users, unless all users are
  // post-dominated by the trunc instruction, i.e., were visited during the
  // expression evaluation.
  unsigned DesiredBitWidth = 0;
  for (auto Itr : InstInfoMap) {
    Instruction *I = Itr.first;
    if (I->hasOneUse())
      continue;
    bool IsExtInst = (isa<ZExtInst>(I) || isa<SExtInst>(I));
    for (auto *U : I->users())
      if (auto *UI = dyn_cast<Instruction>(U))
        if (UI != CurrentTruncInst && !InstInfoMap.count(UI)) {
          if (!IsExtInst)
            return nullptr;
          // If this is an extension from the dest type, we can eliminate it,
          // even if it has multiple users. Thus, update the DesiredBitWidth and
          // validate that all extension instructions agree on the same
          // DesiredBitWidth.
          unsigned ExtInstBitWidth =
              I->getOperand(0)->getType()->getScalarSizeInBits();
          if (DesiredBitWidth && DesiredBitWidth != ExtInstBitWidth)
            return nullptr;
          DesiredBitWidth = ExtInstBitWidth;
        }
  }

  unsigned OrigBitWidth =
      CurrentTruncInst->getOperand(0)->getType()->getScalarSizeInBits();

  // Initialize MinBitWidth for shift instructions with the minimum number
  // that is greater than shift amount (i.e. shift amount + 1). For `lshr`
  // adjust MinBitWidth so that all potentially truncated bits of
  // the value-to-be-shifted are zeros.
  // Also normalize MinBitWidth not to be greater than source bitwidth.
  for (auto &Itr : InstInfoMap) {
    Instruction *I = Itr.first;
    if (I->getOpcode() == Instruction::Shl ||
        I->getOpcode() == Instruction::LShr) {
      KnownBits KnownRHS = computeKnownBits(I->getOperand(1), DL);
      const unsigned SrcBitWidth = KnownRHS.getBitWidth();
      // Saturating add/clamp guard against a shift amount of all-ones.
      unsigned MinBitWidth = KnownRHS.getMaxValue()
                                 .uadd_sat(APInt(SrcBitWidth, 1))
                                 .getLimitedValue(SrcBitWidth);
      if (MinBitWidth >= OrigBitWidth)
        return nullptr;
      if (I->getOpcode() == Instruction::LShr) {
        KnownBits KnownLHS = computeKnownBits(I->getOperand(0), DL);
        MinBitWidth =
            std::max(MinBitWidth, KnownLHS.getMaxValue().getActiveBits());
        if (MinBitWidth >= OrigBitWidth)
          return nullptr;
      }
      Itr.second.MinBitWidth = MinBitWidth;
    }
  }

  // Calculate the minimum bit-width allowed for shrinking the currently
  // visited truncate's operand.
  unsigned MinBitWidth = getMinBitWidth();

  // Check that we can shrink to smaller bit-width than original one and that
  // it is similar to the DesiredBitWidth if such exists.
  if (MinBitWidth >= OrigBitWidth ||
      (DesiredBitWidth && DesiredBitWidth != MinBitWidth))
    return nullptr;

  return IntegerType::get(CurrentTruncInst->getContext(), MinBitWidth);
}

/// Given a reduced scalar type \p Ty and a \p V value, return a reduced type
/// for \p V, according to its type: if it is a vector type, return the vector
/// version of \p Ty, otherwise return \p Ty.
static Type *getReducedType(Value *V, Type *Ty) {
  assert(Ty && !Ty->isVectorTy() && "Expect Scalar Type");
  if (auto *VTy = dyn_cast<VectorType>(V->getType()))
    return VectorType::get(Ty, VTy->getElementCount());
  return Ty;
}

/// Return the reduced-width counterpart of \p V: for constants, an integer
/// cast (constant-folded with DataLayout info); for instructions, the NewValue
/// previously recorded in InstInfoMap by ReduceExpressionDag.
Value *TruncInstCombine::getReducedOperand(Value *V, Type *SclTy) {
  Type *Ty = getReducedType(V, SclTy);
  if (auto *C = dyn_cast<Constant>(V)) {
    C = ConstantExpr::getIntegerCast(C, Ty, false);
    // If we got a constantexpr back, try to simplify it with DL info.
    return ConstantFoldConstant(C, DL, &TLI);
  }

  auto *I = cast<Instruction>(V);
  Info Entry = InstInfoMap.lookup(I);
  assert(Entry.NewValue);
  return Entry.NewValue;
}

/// Rebuild the expression dag in the reduced scalar type \p SclTy, replace all
/// uses of CurrentTruncInst with the reduced value, and erase the old dag.
/// Relies on InstInfoMap's insertion order (operands before users) so each
/// node's reduced operands already exist when the node is rebuilt.
void TruncInstCombine::ReduceExpressionDag(Type *SclTy) {
  NumInstrsReduced += InstInfoMap.size();
  for (auto &Itr : InstInfoMap) { // Forward
    Instruction *I = Itr.first;
    TruncInstCombine::Info &NodeInfo = Itr.second;

    assert(!NodeInfo.NewValue && "Instruction has been evaluated");

    // Insert the reduced replacement right before the instruction it replaces.
    IRBuilder<> Builder(I);
    Value *Res = nullptr;
    unsigned Opc = I->getOpcode();
    switch (Opc) {
    case Instruction::Trunc:
    case Instruction::ZExt:
    case Instruction::SExt: {
      Type *Ty = getReducedType(I, SclTy);
      // If the source type of the cast is the type we're trying for then we can
      // just return the source. There's no need to insert it because it is not
      // new.
      if (I->getOperand(0)->getType() == Ty) {
        assert(!isa<TruncInst>(I) && "Cannot reach here with TruncInst");
        NodeInfo.NewValue = I->getOperand(0);
        continue;
      }
      // Otherwise, must be the same type of cast, so just reinsert a new one.
      // This also handles the case of zext(trunc(x)) -> zext(x).
      Res = Builder.CreateIntCast(I->getOperand(0), Ty,
                                  Opc == Instruction::SExt);

      // Update Worklist entries with new value if needed.
      // There are three possible changes to the Worklist:
      // 1. Update Old-TruncInst -> New-TruncInst.
      // 2. Remove Old-TruncInst (if New node is not TruncInst).
      // 3. Add New-TruncInst (if Old node was not TruncInst).
      auto *Entry = find(Worklist, I);
      if (Entry != Worklist.end()) {
        if (auto *NewCI = dyn_cast<TruncInst>(Res))
          *Entry = NewCI;
        else
          Worklist.erase(Entry);
      } else if (auto *NewCI = dyn_cast<TruncInst>(Res))
        Worklist.push_back(NewCI);
      break;
    }
    case Instruction::Add:
    case Instruction::Sub:
    case Instruction::Mul:
    case Instruction::And:
    case Instruction::Or:
    case Instruction::Xor:
    case Instruction::Shl:
    case Instruction::LShr: {
      Value *LHS = getReducedOperand(I->getOperand(0), SclTy);
      Value *RHS = getReducedOperand(I->getOperand(1), SclTy);
      Res = Builder.CreateBinOp((Instruction::BinaryOps)Opc, LHS, RHS);
      // Preserve `exact` flag since truncation doesn't change exactness.
      if (Opc == Instruction::LShr)
        cast<Instruction>(Res)->setIsExact(I->isExact());
      break;
    }
    case Instruction::Select: {
      // The condition keeps its original type; only the value operands are
      // reduced.
      Value *Op0 = I->getOperand(0);
      Value *LHS = getReducedOperand(I->getOperand(1), SclTy);
      Value *RHS = getReducedOperand(I->getOperand(2), SclTy);
      Res = Builder.CreateSelect(Op0, LHS, RHS);
      break;
    }
    default:
      llvm_unreachable("Unhandled instruction");
    }

    NodeInfo.NewValue = Res;
    if (auto *ResI = dyn_cast<Instruction>(Res))
      ResI->takeName(I);
  }

  Value *Res = getReducedOperand(CurrentTruncInst->getOperand(0), SclTy);
  Type *DstTy = CurrentTruncInst->getType();
  // If the reduced type is narrower than the trunc destination, re-extend
  // (zext) to the destination type.
  if (Res->getType() != DstTy) {
    IRBuilder<> Builder(CurrentTruncInst);
    Res = Builder.CreateIntCast(Res, DstTy, false);
    if (auto *ResI = dyn_cast<Instruction>(Res))
      ResI->takeName(CurrentTruncInst);
  }
  CurrentTruncInst->replaceAllUsesWith(Res);

  // Erase old expression dag, which was replaced by the reduced expression dag.
  // We iterate backward, which means we visit the instruction before we visit
  // any of its operands, this way, when we get to the operand, we already
  // removed the instructions (from the expression dag) that use it.
  CurrentTruncInst->eraseFromParent();
  for (auto I = InstInfoMap.rbegin(), E = InstInfoMap.rend(); I != E; ++I) {
    // We still need to check that the instruction has no users before we erase
    // it, because {SExt, ZExt}Inst Instruction might have other users that were
    // not reduced, in such case, we need to keep that instruction.
    if (I->first->use_empty())
      I->first->eraseFromParent();
  }
}

/// Entry point: collect every TruncInst in reachable blocks of \p F and, for
/// each one, try to reduce the bit-width of the expression dag it dominates.
/// Returns true if any IR change was made.
bool TruncInstCombine::run(Function &F) {
  bool MadeIRChange = false;

  // Collect all TruncInst in the function into the Worklist for evaluating.
  for (auto &BB : F) {
    // Ignore unreachable basic block.
    if (!DT.isReachableFromEntry(&BB))
      continue;
    for (auto &I : BB)
      if (auto *CI = dyn_cast<TruncInst>(&I))
        Worklist.push_back(CI);
  }

  // Process all TruncInst in the Worklist, for each instruction:
  //   1. Check if it dominates an eligible expression dag to be reduced.
  //   2. Create a reduced expression dag and replace the old one with it.
  while (!Worklist.empty()) {
    CurrentTruncInst = Worklist.pop_back_val();

    if (Type *NewDstSclTy = getBestTruncatedType()) {
      LLVM_DEBUG(
          dbgs() << "ICE: TruncInstCombine reducing type of expression dag "
                    "dominated by: "
                 << CurrentTruncInst << '\n');
      ReduceExpressionDag(NewDstSclTy);
      ++NumDAGsReduced;
      MadeIRChange = true;
    }
  }

  return MadeIRChange;
}