1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "ReduceOperands.h"
10 #include "llvm/IR/Constants.h"
11 #include "llvm/IR/InstIterator.h"
12 #include "llvm/IR/InstrTypes.h"
13 #include "llvm/IR/Operator.h"
14 #include "llvm/IR/PatternMatch.h"
15 #include "llvm/IR/Type.h"
16
17 using namespace llvm;
18 using namespace PatternMatch;
19
20 static void
extractOperandsFromModule(Oracle & O,Module & Program,function_ref<Value * (Use &)> ReduceValue)21 extractOperandsFromModule(Oracle &O, Module &Program,
22 function_ref<Value *(Use &)> ReduceValue) {
23 for (auto &F : Program.functions()) {
24 for (auto &I : instructions(&F)) {
25 for (auto &Op : I.operands()) {
26 if (!O.shouldKeep()) {
27 if (Value *Reduced = ReduceValue(Op))
28 Op.set(Reduced);
29 }
30 }
31 }
32 }
33 }
34
isOne(Use & Op)35 static bool isOne(Use &Op) {
36 auto *C = dyn_cast<Constant>(Op);
37 return C && C->isOneValue();
38 }
39
isZero(Use & Op)40 static bool isZero(Use &Op) {
41 auto *C = dyn_cast<Constant>(Op);
42 return C && C->isNullValue();
43 }
44
isZeroOrOneFP(Value * Op)45 static bool isZeroOrOneFP(Value *Op) {
46 const APFloat *C;
47 return match(Op, m_APFloat(C)) &&
48 ((C->isZero() && !C->isNegative()) || C->isExactlyValue(1.0));
49 }
50
shouldReduceOperand(Use & Op)51 static bool shouldReduceOperand(Use &Op) {
52 Type *Ty = Op->getType();
53 if (Ty->isLabelTy() || Ty->isMetadataTy())
54 return false;
55 // TODO: be more precise about which GEP operands we can reduce (e.g. array
56 // indexes)
57 if (isa<GEPOperator>(Op.getUser()))
58 return false;
59 if (auto *CB = dyn_cast<CallBase>(Op.getUser())) {
60 if (&CB->getCalledOperandUse() == &Op)
61 return false;
62 }
63 return true;
64 }
65
reduceOperandsOneDeltaPass(TestRunner & Test)66 void llvm::reduceOperandsOneDeltaPass(TestRunner &Test) {
67 errs() << "*** Reducing Operands to one...\n";
68 auto ReduceValue = [](Use &Op) -> Value * {
69 if (!shouldReduceOperand(Op))
70 return nullptr;
71
72 Type *Ty = Op->getType();
73 if (auto *IntTy = dyn_cast<IntegerType>(Ty)) {
74 // Don't replace existing ones and zeroes.
75 return (isOne(Op) || isZero(Op)) ? nullptr : ConstantInt::get(IntTy, 1);
76 }
77
78 if (Ty->isFloatingPointTy())
79 return isZeroOrOneFP(Op) ? nullptr : ConstantFP::get(Ty, 1.0);
80
81 if (VectorType *VT = dyn_cast<VectorType>(Ty)) {
82 if (isOne(Op) || isZero(Op) || isZeroOrOneFP(Op))
83 return nullptr;
84
85 if (auto *IntTy = dyn_cast<IntegerType>(VT->getElementType()))
86 return ConstantVector::getSplat(VT->getElementCount(),
87 ConstantInt::get(IntTy, 1));
88
89 return ConstantVector::getSplat(
90 VT->getElementCount(), ConstantFP::get(VT->getElementType(), 1.0));
91 }
92
93 return nullptr;
94 };
95 runDeltaPass(Test, [ReduceValue](Oracle &O, Module &Program) {
96 extractOperandsFromModule(O, Program, ReduceValue);
97 });
98 }
99
reduceOperandsZeroDeltaPass(TestRunner & Test)100 void llvm::reduceOperandsZeroDeltaPass(TestRunner &Test) {
101 errs() << "*** Reducing Operands to zero...\n";
102 auto ReduceValue = [](Use &Op) -> Value * {
103 if (!shouldReduceOperand(Op))
104 return nullptr;
105 // Don't replace existing zeroes.
106 return isZero(Op) ? nullptr : Constant::getNullValue(Op->getType());
107 };
108 runDeltaPass(Test, [ReduceValue](Oracle &O, Module &Program) {
109 extractOperandsFromModule(O, Program, ReduceValue);
110 });
111 }
112
reduceOperandsNaNDeltaPass(TestRunner & Test)113 void llvm::reduceOperandsNaNDeltaPass(TestRunner &Test) {
114 errs() << "*** Reducing Operands to NaN...\n";
115 auto ReduceValue = [](Use &Op) -> Value * {
116 Type *Ty = Op->getType();
117 if (!Ty->isFPOrFPVectorTy())
118 return nullptr;
119
120 // Prefer 0.0 or 1.0 over NaN.
121 //
122 // TODO: Preferring NaN may make more sense because FP operations are more
123 // universally foldable.
124 if (match(Op.get(), m_NaN()) || isZeroOrOneFP(Op.get()))
125 return nullptr;
126
127 if (VectorType *VT = dyn_cast<VectorType>(Ty)) {
128 return ConstantVector::getSplat(VT->getElementCount(),
129 ConstantFP::getQNaN(VT->getElementType()));
130 }
131
132 return ConstantFP::getQNaN(Ty);
133 };
134 runDeltaPass(Test, [ReduceValue](Oracle &O, Module &Program) {
135 extractOperandsFromModule(O, Program, ReduceValue);
136 });
137 }
138