1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "ReduceOperands.h"
10 #include "llvm/IR/Constants.h"
11 #include "llvm/IR/InstIterator.h"
12 #include "llvm/IR/InstrTypes.h"
13 #include "llvm/IR/Operator.h"
14 #include "llvm/IR/PatternMatch.h"
15 #include "llvm/IR/Type.h"
16 
17 using namespace llvm;
18 using namespace PatternMatch;
19 
20 static void
extractOperandsFromModule(Oracle & O,Module & Program,function_ref<Value * (Use &)> ReduceValue)21 extractOperandsFromModule(Oracle &O, Module &Program,
22                           function_ref<Value *(Use &)> ReduceValue) {
23   for (auto &F : Program.functions()) {
24     for (auto &I : instructions(&F)) {
25       for (auto &Op : I.operands()) {
26         if (!O.shouldKeep()) {
27           if (Value *Reduced = ReduceValue(Op))
28             Op.set(Reduced);
29         }
30       }
31     }
32   }
33 }
34 
isOne(Use & Op)35 static bool isOne(Use &Op) {
36   auto *C = dyn_cast<Constant>(Op);
37   return C && C->isOneValue();
38 }
39 
isZero(Use & Op)40 static bool isZero(Use &Op) {
41   auto *C = dyn_cast<Constant>(Op);
42   return C && C->isNullValue();
43 }
44 
isZeroOrOneFP(Value * Op)45 static bool isZeroOrOneFP(Value *Op) {
46   const APFloat *C;
47   return match(Op, m_APFloat(C)) &&
48          ((C->isZero() && !C->isNegative()) || C->isExactlyValue(1.0));
49 }
50 
shouldReduceOperand(Use & Op)51 static bool shouldReduceOperand(Use &Op) {
52   Type *Ty = Op->getType();
53   if (Ty->isLabelTy() || Ty->isMetadataTy())
54     return false;
55   // TODO: be more precise about which GEP operands we can reduce (e.g. array
56   // indexes)
57   if (isa<GEPOperator>(Op.getUser()))
58     return false;
59   if (auto *CB = dyn_cast<CallBase>(Op.getUser())) {
60     if (&CB->getCalledOperandUse() == &Op)
61       return false;
62   }
63   return true;
64 }
65 
reduceOperandsOneDeltaPass(TestRunner & Test)66 void llvm::reduceOperandsOneDeltaPass(TestRunner &Test) {
67   errs() << "*** Reducing Operands to one...\n";
68   auto ReduceValue = [](Use &Op) -> Value * {
69     if (!shouldReduceOperand(Op))
70       return nullptr;
71 
72     Type *Ty = Op->getType();
73     if (auto *IntTy = dyn_cast<IntegerType>(Ty)) {
74       // Don't replace existing ones and zeroes.
75       return (isOne(Op) || isZero(Op)) ? nullptr : ConstantInt::get(IntTy, 1);
76     }
77 
78     if (Ty->isFloatingPointTy())
79       return isZeroOrOneFP(Op) ? nullptr : ConstantFP::get(Ty, 1.0);
80 
81     if (VectorType *VT = dyn_cast<VectorType>(Ty)) {
82       if (isOne(Op) || isZero(Op) || isZeroOrOneFP(Op))
83         return nullptr;
84 
85       if (auto *IntTy = dyn_cast<IntegerType>(VT->getElementType()))
86         return ConstantVector::getSplat(VT->getElementCount(),
87                                         ConstantInt::get(IntTy, 1));
88 
89       return ConstantVector::getSplat(
90           VT->getElementCount(), ConstantFP::get(VT->getElementType(), 1.0));
91     }
92 
93     return nullptr;
94   };
95   runDeltaPass(Test, [ReduceValue](Oracle &O, Module &Program) {
96     extractOperandsFromModule(O, Program, ReduceValue);
97   });
98 }
99 
reduceOperandsZeroDeltaPass(TestRunner & Test)100 void llvm::reduceOperandsZeroDeltaPass(TestRunner &Test) {
101   errs() << "*** Reducing Operands to zero...\n";
102   auto ReduceValue = [](Use &Op) -> Value * {
103     if (!shouldReduceOperand(Op))
104       return nullptr;
105     // Don't replace existing zeroes.
106     return isZero(Op) ? nullptr : Constant::getNullValue(Op->getType());
107   };
108   runDeltaPass(Test, [ReduceValue](Oracle &O, Module &Program) {
109     extractOperandsFromModule(O, Program, ReduceValue);
110   });
111 }
112 
reduceOperandsNaNDeltaPass(TestRunner & Test)113 void llvm::reduceOperandsNaNDeltaPass(TestRunner &Test) {
114   errs() << "*** Reducing Operands to NaN...\n";
115   auto ReduceValue = [](Use &Op) -> Value * {
116     Type *Ty = Op->getType();
117     if (!Ty->isFPOrFPVectorTy())
118       return nullptr;
119 
120     // Prefer 0.0 or 1.0 over NaN.
121     //
122     // TODO: Preferring NaN may make more sense because FP operations are more
123     // universally foldable.
124     if (match(Op.get(), m_NaN()) || isZeroOrOneFP(Op.get()))
125       return nullptr;
126 
127     if (VectorType *VT = dyn_cast<VectorType>(Ty)) {
128       return ConstantVector::getSplat(VT->getElementCount(),
129                                       ConstantFP::getQNaN(VT->getElementType()));
130     }
131 
132     return ConstantFP::getQNaN(Ty);
133   };
134   runDeltaPass(Test, [ReduceValue](Oracle &O, Module &Program) {
135     extractOperandsFromModule(O, Program, ReduceValue);
136   });
137 }
138