//===- CostModel.cpp ------ Cost Model Analysis ---------------------------===//
//
//                     The LLVM Compiler Infrastructure
//
// This file is distributed under the University of Illinois Open Source
// License. See LICENSE.TXT for details.
//
//===----------------------------------------------------------------------===//
//
// This file defines the cost model analysis. It provides a very basic cost
// estimation for LLVM-IR. This analysis uses the services of the codegen
// to approximate the cost of any IR instruction when lowered to machine
// instructions. The cost results are unit-less and the cost number represents
// the throughput of the machine assuming that all loads hit the cache, all
// branches are predicted, etc. The cost numbers can be added in order to
// compare two or more transformation alternatives.
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/STLExtras.h"
#include "llvm/Analysis/Passes.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/VectorUtils.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/Value.h"
#include "llvm/Pass.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
using namespace llvm;

#define CM_NAME "cost-model"
#define DEBUG_TYPE CM_NAME

static cl::opt<bool> EnableReduxCost("costmodel-reduxcost", cl::init(false),
                                     cl::Hidden,
                                     cl::desc("Recognize reduction patterns."));

namespace {
class CostModelAnalysis : public FunctionPass {

public:
  static char ID; // Class identification, replacement for typeinfo
  CostModelAnalysis() : FunctionPass(ID), F(nullptr), TTI(nullptr) {
    initializeCostModelAnalysisPass(
        *PassRegistry::getPassRegistry());
  }

  /// Returns the expected cost of the instruction.
  /// Returns -1 if the cost is unknown.
  /// Note, this method does not cache the cost calculation and it
  /// can be expensive in some cases.
  unsigned getInstructionCost(const Instruction *I) const;

private:
  void getAnalysisUsage(AnalysisUsage &AU) const override;
  bool runOnFunction(Function &F) override;
  void print(raw_ostream &OS, const Module *) const override;

  /// The function that we analyze.
  Function *F;
  /// Target information.
  const TargetTransformInfo *TTI;
};
} // End of anonymous namespace

// Register this pass.
char CostModelAnalysis::ID = 0;
static const char cm_name[] = "Cost Model Analysis";
INITIALIZE_PASS_BEGIN(CostModelAnalysis, CM_NAME, cm_name, false, true)
INITIALIZE_PASS_END  (CostModelAnalysis, CM_NAME, cm_name, false, true)

FunctionPass *llvm::createCostModelAnalysisPass() {
  return new CostModelAnalysis();
}

void
CostModelAnalysis::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.setPreservesAll();
}

bool
CostModelAnalysis::runOnFunction(Function &F) {
  this->F = &F;
  auto *TTIWP = getAnalysisIfAvailable<TargetTransformInfoWrapperPass>();
  TTI = TTIWP ? &TTIWP->getTTI(F) : nullptr;

  return false;
}

/// Returns true if the mask reverses the order of the vector's elements.
/// Negative (undef) mask entries are ignored.
static bool isReverseVectorMask(SmallVectorImpl<int> &Mask) {
  for (unsigned i = 0, MaskSize = Mask.size(); i < MaskSize; ++i)
    if (Mask[i] >= 0 && Mask[i] != (int)(MaskSize - 1 - i))
      return false;
  return true;
}

/// Returns true if the mask selects alternating elements from the two input
/// vectors. Negative (undef) mask entries are ignored.
static bool isAlternateVectorMask(SmallVectorImpl<int> &Mask) {
  bool isAlternate = true;
  unsigned MaskSize = Mask.size();

  // Example: shufflevector A, B, <0,5,2,7>
  for (unsigned i = 0; i < MaskSize && isAlternate; ++i) {
    if (Mask[i] < 0)
      continue;
    isAlternate = Mask[i] == (int)((i & 1) ? MaskSize + i : i);
  }

  if (isAlternate)
    return true;

  isAlternate = true;
  // Example: shufflevector A, B, <4,1,6,3>
  for (unsigned i = 0; i < MaskSize && isAlternate; ++i) {
    if (Mask[i] < 0)
      continue;
    isAlternate = Mask[i] == (int)((i & 1) ? i : MaskSize + i);
  }

  return isAlternate;
}

