1 //===----------------------------------------------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #include "ReduceOperands.h" 10 #include "llvm/IR/Constants.h" 11 #include "llvm/IR/InstIterator.h" 12 #include "llvm/IR/InstrTypes.h" 13 #include "llvm/IR/Operator.h" 14 #include "llvm/IR/PatternMatch.h" 15 #include "llvm/IR/Type.h" 16 17 using namespace llvm; 18 using namespace PatternMatch; 19 20 static void 21 extractOperandsFromModule(Oracle &O, Module &Program, 22 function_ref<Value *(Use &)> ReduceValue) { 23 for (auto &F : Program.functions()) { 24 for (auto &I : instructions(&F)) { 25 for (auto &Op : I.operands()) { 26 if (!O.shouldKeep()) { 27 if (Value *Reduced = ReduceValue(Op)) 28 Op.set(Reduced); 29 } 30 } 31 } 32 } 33 } 34 35 static bool isOne(Use &Op) { 36 auto *C = dyn_cast<Constant>(Op); 37 return C && C->isOneValue(); 38 } 39 40 static bool isZero(Use &Op) { 41 auto *C = dyn_cast<Constant>(Op); 42 return C && C->isNullValue(); 43 } 44 45 static bool isZeroOrOneFP(Value *Op) { 46 const APFloat *C; 47 return match(Op, m_APFloat(C)) && 48 ((C->isZero() && !C->isNegative()) || C->isExactlyValue(1.0)); 49 } 50 51 static bool shouldReduceOperand(Use &Op) { 52 Type *Ty = Op->getType(); 53 if (Ty->isLabelTy() || Ty->isMetadataTy()) 54 return false; 55 // TODO: be more precise about which GEP operands we can reduce (e.g. array 56 // indexes) 57 if (isa<GEPOperator>(Op.getUser())) 58 return false; 59 if (auto *CB = dyn_cast<CallBase>(Op.getUser())) { 60 if (&CB->getCalledOperandUse() == &Op) 61 return false; 62 } 63 return true; 64 } 65 66 void llvm::reduceOperandsOneDeltaPass(TestRunner &Test) { 67 errs() << "*** Reducing Operands to one...\n"; 68 auto ReduceValue = [](Use &Op) -> Value * { 69 if (!shouldReduceOperand(Op)) 70 return nullptr; 71 72 Type *Ty = Op->getType(); 73 if (auto *IntTy = dyn_cast<IntegerType>(Ty)) { 74 // Don't replace existing ones and zeroes. 75 return (isOne(Op) || isZero(Op)) ? nullptr : ConstantInt::get(IntTy, 1); 76 } 77 78 if (Ty->isFloatingPointTy()) 79 return isZeroOrOneFP(Op) ? nullptr : ConstantFP::get(Ty, 1.0); 80 81 if (VectorType *VT = dyn_cast<VectorType>(Ty)) { 82 if (isOne(Op) || isZero(Op) || isZeroOrOneFP(Op)) 83 return nullptr; 84 85 if (auto *IntTy = dyn_cast<IntegerType>(VT->getElementType())) 86 return ConstantVector::getSplat(VT->getElementCount(), 87 ConstantInt::get(IntTy, 1)); 88 89 return ConstantVector::getSplat( 90 VT->getElementCount(), ConstantFP::get(VT->getElementType(), 1.0)); 91 } 92 93 return nullptr; 94 }; 95 runDeltaPass(Test, [ReduceValue](Oracle &O, Module &Program) { 96 extractOperandsFromModule(O, Program, ReduceValue); 97 }); 98 } 99 100 void llvm::reduceOperandsZeroDeltaPass(TestRunner &Test) { 101 errs() << "*** Reducing Operands to zero...\n"; 102 auto ReduceValue = [](Use &Op) -> Value * { 103 if (!shouldReduceOperand(Op)) 104 return nullptr; 105 // Don't replace existing zeroes. 106 return isZero(Op) ? nullptr : Constant::getNullValue(Op->getType()); 107 }; 108 runDeltaPass(Test, [ReduceValue](Oracle &O, Module &Program) { 109 extractOperandsFromModule(O, Program, ReduceValue); 110 }); 111 } 112 113 void llvm::reduceOperandsNaNDeltaPass(TestRunner &Test) { 114 errs() << "*** Reducing Operands to NaN...\n"; 115 auto ReduceValue = [](Use &Op) -> Value * { 116 Type *Ty = Op->getType(); 117 if (!Ty->isFPOrFPVectorTy()) 118 return nullptr; 119 120 // Prefer 0.0 or 1.0 over NaN. 121 // 122 // TODO: Preferring NaN may make more sense because FP operations are more 123 // universally foldable. 124 if (match(Op.get(), m_NaN()) || isZeroOrOneFP(Op.get())) 125 return nullptr; 126 127 if (VectorType *VT = dyn_cast<VectorType>(Ty)) { 128 return ConstantVector::getSplat(VT->getElementCount(), 129 ConstantFP::getQNaN(VT->getElementType())); 130 } 131 132 return ConstantFP::getQNaN(Ty); 133 }; 134 runDeltaPass(Test, [ReduceValue](Oracle &O, Module &Program) { 135 extractOperandsFromModule(O, Program, ReduceValue); 136 }); 137 } 138