1 //===- FoldUtils.cpp ---- Fold Utilities ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file defines various operation fold utilities. These utilities are
// intended to be used by passes to unify and simplify their logic.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Transforms/FoldUtils.h"
15 
16 #include "mlir/Dialect/StandardOps/IR/Ops.h"
17 #include "mlir/IR/Builders.h"
18 #include "mlir/IR/Matchers.h"
19 #include "mlir/IR/Operation.h"
20 
21 using namespace mlir;
22 
23 /// Given an operation, find the parent region that folded constants should be
24 /// inserted into.
25 static Region *getInsertionRegion(
26     DialectInterfaceCollection<OpFolderDialectInterface> &interfaces,
27     Block *insertionBlock) {
28   while (Region *region = insertionBlock->getParent()) {
29     // Insert in this region for any of the following scenarios:
30     //  * The parent is unregistered, or is known to be isolated from above.
31     //  * The parent is a top-level operation.
32     auto *parentOp = region->getParentOp();
33     if (!parentOp->isRegistered() || parentOp->isKnownIsolatedFromAbove() ||
34         !parentOp->getBlock())
35       return region;
36 
37     // Otherwise, check if this region is a desired insertion region.
38     auto *interface = interfaces.getInterfaceFor(parentOp);
39     if (LLVM_UNLIKELY(interface && interface->shouldMaterializeInto(region)))
40       return region;
41 
42     // Traverse up the parent looking for an insertion region.
43     insertionBlock = parentOp->getBlock();
44   }
45   llvm_unreachable("expected valid insertion region");
46 }
47 
48 /// A utility function used to materialize a constant for a given attribute and
49 /// type. On success, a valid constant value is returned. Otherwise, null is
50 /// returned
51 static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder,
52                                       Attribute value, Type type,
53                                       Location loc) {
54   auto insertPt = builder.getInsertionPoint();
55   (void)insertPt;
56 
57   // Ask the dialect to materialize a constant operation for this value.
58   if (auto *constOp = dialect->materializeConstant(builder, value, type, loc)) {
59     assert(insertPt == builder.getInsertionPoint());
60     assert(matchPattern(constOp, m_Constant()));
61     return constOp;
62   }
63 
64   // If the dialect is unable to materialize a constant, check to see if the
65   // standard constant can be used.
66   if (ConstantOp::isBuildableWith(value, type))
67     return builder.create<ConstantOp>(loc, type, value);
68   return nullptr;
69 }
70 
71 //===----------------------------------------------------------------------===//
72 // OperationFolder
73 //===----------------------------------------------------------------------===//
74 
75 LogicalResult OperationFolder::tryToFold(
76     Operation *op, function_ref<void(Operation *)> processGeneratedConstants,
77     function_ref<void(Operation *)> preReplaceAction) {
78   // If this is a unique'd constant, return failure as we know that it has
79   // already been folded.
80   if (referencedDialects.count(op))
81     return failure();
82 
83   // Try to fold the operation.
84   SmallVector<Value, 8> results;
85   OpBuilder builder(op);
86   if (failed(tryToFold(builder, op, results, processGeneratedConstants)))
87     return failure();
88 
89   // Check to see if the operation was just updated in place.
90   if (results.empty())
91     return success();
92 
93   // Constant folding succeeded. We will start replacing this op's uses and
94   // erase this op. Invoke the callback provided by the caller to perform any
95   // pre-replacement action.
96   if (preReplaceAction)
97     preReplaceAction(op);
98 
99   // Replace all of the result values and erase the operation.
100   for (unsigned i = 0, e = results.size(); i != e; ++i)
101     op->getResult(i).replaceAllUsesWith(results[i]);
102   op->erase();
103   return success();
104 }
105 
106 /// Notifies that the given constant `op` should be remove from this
107 /// OperationFolder's internal bookkeeping.
108 void OperationFolder::notifyRemoval(Operation *op) {
109   // Check to see if this operation is uniqued within the folder.
110   auto it = referencedDialects.find(op);
111   if (it == referencedDialects.end())
112     return;
113 
114   // Get the constant value for this operation, this is the value that was used
115   // to unique the operation internally.
116   Attribute constValue;
117   matchPattern(op, m_Constant(&constValue));
118   assert(constValue);
119 
120   // Get the constant map that this operation was uniqued in.
121   auto &uniquedConstants =
122       foldScopes[getInsertionRegion(interfaces, op->getBlock())];
123 
124   // Erase all of the references to this operation.
125   auto type = op->getResult(0).getType();
126   for (auto *dialect : it->second)
127     uniquedConstants.erase(std::make_tuple(dialect, constValue, type));
128   referencedDialects.erase(it);
129 }
130 
131 /// Clear out any constants cached inside of the folder.
132 void OperationFolder::clear() {
133   foldScopes.clear();
134   referencedDialects.clear();
135 }
136 
137 /// Tries to perform folding on the given `op`. If successful, populates
138 /// `results` with the results of the folding.
139 LogicalResult OperationFolder::tryToFold(
140     OpBuilder &builder, Operation *op, SmallVectorImpl<Value> &results,
141     function_ref<void(Operation *)> processGeneratedConstants) {
142   SmallVector<Attribute, 8> operandConstants;
143   SmallVector<OpFoldResult, 8> foldResults;
144 
145   // If this is a commutative operation, move constants to be trailing operands.
146   if (op->getNumOperands() >= 2 && op->isCommutative()) {
147     std::stable_partition(
148         op->getOpOperands().begin(), op->getOpOperands().end(),
149         [&](OpOperand &O) { return !matchPattern(O.get(), m_Constant()); });
150   }
151 
152   // Check to see if any operands to the operation is constant and whether
153   // the operation knows how to constant fold itself.
154   operandConstants.assign(op->getNumOperands(), Attribute());
155   for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
156     matchPattern(op->getOperand(i), m_Constant(&operandConstants[i]));
157 
158   // Attempt to constant fold the operation.
159   if (failed(op->fold(operandConstants, foldResults)))
160     return failure();
161 
162   // Check to see if the operation was just updated in place.
163   if (foldResults.empty())
164     return success();
165   assert(foldResults.size() == op->getNumResults());
166 
167   // Create a builder to insert new operations into the entry block of the
168   // insertion region.
169   auto *insertRegion =
170       getInsertionRegion(interfaces, builder.getInsertionBlock());
171   auto &entry = insertRegion->front();
172   OpBuilder::InsertionGuard foldGuard(builder);
173   builder.setInsertionPoint(&entry, entry.begin());
174 
175   // Get the constant map for the insertion region of this operation.
176   auto &uniquedConstants = foldScopes[insertRegion];
177 
178   // Create the result constants and replace the results.
179   auto *dialect = op->getDialect();
180   for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
181     assert(!foldResults[i].isNull() && "expected valid OpFoldResult");
182 
183     // Check if the result was an SSA value.
184     if (auto repl = foldResults[i].dyn_cast<Value>()) {
185       results.emplace_back(repl);
186       continue;
187     }
188 
189     // Check to see if there is a canonicalized version of this constant.
190     auto res = op->getResult(i);
191     Attribute attrRepl = foldResults[i].get<Attribute>();
192     if (auto *constOp =
193             tryGetOrCreateConstant(uniquedConstants, dialect, builder, attrRepl,
194                                    res.getType(), op->getLoc())) {
195       results.push_back(constOp->getResult(0));
196       continue;
197     }
198     // If materialization fails, cleanup any operations generated for the
199     // previous results and return failure.
200     for (Operation &op : llvm::make_early_inc_range(
201              llvm::make_range(entry.begin(), builder.getInsertionPoint()))) {
202       notifyRemoval(&op);
203       op.erase();
204     }
205     return failure();
206   }
207 
208   // Process any newly generated operations.
209   if (processGeneratedConstants) {
210     for (auto i = entry.begin(), e = builder.getInsertionPoint(); i != e; ++i)
211       processGeneratedConstants(&*i);
212   }
213 
214   return success();
215 }
216 
217 /// Try to get or create a new constant entry. On success this returns the
218 /// constant operation value, nullptr otherwise.
219 Operation *OperationFolder::tryGetOrCreateConstant(
220     ConstantMap &uniquedConstants, Dialect *dialect, OpBuilder &builder,
221     Attribute value, Type type, Location loc) {
222   // Check if an existing mapping already exists.
223   auto constKey = std::make_tuple(dialect, value, type);
224   auto *&constInst = uniquedConstants[constKey];
225   if (constInst)
226     return constInst;
227 
228   // If one doesn't exist, try to materialize one.
229   if (!(constInst = materializeConstant(dialect, builder, value, type, loc)))
230     return nullptr;
231 
232   // Check to see if the generated constant is in the expected dialect.
233   auto *newDialect = constInst->getDialect();
234   if (newDialect == dialect) {
235     referencedDialects[constInst].push_back(dialect);
236     return constInst;
237   }
238 
239   // If it isn't, then we also need to make sure that the mapping for the new
240   // dialect is valid.
241   auto newKey = std::make_tuple(newDialect, value, type);
242 
243   // If an existing operation in the new dialect already exists, delete the
244   // materialized operation in favor of the existing one.
245   if (auto *existingOp = uniquedConstants.lookup(newKey)) {
246     constInst->erase();
247     referencedDialects[existingOp].push_back(dialect);
248     return constInst = existingOp;
249   }
250 
251   // Otherwise, update the new dialect to the materialized operation.
252   referencedDialects[constInst].assign({dialect, newDialect});
253   auto newIt = uniquedConstants.insert({newKey, constInst});
254   return newIt.first->second;
255 }
256