static TargetTransformInfo::OperandValueKind getOperandInfo(Value *V) {
  TargetTransformInfo::OperandValueKind OpInfo =
      TargetTransformInfo::OK_AnyValue;

  // Check for a splat of a constant or for a non-uniform vector of constants.
  if (isa<ConstantVector>(V) || isa<ConstantDataVector>(V)) {
    OpInfo = TargetTransformInfo::OK_NonUniformConstantValue;
    if (cast<Constant>(V)->getSplatValue() != nullptr)
      OpInfo = TargetTransformInfo::OK_UniformConstantValue;
  }

  // Check for a splat of a uniform value. This is not loop aware, so we report
  // a uniform value only for the obviously uniform cases (Argument,
  // GlobalValue).
  const Value *Splat = getSplatValue(V);
  if (Splat && (isa<Argument>(Splat) || isa<GlobalValue>(Splat)))
    OpInfo = TargetTransformInfo::OK_UniformValue;

  return OpInfo;
}

static bool matchPairwiseShuffleMask(ShuffleVectorInst *SI, bool IsLeft,
                                     unsigned Level) {
  // We don't need a shuffle if we just want to have element 0 in position 0 of
  // the vector.
  if (!SI && Level == 0 && IsLeft)
    return true;
  else if (!SI)
    return false;

  SmallVector<int, 32> Mask(SI->getType()->getVectorNumElements(), -1);

  // Build a mask of 0, 2, ... (left) or 1, 3, ... (right) depending on whether
  // we look at the left or right side.
  for (unsigned i = 0, e = (1 << Level), val = !IsLeft; i != e; ++i, val += 2)
    Mask[i] = val;

  SmallVector<int, 16> ActualMask = SI->getShuffleMask();
  return Mask == ActualMask;
}

static bool matchPairwiseReductionAtLevel(const BinaryOperator *BinOp,
                                          unsigned Level, unsigned NumLevels) {
  // Match one level of pairwise operations.
  // %rdx.shuf.0.0 = shufflevector <4 x float> %rdx, <4 x float> undef,
  //       <4 x i32> <i32 0, i32 2, i32 undef, i32 undef>
  // %rdx.shuf.0.1 = shufflevector <4 x float> %rdx, <4 x float> undef,
  //       <4 x i32> <i32 1, i32 3, i32 undef, i32 undef>
  // %bin.rdx.0 = fadd <4 x float> %rdx.shuf.0.0, %rdx.shuf.0.1
  if (BinOp == nullptr)
    return false;

  assert(BinOp->getType()->isVectorTy() && "Expecting a vector type");

  unsigned Opcode = BinOp->getOpcode();
  Value *L = BinOp->getOperand(0);
  Value *R = BinOp->getOperand(1);

  ShuffleVectorInst *LS = dyn_cast<ShuffleVectorInst>(L);
  if (!LS && Level)
    return false;
  ShuffleVectorInst *RS = dyn_cast<ShuffleVectorInst>(R);
  if (!RS && Level)
    return false;

  // On level 0 we can omit one shufflevector instruction.
  if (!Level && !RS && !LS)
    return false;

  // Shuffle inputs must match.
  Value *NextLevelOpL = LS ? LS->getOperand(0) : nullptr;
  Value *NextLevelOpR = RS ? RS->getOperand(0) : nullptr;
  Value *NextLevelOp = nullptr;
  if (NextLevelOpR && NextLevelOpL) {
    // If we have two shuffles their operands must match.
    if (NextLevelOpL != NextLevelOpR)
      return false;

    NextLevelOp = NextLevelOpL;
  } else if (Level == 0 && (NextLevelOpR || NextLevelOpL)) {
    // On the first level we can omit the shufflevector <0, undef,...>. So the
    // input to the other shufflevector <1, undef> must match with one of the
    // inputs to the current binary operation.
    // Example:
    //  %NextLevelOpL = shufflevector %R, <1, undef ...>
    //  %BinOp        = fadd          %NextLevelOpL, %R
    if (NextLevelOpL && NextLevelOpL != R)
      return false;
    else if (NextLevelOpR && NextLevelOpR != L)
      return false;

    NextLevelOp = NextLevelOpL ? R : L;
  } else
    return false;

  // Check that the next level's binary operation exists and matches with the
  // current one.
  BinaryOperator *NextLevelBinOp = nullptr;
  if (Level + 1 != NumLevels) {
    if (!(NextLevelBinOp = dyn_cast<BinaryOperator>(NextLevelOp)))
      return false;
    else if (NextLevelBinOp->getOpcode() != Opcode)
      return false;
  }

  // Shuffle mask for pairwise operation must match.
  if (matchPairwiseShuffleMask(LS, true, Level)) {
    if (!matchPairwiseShuffleMask(RS, false, Level))
      return false;
  } else if (matchPairwiseShuffleMask(RS, true, Level)) {
    if (!matchPairwiseShuffleMask(LS, false, Level))
      return false;
  } else
    return false;

  if (++Level == NumLevels)
    return true;

  // Match next level.
  return matchPairwiseReductionAtLevel(NextLevelBinOp, Level, NumLevels);
}

