1 //===- FoldUtils.cpp ---- Fold Utilities ----------------------------------===//
2 //
3 // Copyright 2019 The MLIR Authors.
4 //
5 // Licensed under the Apache License, Version 2.0 (the "License");
6 // you may not use this file except in compliance with the License.
7 // You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing, software
12 // distributed under the License is distributed on an "AS IS" BASIS,
13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 // See the License for the specific language governing permissions and
15 // limitations under the License.
16 // =============================================================================
17 //
18 // This file defines various operation fold utilities. These utilities are
// intended to be used by passes to unify and simplify their logic.
20 //
21 //===----------------------------------------------------------------------===//
22 
23 #include "mlir/Transforms/FoldUtils.h"
24 
25 #include "mlir/Dialect/StandardOps/Ops.h"
26 #include "mlir/IR/Builders.h"
27 #include "mlir/IR/Matchers.h"
28 #include "mlir/IR/Operation.h"
29 
30 using namespace mlir;
31 
32 /// Given an operation, find the parent region that folded constants should be
33 /// inserted into.
34 static Region *getInsertionRegion(Operation *op) {
35   while (Region *region = op->getParentRegion()) {
36     // Insert in this region for any of the following scenarios:
37     //  * The parent is unregistered, or is known to be isolated from above.
38     //  * The parent is a top-level operation.
39     auto *parentOp = region->getParentOp();
40     if (!parentOp->isRegistered() || parentOp->isKnownIsolatedFromAbove() ||
41         !parentOp->getBlock())
42       return region;
43     // Traverse up the parent looking for an insertion region.
44     op = parentOp;
45   }
46   llvm_unreachable("expected valid insertion region");
47 }
48 
49 /// A utility function used to materialize a constant for a given attribute and
50 /// type. On success, a valid constant value is returned. Otherwise, null is
51 /// returned
52 static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder,
53                                       Attribute value, Type type,
54                                       Location loc) {
55   auto insertPt = builder.getInsertionPoint();
56   (void)insertPt;
57 
58   // Ask the dialect to materialize a constant operation for this value.
59   if (auto *constOp = dialect->materializeConstant(builder, value, type, loc)) {
60     assert(insertPt == builder.getInsertionPoint());
61     assert(matchPattern(constOp, m_Constant(&value)));
62     return constOp;
63   }
64 
65   // If the dialect is unable to materialize a constant, check to see if the
66   // standard constant can be used.
67   if (ConstantOp::isBuildableWith(value, type))
68     return builder.create<ConstantOp>(loc, type, value);
69   return nullptr;
70 }
71 
72 //===----------------------------------------------------------------------===//
73 // OperationFolder
74 //===----------------------------------------------------------------------===//
75 
76 LogicalResult OperationFolder::tryToFold(
77     Operation *op,
78     llvm::function_ref<void(Operation *)> processGeneratedConstants,
79     llvm::function_ref<void(Operation *)> preReplaceAction) {
80   // If this is a unique'd constant, return failure as we know that it has
81   // already been folded.
82   if (referencedDialects.count(op))
83     return failure();
84 
85   // Try to fold the operation.
86   SmallVector<Value *, 8> results;
87   if (failed(tryToFold(op, results, processGeneratedConstants)))
88     return failure();
89 
90   // Constant folding succeeded. We will start replacing this op's uses and
91   // eventually erase this op. Invoke the callback provided by the caller to
92   // perform any pre-replacement action.
93   if (preReplaceAction)
94     preReplaceAction(op);
95 
96   // Check to see if the operation was just updated in place.
97   if (results.empty())
98     return success();
99 
100   // Otherwise, replace all of the result values and erase the operation.
101   for (unsigned i = 0, e = results.size(); i != e; ++i)
102     op->getResult(i)->replaceAllUsesWith(results[i]);
103   op->erase();
104   return success();
105 }
106 
107 /// Notifies that the given constant `op` should be remove from this
108 /// OperationFolder's internal bookkeeping.
109 void OperationFolder::notifyRemoval(Operation *op) {
110   // Check to see if this operation is uniqued within the folder.
111   auto it = referencedDialects.find(op);
112   if (it == referencedDialects.end())
113     return;
114 
115   // Get the constant value for this operation, this is the value that was used
116   // to unique the operation internally.
117   Attribute constValue;
118   matchPattern(op, m_Constant(&constValue));
119   assert(constValue);
120 
121   // Get the constant map that this operation was uniqued in.
122   auto &uniquedConstants = foldScopes[getInsertionRegion(op)];
123 
124   // Erase all of the references to this operation.
125   auto type = op->getResult(0)->getType();
126   for (auto *dialect : it->second)
127     uniquedConstants.erase(std::make_tuple(dialect, constValue, type));
128   referencedDialects.erase(it);
129 }
130 
131 /// Tries to perform folding on the given `op`. If successful, populates
132 /// `results` with the results of the folding.
133 LogicalResult OperationFolder::tryToFold(
134     Operation *op, SmallVectorImpl<Value *> &results,
135     llvm::function_ref<void(Operation *)> processGeneratedConstants) {
136   SmallVector<Attribute, 8> operandConstants;
137   SmallVector<OpFoldResult, 8> foldResults;
138 
139   // Check to see if any operands to the operation is constant and whether
140   // the operation knows how to constant fold itself.
141   operandConstants.assign(op->getNumOperands(), Attribute());
142   for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
143     matchPattern(op->getOperand(i), m_Constant(&operandConstants[i]));
144 
145   // If this is a commutative binary operation with a constant on the left
146   // side move it to the right side.
147   if (operandConstants.size() == 2 && operandConstants[0] &&
148       !operandConstants[1] && op->isCommutative()) {
149     std::swap(op->getOpOperand(0), op->getOpOperand(1));
150     std::swap(operandConstants[0], operandConstants[1]);
151   }
152 
153   // Attempt to constant fold the operation.
154   if (failed(op->fold(operandConstants, foldResults)))
155     return failure();
156 
157   // Check to see if the operation was just updated in place.
158   if (foldResults.empty())
159     return success();
160   assert(foldResults.size() == op->getNumResults());
161 
162   // Create a builder to insert new operations into the entry block of the
163   // insertion region.
164   auto *insertionRegion = getInsertionRegion(op);
165   auto &entry = insertionRegion->front();
166   OpBuilder builder(&entry, entry.begin());
167 
168   // Get the constant map for the insertion region of this operation.
169   auto &uniquedConstants = foldScopes[insertionRegion];
170 
171   // Create the result constants and replace the results.
172   auto *dialect = op->getDialect();
173   for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
174     assert(!foldResults[i].isNull() && "expected valid OpFoldResult");
175 
176     // Check if the result was an SSA value.
177     if (auto *repl = foldResults[i].dyn_cast<Value *>()) {
178       results.emplace_back(repl);
179       continue;
180     }
181 
182     // Check to see if there is a canonicalized version of this constant.
183     auto *res = op->getResult(i);
184     Attribute attrRepl = foldResults[i].get<Attribute>();
185     if (auto *constOp =
186             tryGetOrCreateConstant(uniquedConstants, dialect, builder, attrRepl,
187                                    res->getType(), op->getLoc())) {
188       results.push_back(constOp->getResult(0));
189       continue;
190     }
191     // If materialization fails, cleanup any operations generated for the
192     // previous results and return failure.
193     for (Operation &op : llvm::make_early_inc_range(
194              llvm::make_range(entry.begin(), builder.getInsertionPoint()))) {
195       notifyRemoval(&op);
196       op.erase();
197     }
198     return failure();
199   }
200 
201   // Process any newly generated operations.
202   if (processGeneratedConstants) {
203     for (auto i = entry.begin(), e = builder.getInsertionPoint(); i != e; ++i)
204       processGeneratedConstants(&*i);
205   }
206 
207   return success();
208 }
209 
/// Try to get or create a new constant entry. On success this returns the
/// constant operation value, nullptr otherwise.
///
/// Looks up the constant keyed by (`dialect`, `value`, `type`) in
/// `uniquedConstants`. If it is absent, materializes one with `builder`. The
/// materialized op may belong to a different dialect than `dialect` (e.g. the
/// standard-dialect fallback). In that case the op is also registered under
/// its own dialect's key, or it is discarded in favor of an op already
/// uniqued under that key. `referencedDialects` records every dialect key
/// that maps to the returned op, so `notifyRemoval` can later erase all of
/// them.
Operation *OperationFolder::tryGetOrCreateConstant(
    ConstantMap &uniquedConstants, Dialect *dialect, OpBuilder &builder,
    Attribute value, Type type, Location loc) {
  // Check if an existing mapping already exists.
  // NOTE: `constInst` is a reference to the map slot. `operator[]` creates a
  // null entry for a new key, and that entry is left null if materialization
  // below fails.
  auto constKey = std::make_tuple(dialect, value, type);
  auto *&constInst = uniquedConstants[constKey];
  if (constInst)
    return constInst;

  // If one doesn't exist, try to materialize one. The assignment stores the
  // result directly into the map slot for `constKey`.
  if (!(constInst = materializeConstant(dialect, builder, value, type, loc)))
    return nullptr;

  // Check to see if the generated constant is in the expected dialect.
  auto *newDialect = constInst->getDialect();
  if (newDialect == dialect) {
    referencedDialects[constInst].push_back(dialect);
    return constInst;
  }

  // If it isn't, then we also need to make sure that the mapping for the new
  // dialect is valid.
  auto newKey = std::make_tuple(newDialect, value, type);

  // If an existing operation in the new dialect already exists, delete the
  // materialized operation in favor of the existing one. `lookup` does not
  // insert, so the `constInst` slot reference is still usable here; the
  // original key is re-pointed at the existing op.
  if (auto *existingOp = uniquedConstants.lookup(newKey)) {
    constInst->erase();
    referencedDialects[existingOp].push_back(dialect);
    return constInst = existingOp;
  }

  // Otherwise, update the new dialect to the materialized operation. The op is
  // now reachable under both keys, so both dialects are recorded.
  // NOTE(review): `constInst` is deliberately not read after `insert` below.
  // If ConstantMap is an llvm::DenseMap (presumably; confirm in the header),
  // an insertion may rehash and invalidate the slot reference. The result is
  // therefore read back through the returned iterator.
  referencedDialects[constInst].assign({dialect, newDialect});
  auto newIt = uniquedConstants.insert({newKey, constInst});
  return newIt.first->second;
}
249