1 //===- FoldUtils.cpp ---- Fold Utilities ----------------------------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file defines various operation fold utilities. These utilities are 10 // intended to be used by passes to unify and simply their logic. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "mlir/Transforms/FoldUtils.h" 15 16 #include "mlir/IR/Builders.h" 17 #include "mlir/IR/Matchers.h" 18 #include "mlir/IR/Operation.h" 19 20 using namespace mlir; 21 22 /// Given an operation, find the parent region that folded constants should be 23 /// inserted into. 24 static Region * 25 getInsertionRegion(DialectInterfaceCollection<DialectFoldInterface> &interfaces, 26 Block *insertionBlock) { 27 while (Region *region = insertionBlock->getParent()) { 28 // Insert in this region for any of the following scenarios: 29 // * The parent is unregistered, or is known to be isolated from above. 30 // * The parent is a top-level operation. 31 auto *parentOp = region->getParentOp(); 32 if (parentOp->mightHaveTrait<OpTrait::IsIsolatedFromAbove>() || 33 !parentOp->getBlock()) 34 return region; 35 36 // Otherwise, check if this region is a desired insertion region. 37 auto *interface = interfaces.getInterfaceFor(parentOp); 38 if (LLVM_UNLIKELY(interface && interface->shouldMaterializeInto(region))) 39 return region; 40 41 // Traverse up the parent looking for an insertion region. 42 insertionBlock = parentOp->getBlock(); 43 } 44 llvm_unreachable("expected valid insertion region"); 45 } 46 47 /// A utility function used to materialize a constant for a given attribute and 48 /// type. On success, a valid constant value is returned. Otherwise, null is 49 /// returned 50 static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder, 51 Attribute value, Type type, 52 Location loc) { 53 auto insertPt = builder.getInsertionPoint(); 54 (void)insertPt; 55 56 // Ask the dialect to materialize a constant operation for this value. 57 if (auto *constOp = dialect->materializeConstant(builder, value, type, loc)) { 58 assert(insertPt == builder.getInsertionPoint()); 59 assert(matchPattern(constOp, m_Constant())); 60 return constOp; 61 } 62 63 return nullptr; 64 } 65 66 //===----------------------------------------------------------------------===// 67 // OperationFolder 68 //===----------------------------------------------------------------------===// 69 70 LogicalResult OperationFolder::tryToFold( 71 Operation *op, function_ref<void(Operation *)> processGeneratedConstants, 72 function_ref<void(Operation *)> preReplaceAction, bool *inPlaceUpdate) { 73 if (inPlaceUpdate) 74 *inPlaceUpdate = false; 75 76 // If this is a unique'd constant, return failure as we know that it has 77 // already been folded. 78 if (isFolderOwnedConstant(op)) { 79 // Check to see if we should rehoist, i.e. if a non-constant operation was 80 // inserted before this one. 81 Block *opBlock = op->getBlock(); 82 if (&opBlock->front() != op && !isFolderOwnedConstant(op->getPrevNode())) 83 op->moveBefore(&opBlock->front()); 84 return failure(); 85 } 86 87 // Try to fold the operation. 88 SmallVector<Value, 8> results; 89 OpBuilder builder(op); 90 if (failed(tryToFold(builder, op, results, processGeneratedConstants))) 91 return failure(); 92 93 // Check to see if the operation was just updated in place. 94 if (results.empty()) { 95 if (inPlaceUpdate) 96 *inPlaceUpdate = true; 97 return success(); 98 } 99 100 // Constant folding succeeded. We will start replacing this op's uses and 101 // erase this op. Invoke the callback provided by the caller to perform any 102 // pre-replacement action. 103 if (preReplaceAction) 104 preReplaceAction(op); 105 106 // Replace all of the result values and erase the operation. 107 for (unsigned i = 0, e = results.size(); i != e; ++i) 108 op->getResult(i).replaceAllUsesWith(results[i]); 109 op->erase(); 110 return success(); 111 } 112 113 bool OperationFolder::insertKnownConstant(Operation *op, Attribute constValue) { 114 Block *opBlock = op->getBlock(); 115 116 // If this is a constant we unique'd, we don't need to insert, but we can 117 // check to see if we should rehoist it. 118 if (isFolderOwnedConstant(op)) { 119 if (&opBlock->front() != op && !isFolderOwnedConstant(op->getPrevNode())) 120 op->moveBefore(&opBlock->front()); 121 return true; 122 } 123 124 // Get the constant value of the op if necessary. 125 if (!constValue) { 126 matchPattern(op, m_Constant(&constValue)); 127 assert(constValue && "expected `op` to be a constant"); 128 } else { 129 // Ensure that the provided constant was actually correct. 130 #ifndef NDEBUG 131 Attribute expectedValue; 132 matchPattern(op, m_Constant(&expectedValue)); 133 assert( 134 expectedValue == constValue && 135 "provided constant value was not the expected value of the constant"); 136 #endif 137 } 138 139 // Check for an existing constant operation for the attribute value. 140 Region *insertRegion = getInsertionRegion(interfaces, opBlock); 141 auto &uniquedConstants = foldScopes[insertRegion]; 142 Operation *&folderConstOp = uniquedConstants[std::make_tuple( 143 op->getDialect(), constValue, *op->result_type_begin())]; 144 145 // If there is an existing constant, replace `op`. 146 if (folderConstOp) { 147 op->replaceAllUsesWith(folderConstOp); 148 op->erase(); 149 return false; 150 } 151 152 // Otherwise, we insert `op`. If `op` is in the insertion block and is either 153 // already at the front of the block, or the previous operation is already a 154 // constant we unique'd (i.e. one we inserted), then we don't need to do 155 // anything. Otherwise, we move the constant to the insertion block. 156 Block *insertBlock = &insertRegion->front(); 157 if (opBlock != insertBlock || (&insertBlock->front() != op && 158 !isFolderOwnedConstant(op->getPrevNode()))) 159 op->moveBefore(&insertBlock->front()); 160 161 folderConstOp = op; 162 referencedDialects[op].push_back(op->getDialect()); 163 return true; 164 } 165 166 /// Notifies that the given constant `op` should be remove from this 167 /// OperationFolder's internal bookkeeping. 168 void OperationFolder::notifyRemoval(Operation *op) { 169 // Check to see if this operation is uniqued within the folder. 170 auto it = referencedDialects.find(op); 171 if (it == referencedDialects.end()) 172 return; 173 174 // Get the constant value for this operation, this is the value that was used 175 // to unique the operation internally. 176 Attribute constValue; 177 matchPattern(op, m_Constant(&constValue)); 178 assert(constValue); 179 180 // Get the constant map that this operation was uniqued in. 181 auto &uniquedConstants = 182 foldScopes[getInsertionRegion(interfaces, op->getBlock())]; 183 184 // Erase all of the references to this operation. 185 auto type = op->getResult(0).getType(); 186 for (auto *dialect : it->second) 187 uniquedConstants.erase(std::make_tuple(dialect, constValue, type)); 188 referencedDialects.erase(it); 189 } 190 191 /// Clear out any constants cached inside of the folder. 192 void OperationFolder::clear() { 193 foldScopes.clear(); 194 referencedDialects.clear(); 195 } 196 197 /// Get or create a constant using the given builder. On success this returns 198 /// the constant operation, nullptr otherwise. 199 Value OperationFolder::getOrCreateConstant(OpBuilder &builder, Dialect *dialect, 200 Attribute value, Type type, 201 Location loc) { 202 OpBuilder::InsertionGuard foldGuard(builder); 203 204 // Use the builder insertion block to find an insertion point for the 205 // constant. 206 auto *insertRegion = 207 getInsertionRegion(interfaces, builder.getInsertionBlock()); 208 auto &entry = insertRegion->front(); 209 builder.setInsertionPoint(&entry, entry.begin()); 210 211 // Get the constant map for the insertion region of this operation. 212 auto &uniquedConstants = foldScopes[insertRegion]; 213 Operation *constOp = tryGetOrCreateConstant(uniquedConstants, dialect, 214 builder, value, type, loc); 215 return constOp ? constOp->getResult(0) : Value(); 216 } 217 218 bool OperationFolder::isFolderOwnedConstant(Operation *op) const { 219 return referencedDialects.count(op); 220 } 221 222 /// Tries to perform folding on the given `op`. If successful, populates 223 /// `results` with the results of the folding. 224 LogicalResult OperationFolder::tryToFold( 225 OpBuilder &builder, Operation *op, SmallVectorImpl<Value> &results, 226 function_ref<void(Operation *)> processGeneratedConstants) { 227 SmallVector<Attribute, 8> operandConstants; 228 229 // If this is a commutative operation, move constants to be trailing operands. 230 bool updatedOpOperands = false; 231 if (op->getNumOperands() >= 2 && op->hasTrait<OpTrait::IsCommutative>()) { 232 auto isNonConstant = [&](OpOperand &o) { 233 return !matchPattern(o.get(), m_Constant()); 234 }; 235 auto *firstConstantIt = 236 llvm::find_if_not(op->getOpOperands(), isNonConstant); 237 auto *newConstantIt = std::stable_partition( 238 firstConstantIt, op->getOpOperands().end(), isNonConstant); 239 240 // Remember if we actually moved anything. 241 updatedOpOperands = firstConstantIt != newConstantIt; 242 } 243 244 // Check to see if any operands to the operation is constant and whether 245 // the operation knows how to constant fold itself. 246 operandConstants.assign(op->getNumOperands(), Attribute()); 247 for (unsigned i = 0, e = op->getNumOperands(); i != e; ++i) 248 matchPattern(op->getOperand(i), m_Constant(&operandConstants[i])); 249 250 // Attempt to constant fold the operation. If we failed, check to see if we at 251 // least updated the operands of the operation. We treat this as an in-place 252 // fold. 253 SmallVector<OpFoldResult, 8> foldResults; 254 if (failed(op->fold(operandConstants, foldResults)) || 255 failed(processFoldResults(builder, op, results, foldResults, 256 processGeneratedConstants))) 257 return success(updatedOpOperands); 258 return success(); 259 } 260 261 LogicalResult OperationFolder::processFoldResults( 262 OpBuilder &builder, Operation *op, SmallVectorImpl<Value> &results, 263 ArrayRef<OpFoldResult> foldResults, 264 function_ref<void(Operation *)> processGeneratedConstants) { 265 // Check to see if the operation was just updated in place. 266 if (foldResults.empty()) 267 return success(); 268 assert(foldResults.size() == op->getNumResults()); 269 270 // Create a builder to insert new operations into the entry block of the 271 // insertion region. 272 auto *insertRegion = 273 getInsertionRegion(interfaces, builder.getInsertionBlock()); 274 auto &entry = insertRegion->front(); 275 OpBuilder::InsertionGuard foldGuard(builder); 276 builder.setInsertionPoint(&entry, entry.begin()); 277 278 // Get the constant map for the insertion region of this operation. 279 auto &uniquedConstants = foldScopes[insertRegion]; 280 281 // Create the result constants and replace the results. 282 auto *dialect = op->getDialect(); 283 for (unsigned i = 0, e = op->getNumResults(); i != e; ++i) { 284 assert(!foldResults[i].isNull() && "expected valid OpFoldResult"); 285 286 // Check if the result was an SSA value. 287 if (auto repl = foldResults[i].dyn_cast<Value>()) { 288 if (repl.getType() != op->getResult(i).getType()) { 289 results.clear(); 290 return failure(); 291 } 292 results.emplace_back(repl); 293 continue; 294 } 295 296 // Check to see if there is a canonicalized version of this constant. 297 auto res = op->getResult(i); 298 Attribute attrRepl = foldResults[i].get<Attribute>(); 299 if (auto *constOp = 300 tryGetOrCreateConstant(uniquedConstants, dialect, builder, attrRepl, 301 res.getType(), op->getLoc())) { 302 // Ensure that this constant dominates the operation we are replacing it 303 // with. This may not automatically happen if the operation being folded 304 // was inserted before the constant within the insertion block. 305 Block *opBlock = op->getBlock(); 306 if (opBlock == constOp->getBlock() && &opBlock->front() != constOp) 307 constOp->moveBefore(&opBlock->front()); 308 309 results.push_back(constOp->getResult(0)); 310 continue; 311 } 312 // If materialization fails, cleanup any operations generated for the 313 // previous results and return failure. 314 for (Operation &op : llvm::make_early_inc_range( 315 llvm::make_range(entry.begin(), builder.getInsertionPoint()))) { 316 notifyRemoval(&op); 317 op.erase(); 318 } 319 results.clear(); 320 return failure(); 321 } 322 323 // Process any newly generated operations. 324 if (processGeneratedConstants) { 325 for (auto i = entry.begin(), e = builder.getInsertionPoint(); i != e; ++i) 326 processGeneratedConstants(&*i); 327 } 328 329 return success(); 330 } 331 332 /// Try to get or create a new constant entry. On success this returns the 333 /// constant operation value, nullptr otherwise. 334 Operation *OperationFolder::tryGetOrCreateConstant( 335 ConstantMap &uniquedConstants, Dialect *dialect, OpBuilder &builder, 336 Attribute value, Type type, Location loc) { 337 // Check if an existing mapping already exists. 338 auto constKey = std::make_tuple(dialect, value, type); 339 Operation *&constOp = uniquedConstants[constKey]; 340 if (constOp) 341 return constOp; 342 343 // If one doesn't exist, try to materialize one. 344 if (!(constOp = materializeConstant(dialect, builder, value, type, loc))) 345 return nullptr; 346 347 // Check to see if the generated constant is in the expected dialect. 348 auto *newDialect = constOp->getDialect(); 349 if (newDialect == dialect) { 350 referencedDialects[constOp].push_back(dialect); 351 return constOp; 352 } 353 354 // If it isn't, then we also need to make sure that the mapping for the new 355 // dialect is valid. 356 auto newKey = std::make_tuple(newDialect, value, type); 357 358 // If an existing operation in the new dialect already exists, delete the 359 // materialized operation in favor of the existing one. 360 if (auto *existingOp = uniquedConstants.lookup(newKey)) { 361 constOp->erase(); 362 referencedDialects[existingOp].push_back(dialect); 363 return constOp = existingOp; 364 } 365 366 // Otherwise, update the new dialect to the materialized operation. 367 referencedDialects[constOp].assign({dialect, newDialect}); 368 auto newIt = uniquedConstants.insert({newKey, constOp}); 369 return newIt.first->second; 370 } 371