static bool matchPairwiseReduction(const ExtractElementInst *ReduxRoot,
                                   unsigned &Opcode, Type *&Ty) {
  if (!EnableReduxCost)
    return false;

  // Need to extract the first element.
  ConstantInt *CI = dyn_cast<ConstantInt>(ReduxRoot->getOperand(1));
  unsigned Idx = ~0u;
  if (CI)
    Idx = CI->getZExtValue();
  if (Idx != 0)
    return false;

  BinaryOperator *RdxStart = dyn_cast<BinaryOperator>(ReduxRoot->getOperand(0));
  if (!RdxStart)
    return false;

  Type *VecTy = ReduxRoot->getOperand(0)->getType();
  unsigned NumVecElems = VecTy->getVectorNumElements();
  if (!isPowerOf2_32(NumVecElems))
    return false;

  // We look for a sequence of shuffle, shuffle, add triples like the following
  // that builds a pairwise reduction tree.
  //
  //  (X0, X1, X2, X3)
  //   (X0 + X1, X2 + X3, undef, undef)
  //    ((X0 + X1) + (X2 + X3), undef, undef, undef)
  //
  // %rdx.shuf.0.0 = shufflevector <4 x float> %rdx, <4 x float> undef,
  //       <4 x i32> <i32 0, i32 2, i32 undef, i32 undef>
  // %rdx.shuf.0.1 = shufflevector <4 x float> %rdx, <4 x float> undef,
  //       <4 x i32> <i32 1, i32 3, i32 undef, i32 undef>
  // %bin.rdx.0 = fadd <4 x float> %rdx.shuf.0.0, %rdx.shuf.0.1
  // %rdx.shuf.1.0 = shufflevector <4 x float> %bin.rdx.0, <4 x float> undef,
  //       <4 x i32> <i32 0, i32 undef, i32 undef, i32 undef>
  // %rdx.shuf.1.1 = shufflevector <4 x float> %bin.rdx.0, <4 x float> undef,
  //       <4 x i32> <i32 1, i32 undef, i32 undef, i32 undef>
  // %bin.rdx8 = fadd <4 x float> %rdx.shuf.1.0, %rdx.shuf.1.1
  // %r = extractelement <4 x float> %bin.rdx8, i32 0
  if (!matchPairwiseReductionAtLevel(RdxStart, 0, Log2_32(NumVecElems)))
    return false;

  Opcode = RdxStart->getOpcode();
  Ty = VecTy;

  return true;
}

/// Returns the shufflevector operand of \p B (if any) together with the other,
/// non-shuffle operand.
static std::pair<Value *, ShuffleVectorInst *>
getShuffleAndOtherOprd(BinaryOperator *B) {
  Value *L = B->getOperand(0);
  Value *R = B->getOperand(1);
  ShuffleVectorInst *S = nullptr;

  if ((S = dyn_cast<ShuffleVectorInst>(L)))
    return std::make_pair(R, S);

  S = dyn_cast<ShuffleVectorInst>(R);
  return std::make_pair(L, S);
}

