1 //===- FoldUtils.cpp ---- Fold Utilities ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines various operation fold utilities. These utilities are
// intended to be used by passes to unify and simplify their logic.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Transforms/FoldUtils.h"
15 
16 #include "mlir/Dialect/StandardOps/IR/Ops.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/Matchers.h"
19 #include "mlir/IR/Operation.h"
20 
21 using namespace mlir;
22 
23 /// Given an operation, find the parent region that folded constants should be
24 /// inserted into.
25 static Region *
26 getInsertionRegion(DialectInterfaceCollection<DialectFoldInterface> &interfaces,
27                    Block *insertionBlock) {
28   while (Region *region = insertionBlock->getParent()) {
29     // Insert in this region for any of the following scenarios:
30     //  * The parent is unregistered, or is known to be isolated from above.
31     //  * The parent is a top-level operation.
32     auto *parentOp = region->getParentOp();
33     if (parentOp->mightHaveTrait<OpTrait::IsIsolatedFromAbove>() ||
34         !parentOp->getBlock())
35       return region;
36 
37     // Otherwise, check if this region is a desired insertion region.
38     auto *interface = interfaces.getInterfaceFor(parentOp);
39     if (LLVM_UNLIKELY(interface && interface->shouldMaterializeInto(region)))
40       return region;
41 
42     // Traverse up the parent looking for an insertion region.
43     insertionBlock = parentOp->getBlock();
44   }
45   llvm_unreachable("expected valid insertion region");
46 }
47 
48 /// A utility function used to materialize a constant for a given attribute and
49 /// type. On success, a valid constant value is returned. Otherwise, null is
50 /// returned
51 static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder,
52                                       Attribute value, Type type,
53                                       Location loc) {
54   auto insertPt = builder.getInsertionPoint();
55   (void)insertPt;
56 
57   // Ask the dialect to materialize a constant operation for this value.
58   if (auto *constOp = dialect->materializeConstant(builder, value, type, loc)) {
59     assert(insertPt == builder.getInsertionPoint());
60     assert(matchPattern(constOp, m_Constant()));
61     return constOp;
62   }
63 
64   // TODO: To facilitate splitting the std dialect (PR48490), have a special
65   // case for falling back to std.constant. Eventually, we will have separate
66   // ops tensor.constant, int.constant, float.constant, etc. that live in their
67   // respective dialects, which will allow each dialect to implement the
68   // materializeConstant hook above.
69   //
70   // The special case is needed because in the interim state while we are
71   // splitting out those dialects from std, the std dialect depends on the
72   // tensor dialect, which makes it impossible for the tensor dialect to use
73   // std.constant (it would be a cyclic dependency) as part of its
74   // materializeConstant hook.
75   //
76   // If the dialect is unable to materialize a constant, check to see if the
77   // standard constant can be used.
78   if (ConstantOp::isBuildableWith(value, type))
79     return builder.create<ConstantOp>(loc, type, value);
80   return nullptr;
81 }
82 
83 //===----------------------------------------------------------------------===//
84 // OperationFolder
85 //===----------------------------------------------------------------------===//
86 
87 /// Scan the specified region for constants that can be used in folding,
88 /// moving them to the entry block (or any custom insertion location specified
89 /// by shouldMaterializeInto), and add them to our known-constants table.
90 void OperationFolder::processExistingConstants(Region &region) {
91   if (region.empty())
92     return;
93 
94   // March the constant insertion point forward, moving all constants to the
95   // top of the block, but keeping them in their order of discovery.
96   Region *insertRegion = getInsertionRegion(interfaces, &region.front());
97   auto &uniquedConstants = foldScopes[insertRegion];
98 
99   Block &insertBlock = insertRegion->front();
100   Block::iterator constantIterator = insertBlock.begin();
101 
102   // Process each constant that we discover in this region.
103   auto processConstant = [&](Operation *op, Attribute value) {
104     assert(op->getNumResults() == 1 && "constants have one result");
105     // Check to see if we already have an instance of this constant.
106     Operation *&constOp = uniquedConstants[std::make_tuple(
107         op->getDialect(), value, op->getResult(0).getType())];
108 
109     // If we already have an instance of this constant, CSE/delete this one as
110     // we go.
111     if (constOp) {
112       if (constantIterator == Block::iterator(op))
113         ++constantIterator; // Don't invalidate our iterator when scanning.
114       op->getResult(0).replaceAllUsesWith(constOp->getResult(0));
115       op->erase();
116       return;
117     }
118 
119     // Otherwise, remember that we have this constant.
120     constOp = op;
121     referencedDialects[op].push_back(op->getDialect());
122 
123     // If the constant isn't already at the insertion point then move it up.
124     if (constantIterator != Block::iterator(op))
125       op->moveBefore(&insertBlock, constantIterator);
126     else
127       ++constantIterator; // It was pointing at the constant.
128   };
129 
130   // Collect all the constants for this region of isolation or insertion (as
131   // specified by the shouldMaterializeInto hook).  Collect any subregions of
132   // isolation/constant insertion for subsequent processing.
133   SmallVector<Operation *> insertionSubregionOps;
134   region.walk<WalkOrder::PreOrder>([&](Operation *op) {
135     // If this is a constant, process it.
136     Attribute value;
137     if (matchPattern(op, m_Constant(&value))) {
138       processConstant(op, value);
139       // We may have deleted the operation, don't check it for regions.
140       return WalkResult::skip();
141     }
142 
143     // If the operation has regions and is isolated, don't recurse into it.
144     if (op->getNumRegions() != 0) {
145       auto hasDifferentInsertRegion = [&](Region &region) {
146         return !region.empty() &&
147                getInsertionRegion(interfaces, &region.front()) != insertRegion;
148       };
149       if (llvm::any_of(op->getRegions(), hasDifferentInsertRegion)) {
150         insertionSubregionOps.push_back(op);
151         return WalkResult::skip();
152       }
153     }
154 
155     // Otherwise keep going.
156     return WalkResult::advance();
157   });
158 
159   // Process regions in any isolated ops separately.
160   for (Operation *subregionOps : insertionSubregionOps) {
161     for (Region &region : subregionOps->getRegions())
162       processExistingConstants(region);
163   }
164 }
165 
166 LogicalResult OperationFolder::tryToFold(
167     Operation *op, function_ref<void(Operation *)> processGeneratedConstants,
168     function_ref<void(Operation *)> preReplaceAction, bool *inPlaceUpdate) {
169   if (inPlaceUpdate)
170     *inPlaceUpdate = false;
171 
172   // If this is a unique'd constant, return failure as we know that it has
173   // already been folded.
174   if (referencedDialects.count(op))
175     return failure();
176 
177   // Try to fold the operation.
178   SmallVector<Value, 8> results;
179   OpBuilder builder(op);
180   if (failed(tryToFold(builder, op, results, processGeneratedConstants)))
181     return failure();
182 
183   // Check to see if the operation was just updated in place.
184   if (results.empty()) {
185     if (inPlaceUpdate)
186       *inPlaceUpdate = true;
187     return success();
188   }
189 
190   // Constant folding succeeded. We will start replacing this op's uses and
191   // erase this op. Invoke the callback provided by the caller to perform any
192   // pre-replacement action.
193   if (preReplaceAction)
194     preReplaceAction(op);
195 
196   // Replace all of the result values and erase the operation.
197   for (unsigned i = 0, e = results.size(); i != e; ++i)
198     op->getResult(i).replaceAllUsesWith(results[i]);
199   op->erase();
200   return success();
201 }
202 
203 /// Notifies that the given constant `op` should be remove from this
204 /// OperationFolder's internal bookkeeping.
205 void OperationFolder::notifyRemoval(Operation *op) {
206   // Check to see if this operation is uniqued within the folder.
207   auto it = referencedDialects.find(op);
208   if (it == referencedDialects.end())
209     return;
210 
211   // Get the constant value for this operation, this is the value that was used
212   // to unique the operation internally.
213   Attribute constValue;
214   matchPattern(op, m_Constant(&constValue));
215   assert(constValue);
216 
217   // Get the constant map that this operation was uniqued in.
218   auto &uniquedConstants =
219       foldScopes[getInsertionRegion(interfaces, op->getBlock())];
220 
221   // Erase all of the references to this operation.
222   auto type = op->getResult(0).getType();
223   for (auto *dialect : it->second)
224     uniquedConstants.erase(std::make_tuple(dialect, constValue, type));
225   referencedDialects.erase(it);
226 }
227 
228 /// Clear out any constants cached inside of the folder.
229 void OperationFolder::clear() {
230   foldScopes.clear();
231   referencedDialects.clear();
232 }
233 
234 /// Get or create a constant using the given builder. On success this returns
235 /// the constant operation, nullptr otherwise.
236 Value OperationFolder::getOrCreateConstant(OpBuilder &builder, Dialect *dialect,
237                                            Attribute value, Type type,
238                                            Location loc) {
239   OpBuilder::InsertionGuard foldGuard(builder);
240 
241   // Use the builder insertion block to find an insertion point for the
242   // constant.
243   auto *insertRegion =
244       getInsertionRegion(interfaces, builder.getInsertionBlock());
245   auto &entry = insertRegion->front();
246   builder.setInsertionPoint(&entry, entry.begin());
247 
248   // Get the constant map for the insertion region of this operation.
249   auto &uniquedConstants = foldScopes[insertRegion];
250   Operation *constOp = tryGetOrCreateConstant(uniquedConstants, dialect,
251                                               builder, value, type, loc);
252   return constOp ? constOp->getResult(0) : Value();
253 }
254 
255 /// Tries to perform folding on the given `op`. If successful, populates
256 /// `results` with the results of the folding.
257 LogicalResult OperationFolder::tryToFold(
258     OpBuilder &builder, Operation *op, SmallVectorImpl<Value> &results,
259     function_ref<void(Operation *)> processGeneratedConstants) {
260   SmallVector<Attribute, 8> operandConstants;
261   SmallVector<OpFoldResult, 8> foldResults;
262 
263   // If this is a commutative operation, move constants to be trailing operands.
264   if (op->getNumOperands() >= 2 && op->hasTrait<OpTrait::IsCommutative>()) {
265     std::stable_partition(
266         op->getOpOperands().begin(), op->getOpOperands().end(),
267         [&](OpOperand &O) { return !matchPattern(O.get(), m_Constant()); });
268   }
269 
270   // Check to see if any operands to the operation is constant and whether
271   // the operation knows how to constant fold itself.
272   operandConstants.assign(op->getNumOperands(), Attribute());
273   for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
274     matchPattern(op->getOperand(i), m_Constant(&operandConstants[i]));
275 
276   // Attempt to constant fold the operation.
277   if (failed(op->fold(operandConstants, foldResults)))
278     return failure();
279 
280   // Check to see if the operation was just updated in place.
281   if (foldResults.empty())
282     return success();
283   assert(foldResults.size() == op->getNumResults());
284 
285   // Create a builder to insert new operations into the entry block of the
286   // insertion region.
287   auto *insertRegion =
288       getInsertionRegion(interfaces, builder.getInsertionBlock());
289   auto &entry = insertRegion->front();
290   OpBuilder::InsertionGuard foldGuard(builder);
291   builder.setInsertionPoint(&entry, entry.begin());
292 
293   // Get the constant map for the insertion region of this operation.
294   auto &uniquedConstants = foldScopes[insertRegion];
295 
296   // Create the result constants and replace the results.
297   auto *dialect = op->getDialect();
298   for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
299     assert(!foldResults[i].isNull() && "expected valid OpFoldResult");
300 
301     // Check if the result was an SSA value.
302     if (auto repl = foldResults[i].dyn_cast<Value>()) {
303       if (repl.getType() != op->getResult(i).getType())
304         return failure();
305       results.emplace_back(repl);
306       continue;
307     }
308 
309     // Check to see if there is a canonicalized version of this constant.
310     auto res = op->getResult(i);
311     Attribute attrRepl = foldResults[i].get<Attribute>();
312     if (auto *constOp =
313             tryGetOrCreateConstant(uniquedConstants, dialect, builder, attrRepl,
314                                    res.getType(), op->getLoc())) {
315       results.push_back(constOp->getResult(0));
316       continue;
317     }
318     // If materialization fails, cleanup any operations generated for the
319     // previous results and return failure.
320     for (Operation &op : llvm::make_early_inc_range(
321              llvm::make_range(entry.begin(), builder.getInsertionPoint()))) {
322       notifyRemoval(&op);
323       op.erase();
324     }
325     return failure();
326   }
327 
328   // Process any newly generated operations.
329   if (processGeneratedConstants) {
330     for (auto i = entry.begin(), e = builder.getInsertionPoint(); i != e; ++i)
331       processGeneratedConstants(&*i);
332   }
333 
334   return success();
335 }
336 
337 /// Try to get or create a new constant entry. On success this returns the
338 /// constant operation value, nullptr otherwise.
339 Operation *OperationFolder::tryGetOrCreateConstant(
340     ConstantMap &uniquedConstants, Dialect *dialect, OpBuilder &builder,
341     Attribute value, Type type, Location loc) {
342   // Check if an existing mapping already exists.
343   auto constKey = std::make_tuple(dialect, value, type);
344   Operation *&constOp = uniquedConstants[constKey];
345   if (constOp)
346     return constOp;
347 
348   // If one doesn't exist, try to materialize one.
349   if (!(constOp = materializeConstant(dialect, builder, value, type, loc)))
350     return nullptr;
351 
352   // Check to see if the generated constant is in the expected dialect.
353   auto *newDialect = constOp->getDialect();
354   if (newDialect == dialect) {
355     referencedDialects[constOp].push_back(dialect);
356     return constOp;
357   }
358 
359   // If it isn't, then we also need to make sure that the mapping for the new
360   // dialect is valid.
361   auto newKey = std::make_tuple(newDialect, value, type);
362 
363   // If an existing operation in the new dialect already exists, delete the
364   // materialized operation in favor of the existing one.
365   if (auto *existingOp = uniquedConstants.lookup(newKey)) {
366     constOp->erase();
367     referencedDialects[existingOp].push_back(dialect);
368     return constOp = existingOp;
369   }
370 
371   // Otherwise, update the new dialect to the materialized operation.
372   referencedDialects[constOp].assign({dialect, newDialect});
373   auto newIt = uniquedConstants.insert({newKey, constOp});
374   return newIt.first->second;
375 }
376