1 //===- FoldUtils.cpp ---- Fold Utilities ----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines various operation fold utilities. These utilities are
// intended to be used by passes to unify and simplify their logic.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/Transforms/FoldUtils.h"
15
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/Matchers.h"
18 #include "mlir/IR/Operation.h"
19
20 using namespace mlir;
21
22 /// Given an operation, find the parent region that folded constants should be
23 /// inserted into.
24 static Region *
getInsertionRegion(DialectInterfaceCollection<DialectFoldInterface> & interfaces,Block * insertionBlock)25 getInsertionRegion(DialectInterfaceCollection<DialectFoldInterface> &interfaces,
26 Block *insertionBlock) {
27 while (Region *region = insertionBlock->getParent()) {
28 // Insert in this region for any of the following scenarios:
29 // * The parent is unregistered, or is known to be isolated from above.
30 // * The parent is a top-level operation.
31 auto *parentOp = region->getParentOp();
32 if (parentOp->mightHaveTrait<OpTrait::IsIsolatedFromAbove>() ||
33 !parentOp->getBlock())
34 return region;
35
36 // Otherwise, check if this region is a desired insertion region.
37 auto *interface = interfaces.getInterfaceFor(parentOp);
38 if (LLVM_UNLIKELY(interface && interface->shouldMaterializeInto(region)))
39 return region;
40
41 // Traverse up the parent looking for an insertion region.
42 insertionBlock = parentOp->getBlock();
43 }
44 llvm_unreachable("expected valid insertion region");
45 }
46
47 /// A utility function used to materialize a constant for a given attribute and
48 /// type. On success, a valid constant value is returned. Otherwise, null is
49 /// returned
materializeConstant(Dialect * dialect,OpBuilder & builder,Attribute value,Type type,Location loc)50 static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder,
51 Attribute value, Type type,
52 Location loc) {
53 auto insertPt = builder.getInsertionPoint();
54 (void)insertPt;
55
56 // Ask the dialect to materialize a constant operation for this value.
57 if (auto *constOp = dialect->materializeConstant(builder, value, type, loc)) {
58 assert(insertPt == builder.getInsertionPoint());
59 assert(matchPattern(constOp, m_Constant()));
60 return constOp;
61 }
62
63 return nullptr;
64 }
65
66 //===----------------------------------------------------------------------===//
67 // OperationFolder
68 //===----------------------------------------------------------------------===//
69
tryToFold(Operation * op,function_ref<void (Operation *)> processGeneratedConstants,function_ref<void (Operation *)> preReplaceAction,bool * inPlaceUpdate)70 LogicalResult OperationFolder::tryToFold(
71 Operation *op, function_ref<void(Operation *)> processGeneratedConstants,
72 function_ref<void(Operation *)> preReplaceAction, bool *inPlaceUpdate) {
73 if (inPlaceUpdate)
74 *inPlaceUpdate = false;
75
76 // If this is a unique'd constant, return failure as we know that it has
77 // already been folded.
78 if (isFolderOwnedConstant(op)) {
79 // Check to see if we should rehoist, i.e. if a non-constant operation was
80 // inserted before this one.
81 Block *opBlock = op->getBlock();
82 if (&opBlock->front() != op && !isFolderOwnedConstant(op->getPrevNode()))
83 op->moveBefore(&opBlock->front());
84 return failure();
85 }
86
87 // Try to fold the operation.
88 SmallVector<Value, 8> results;
89 OpBuilder builder(op);
90 if (failed(tryToFold(builder, op, results, processGeneratedConstants)))
91 return failure();
92
93 // Check to see if the operation was just updated in place.
94 if (results.empty()) {
95 if (inPlaceUpdate)
96 *inPlaceUpdate = true;
97 return success();
98 }
99
100 // Constant folding succeeded. We will start replacing this op's uses and
101 // erase this op. Invoke the callback provided by the caller to perform any
102 // pre-replacement action.
103 if (preReplaceAction)
104 preReplaceAction(op);
105
106 // Replace all of the result values and erase the operation.
107 for (unsigned i = 0, e = results.size(); i != e; ++i)
108 op->getResult(i).replaceAllUsesWith(results[i]);
109 op->erase();
110 return success();
111 }
112
insertKnownConstant(Operation * op,Attribute constValue)113 bool OperationFolder::insertKnownConstant(Operation *op, Attribute constValue) {
114 Block *opBlock = op->getBlock();
115
116 // If this is a constant we unique'd, we don't need to insert, but we can
117 // check to see if we should rehoist it.
118 if (isFolderOwnedConstant(op)) {
119 if (&opBlock->front() != op && !isFolderOwnedConstant(op->getPrevNode()))
120 op->moveBefore(&opBlock->front());
121 return true;
122 }
123
124 // Get the constant value of the op if necessary.
125 if (!constValue) {
126 matchPattern(op, m_Constant(&constValue));
127 assert(constValue && "expected `op` to be a constant");
128 } else {
129 // Ensure that the provided constant was actually correct.
130 #ifndef NDEBUG
131 Attribute expectedValue;
132 matchPattern(op, m_Constant(&expectedValue));
133 assert(
134 expectedValue == constValue &&
135 "provided constant value was not the expected value of the constant");
136 #endif
137 }
138
139 // Check for an existing constant operation for the attribute value.
140 Region *insertRegion = getInsertionRegion(interfaces, opBlock);
141 auto &uniquedConstants = foldScopes[insertRegion];
142 Operation *&folderConstOp = uniquedConstants[std::make_tuple(
143 op->getDialect(), constValue, *op->result_type_begin())];
144
145 // If there is an existing constant, replace `op`.
146 if (folderConstOp) {
147 op->replaceAllUsesWith(folderConstOp);
148 op->erase();
149 return false;
150 }
151
152 // Otherwise, we insert `op`. If `op` is in the insertion block and is either
153 // already at the front of the block, or the previous operation is already a
154 // constant we unique'd (i.e. one we inserted), then we don't need to do
155 // anything. Otherwise, we move the constant to the insertion block.
156 Block *insertBlock = &insertRegion->front();
157 if (opBlock != insertBlock || (&insertBlock->front() != op &&
158 !isFolderOwnedConstant(op->getPrevNode())))
159 op->moveBefore(&insertBlock->front());
160
161 folderConstOp = op;
162 referencedDialects[op].push_back(op->getDialect());
163 return true;
164 }
165
166 /// Notifies that the given constant `op` should be remove from this
167 /// OperationFolder's internal bookkeeping.
notifyRemoval(Operation * op)168 void OperationFolder::notifyRemoval(Operation *op) {
169 // Check to see if this operation is uniqued within the folder.
170 auto it = referencedDialects.find(op);
171 if (it == referencedDialects.end())
172 return;
173
174 // Get the constant value for this operation, this is the value that was used
175 // to unique the operation internally.
176 Attribute constValue;
177 matchPattern(op, m_Constant(&constValue));
178 assert(constValue);
179
180 // Get the constant map that this operation was uniqued in.
181 auto &uniquedConstants =
182 foldScopes[getInsertionRegion(interfaces, op->getBlock())];
183
184 // Erase all of the references to this operation.
185 auto type = op->getResult(0).getType();
186 for (auto *dialect : it->second)
187 uniquedConstants.erase(std::make_tuple(dialect, constValue, type));
188 referencedDialects.erase(it);
189 }
190
191 /// Clear out any constants cached inside of the folder.
clear()192 void OperationFolder::clear() {
193 foldScopes.clear();
194 referencedDialects.clear();
195 }
196
197 /// Get or create a constant using the given builder. On success this returns
198 /// the constant operation, nullptr otherwise.
getOrCreateConstant(OpBuilder & builder,Dialect * dialect,Attribute value,Type type,Location loc)199 Value OperationFolder::getOrCreateConstant(OpBuilder &builder, Dialect *dialect,
200 Attribute value, Type type,
201 Location loc) {
202 OpBuilder::InsertionGuard foldGuard(builder);
203
204 // Use the builder insertion block to find an insertion point for the
205 // constant.
206 auto *insertRegion =
207 getInsertionRegion(interfaces, builder.getInsertionBlock());
208 auto &entry = insertRegion->front();
209 builder.setInsertionPoint(&entry, entry.begin());
210
211 // Get the constant map for the insertion region of this operation.
212 auto &uniquedConstants = foldScopes[insertRegion];
213 Operation *constOp = tryGetOrCreateConstant(uniquedConstants, dialect,
214 builder, value, type, loc);
215 return constOp ? constOp->getResult(0) : Value();
216 }
217
isFolderOwnedConstant(Operation * op) const218 bool OperationFolder::isFolderOwnedConstant(Operation *op) const {
219 return referencedDialects.count(op);
220 }
221
222 /// Tries to perform folding on the given `op`. If successful, populates
223 /// `results` with the results of the folding.
tryToFold(OpBuilder & builder,Operation * op,SmallVectorImpl<Value> & results,function_ref<void (Operation *)> processGeneratedConstants)224 LogicalResult OperationFolder::tryToFold(
225 OpBuilder &builder, Operation *op, SmallVectorImpl<Value> &results,
226 function_ref<void(Operation *)> processGeneratedConstants) {
227 SmallVector<Attribute, 8> operandConstants;
228
229 // If this is a commutative operation, move constants to be trailing operands.
230 bool updatedOpOperands = false;
231 if (op->getNumOperands() >= 2 && op->hasTrait<OpTrait::IsCommutative>()) {
232 auto isNonConstant = [&](OpOperand &o) {
233 return !matchPattern(o.get(), m_Constant());
234 };
235 auto *firstConstantIt =
236 llvm::find_if_not(op->getOpOperands(), isNonConstant);
237 auto *newConstantIt = std::stable_partition(
238 firstConstantIt, op->getOpOperands().end(), isNonConstant);
239
240 // Remember if we actually moved anything.
241 updatedOpOperands = firstConstantIt != newConstantIt;
242 }
243
244 // Check to see if any operands to the operation is constant and whether
245 // the operation knows how to constant fold itself.
246 operandConstants.assign(op->getNumOperands(), Attribute());
247 for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i)
248 matchPattern(op->getOperand(i), m_Constant(&operandConstants[i]));
249
250 // Attempt to constant fold the operation. If we failed, check to see if we at
251 // least updated the operands of the operation. We treat this as an in-place
252 // fold.
253 SmallVector<OpFoldResult, 8> foldResults;
254 if (failed(op->fold(operandConstants, foldResults)) ||
255 failed(processFoldResults(builder, op, results, foldResults,
256 processGeneratedConstants)))
257 return success(updatedOpOperands);
258 return success();
259 }
260
processFoldResults(OpBuilder & builder,Operation * op,SmallVectorImpl<Value> & results,ArrayRef<OpFoldResult> foldResults,function_ref<void (Operation *)> processGeneratedConstants)261 LogicalResult OperationFolder::processFoldResults(
262 OpBuilder &builder, Operation *op, SmallVectorImpl<Value> &results,
263 ArrayRef<OpFoldResult> foldResults,
264 function_ref<void(Operation *)> processGeneratedConstants) {
265 // Check to see if the operation was just updated in place.
266 if (foldResults.empty())
267 return success();
268 assert(foldResults.size() == op->getNumResults());
269
270 // Create a builder to insert new operations into the entry block of the
271 // insertion region.
272 auto *insertRegion =
273 getInsertionRegion(interfaces, builder.getInsertionBlock());
274 auto &entry = insertRegion->front();
275 OpBuilder::InsertionGuard foldGuard(builder);
276 builder.setInsertionPoint(&entry, entry.begin());
277
278 // Get the constant map for the insertion region of this operation.
279 auto &uniquedConstants = foldScopes[insertRegion];
280
281 // Create the result constants and replace the results.
282 auto *dialect = op->getDialect();
283 for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) {
284 assert(!foldResults[i].isNull() && "expected valid OpFoldResult");
285
286 // Check if the result was an SSA value.
287 if (auto repl = foldResults[i].dyn_cast<Value>()) {
288 if (repl.getType() != op->getResult(i).getType()) {
289 results.clear();
290 return failure();
291 }
292 results.emplace_back(repl);
293 continue;
294 }
295
296 // Check to see if there is a canonicalized version of this constant.
297 auto res = op->getResult(i);
298 Attribute attrRepl = foldResults[i].get<Attribute>();
299 if (auto *constOp =
300 tryGetOrCreateConstant(uniquedConstants, dialect, builder, attrRepl,
301 res.getType(), op->getLoc())) {
302 // Ensure that this constant dominates the operation we are replacing it
303 // with. This may not automatically happen if the operation being folded
304 // was inserted before the constant within the insertion block.
305 Block *opBlock = op->getBlock();
306 if (opBlock == constOp->getBlock() && &opBlock->front() != constOp)
307 constOp->moveBefore(&opBlock->front());
308
309 results.push_back(constOp->getResult(0));
310 continue;
311 }
312 // If materialization fails, cleanup any operations generated for the
313 // previous results and return failure.
314 for (Operation &op : llvm::make_early_inc_range(
315 llvm::make_range(entry.begin(), builder.getInsertionPoint()))) {
316 notifyRemoval(&op);
317 op.erase();
318 }
319 results.clear();
320 return failure();
321 }
322
323 // Process any newly generated operations.
324 if (processGeneratedConstants) {
325 for (auto i = entry.begin(), e = builder.getInsertionPoint(); i != e; ++i)
326 processGeneratedConstants(&*i);
327 }
328
329 return success();
330 }
331
332 /// Try to get or create a new constant entry. On success this returns the
333 /// constant operation value, nullptr otherwise.
tryGetOrCreateConstant(ConstantMap & uniquedConstants,Dialect * dialect,OpBuilder & builder,Attribute value,Type type,Location loc)334 Operation *OperationFolder::tryGetOrCreateConstant(
335 ConstantMap &uniquedConstants, Dialect *dialect, OpBuilder &builder,
336 Attribute value, Type type, Location loc) {
337 // Check if an existing mapping already exists.
338 auto constKey = std::make_tuple(dialect, value, type);
339 Operation *&constOp = uniquedConstants[constKey];
340 if (constOp)
341 return constOp;
342
343 // If one doesn't exist, try to materialize one.
344 if (!(constOp = materializeConstant(dialect, builder, value, type, loc)))
345 return nullptr;
346
347 // Check to see if the generated constant is in the expected dialect.
348 auto *newDialect = constOp->getDialect();
349 if (newDialect == dialect) {
350 referencedDialects[constOp].push_back(dialect);
351 return constOp;
352 }
353
354 // If it isn't, then we also need to make sure that the mapping for the new
355 // dialect is valid.
356 auto newKey = std::make_tuple(newDialect, value, type);
357
358 // If an existing operation in the new dialect already exists, delete the
359 // materialized operation in favor of the existing one.
360 if (auto *existingOp = uniquedConstants.lookup(newKey)) {
361 constOp->erase();
362 referencedDialects[existingOp].push_back(dialect);
363 return constOp = existingOp;
364 }
365
366 // Otherwise, update the new dialect to the materialized operation.
367 referencedDialects[constOp].assign({dialect, newDialect});
368 auto newIt = uniquedConstants.insert({newKey, constOp});
369 return newIt.first->second;
370 }
371