static bool matchVectorSplittingReduction(const ExtractElementInst *ReduxRoot,
                                          unsigned &Opcode, Type *&Ty) {
  if (!EnableReduxCost)
    return false;

  // Need to extract the first element.
  ConstantInt *CI = dyn_cast<ConstantInt>(ReduxRoot->getOperand(1));
  unsigned Idx = ~0u;
  if (CI)
    Idx = CI->getZExtValue();
  if (Idx != 0)
    return false;

  BinaryOperator *RdxStart = dyn_cast<BinaryOperator>(ReduxRoot->getOperand(0));
  if (!RdxStart)
    return false;
  unsigned RdxOpcode = RdxStart->getOpcode();

  Type *VecTy = ReduxRoot->getOperand(0)->getType();
  unsigned NumVecElems = VecTy->getVectorNumElements();
  if (!isPowerOf2_32(NumVecElems))
    return false;

  // We look for a sequence of shuffles and adds like the following, matching
  // one fadd/shufflevector pair at a time.
  //
  // %rdx.shuf = shufflevector <4 x float> %rdx, <4 x float> undef,
  //       <4 x i32> <i32 2, i32 3, i32 undef, i32 undef>
  // %bin.rdx = fadd <4 x float> %rdx, %rdx.shuf
  // %rdx.shuf7 = shufflevector <4 x float> %bin.rdx, <4 x float> undef,
  //       <4 x i32> <i32 1, i32 undef, i32 undef, i32 undef>
  // %bin.rdx8 = fadd <4 x float> %bin.rdx, %rdx.shuf7
  // %r = extractelement <4 x float> %bin.rdx8, i32 0

  unsigned MaskStart = 1;
  Value *RdxOp = RdxStart;
  SmallVector<int, 32> ShuffleMask(NumVecElems, 0);
  unsigned NumVecElemsRemain = NumVecElems;
  while (NumVecElemsRemain - 1) {
    // Check for the right reduction operation.
    BinaryOperator *BinOp;
    if (!(BinOp = dyn_cast<BinaryOperator>(RdxOp)))
      return false;
    if (BinOp->getOpcode() != RdxOpcode)
      return false;

    Value *NextRdxOp;
    ShuffleVectorInst *Shuffle;
    std::tie(NextRdxOp, Shuffle) = getShuffleAndOtherOprd(BinOp);

    // Check that the current reduction operation and the shuffle use the same
    // value.
    if (Shuffle == nullptr)
      return false;
    if (Shuffle->getOperand(0) != NextRdxOp)
      return false;

    // Check that the shuffle mask matches.
    for (unsigned j = 0; j != MaskStart; ++j)
      ShuffleMask[j] = MaskStart + j;
    // Fill the rest of the mask with -1 for undef.
    std::fill(&ShuffleMask[MaskStart], ShuffleMask.end(), -1);

    SmallVector<int, 16> Mask = Shuffle->getShuffleMask();
    if (ShuffleMask != Mask)
      return false;

    RdxOp = NextRdxOp;
    NumVecElemsRemain /= 2;
    MaskStart *= 2;
  }

  Opcode = RdxOpcode;
  Ty = VecTy;
  return true;
}

