1 //===- InliningUtils.cpp ---- Misc utilities for inlining -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements miscellaneous inlining utilities. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Transforms/InliningUtils.h" 14 15 #include "mlir/IR/BlockAndValueMapping.h" 16 #include "mlir/IR/Builders.h" 17 #include "mlir/IR/BuiltinOps.h" 18 #include "mlir/IR/Operation.h" 19 #include "llvm/ADT/MapVector.h" 20 #include "llvm/Support/Debug.h" 21 #include "llvm/Support/raw_ostream.h" 22 23 #define DEBUG_TYPE "inlining" 24 25 using namespace mlir; 26 27 /// Remap locations from the inlined blocks with CallSiteLoc locations with the 28 /// provided caller location. 29 static void 30 remapInlinedLocations(iterator_range<Region::iterator> inlinedBlocks, 31 Location callerLoc) { 32 DenseMap<Location, Location> mappedLocations; 33 auto remapOpLoc = [&](Operation *op) { 34 auto it = mappedLocations.find(op->getLoc()); 35 if (it == mappedLocations.end()) { 36 auto newLoc = CallSiteLoc::get(op->getLoc(), callerLoc); 37 it = mappedLocations.try_emplace(op->getLoc(), newLoc).first; 38 } 39 op->setLoc(it->second); 40 }; 41 for (auto &block : inlinedBlocks) 42 block.walk(remapOpLoc); 43 } 44 45 static void remapInlinedOperands(iterator_range<Region::iterator> inlinedBlocks, 46 BlockAndValueMapping &mapper) { 47 auto remapOperands = [&](Operation *op) { 48 for (auto &operand : op->getOpOperands()) 49 if (auto mappedOp = mapper.lookupOrNull(operand.get())) 50 operand.set(mappedOp); 51 }; 52 for (auto &block : inlinedBlocks) 53 block.walk(remapOperands); 54 } 55 56 //===----------------------------------------------------------------------===// 57 // InlinerInterface 58 //===----------------------------------------------------------------------===// 59 60 bool InlinerInterface::isLegalToInline(Operation *call, Operation *callable, 61 bool wouldBeCloned) const { 62 if (auto *handler = getInterfaceFor(call)) 63 return handler->isLegalToInline(call, callable, wouldBeCloned); 64 return false; 65 } 66 67 bool InlinerInterface::isLegalToInline( 68 Region *dest, Region *src, bool wouldBeCloned, 69 BlockAndValueMapping &valueMapping) const { 70 // Regions can always be inlined into functions. 71 if (isa<FuncOp>(dest->getParentOp())) 72 return true; 73 74 if (auto *handler = getInterfaceFor(dest->getParentOp())) 75 return handler->isLegalToInline(dest, src, wouldBeCloned, valueMapping); 76 return false; 77 } 78 79 bool InlinerInterface::isLegalToInline( 80 Operation *op, Region *dest, bool wouldBeCloned, 81 BlockAndValueMapping &valueMapping) const { 82 if (auto *handler = getInterfaceFor(op)) 83 return handler->isLegalToInline(op, dest, wouldBeCloned, valueMapping); 84 return false; 85 } 86 87 bool InlinerInterface::shouldAnalyzeRecursively(Operation *op) const { 88 auto *handler = getInterfaceFor(op); 89 return handler ? handler->shouldAnalyzeRecursively(op) : true; 90 } 91 92 /// Handle the given inlined terminator by replacing it with a new operation 93 /// as necessary. 94 void InlinerInterface::handleTerminator(Operation *op, Block *newDest) const { 95 auto *handler = getInterfaceFor(op); 96 assert(handler && "expected valid dialect handler"); 97 handler->handleTerminator(op, newDest); 98 } 99 100 /// Handle the given inlined terminator by replacing it with a new operation 101 /// as necessary. 102 void InlinerInterface::handleTerminator(Operation *op, 103 ArrayRef<Value> valuesToRepl) const { 104 auto *handler = getInterfaceFor(op); 105 assert(handler && "expected valid dialect handler"); 106 handler->handleTerminator(op, valuesToRepl); 107 } 108 109 void InlinerInterface::processInlinedCallBlocks( 110 Operation *call, iterator_range<Region::iterator> inlinedBlocks) const { 111 auto *handler = getInterfaceFor(call); 112 assert(handler && "expected valid dialect handler"); 113 handler->processInlinedCallBlocks(call, inlinedBlocks); 114 } 115 116 /// Utility to check that all of the operations within 'src' can be inlined. 117 static bool isLegalToInline(InlinerInterface &interface, Region *src, 118 Region *insertRegion, bool shouldCloneInlinedRegion, 119 BlockAndValueMapping &valueMapping) { 120 for (auto &block : *src) { 121 for (auto &op : block) { 122 // Check this operation. 123 if (!interface.isLegalToInline(&op, insertRegion, 124 shouldCloneInlinedRegion, valueMapping)) { 125 LLVM_DEBUG({ 126 llvm::dbgs() << "* Illegal to inline because of op: "; 127 op.dump(); 128 }); 129 return false; 130 } 131 // Check any nested regions. 132 if (interface.shouldAnalyzeRecursively(&op) && 133 llvm::any_of(op.getRegions(), [&](Region ®ion) { 134 return !isLegalToInline(interface, ®ion, insertRegion, 135 shouldCloneInlinedRegion, valueMapping); 136 })) 137 return false; 138 } 139 } 140 return true; 141 } 142 143 //===----------------------------------------------------------------------===// 144 // Inline Methods 145 //===----------------------------------------------------------------------===// 146 147 static LogicalResult 148 inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock, 149 Block::iterator inlinePoint, BlockAndValueMapping &mapper, 150 ValueRange resultsToReplace, TypeRange regionResultTypes, 151 Optional<Location> inlineLoc, bool shouldCloneInlinedRegion, 152 Operation *call = nullptr) { 153 assert(resultsToReplace.size() == regionResultTypes.size()); 154 // We expect the region to have at least one block. 155 if (src->empty()) 156 return failure(); 157 158 // Check that all of the region arguments have been mapped. 159 auto *srcEntryBlock = &src->front(); 160 if (llvm::any_of(srcEntryBlock->getArguments(), 161 [&](BlockArgument arg) { return !mapper.contains(arg); })) 162 return failure(); 163 164 // Check that the operations within the source region are valid to inline. 165 Region *insertRegion = inlineBlock->getParent(); 166 if (!interface.isLegalToInline(insertRegion, src, shouldCloneInlinedRegion, 167 mapper) || 168 !isLegalToInline(interface, src, insertRegion, shouldCloneInlinedRegion, 169 mapper)) 170 return failure(); 171 172 // Check to see if the region is being cloned, or moved inline. In either 173 // case, move the new blocks after the 'insertBlock' to improve IR 174 // readability. 175 Block *postInsertBlock = inlineBlock->splitBlock(inlinePoint); 176 if (shouldCloneInlinedRegion) 177 src->cloneInto(insertRegion, postInsertBlock->getIterator(), mapper); 178 else 179 insertRegion->getBlocks().splice(postInsertBlock->getIterator(), 180 src->getBlocks(), src->begin(), 181 src->end()); 182 183 // Get the range of newly inserted blocks. 184 auto newBlocks = llvm::make_range(std::next(inlineBlock->getIterator()), 185 postInsertBlock->getIterator()); 186 Block *firstNewBlock = &*newBlocks.begin(); 187 188 // Remap the locations of the inlined operations if a valid source location 189 // was provided. 190 if (inlineLoc && !inlineLoc->isa<UnknownLoc>()) 191 remapInlinedLocations(newBlocks, *inlineLoc); 192 193 // If the blocks were moved in-place, make sure to remap any necessary 194 // operands. 195 if (!shouldCloneInlinedRegion) 196 remapInlinedOperands(newBlocks, mapper); 197 198 // Process the newly inlined blocks. 199 if (call) 200 interface.processInlinedCallBlocks(call, newBlocks); 201 interface.processInlinedBlocks(newBlocks); 202 203 // Handle the case where only a single block was inlined. 204 if (std::next(newBlocks.begin()) == newBlocks.end()) { 205 // Have the interface handle the terminator of this block. 206 auto *firstBlockTerminator = firstNewBlock->getTerminator(); 207 interface.handleTerminator(firstBlockTerminator, 208 llvm::to_vector<6>(resultsToReplace)); 209 firstBlockTerminator->erase(); 210 211 // Merge the post insert block into the cloned entry block. 212 firstNewBlock->getOperations().splice(firstNewBlock->end(), 213 postInsertBlock->getOperations()); 214 postInsertBlock->erase(); 215 } else { 216 // Otherwise, there were multiple blocks inlined. Add arguments to the post 217 // insertion block to represent the results to replace. 218 for (const auto &resultToRepl : llvm::enumerate(resultsToReplace)) { 219 resultToRepl.value().replaceAllUsesWith( 220 postInsertBlock->addArgument(regionResultTypes[resultToRepl.index()], 221 resultToRepl.value().getLoc())); 222 } 223 224 /// Handle the terminators for each of the new blocks. 225 for (auto &newBlock : newBlocks) 226 interface.handleTerminator(newBlock.getTerminator(), postInsertBlock); 227 } 228 229 // Splice the instructions of the inlined entry block into the insert block. 230 inlineBlock->getOperations().splice(inlineBlock->end(), 231 firstNewBlock->getOperations()); 232 firstNewBlock->erase(); 233 return success(); 234 } 235 236 static LogicalResult 237 inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock, 238 Block::iterator inlinePoint, ValueRange inlinedOperands, 239 ValueRange resultsToReplace, Optional<Location> inlineLoc, 240 bool shouldCloneInlinedRegion, Operation *call = nullptr) { 241 // We expect the region to have at least one block. 242 if (src->empty()) 243 return failure(); 244 245 auto *entryBlock = &src->front(); 246 if (inlinedOperands.size() != entryBlock->getNumArguments()) 247 return failure(); 248 249 // Map the provided call operands to the arguments of the region. 250 BlockAndValueMapping mapper; 251 for (unsigned i = 0, e = inlinedOperands.size(); i != e; ++i) { 252 // Verify that the types of the provided values match the function argument 253 // types. 254 BlockArgument regionArg = entryBlock->getArgument(i); 255 if (inlinedOperands[i].getType() != regionArg.getType()) 256 return failure(); 257 mapper.map(regionArg, inlinedOperands[i]); 258 } 259 260 // Call into the main region inliner function. 261 return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, mapper, 262 resultsToReplace, resultsToReplace.getTypes(), 263 inlineLoc, shouldCloneInlinedRegion, call); 264 } 265 266 LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, 267 Operation *inlinePoint, 268 BlockAndValueMapping &mapper, 269 ValueRange resultsToReplace, 270 TypeRange regionResultTypes, 271 Optional<Location> inlineLoc, 272 bool shouldCloneInlinedRegion) { 273 return inlineRegion(interface, src, inlinePoint->getBlock(), 274 ++inlinePoint->getIterator(), mapper, resultsToReplace, 275 regionResultTypes, inlineLoc, shouldCloneInlinedRegion); 276 } 277 LogicalResult 278 mlir::inlineRegion(InlinerInterface &interface, Region *src, Block *inlineBlock, 279 Block::iterator inlinePoint, BlockAndValueMapping &mapper, 280 ValueRange resultsToReplace, TypeRange regionResultTypes, 281 Optional<Location> inlineLoc, 282 bool shouldCloneInlinedRegion) { 283 return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, mapper, 284 resultsToReplace, regionResultTypes, inlineLoc, 285 shouldCloneInlinedRegion); 286 } 287 288 LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, 289 Operation *inlinePoint, 290 ValueRange inlinedOperands, 291 ValueRange resultsToReplace, 292 Optional<Location> inlineLoc, 293 bool shouldCloneInlinedRegion) { 294 return inlineRegion(interface, src, inlinePoint->getBlock(), 295 ++inlinePoint->getIterator(), inlinedOperands, 296 resultsToReplace, inlineLoc, shouldCloneInlinedRegion); 297 } 298 LogicalResult 299 mlir::inlineRegion(InlinerInterface &interface, Region *src, Block *inlineBlock, 300 Block::iterator inlinePoint, ValueRange inlinedOperands, 301 ValueRange resultsToReplace, Optional<Location> inlineLoc, 302 bool shouldCloneInlinedRegion) { 303 return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, 304 inlinedOperands, resultsToReplace, inlineLoc, 305 shouldCloneInlinedRegion); 306 } 307 308 /// Utility function used to generate a cast operation from the given interface, 309 /// or return nullptr if a cast could not be generated. 310 static Value materializeConversion(const DialectInlinerInterface *interface, 311 SmallVectorImpl<Operation *> &castOps, 312 OpBuilder &castBuilder, Value arg, Type type, 313 Location conversionLoc) { 314 if (!interface) 315 return nullptr; 316 317 // Check to see if the interface for the call can materialize a conversion. 318 Operation *castOp = interface->materializeCallConversion(castBuilder, arg, 319 type, conversionLoc); 320 if (!castOp) 321 return nullptr; 322 castOps.push_back(castOp); 323 324 // Ensure that the generated cast is correct. 325 assert(castOp->getNumOperands() == 1 && castOp->getOperand(0) == arg && 326 castOp->getNumResults() == 1 && *castOp->result_type_begin() == type); 327 return castOp->getResult(0); 328 } 329 330 /// This function inlines a given region, 'src', of a callable operation, 331 /// 'callable', into the location defined by the given call operation. This 332 /// function returns failure if inlining is not possible, success otherwise. On 333 /// failure, no changes are made to the module. 'shouldCloneInlinedRegion' 334 /// corresponds to whether the source region should be cloned into the 'call' or 335 /// spliced directly. 336 LogicalResult mlir::inlineCall(InlinerInterface &interface, 337 CallOpInterface call, 338 CallableOpInterface callable, Region *src, 339 bool shouldCloneInlinedRegion) { 340 // We expect the region to have at least one block. 341 if (src->empty()) 342 return failure(); 343 auto *entryBlock = &src->front(); 344 ArrayRef<Type> callableResultTypes = callable.getCallableResults(); 345 346 // Make sure that the number of arguments and results matchup between the call 347 // and the region. 348 SmallVector<Value, 8> callOperands(call.getArgOperands()); 349 SmallVector<Value, 8> callResults(call->getResults()); 350 if (callOperands.size() != entryBlock->getNumArguments() || 351 callResults.size() != callableResultTypes.size()) 352 return failure(); 353 354 // A set of cast operations generated to matchup the signature of the region 355 // with the signature of the call. 356 SmallVector<Operation *, 4> castOps; 357 castOps.reserve(callOperands.size() + callResults.size()); 358 359 // Functor used to cleanup generated state on failure. 360 auto cleanupState = [&] { 361 for (auto *op : castOps) { 362 op->getResult(0).replaceAllUsesWith(op->getOperand(0)); 363 op->erase(); 364 } 365 return failure(); 366 }; 367 368 // Builder used for any conversion operations that need to be materialized. 369 OpBuilder castBuilder(call); 370 Location castLoc = call.getLoc(); 371 const auto *callInterface = interface.getInterfaceFor(call->getDialect()); 372 373 // Map the provided call operands to the arguments of the region. 374 BlockAndValueMapping mapper; 375 for (unsigned i = 0, e = callOperands.size(); i != e; ++i) { 376 BlockArgument regionArg = entryBlock->getArgument(i); 377 Value operand = callOperands[i]; 378 379 // If the call operand doesn't match the expected region argument, try to 380 // generate a cast. 381 Type regionArgType = regionArg.getType(); 382 if (operand.getType() != regionArgType) { 383 if (!(operand = materializeConversion(callInterface, castOps, castBuilder, 384 operand, regionArgType, castLoc))) 385 return cleanupState(); 386 } 387 mapper.map(regionArg, operand); 388 } 389 390 // Ensure that the resultant values of the call match the callable. 391 castBuilder.setInsertionPointAfter(call); 392 for (unsigned i = 0, e = callResults.size(); i != e; ++i) { 393 Value callResult = callResults[i]; 394 if (callResult.getType() == callableResultTypes[i]) 395 continue; 396 397 // Generate a conversion that will produce the original type, so that the IR 398 // is still valid after the original call gets replaced. 399 Value castResult = 400 materializeConversion(callInterface, castOps, castBuilder, callResult, 401 callResult.getType(), castLoc); 402 if (!castResult) 403 return cleanupState(); 404 callResult.replaceAllUsesWith(castResult); 405 castResult.getDefiningOp()->replaceUsesOfWith(castResult, callResult); 406 } 407 408 // Check that it is legal to inline the callable into the call. 409 if (!interface.isLegalToInline(call, callable, shouldCloneInlinedRegion)) 410 return cleanupState(); 411 412 // Attempt to inline the call. 413 if (failed(inlineRegionImpl(interface, src, call->getBlock(), 414 ++call->getIterator(), mapper, callResults, 415 callableResultTypes, call.getLoc(), 416 shouldCloneInlinedRegion, call))) 417 return cleanupState(); 418 return success(); 419 } 420