1 //===- InliningUtils.cpp ---- Misc utilities for inlining -----------------===// 2 // 3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements miscellaneous inlining utilities. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Transforms/InliningUtils.h" 14 15 #include "mlir/IR/BlockAndValueMapping.h" 16 #include "mlir/IR/Builders.h" 17 #include "mlir/IR/Function.h" 18 #include "mlir/IR/Operation.h" 19 #include "llvm/ADT/MapVector.h" 20 #include "llvm/Support/Debug.h" 21 #include "llvm/Support/raw_ostream.h" 22 23 #define DEBUG_TYPE "inlining" 24 25 using namespace mlir; 26 27 /// Remap locations from the inlined blocks with CallSiteLoc locations with the 28 /// provided caller location. 29 static void 30 remapInlinedLocations(iterator_range<Region::iterator> inlinedBlocks, 31 Location callerLoc) { 32 DenseMap<Location, Location> mappedLocations; 33 auto remapOpLoc = [&](Operation *op) { 34 auto it = mappedLocations.find(op->getLoc()); 35 if (it == mappedLocations.end()) { 36 auto newLoc = CallSiteLoc::get(op->getLoc(), callerLoc); 37 it = mappedLocations.try_emplace(op->getLoc(), newLoc).first; 38 } 39 op->setLoc(it->second); 40 }; 41 for (auto &block : inlinedBlocks) 42 block.walk(remapOpLoc); 43 } 44 45 static void remapInlinedOperands(iterator_range<Region::iterator> inlinedBlocks, 46 BlockAndValueMapping &mapper) { 47 auto remapOperands = [&](Operation *op) { 48 for (auto &operand : op->getOpOperands()) 49 if (auto mappedOp = mapper.lookupOrNull(operand.get())) 50 operand.set(mappedOp); 51 }; 52 for (auto &block : inlinedBlocks) 53 block.walk(remapOperands); 54 } 55 56 //===----------------------------------------------------------------------===// 57 // InlinerInterface 58 //===----------------------------------------------------------------------===// 59 60 bool InlinerInterface::isLegalToInline( 61 Region *dest, Region *src, BlockAndValueMapping &valueMapping) const { 62 // Regions can always be inlined into functions. 63 if (isa<FuncOp>(dest->getParentOp())) 64 return true; 65 66 auto *handler = getInterfaceFor(dest->getParentOp()); 67 return handler ? handler->isLegalToInline(dest, src, valueMapping) : false; 68 } 69 70 bool InlinerInterface::isLegalToInline( 71 Operation *op, Region *dest, BlockAndValueMapping &valueMapping) const { 72 auto *handler = getInterfaceFor(op); 73 return handler ? handler->isLegalToInline(op, dest, valueMapping) : false; 74 } 75 76 bool InlinerInterface::shouldAnalyzeRecursively(Operation *op) const { 77 auto *handler = getInterfaceFor(op); 78 return handler ? handler->shouldAnalyzeRecursively(op) : true; 79 } 80 81 /// Handle the given inlined terminator by replacing it with a new operation 82 /// as necessary. 83 void InlinerInterface::handleTerminator(Operation *op, Block *newDest) const { 84 auto *handler = getInterfaceFor(op); 85 assert(handler && "expected valid dialect handler"); 86 handler->handleTerminator(op, newDest); 87 } 88 89 /// Handle the given inlined terminator by replacing it with a new operation 90 /// as necessary. 91 void InlinerInterface::handleTerminator(Operation *op, 92 ArrayRef<Value> valuesToRepl) const { 93 auto *handler = getInterfaceFor(op); 94 assert(handler && "expected valid dialect handler"); 95 handler->handleTerminator(op, valuesToRepl); 96 } 97 98 /// Utility to check that all of the operations within 'src' can be inlined. 99 static bool isLegalToInline(InlinerInterface &interface, Region *src, 100 Region *insertRegion, 101 BlockAndValueMapping &valueMapping) { 102 for (auto &block : *src) { 103 for (auto &op : block) { 104 // Check this operation. 105 if (!interface.isLegalToInline(&op, insertRegion, valueMapping)) { 106 LLVM_DEBUG({ 107 llvm::dbgs() << "* Illegal to inline because of op: "; 108 op.dump(); 109 }); 110 return false; 111 } 112 // Check any nested regions. 113 if (interface.shouldAnalyzeRecursively(&op) && 114 llvm::any_of(op.getRegions(), [&](Region ®ion) { 115 return !isLegalToInline(interface, ®ion, insertRegion, 116 valueMapping); 117 })) 118 return false; 119 } 120 } 121 return true; 122 } 123 124 //===----------------------------------------------------------------------===// 125 // Inline Methods 126 //===----------------------------------------------------------------------===// 127 128 LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, 129 Operation *inlinePoint, 130 BlockAndValueMapping &mapper, 131 ArrayRef<Value> resultsToReplace, 132 Optional<Location> inlineLoc, 133 bool shouldCloneInlinedRegion) { 134 // We expect the region to have at least one block. 135 if (src->empty()) 136 return failure(); 137 138 // Check that all of the region arguments have been mapped. 139 auto *srcEntryBlock = &src->front(); 140 if (llvm::any_of(srcEntryBlock->getArguments(), 141 [&](BlockArgument arg) { return !mapper.contains(arg); })) 142 return failure(); 143 144 // The insertion point must be within a block. 145 Block *insertBlock = inlinePoint->getBlock(); 146 if (!insertBlock) 147 return failure(); 148 Region *insertRegion = insertBlock->getParent(); 149 150 // Check that the operations within the source region are valid to inline. 151 if (!interface.isLegalToInline(insertRegion, src, mapper) || 152 !isLegalToInline(interface, src, insertRegion, mapper)) 153 return failure(); 154 155 // Split the insertion block. 156 Block *postInsertBlock = 157 insertBlock->splitBlock(++inlinePoint->getIterator()); 158 159 // Check to see if the region is being cloned, or moved inline. In either 160 // case, move the new blocks after the 'insertBlock' to improve IR 161 // readability. 162 if (shouldCloneInlinedRegion) 163 src->cloneInto(insertRegion, postInsertBlock->getIterator(), mapper); 164 else 165 insertRegion->getBlocks().splice(postInsertBlock->getIterator(), 166 src->getBlocks(), src->begin(), 167 src->end()); 168 169 // Get the range of newly inserted blocks. 170 auto newBlocks = llvm::make_range(std::next(insertBlock->getIterator()), 171 postInsertBlock->getIterator()); 172 Block *firstNewBlock = &*newBlocks.begin(); 173 174 // Remap the locations of the inlined operations if a valid source location 175 // was provided. 176 if (inlineLoc && !inlineLoc->isa<UnknownLoc>()) 177 remapInlinedLocations(newBlocks, *inlineLoc); 178 179 // If the blocks were moved in-place, make sure to remap any necessary 180 // operands. 181 if (!shouldCloneInlinedRegion) 182 remapInlinedOperands(newBlocks, mapper); 183 184 // Process the newly inlined blocks. 185 interface.processInlinedBlocks(newBlocks); 186 187 // Handle the case where only a single block was inlined. 188 if (std::next(newBlocks.begin()) == newBlocks.end()) { 189 // Have the interface handle the terminator of this block. 190 auto *firstBlockTerminator = firstNewBlock->getTerminator(); 191 interface.handleTerminator(firstBlockTerminator, resultsToReplace); 192 firstBlockTerminator->erase(); 193 194 // Merge the post insert block into the cloned entry block. 195 firstNewBlock->getOperations().splice(firstNewBlock->end(), 196 postInsertBlock->getOperations()); 197 postInsertBlock->erase(); 198 } else { 199 // Otherwise, there were multiple blocks inlined. Add arguments to the post 200 // insertion block to represent the results to replace. 201 for (Value resultToRepl : resultsToReplace) { 202 resultToRepl.replaceAllUsesWith( 203 postInsertBlock->addArgument(resultToRepl.getType())); 204 } 205 206 /// Handle the terminators for each of the new blocks. 207 for (auto &newBlock : newBlocks) 208 interface.handleTerminator(newBlock.getTerminator(), postInsertBlock); 209 } 210 211 // Splice the instructions of the inlined entry block into the insert block. 212 insertBlock->getOperations().splice(insertBlock->end(), 213 firstNewBlock->getOperations()); 214 firstNewBlock->erase(); 215 return success(); 216 } 217 218 /// This function is an overload of the above 'inlineRegion' that allows for 219 /// providing the set of operands ('inlinedOperands') that should be used 220 /// in-favor of the region arguments when inlining. 221 LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, 222 Operation *inlinePoint, 223 ArrayRef<Value> inlinedOperands, 224 ArrayRef<Value> resultsToReplace, 225 Optional<Location> inlineLoc, 226 bool shouldCloneInlinedRegion) { 227 // We expect the region to have at least one block. 228 if (src->empty()) 229 return failure(); 230 231 auto *entryBlock = &src->front(); 232 if (inlinedOperands.size() != entryBlock->getNumArguments()) 233 return failure(); 234 235 // Map the provided call operands to the arguments of the region. 236 BlockAndValueMapping mapper; 237 for (unsigned i = 0, e = inlinedOperands.size(); i != e; ++i) { 238 // Verify that the types of the provided values match the function argument 239 // types. 240 BlockArgument regionArg = entryBlock->getArgument(i); 241 if (inlinedOperands[i].getType() != regionArg.getType()) 242 return failure(); 243 mapper.map(regionArg, inlinedOperands[i]); 244 } 245 246 // Call into the main region inliner function. 247 return inlineRegion(interface, src, inlinePoint, mapper, resultsToReplace, 248 inlineLoc, shouldCloneInlinedRegion); 249 } 250 251 /// Utility function used to generate a cast operation from the given interface, 252 /// or return nullptr if a cast could not be generated. 253 static Value materializeConversion(const DialectInlinerInterface *interface, 254 SmallVectorImpl<Operation *> &castOps, 255 OpBuilder &castBuilder, Value arg, Type type, 256 Location conversionLoc) { 257 if (!interface) 258 return nullptr; 259 260 // Check to see if the interface for the call can materialize a conversion. 261 Operation *castOp = interface->materializeCallConversion(castBuilder, arg, 262 type, conversionLoc); 263 if (!castOp) 264 return nullptr; 265 castOps.push_back(castOp); 266 267 // Ensure that the generated cast is correct. 268 assert(castOp->getNumOperands() == 1 && castOp->getOperand(0) == arg && 269 castOp->getNumResults() == 1 && *castOp->result_type_begin() == type); 270 return castOp->getResult(0); 271 } 272 273 /// This function inlines a given region, 'src', of a callable operation, 274 /// 'callable', into the location defined by the given call operation. This 275 /// function returns failure if inlining is not possible, success otherwise. On 276 /// failure, no changes are made to the module. 'shouldCloneInlinedRegion' 277 /// corresponds to whether the source region should be cloned into the 'call' or 278 /// spliced directly. 279 LogicalResult mlir::inlineCall(InlinerInterface &interface, 280 CallOpInterface call, 281 CallableOpInterface callable, Region *src, 282 bool shouldCloneInlinedRegion) { 283 // We expect the region to have at least one block. 284 if (src->empty()) 285 return failure(); 286 auto *entryBlock = &src->front(); 287 ArrayRef<Type> callableResultTypes = callable.getCallableResults(src); 288 289 // Make sure that the number of arguments and results matchup between the call 290 // and the region. 291 SmallVector<Value, 8> callOperands(call.getArgOperands()); 292 SmallVector<Value, 8> callResults(call.getOperation()->getResults()); 293 if (callOperands.size() != entryBlock->getNumArguments() || 294 callResults.size() != callableResultTypes.size()) 295 return failure(); 296 297 // A set of cast operations generated to matchup the signature of the region 298 // with the signature of the call. 299 SmallVector<Operation *, 4> castOps; 300 castOps.reserve(callOperands.size() + callResults.size()); 301 302 // Functor used to cleanup generated state on failure. 303 auto cleanupState = [&] { 304 for (auto *op : castOps) { 305 op->getResult(0).replaceAllUsesWith(op->getOperand(0)); 306 op->erase(); 307 } 308 return failure(); 309 }; 310 311 // Builder used for any conversion operations that need to be materialized. 312 OpBuilder castBuilder(call); 313 Location castLoc = call.getLoc(); 314 auto *callInterface = interface.getInterfaceFor(call.getDialect()); 315 316 // Map the provided call operands to the arguments of the region. 317 BlockAndValueMapping mapper; 318 for (unsigned i = 0, e = callOperands.size(); i != e; ++i) { 319 BlockArgument regionArg = entryBlock->getArgument(i); 320 Value operand = callOperands[i]; 321 322 // If the call operand doesn't match the expected region argument, try to 323 // generate a cast. 324 Type regionArgType = regionArg.getType(); 325 if (operand.getType() != regionArgType) { 326 if (!(operand = materializeConversion(callInterface, castOps, castBuilder, 327 operand, regionArgType, castLoc))) 328 return cleanupState(); 329 } 330 mapper.map(regionArg, operand); 331 } 332 333 // Ensure that the resultant values of the call, match the callable. 334 castBuilder.setInsertionPointAfter(call); 335 for (unsigned i = 0, e = callResults.size(); i != e; ++i) { 336 Value callResult = callResults[i]; 337 if (callResult.getType() == callableResultTypes[i]) 338 continue; 339 340 // Generate a conversion that will produce the original type, so that the IR 341 // is still valid after the original call gets replaced. 342 Value castResult = 343 materializeConversion(callInterface, castOps, castBuilder, callResult, 344 callResult.getType(), castLoc); 345 if (!castResult) 346 return cleanupState(); 347 callResult.replaceAllUsesWith(castResult); 348 castResult.getDefiningOp()->replaceUsesOfWith(castResult, callResult); 349 } 350 351 // Attempt to inline the call. 352 if (failed(inlineRegion(interface, src, call, mapper, callResults, 353 call.getLoc(), shouldCloneInlinedRegion))) 354 return cleanupState(); 355 return success(); 356 } 357