unsigned CostModelAnalysis::getInstructionCost(const Instruction *I) const {
  if (!TTI)
    return -1;

  switch (I->getOpcode()) {
  case Instruction::GetElementPtr:
    return TTI->getUserCost(I);

  case Instruction::Ret:
  case Instruction::PHI:
  case Instruction::Br: {
    return TTI->getCFInstrCost(I->getOpcode());
  }
  case Instruction::Add:
  case Instruction::FAdd:
  case Instruction::Sub:
  case Instruction::FSub:
  case Instruction::Mul:
  case Instruction::FMul:
  case Instruction::UDiv:
  case Instruction::SDiv:
  case Instruction::FDiv:
  case Instruction::URem:
  case Instruction::SRem:
  case Instruction::FRem:
  case Instruction::Shl:
  case Instruction::LShr:
  case Instruction::AShr:
  case Instruction::And:
  case Instruction::Or:
  case Instruction::Xor: {
    TargetTransformInfo::OperandValueKind Op1VK =
      getOperandInfo(I->getOperand(0));
    TargetTransformInfo::OperandValueKind Op2VK =
      getOperandInfo(I->getOperand(1));
    return TTI->getArithmeticInstrCost(I->getOpcode(), I->getType(), Op1VK,
                                       Op2VK);
  }
  case Instruction::Select: {
    const SelectInst *SI = cast<SelectInst>(I);
    Type *CondTy = SI->getCondition()->getType();
    return TTI->getCmpSelInstrCost(I->getOpcode(), I->getType(), CondTy);
  }
  case Instruction::ICmp:
  case Instruction::FCmp: {
    Type *ValTy = I->getOperand(0)->getType();
    return TTI->getCmpSelInstrCost(I->getOpcode(), ValTy);
  }
  case Instruction::Store: {
    const StoreInst *SI = cast<StoreInst>(I);
    Type *ValTy = SI->getValueOperand()->getType();
    return TTI->getMemoryOpCost(I->getOpcode(), ValTy,
                                SI->getAlignment(),
                                SI->getPointerAddressSpace());
  }
  case Instruction::Load: {
    const LoadInst *LI = cast<LoadInst>(I);
    return TTI->getMemoryOpCost(I->getOpcode(), I->getType(),
                                LI->getAlignment(),
                                LI->getPointerAddressSpace());
  }
  case Instruction::ZExt:
  case Instruction::SExt:
  case Instruction::FPToUI:
  case Instruction::FPToSI:
  case Instruction::FPExt:
  case Instruction::PtrToInt:
  case Instruction::IntToPtr:
  case Instruction::SIToFP:
  case Instruction::UIToFP:
  case Instruction::Trunc:
  case Instruction::FPTrunc:
  case Instruction::BitCast:
  case Instruction::AddrSpaceCast: {
    Type *SrcTy = I->getOperand(0)->getType();
    return TTI->getCastInstrCost(I->getOpcode(), I->getType(), SrcTy);
  }
  case Instruction::ExtractElement: {
    const ExtractElementInst *EEI = cast<ExtractElementInst>(I);
    ConstantInt *CI = dyn_cast<ConstantInt>(I->getOperand(1));
    unsigned Idx = -1;
    if (CI)
      Idx = CI->getZExtValue();

    // Try to match a reduction sequence (series of shufflevector and vector
    // adds followed by an extractelement).
    unsigned ReduxOpCode;
    Type *ReduxType;

    if (matchVectorSplittingReduction(EEI, ReduxOpCode, ReduxType))
      return TTI->getReductionCost(ReduxOpCode, ReduxType, false);
    else if (matchPairwiseReduction(EEI, ReduxOpCode, ReduxType))
      return TTI->getReductionCost(ReduxOpCode, ReduxType, true);

    return TTI->getVectorInstrCost(I->getOpcode(),
                                   EEI->getOperand(0)->getType(), Idx);
  }
  case Instruction::InsertElement: {
    const InsertElementInst *IE = cast<InsertElementInst>(I);
    ConstantInt *CI = dyn_cast<ConstantInt>(IE->getOperand(2));
    unsigned Idx = -1;
    if (CI)
      Idx = CI->getZExtValue();
    return TTI->getVectorInstrCost(I->getOpcode(),
                                   IE->getType(), Idx);
  }
  case Instruction::ShuffleVector: {
    const ShuffleVectorInst *Shuffle = cast<ShuffleVectorInst>(I);
    Type *VecTypOp0 = Shuffle->getOperand(0)->getType();
    unsigned NumVecElems = VecTypOp0->getVectorNumElements();
    SmallVector<int, 16> Mask = Shuffle->getShuffleMask();

    if (NumVecElems == Mask.size()) {
      if (isReverseVectorMask(Mask))
        return TTI->getShuffleCost(TargetTransformInfo::SK_Reverse, VecTypOp0,
                                   0, nullptr);
      if (isAlternateVectorMask(Mask))
        return TTI->getShuffleCost(TargetTransformInfo::SK_Alternate,
                                   VecTypOp0, 0, nullptr);
    }

    return -1;
  }
  case Instruction::Call:
    if (const IntrinsicInst *II = dyn_cast<IntrinsicInst>(I)) {
      SmallVector<Value *, 4> Args;
      for (unsigned J = 0, JE = II->getNumArgOperands(); J != JE; ++J)
        Args.push_back(II->getArgOperand(J));

      FastMathFlags FMF;
      if (auto *FPMO = dyn_cast<FPMathOperator>(II))
        FMF = FPMO->getFastMathFlags();

      return TTI->getIntrinsicInstrCost(II->getIntrinsicID(), II->getType(),
                                        Args, FMF);
    }
    return -1;
  default:
    // We don't have any information on this instruction.
    return -1;
  }
}

void CostModelAnalysis::print(raw_ostream &OS, const Module *) const {
  if (!F)
    return;

  for (BasicBlock &B : *F) {
    for (Instruction &Inst : B) {
      unsigned Cost = getInstructionCost(&Inst);
      if (Cost != (unsigned)-1)
        OS << "Cost Model: Found an estimated cost of " << Cost;
      else
        OS << "Cost Model: Unknown cost";

      OS << " for instruction: " << Inst << "\n";
    }
  }
}