1 //===- FoldUtils.cpp ---- Fold Utilities ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines various operation fold utilities. These utilities are 10 // intended to be used by passes to unify and simply their logic. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Transforms/FoldUtils.h" 15 16 #include "mlir/Dialect/StandardOps/IR/Ops.h" 17 #include "mlir/IR/Builders.h" 18 #include "mlir/IR/Matchers.h" 19 #include "mlir/IR/Operation.h" 20 21 using namespace mlir; 22 23 /// Given an operation, find the parent region that folded constants should be 24 /// inserted into. 25 static Region * 26 getInsertionRegion(DialectInterfaceCollection<DialectFoldInterface> &interfaces, 27 Block *insertionBlock) { 28 while (Region *region = insertionBlock->getParent()) { 29 // Insert in this region for any of the following scenarios: 30 // * The parent is unregistered, or is known to be isolated from above. 31 // * The parent is a top-level operation. 32 auto *parentOp = region->getParentOp(); 33 if (parentOp->mightHaveTrait<OpTrait::IsIsolatedFromAbove>() || 34 !parentOp->getBlock()) 35 return region; 36 37 // Otherwise, check if this region is a desired insertion region. 38 auto *interface = interfaces.getInterfaceFor(parentOp); 39 if (LLVM_UNLIKELY(interface && interface->shouldMaterializeInto(region))) 40 return region; 41 42 // Traverse up the parent looking for an insertion region. 43 insertionBlock = parentOp->getBlock(); 44 } 45 llvm_unreachable("expected valid insertion region"); 46 } 47 48 /// A utility function used to materialize a constant for a given attribute and 49 /// type. On success, a valid constant value is returned. Otherwise, null is 50 /// returned 51 static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder, 52 Attribute value, Type type, 53 Location loc) { 54 auto insertPt = builder.getInsertionPoint(); 55 (void)insertPt; 56 57 // Ask the dialect to materialize a constant operation for this value. 58 if (auto *constOp = dialect->materializeConstant(builder, value, type, loc)) { 59 assert(insertPt == builder.getInsertionPoint()); 60 assert(matchPattern(constOp, m_Constant())); 61 return constOp; 62 } 63 64 // TODO: To facilitate splitting the std dialect (PR48490), have a special 65 // case for falling back to std.constant. Eventually, we will have separate 66 // ops tensor.constant, int.constant, float.constant, etc. that live in their 67 // respective dialects, which will allow each dialect to implement the 68 // materializeConstant hook above. 69 // 70 // The special case is needed because in the interim state while we are 71 // splitting out those dialects from std, the std dialect depends on the 72 // tensor dialect, which makes it impossible for the tensor dialect to use 73 // std.constant (it would be a cyclic dependency) as part of its 74 // materializeConstant hook. 75 // 76 // If the dialect is unable to materialize a constant, check to see if the 77 // standard constant can be used. 78 if (ConstantOp::isBuildableWith(value, type)) 79 return builder.create<ConstantOp>(loc, type, value); 80 return nullptr; 81 } 82 83 //===----------------------------------------------------------------------===// 84 // OperationFolder 85 //===----------------------------------------------------------------------===// 86 87 /// Scan the specified region for constants that can be used in folding, 88 /// moving them to the entry block (or any custom insertion location specified 89 /// by shouldMaterializeInto), and add them to our known-constants table. 90 void OperationFolder::processExistingConstants(Region ®ion) { 91 if (region.empty()) 92 return; 93 94 // March the constant insertion point forward, moving all constants to the 95 // top of the block, but keeping them in their order of discovery. 96 Region *insertRegion = getInsertionRegion(interfaces, ®ion.front()); 97 auto &uniquedConstants = foldScopes[insertRegion]; 98 99 Block &insertBlock = insertRegion->front(); 100 Block::iterator constantIterator = insertBlock.begin(); 101 102 // Process each constant that we discover in this region. 103 auto processConstant = [&](Operation *op, Attribute value) { 104 assert(op->getNumResults() == 1 && "constants have one result"); 105 // Check to see if we already have an instance of this constant. 106 Operation *&constOp = uniquedConstants[std::make_tuple( 107 op->getDialect(), value, op->getResult(0).getType())]; 108 109 // If we already have an instance of this constant, CSE/delete this one as 110 // we go. 111 if (constOp) { 112 if (constantIterator == Block::iterator(op)) 113 ++constantIterator; // Don't invalidate our iterator when scanning. 114 op->getResult(0).replaceAllUsesWith(constOp->getResult(0)); 115 op->erase(); 116 return; 117 } 118 119 // Otherwise, remember that we have this constant. 120 constOp = op; 121 referencedDialects[op].push_back(op->getDialect()); 122 123 // If the constant isn't already at the insertion point then move it up. 124 if (constantIterator != Block::iterator(op)) 125 op->moveBefore(&insertBlock, constantIterator); 126 else 127 ++constantIterator; // It was pointing at the constant. 128 }; 129 130 // Collect all the constants for this region of isolation or insertion (as 131 // specified by the shouldMaterializeInto hook). Collect any subregions of 132 // isolation/constant insertion for subsequent processing. 133 SmallVector<Operation *> insertionSubregionOps; 134 region.walk<WalkOrder::PreOrder>([&](Operation *op) { 135 // If this is a constant, process it. 136 Attribute value; 137 if (matchPattern(op, m_Constant(&value))) { 138 processConstant(op, value); 139 // We may have deleted the operation, don't check it for regions. 140 return WalkResult::skip(); 141 } 142 143 // If the operation has regions and is isolated, don't recurse into it. 144 if (op->getNumRegions() != 0) { 145 auto hasDifferentInsertRegion = [&](Region ®ion) { 146 return !region.empty() && 147 getInsertionRegion(interfaces, ®ion.front()) != insertRegion; 148 }; 149 if (llvm::any_of(op->getRegions(), hasDifferentInsertRegion)) { 150 insertionSubregionOps.push_back(op); 151 return WalkResult::skip(); 152 } 153 } 154 155 // Otherwise keep going. 156 return WalkResult::advance(); 157 }); 158 159 // Process regions in any isolated ops separately. 160 for (Operation *subregionOps : insertionSubregionOps) { 161 for (Region ®ion : subregionOps->getRegions()) 162 processExistingConstants(region); 163 } 164 } 165 166 LogicalResult OperationFolder::tryToFold( 167 Operation *op, function_ref<void(Operation *)> processGeneratedConstants, 168 function_ref<void(Operation *)> preReplaceAction, bool *inPlaceUpdate) { 169 if (inPlaceUpdate) 170 *inPlaceUpdate = false; 171 172 // If this is a unique'd constant, return failure as we know that it has 173 // already been folded. 174 if (referencedDialects.count(op)) 175 return failure(); 176 177 // Try to fold the operation. 178 SmallVector<Value, 8> results; 179 OpBuilder builder(op); 180 if (failed(tryToFold(builder, op, results, processGeneratedConstants))) 181 return failure(); 182 183 // Check to see if the operation was just updated in place. 184 if (results.empty()) { 185 if (inPlaceUpdate) 186 *inPlaceUpdate = true; 187 return success(); 188 } 189 190 // Constant folding succeeded. We will start replacing this op's uses and 191 // erase this op. Invoke the callback provided by the caller to perform any 192 // pre-replacement action. 193 if (preReplaceAction) 194 preReplaceAction(op); 195 196 // Replace all of the result values and erase the operation. 197 for (unsigned i = 0, e = results.size(); i != e; ++i) 198 op->getResult(i).replaceAllUsesWith(results[i]); 199 op->erase(); 200 return success(); 201 } 202 203 /// Notifies that the given constant `op` should be remove from this 204 /// OperationFolder's internal bookkeeping. 205 void OperationFolder::notifyRemoval(Operation *op) { 206 // Check to see if this operation is uniqued within the folder. 207 auto it = referencedDialects.find(op); 208 if (it == referencedDialects.end()) 209 return; 210 211 // Get the constant value for this operation, this is the value that was used 212 // to unique the operation internally. 213 Attribute constValue; 214 matchPattern(op, m_Constant(&constValue)); 215 assert(constValue); 216 217 // Get the constant map that this operation was uniqued in. 218 auto &uniquedConstants = 219 foldScopes[getInsertionRegion(interfaces, op->getBlock())]; 220 221 // Erase all of the references to this operation. 222 auto type = op->getResult(0).getType(); 223 for (auto *dialect : it->second) 224 uniquedConstants.erase(std::make_tuple(dialect, constValue, type)); 225 referencedDialects.erase(it); 226 } 227 228 /// Clear out any constants cached inside of the folder. 229 void OperationFolder::clear() { 230 foldScopes.clear(); 231 referencedDialects.clear(); 232 } 233 234 /// Get or create a constant using the given builder. On success this returns 235 /// the constant operation, nullptr otherwise. 236 Value OperationFolder::getOrCreateConstant(OpBuilder &builder, Dialect *dialect, 237 Attribute value, Type type, 238 Location loc) { 239 OpBuilder::InsertionGuard foldGuard(builder); 240 241 // Use the builder insertion block to find an insertion point for the 242 // constant. 243 auto *insertRegion = 244 getInsertionRegion(interfaces, builder.getInsertionBlock()); 245 auto &entry = insertRegion->front(); 246 builder.setInsertionPoint(&entry, entry.begin()); 247 248 // Get the constant map for the insertion region of this operation. 249 auto &uniquedConstants = foldScopes[insertRegion]; 250 Operation *constOp = tryGetOrCreateConstant(uniquedConstants, dialect, 251 builder, value, type, loc); 252 return constOp ? constOp->getResult(0) : Value(); 253 } 254 255 /// Tries to perform folding on the given `op`. If successful, populates 256 /// `results` with the results of the folding. 257 LogicalResult OperationFolder::tryToFold( 258 OpBuilder &builder, Operation *op, SmallVectorImpl<Value> &results, 259 function_ref<void(Operation *)> processGeneratedConstants) { 260 SmallVector<Attribute, 8> operandConstants; 261 SmallVector<OpFoldResult, 8> foldResults; 262 263 // If this is a commutative operation, move constants to be trailing operands. 264 if (op->getNumOperands() >= 2 && op->hasTrait<OpTrait::IsCommutative>()) { 265 std::stable_partition( 266 op->getOpOperands().begin(), op->getOpOperands().end(), 267 [&](OpOperand &O) { return !matchPattern(O.get(), m_Constant()); }); 268 } 269 270 // Check to see if any operands to the operation is constant and whether 271 // the operation knows how to constant fold itself. 272 operandConstants.assign(op->getNumOperands(), Attribute()); 273 for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) 274 matchPattern(op->getOperand(i), m_Constant(&operandConstants[i])); 275 276 // Attempt to constant fold the operation. 277 if (failed(op->fold(operandConstants, foldResults))) 278 return failure(); 279 280 // Check to see if the operation was just updated in place. 281 if (foldResults.empty()) 282 return success(); 283 assert(foldResults.size() == op->getNumResults()); 284 285 // Create a builder to insert new operations into the entry block of the 286 // insertion region. 287 auto *insertRegion = 288 getInsertionRegion(interfaces, builder.getInsertionBlock()); 289 auto &entry = insertRegion->front(); 290 OpBuilder::InsertionGuard foldGuard(builder); 291 builder.setInsertionPoint(&entry, entry.begin()); 292 293 // Get the constant map for the insertion region of this operation. 294 auto &uniquedConstants = foldScopes[insertRegion]; 295 296 // Create the result constants and replace the results. 297 auto *dialect = op->getDialect(); 298 for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) { 299 assert(!foldResults[i].isNull() && "expected valid OpFoldResult"); 300 301 // Check if the result was an SSA value. 302 if (auto repl = foldResults[i].dyn_cast<Value>()) { 303 if (repl.getType() != op->getResult(i).getType()) 304 return failure(); 305 results.emplace_back(repl); 306 continue; 307 } 308 309 // Check to see if there is a canonicalized version of this constant. 310 auto res = op->getResult(i); 311 Attribute attrRepl = foldResults[i].get<Attribute>(); 312 if (auto *constOp = 313 tryGetOrCreateConstant(uniquedConstants, dialect, builder, attrRepl, 314 res.getType(), op->getLoc())) { 315 results.push_back(constOp->getResult(0)); 316 continue; 317 } 318 // If materialization fails, cleanup any operations generated for the 319 // previous results and return failure. 320 for (Operation &op : llvm::make_early_inc_range( 321 llvm::make_range(entry.begin(), builder.getInsertionPoint()))) { 322 notifyRemoval(&op); 323 op.erase(); 324 } 325 return failure(); 326 } 327 328 // Process any newly generated operations. 329 if (processGeneratedConstants) { 330 for (auto i = entry.begin(), e = builder.getInsertionPoint(); i != e; ++i) 331 processGeneratedConstants(&*i); 332 } 333 334 return success(); 335 } 336 337 /// Try to get or create a new constant entry. On success this returns the 338 /// constant operation value, nullptr otherwise. 339 Operation *OperationFolder::tryGetOrCreateConstant( 340 ConstantMap &uniquedConstants, Dialect *dialect, OpBuilder &builder, 341 Attribute value, Type type, Location loc) { 342 // Check if an existing mapping already exists. 343 auto constKey = std::make_tuple(dialect, value, type); 344 Operation *&constOp = uniquedConstants[constKey]; 345 if (constOp) 346 return constOp; 347 348 // If one doesn't exist, try to materialize one. 349 if (!(constOp = materializeConstant(dialect, builder, value, type, loc))) 350 return nullptr; 351 352 // Check to see if the generated constant is in the expected dialect. 353 auto *newDialect = constOp->getDialect(); 354 if (newDialect == dialect) { 355 referencedDialects[constOp].push_back(dialect); 356 return constOp; 357 } 358 359 // If it isn't, then we also need to make sure that the mapping for the new 360 // dialect is valid. 361 auto newKey = std::make_tuple(newDialect, value, type); 362 363 // If an existing operation in the new dialect already exists, delete the 364 // materialized operation in favor of the existing one. 365 if (auto *existingOp = uniquedConstants.lookup(newKey)) { 366 constOp->erase(); 367 referencedDialects[existingOp].push_back(dialect); 368 return constOp = existingOp; 369 } 370 371 // Otherwise, update the new dialect to the materialized operation. 372 referencedDialects[constOp].assign({dialect, newDialect}); 373 auto newIt = uniquedConstants.insert({newKey, constOp}); 374 return newIt.first->second; 375 } 376