//===- TruncInstCombine.cpp -----------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// TruncInstCombine - looks for expression dags post-dominated by TruncInst and
// for each eligible dag, it will create a reduced bit-width expression, replace
// the old expression with this new one and remove the old expression.
// Eligible expression dag is such that:
//   1. Contains only supported instructions.
//   2. Supported leaves: ZExtInst, SExtInst, TruncInst and Constant value.
//   3. Can be evaluated into type with reduced legal bit-width.
//   4. All instructions in the dag must not have users outside the dag.
//      The only exception is for {ZExt, SExt}Inst with operand type equal to
//      the new reduced type evaluated in (3).
//
// The motivation for this optimization is that evaluating an expression using
// smaller bit-width is preferable, especially for vectorization where we can
// fit more values in one vectorized instruction. In addition, this optimization
// may decrease the number of cast instructions, but will not increase it.
24 // 25 //===----------------------------------------------------------------------===// 26 27 #include "AggressiveInstCombineInternal.h" 28 #include "llvm/ADT/STLExtras.h" 29 #include "llvm/ADT/Statistic.h" 30 #include "llvm/Analysis/ConstantFolding.h" 31 #include "llvm/Analysis/TargetLibraryInfo.h" 32 #include "llvm/Analysis/ValueTracking.h" 33 #include "llvm/IR/DataLayout.h" 34 #include "llvm/IR/Dominators.h" 35 #include "llvm/IR/IRBuilder.h" 36 #include "llvm/IR/Instruction.h" 37 #include "llvm/Support/KnownBits.h" 38 39 using namespace llvm; 40 41 #define DEBUG_TYPE "aggressive-instcombine" 42 43 STATISTIC( 44 NumDAGsReduced, 45 "Number of truncations eliminated by reducing bit width of expression DAG"); 46 STATISTIC(NumInstrsReduced, 47 "Number of instructions whose bit width was reduced"); 48 49 /// Given an instruction and a container, it fills all the relevant operands of 50 /// that instruction, with respect to the Trunc expression dag optimizaton. 51 static void getRelevantOperands(Instruction *I, SmallVectorImpl<Value *> &Ops) { 52 unsigned Opc = I->getOpcode(); 53 switch (Opc) { 54 case Instruction::Trunc: 55 case Instruction::ZExt: 56 case Instruction::SExt: 57 // These CastInst are considered leaves of the evaluated expression, thus, 58 // their operands are not relevent. 59 break; 60 case Instruction::Add: 61 case Instruction::Sub: 62 case Instruction::Mul: 63 case Instruction::And: 64 case Instruction::Or: 65 case Instruction::Xor: 66 case Instruction::Shl: 67 Ops.push_back(I->getOperand(0)); 68 Ops.push_back(I->getOperand(1)); 69 break; 70 case Instruction::Select: 71 Ops.push_back(I->getOperand(1)); 72 Ops.push_back(I->getOperand(2)); 73 break; 74 default: 75 llvm_unreachable("Unreachable!"); 76 } 77 } 78 79 bool TruncInstCombine::buildTruncExpressionDag() { 80 SmallVector<Value *, 8> Worklist; 81 SmallVector<Instruction *, 8> Stack; 82 // Clear old expression dag. 
83 InstInfoMap.clear(); 84 85 Worklist.push_back(CurrentTruncInst->getOperand(0)); 86 87 while (!Worklist.empty()) { 88 Value *Curr = Worklist.back(); 89 90 if (isa<Constant>(Curr)) { 91 Worklist.pop_back(); 92 continue; 93 } 94 95 auto *I = dyn_cast<Instruction>(Curr); 96 if (!I) 97 return false; 98 99 if (!Stack.empty() && Stack.back() == I) { 100 // Already handled all instruction operands, can remove it from both the 101 // Worklist and the Stack, and add it to the instruction info map. 102 Worklist.pop_back(); 103 Stack.pop_back(); 104 // Insert I to the Info map. 105 InstInfoMap.insert(std::make_pair(I, Info())); 106 continue; 107 } 108 109 if (InstInfoMap.count(I)) { 110 Worklist.pop_back(); 111 continue; 112 } 113 114 // Add the instruction to the stack before start handling its operands. 115 Stack.push_back(I); 116 117 unsigned Opc = I->getOpcode(); 118 switch (Opc) { 119 case Instruction::Trunc: 120 case Instruction::ZExt: 121 case Instruction::SExt: 122 // trunc(trunc(x)) -> trunc(x) 123 // trunc(ext(x)) -> ext(x) if the source type is smaller than the new dest 124 // trunc(ext(x)) -> trunc(x) if the source type is larger than the new 125 // dest 126 break; 127 case Instruction::Add: 128 case Instruction::Sub: 129 case Instruction::Mul: 130 case Instruction::And: 131 case Instruction::Or: 132 case Instruction::Xor: 133 case Instruction::Shl: 134 case Instruction::Select: { 135 SmallVector<Value *, 2> Operands; 136 getRelevantOperands(I, Operands); 137 append_range(Worklist, Operands); 138 break; 139 } 140 default: 141 // TODO: Can handle more cases here: 142 // 1. shufflevector, extractelement, insertelement 143 // 2. udiv, urem 144 // 3. lshr, ashr 145 // 4. phi node(and loop handling) 146 // ... 
147 return false; 148 } 149 } 150 return true; 151 } 152 153 unsigned TruncInstCombine::getMinBitWidth() { 154 SmallVector<Value *, 8> Worklist; 155 SmallVector<Instruction *, 8> Stack; 156 157 Value *Src = CurrentTruncInst->getOperand(0); 158 Type *DstTy = CurrentTruncInst->getType(); 159 unsigned TruncBitWidth = DstTy->getScalarSizeInBits(); 160 unsigned OrigBitWidth = 161 CurrentTruncInst->getOperand(0)->getType()->getScalarSizeInBits(); 162 163 if (isa<Constant>(Src)) 164 return TruncBitWidth; 165 166 Worklist.push_back(Src); 167 InstInfoMap[cast<Instruction>(Src)].ValidBitWidth = TruncBitWidth; 168 169 while (!Worklist.empty()) { 170 Value *Curr = Worklist.back(); 171 172 if (isa<Constant>(Curr)) { 173 Worklist.pop_back(); 174 continue; 175 } 176 177 // Otherwise, it must be an instruction. 178 auto *I = cast<Instruction>(Curr); 179 180 auto &Info = InstInfoMap[I]; 181 182 SmallVector<Value *, 2> Operands; 183 getRelevantOperands(I, Operands); 184 185 if (!Stack.empty() && Stack.back() == I) { 186 // Already handled all instruction operands, can remove it from both, the 187 // Worklist and the Stack, and update MinBitWidth. 188 Worklist.pop_back(); 189 Stack.pop_back(); 190 for (auto *Operand : Operands) 191 if (auto *IOp = dyn_cast<Instruction>(Operand)) 192 Info.MinBitWidth = 193 std::max(Info.MinBitWidth, InstInfoMap[IOp].MinBitWidth); 194 continue; 195 } 196 197 // Add the instruction to the stack before start handling its operands. 198 Stack.push_back(I); 199 unsigned ValidBitWidth = Info.ValidBitWidth; 200 201 // Update minimum bit-width before handling its operands. This is required 202 // when the instruction is part of a loop. 
203 Info.MinBitWidth = std::max(Info.MinBitWidth, Info.ValidBitWidth); 204 205 for (auto *Operand : Operands) 206 if (auto *IOp = dyn_cast<Instruction>(Operand)) { 207 // If we already calculated the minimum bit-width for this valid 208 // bit-width, or for a smaller valid bit-width, then just keep the 209 // answer we already calculated. 210 unsigned IOpBitwidth = InstInfoMap.lookup(IOp).ValidBitWidth; 211 if (IOpBitwidth >= ValidBitWidth) 212 continue; 213 InstInfoMap[IOp].ValidBitWidth = ValidBitWidth; 214 Worklist.push_back(IOp); 215 } 216 } 217 unsigned MinBitWidth = InstInfoMap.lookup(cast<Instruction>(Src)).MinBitWidth; 218 assert(MinBitWidth >= TruncBitWidth); 219 220 if (MinBitWidth > TruncBitWidth) { 221 // In this case reducing expression with vector type might generate a new 222 // vector type, which is not preferable as it might result in generating 223 // sub-optimal code. 224 if (DstTy->isVectorTy()) 225 return OrigBitWidth; 226 // Use the smallest integer type in the range [MinBitWidth, OrigBitWidth). 227 Type *Ty = DL.getSmallestLegalIntType(DstTy->getContext(), MinBitWidth); 228 // Update minimum bit-width with the new destination type bit-width if 229 // succeeded to find such, otherwise, with original bit-width. 230 MinBitWidth = Ty ? Ty->getScalarSizeInBits() : OrigBitWidth; 231 } else { // MinBitWidth == TruncBitWidth 232 // In this case the expression can be evaluated with the trunc instruction 233 // destination type, and trunc instruction can be omitted. However, we 234 // should not perform the evaluation if the original type is a legal scalar 235 // type and the target type is illegal. 
236 bool FromLegal = MinBitWidth == 1 || DL.isLegalInteger(OrigBitWidth); 237 bool ToLegal = MinBitWidth == 1 || DL.isLegalInteger(MinBitWidth); 238 if (!DstTy->isVectorTy() && FromLegal && !ToLegal) 239 return OrigBitWidth; 240 } 241 return MinBitWidth; 242 } 243 244 Type *TruncInstCombine::getBestTruncatedType() { 245 if (!buildTruncExpressionDag()) 246 return nullptr; 247 248 // We don't want to duplicate instructions, which isn't profitable. Thus, we 249 // can't shrink something that has multiple users, unless all users are 250 // post-dominated by the trunc instruction, i.e., were visited during the 251 // expression evaluation. 252 unsigned DesiredBitWidth = 0; 253 for (auto Itr : InstInfoMap) { 254 Instruction *I = Itr.first; 255 if (I->hasOneUse()) 256 continue; 257 bool IsExtInst = (isa<ZExtInst>(I) || isa<SExtInst>(I)); 258 for (auto *U : I->users()) 259 if (auto *UI = dyn_cast<Instruction>(U)) 260 if (UI != CurrentTruncInst && !InstInfoMap.count(UI)) { 261 if (!IsExtInst) 262 return nullptr; 263 // If this is an extension from the dest type, we can eliminate it, 264 // even if it has multiple users. Thus, update the DesiredBitWidth and 265 // validate all extension instructions agrees on same DesiredBitWidth. 266 unsigned ExtInstBitWidth = 267 I->getOperand(0)->getType()->getScalarSizeInBits(); 268 if (DesiredBitWidth && DesiredBitWidth != ExtInstBitWidth) 269 return nullptr; 270 DesiredBitWidth = ExtInstBitWidth; 271 } 272 } 273 274 unsigned OrigBitWidth = 275 CurrentTruncInst->getOperand(0)->getType()->getScalarSizeInBits(); 276 277 // Initialize MinBitWidth for `shl` instructions with the minimum number 278 // that is greater than shift amount (i.e. shift amount + 1). 279 // Also normalize MinBitWidth not to be greater than source bitwidth. 
280 for (auto &Itr : InstInfoMap) { 281 Instruction *I = Itr.first; 282 if (I->getOpcode() == Instruction::Shl) { 283 KnownBits KnownRHS = computeKnownBits(I->getOperand(1), DL); 284 const unsigned SrcBitWidth = KnownRHS.getBitWidth(); 285 unsigned MinBitWidth = 286 KnownRHS.getMaxValue().uadd_sat(APInt(SrcBitWidth, 1)).getZExtValue(); 287 MinBitWidth = std::min(MinBitWidth, SrcBitWidth); 288 if (MinBitWidth >= OrigBitWidth) 289 return nullptr; 290 Itr.second.MinBitWidth = MinBitWidth; 291 } 292 } 293 294 // Calculate minimum allowed bit-width allowed for shrinking the currently 295 // visited truncate's operand. 296 unsigned MinBitWidth = getMinBitWidth(); 297 298 // Check that we can shrink to smaller bit-width than original one and that 299 // it is similar to the DesiredBitWidth is such exists. 300 if (MinBitWidth >= OrigBitWidth || 301 (DesiredBitWidth && DesiredBitWidth != MinBitWidth)) 302 return nullptr; 303 304 return IntegerType::get(CurrentTruncInst->getContext(), MinBitWidth); 305 } 306 307 /// Given a reduced scalar type \p Ty and a \p V value, return a reduced type 308 /// for \p V, according to its type, if it vector type, return the vector 309 /// version of \p Ty, otherwise return \p Ty. 310 static Type *getReducedType(Value *V, Type *Ty) { 311 assert(Ty && !Ty->isVectorTy() && "Expect Scalar Type"); 312 if (auto *VTy = dyn_cast<VectorType>(V->getType())) 313 return VectorType::get(Ty, VTy->getElementCount()); 314 return Ty; 315 } 316 317 Value *TruncInstCombine::getReducedOperand(Value *V, Type *SclTy) { 318 Type *Ty = getReducedType(V, SclTy); 319 if (auto *C = dyn_cast<Constant>(V)) { 320 C = ConstantExpr::getIntegerCast(C, Ty, false); 321 // If we got a constantexpr back, try to simplify it with DL info. 
322 return ConstantFoldConstant(C, DL, &TLI); 323 } 324 325 auto *I = cast<Instruction>(V); 326 Info Entry = InstInfoMap.lookup(I); 327 assert(Entry.NewValue); 328 return Entry.NewValue; 329 } 330 331 void TruncInstCombine::ReduceExpressionDag(Type *SclTy) { 332 NumInstrsReduced += InstInfoMap.size(); 333 for (auto &Itr : InstInfoMap) { // Forward 334 Instruction *I = Itr.first; 335 TruncInstCombine::Info &NodeInfo = Itr.second; 336 337 assert(!NodeInfo.NewValue && "Instruction has been evaluated"); 338 339 IRBuilder<> Builder(I); 340 Value *Res = nullptr; 341 unsigned Opc = I->getOpcode(); 342 switch (Opc) { 343 case Instruction::Trunc: 344 case Instruction::ZExt: 345 case Instruction::SExt: { 346 Type *Ty = getReducedType(I, SclTy); 347 // If the source type of the cast is the type we're trying for then we can 348 // just return the source. There's no need to insert it because it is not 349 // new. 350 if (I->getOperand(0)->getType() == Ty) { 351 assert(!isa<TruncInst>(I) && "Cannot reach here with TruncInst"); 352 NodeInfo.NewValue = I->getOperand(0); 353 continue; 354 } 355 // Otherwise, must be the same type of cast, so just reinsert a new one. 356 // This also handles the case of zext(trunc(x)) -> zext(x). 357 Res = Builder.CreateIntCast(I->getOperand(0), Ty, 358 Opc == Instruction::SExt); 359 360 // Update Worklist entries with new value if needed. 361 // There are three possible changes to the Worklist: 362 // 1. Update Old-TruncInst -> New-TruncInst. 363 // 2. Remove Old-TruncInst (if New node is not TruncInst). 364 // 3. Add New-TruncInst (if Old node was not TruncInst). 
365 auto *Entry = find(Worklist, I); 366 if (Entry != Worklist.end()) { 367 if (auto *NewCI = dyn_cast<TruncInst>(Res)) 368 *Entry = NewCI; 369 else 370 Worklist.erase(Entry); 371 } else if (auto *NewCI = dyn_cast<TruncInst>(Res)) 372 Worklist.push_back(NewCI); 373 break; 374 } 375 case Instruction::Add: 376 case Instruction::Sub: 377 case Instruction::Mul: 378 case Instruction::And: 379 case Instruction::Or: 380 case Instruction::Xor: 381 case Instruction::Shl: { 382 Value *LHS = getReducedOperand(I->getOperand(0), SclTy); 383 Value *RHS = getReducedOperand(I->getOperand(1), SclTy); 384 Res = Builder.CreateBinOp((Instruction::BinaryOps)Opc, LHS, RHS); 385 break; 386 } 387 case Instruction::Select: { 388 Value *Op0 = I->getOperand(0); 389 Value *LHS = getReducedOperand(I->getOperand(1), SclTy); 390 Value *RHS = getReducedOperand(I->getOperand(2), SclTy); 391 Res = Builder.CreateSelect(Op0, LHS, RHS); 392 break; 393 } 394 default: 395 llvm_unreachable("Unhandled instruction"); 396 } 397 398 NodeInfo.NewValue = Res; 399 if (auto *ResI = dyn_cast<Instruction>(Res)) 400 ResI->takeName(I); 401 } 402 403 Value *Res = getReducedOperand(CurrentTruncInst->getOperand(0), SclTy); 404 Type *DstTy = CurrentTruncInst->getType(); 405 if (Res->getType() != DstTy) { 406 IRBuilder<> Builder(CurrentTruncInst); 407 Res = Builder.CreateIntCast(Res, DstTy, false); 408 if (auto *ResI = dyn_cast<Instruction>(Res)) 409 ResI->takeName(CurrentTruncInst); 410 } 411 CurrentTruncInst->replaceAllUsesWith(Res); 412 413 // Erase old expression dag, which was replaced by the reduced expression dag. 414 // We iterate backward, which means we visit the instruction before we visit 415 // any of its operands, this way, when we get to the operand, we already 416 // removed the instructions (from the expression dag) that uses it. 
417 CurrentTruncInst->eraseFromParent(); 418 for (auto I = InstInfoMap.rbegin(), E = InstInfoMap.rend(); I != E; ++I) { 419 // We still need to check that the instruction has no users before we erase 420 // it, because {SExt, ZExt}Inst Instruction might have other users that was 421 // not reduced, in such case, we need to keep that instruction. 422 if (I->first->use_empty()) 423 I->first->eraseFromParent(); 424 } 425 } 426 427 bool TruncInstCombine::run(Function &F) { 428 bool MadeIRChange = false; 429 430 // Collect all TruncInst in the function into the Worklist for evaluating. 431 for (auto &BB : F) { 432 // Ignore unreachable basic block. 433 if (!DT.isReachableFromEntry(&BB)) 434 continue; 435 for (auto &I : BB) 436 if (auto *CI = dyn_cast<TruncInst>(&I)) 437 Worklist.push_back(CI); 438 } 439 440 // Process all TruncInst in the Worklist, for each instruction: 441 // 1. Check if it dominates an eligible expression dag to be reduced. 442 // 2. Create a reduced expression dag and replace the old one with it. 443 while (!Worklist.empty()) { 444 CurrentTruncInst = Worklist.pop_back_val(); 445 446 if (Type *NewDstSclTy = getBestTruncatedType()) { 447 LLVM_DEBUG( 448 dbgs() << "ICE: TruncInstCombine reducing type of expression dag " 449 "dominated by: " 450 << CurrentTruncInst << '\n'); 451 ReduceExpressionDag(NewDstSclTy); 452 ++NumDAGsReduced; 453 MadeIRChange = true; 454 } 455 } 456 457 return MadeIRChange; 458 } 459