1 //===- FoldUtils.cpp ---- Fold Utilities ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines various operation fold utilities. These utilities are
// intended to be used by passes to unify and simplify their logic.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Transforms/FoldUtils.h"
15 
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/Operation.h"
19 
20 using namespace mlir;
21 
22 /// Given an operation, find the parent region that folded constants should be
23 /// inserted into.
24 static Region *
25 getInsertionRegion(DialectInterfaceCollection<DialectFoldInterface> &interfaces,
26                    Block *insertionBlock) {
27   while (Region *region = insertionBlock->getParent()) {
28     // Insert in this region for any of the following scenarios:
29     //  * The parent is unregistered, or is known to be isolated from above.
30     //  * The parent is a top-level operation.
31     auto *parentOp = region->getParentOp();
32     if (parentOp->mightHaveTrait<OpTrait::IsIsolatedFromAbove>() ||
33         !parentOp->getBlock())
34       return region;
35 
36     // Otherwise, check if this region is a desired insertion region.
37     auto *interface = interfaces.getInterfaceFor(parentOp);
38     if (LLVM_UNLIKELY(interface && interface->shouldMaterializeInto(region)))
39       return region;
40 
41     // Traverse up the parent looking for an insertion region.
42     insertionBlock = parentOp->getBlock();
43   }
44   llvm_unreachable("expected valid insertion region");
45 }
46 
47 /// A utility function used to materialize a constant for a given attribute and
48 /// type. On success, a valid constant value is returned. Otherwise, null is
49 /// returned
50 static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder,
51                                       Attribute value, Type type,
52                                       Location loc) {
53   auto insertPt = builder.getInsertionPoint();
54   (void)insertPt;
55 
56   // Ask the dialect to materialize a constant operation for this value.
57   if (auto *constOp = dialect->materializeConstant(builder, value, type, loc)) {
58     assert(insertPt == builder.getInsertionPoint());
59     assert(matchPattern(constOp, m_Constant()));
60     return constOp;
61   }
62 
63   return nullptr;
64 }
65 
66 //===----------------------------------------------------------------------===//
67 // OperationFolder
68 //===----------------------------------------------------------------------===//
69 
70 LogicalResult OperationFolder::tryToFold(
71     Operation *op, function_ref<void(Operation *)> processGeneratedConstants,
72     function_ref<void(Operation *)> preReplaceAction, bool *inPlaceUpdate) {
73   if (inPlaceUpdate)
74     *inPlaceUpdate = false;
75 
76   // If this is a unique'd constant, return failure as we know that it has
77   // already been folded.
78   if (isFolderOwnedConstant(op)) {
79     // Check to see if we should rehoist, i.e. if a non-constant operation was
80     // inserted before this one.
81     Block *opBlock = op->getBlock();
82     if (&opBlock->front() != op && !isFolderOwnedConstant(op->getPrevNode()))
83       op->moveBefore(&opBlock->front());
84     return failure();
85   }
86 
87   // Try to fold the operation.
88   SmallVector<Value, 8> results;
89   OpBuilder builder(op);
90   if (failed(tryToFold(builder, op, results, processGeneratedConstants)))
91     return failure();
92 
93   // Check to see if the operation was just updated in place.
94   if (results.empty()) {
95     if (inPlaceUpdate)
96       *inPlaceUpdate = true;
97     return success();
98   }
99 
100   // Constant folding succeeded. We will start replacing this op's uses and
101   // erase this op. Invoke the callback provided by the caller to perform any
102   // pre-replacement action.
103   if (preReplaceAction)
104     preReplaceAction(op);
105 
106   // Replace all of the result values and erase the operation.
107   for (unsigned i = 0, e = results.size(); i != e; ++i)
108     op->getResult(i).replaceAllUsesWith(results[i]);
109   op->erase();
110   return success();
111 }
112 
/// Registers the existing constant operation `op` with this folder.
///
/// `constValue` is the attribute held by `op`. If it is null, it is recomputed
/// by matching `op` against m_Constant. In debug builds a provided value is
/// checked against the matched one.
///  * If `op` is already owned by the folder, it is only moved to the front of
///    its block when a non-folder-owned operation precedes it.
///  * If a constant with the same (dialect, value, result type) key is already
///    uniqued in `op`'s insertion region, all uses of `op` are redirected to it
///    and `op` is erased.
///  * Otherwise `op` is moved to the front of the insertion region's entry
///    block (unless it already sits at the front, or directly after another
///    folder-owned constant, of that block) and recorded as folder-owned.
///
/// Returns false if `op` was replaced and erased, true otherwise.
bool OperationFolder::insertKnownConstant(Operation *op, Attribute constValue) {
  Block *opBlock = op->getBlock();

  // If this is a constant we unique'd, we don't need to insert, but we can
  // check to see if we should rehoist it.
  if (isFolderOwnedConstant(op)) {
    if (&opBlock->front() != op && !isFolderOwnedConstant(op->getPrevNode()))
      op->moveBefore(&opBlock->front());
    return true;
  }

  // Get the constant value of the op if necessary.
  if (!constValue) {
    matchPattern(op, m_Constant(&constValue));
    assert(constValue && "expected `op` to be a constant");
  } else {
    // Ensure that the provided constant was actually correct.
#ifndef NDEBUG
    Attribute expectedValue;
    matchPattern(op, m_Constant(&expectedValue));
    assert(
        expectedValue == constValue &&
        "provided constant value was not the expected value of the constant");
#endif
  }

  // Check for an existing constant operation for the attribute value.
  Region *insertRegion = getInsertionRegion(interfaces, opBlock);
  auto &uniquedConstants = foldScopes[insertRegion];
  // `folderConstOp` aliases the map slot for this key. operator[] inserts a
  // null slot when the key is absent; that slot is filled in at the end of
  // this function.
  Operation *&folderConstOp = uniquedConstants[std::make_tuple(
      op->getDialect(), constValue, *op->result_type_begin())];

  // If there is an existing constant, replace `op`.
  // NOTE(review): nothing here re-checks that `folderConstOp` dominates every
  // use of `op` (e.g. if `op` precedes it in the same block) — confirm callers
  // guarantee this.
  if (folderConstOp) {
    op->replaceAllUsesWith(folderConstOp);
    op->erase();
    return false;
  }

  // Otherwise, we insert `op`. If `op` is in the insertion block and is either
  // already at the front of the block, or the previous operation is already a
  // constant we unique'd (i.e. one we inserted), then we don't need to do
  // anything. Otherwise, we move the constant to the insertion block.
  Block *insertBlock = &insertRegion->front();
  if (opBlock != insertBlock || (&insertBlock->front() != op &&
                                 !isFolderOwnedConstant(op->getPrevNode())))
    op->moveBefore(&insertBlock->front());

  // Record `op` both in the uniquing map (through the aliased slot) and as
  // folder-owned, referenced by its own dialect.
  folderConstOp = op;
  referencedDialects[op].push_back(op->getDialect());
  return true;
}
165 
166 /// Notifies that the given constant `op` should be remove from this
167 /// OperationFolder's internal bookkeeping.
168 void OperationFolder::notifyRemoval(Operation *op) {
169   // Check to see if this operation is uniqued within the folder.
170   auto it = referencedDialects.find(op);
171   if (it == referencedDialects.end())
172     return;
173 
174   // Get the constant value for this operation, this is the value that was used
175   // to unique the operation internally.
176   Attribute constValue;
177   matchPattern(op, m_Constant(&constValue));
178   assert(constValue);
179 
180   // Get the constant map that this operation was uniqued in.
181   auto &uniquedConstants =
182       foldScopes[getInsertionRegion(interfaces, op->getBlock())];
183 
184   // Erase all of the references to this operation.
185   auto type = op->getResult(0).getType();
186   for (auto *dialect : it->second)
187     uniquedConstants.erase(std::make_tuple(dialect, constValue, type));
188   referencedDialects.erase(it);
189 }
190 
191 /// Clear out any constants cached inside of the folder.
192 void OperationFolder::clear() {
193   foldScopes.clear();
194   referencedDialects.clear();
195 }
196 
197 /// Get or create a constant using the given builder. On success this returns
198 /// the constant operation, nullptr otherwise.
199 Value OperationFolder::getOrCreateConstant(OpBuilder &builder, Dialect *dialect,
200                                            Attribute value, Type type,
201                                            Location loc) {
202   OpBuilder::InsertionGuard foldGuard(builder);
203 
204   // Use the builder insertion block to find an insertion point for the
205   // constant.
206   auto *insertRegion =
207       getInsertionRegion(interfaces, builder.getInsertionBlock());
208   auto &entry = insertRegion->front();
209   builder.setInsertionPoint(&entry, entry.begin());
210 
211   // Get the constant map for the insertion region of this operation.
212   auto &uniquedConstants = foldScopes[insertRegion];
213   Operation *constOp = tryGetOrCreateConstant(uniquedConstants, dialect,
214                                               builder, value, type, loc);
215   return constOp ? constOp->getResult(0) : Value();
216 }
217 
218 bool OperationFolder::isFolderOwnedConstant(Operation *op) const {
219   return referencedDialects.count(op);
220 }
221 
222 /// Tries to perform folding on the given `op`. If successful, populates
223 /// `results` with the results of the folding.
224 LogicalResult OperationFolder::tryToFold(
225     OpBuilder &builder, Operation *op, SmallVectorImpl<Value> &results,
226     function_ref<void(Operation *)> processGeneratedConstants) {
227   SmallVector<Attribute, 8> operandConstants;
228 
229   // If this is a commutative operation, move constants to be trailing operands.
230   bool updatedOpOperands = false;
231   if (op->getNumOperands() >= 2 && op->hasTrait<OpTrait::IsCommutative>()) {
232     auto isNonConstant = [&](OpOperand &o) {
233       return !matchPattern(o.get(), m_Constant());
234     };
235     auto *firstConstantIt =
236         llvm::find_if_not(op->getOpOperands(), isNonConstant);
237     auto *newConstantIt = std::stable_partition(
238         firstConstantIt, op->getOpOperands().end(), isNonConstant);
239 
240     // Remember if we actually moved anything.
241     updatedOpOperands = firstConstantIt != newConstantIt;
242   }
243 
244   // Check to see if any operands to the operation is constant and whether
245   // the operation knows how to constant fold itself.
246   operandConstants.assign(op->getNumOperands(), Attribute());
247   for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
248     matchPattern(op->getOperand(i), m_Constant(&operandConstants[i]));
249 
250   // Attempt to constant fold the operation. If we failed, check to see if we at
251   // least updated the operands of the operation. We treat this as an in-place
252   // fold.
253   SmallVector<OpFoldResult, 8> foldResults;
254   if (failed(op->fold(operandConstants, foldResults)) ||
255       failed(processFoldResults(builder, op, results, foldResults,
256                                 processGeneratedConstants)))
257     return success(updatedOpOperands);
258   return success();
259 }
260 
261 LogicalResult OperationFolder::processFoldResults(
262     OpBuilder &builder, Operation *op, SmallVectorImpl<Value> &results,
263     ArrayRef<OpFoldResult> foldResults,
264     function_ref<void(Operation *)> processGeneratedConstants) {
265   // Check to see if the operation was just updated in place.
266   if (foldResults.empty())
267     return success();
268   assert(foldResults.size() == op->getNumResults());
269 
270   // Create a builder to insert new operations into the entry block of the
271   // insertion region.
272   auto *insertRegion =
273       getInsertionRegion(interfaces, builder.getInsertionBlock());
274   auto &entry = insertRegion->front();
275   OpBuilder::InsertionGuard foldGuard(builder);
276   builder.setInsertionPoint(&entry, entry.begin());
277 
278   // Get the constant map for the insertion region of this operation.
279   auto &uniquedConstants = foldScopes[insertRegion];
280 
281   // Create the result constants and replace the results.
282   auto *dialect = op->getDialect();
283   for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
284     assert(!foldResults[i].isNull() && "expected valid OpFoldResult");
285 
286     // Check if the result was an SSA value.
287     if (auto repl = foldResults[i].dyn_cast<Value>()) {
288       if (repl.getType() != op->getResult(i).getType()) {
289         results.clear();
290         return failure();
291       }
292       results.emplace_back(repl);
293       continue;
294     }
295 
296     // Check to see if there is a canonicalized version of this constant.
297     auto res = op->getResult(i);
298     Attribute attrRepl = foldResults[i].get<Attribute>();
299     if (auto *constOp =
300             tryGetOrCreateConstant(uniquedConstants, dialect, builder, attrRepl,
301                                    res.getType(), op->getLoc())) {
302       // Ensure that this constant dominates the operation we are replacing it
303       // with. This may not automatically happen if the operation being folded
304       // was inserted before the constant within the insertion block.
305       Block *opBlock = op->getBlock();
306       if (opBlock == constOp->getBlock() && &opBlock->front() != constOp)
307         constOp->moveBefore(&opBlock->front());
308 
309       results.push_back(constOp->getResult(0));
310       continue;
311     }
312     // If materialization fails, cleanup any operations generated for the
313     // previous results and return failure.
314     for (Operation &op : llvm::make_early_inc_range(
315              llvm::make_range(entry.begin(), builder.getInsertionPoint()))) {
316       notifyRemoval(&op);
317       op.erase();
318     }
319     results.clear();
320     return failure();
321   }
322 
323   // Process any newly generated operations.
324   if (processGeneratedConstants) {
325     for (auto i = entry.begin(), e = builder.getInsertionPoint(); i != e; ++i)
326       processGeneratedConstants(&*i);
327   }
328 
329   return success();
330 }
331 
332 /// Try to get or create a new constant entry. On success this returns the
333 /// constant operation value, nullptr otherwise.
334 Operation *OperationFolder::tryGetOrCreateConstant(
335     ConstantMap &uniquedConstants, Dialect *dialect, OpBuilder &builder,
336     Attribute value, Type type, Location loc) {
337   // Check if an existing mapping already exists.
338   auto constKey = std::make_tuple(dialect, value, type);
339   Operation *&constOp = uniquedConstants[constKey];
340   if (constOp)
341     return constOp;
342 
343   // If one doesn't exist, try to materialize one.
344   if (!(constOp = materializeConstant(dialect, builder, value, type, loc)))
345     return nullptr;
346 
347   // Check to see if the generated constant is in the expected dialect.
348   auto *newDialect = constOp->getDialect();
349   if (newDialect == dialect) {
350     referencedDialects[constOp].push_back(dialect);
351     return constOp;
352   }
353 
354   // If it isn't, then we also need to make sure that the mapping for the new
355   // dialect is valid.
356   auto newKey = std::make_tuple(newDialect, value, type);
357 
358   // If an existing operation in the new dialect already exists, delete the
359   // materialized operation in favor of the existing one.
360   if (auto *existingOp = uniquedConstants.lookup(newKey)) {
361     constOp->erase();
362     referencedDialects[existingOp].push_back(dialect);
363     return constOp = existingOp;
364   }
365 
366   // Otherwise, update the new dialect to the materialized operation.
367   referencedDialects[constOp].assign({dialect, newDialect});
368   auto newIt = uniquedConstants.insert({newKey, constOp});
369   return newIt.first->second;
370 }
371