1 //===- InliningUtils.cpp ---- Misc utilities for inlining -----------------===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements miscellaneous inlining utilities. 10 // 11 //===----------------------------------------------------------------------===// 12 13 #include "mlir/Transforms/InliningUtils.h" 14 15 #include "mlir/IR/BlockAndValueMapping.h" 16 #include "mlir/IR/Builders.h" 17 #include "mlir/IR/Operation.h" 18 #include "mlir/Interfaces/CallInterfaces.h" 19 #include "llvm/ADT/MapVector.h" 20 #include "llvm/Support/Debug.h" 21 #include "llvm/Support/raw_ostream.h" 22 23 #define DEBUG_TYPE "inlining" 24 25 using namespace mlir; 26 27 /// Remap locations from the inlined blocks with CallSiteLoc locations with the 28 /// provided caller location. 29 static void 30 remapInlinedLocations(iterator_range<Region::iterator> inlinedBlocks, 31 Location callerLoc) { 32 DenseMap<Location, Location> mappedLocations; 33 auto remapOpLoc = [&](Operation *op) { 34 auto it = mappedLocations.find(op->getLoc()); 35 if (it == mappedLocations.end()) { 36 auto newLoc = CallSiteLoc::get(op->getLoc(), callerLoc); 37 it = mappedLocations.try_emplace(op->getLoc(), newLoc).first; 38 } 39 op->setLoc(it->second); 40 }; 41 for (auto &block : inlinedBlocks) 42 block.walk(remapOpLoc); 43 } 44 45 static void remapInlinedOperands(iterator_range<Region::iterator> inlinedBlocks, 46 BlockAndValueMapping &mapper) { 47 auto remapOperands = [&](Operation *op) { 48 for (auto &operand : op->getOpOperands()) 49 if (auto mappedOp = mapper.lookupOrNull(operand.get())) 50 operand.set(mappedOp); 51 }; 52 for (auto &block : inlinedBlocks) 53 block.walk(remapOperands); 54 } 55 56 //===----------------------------------------------------------------------===// 57 // InlinerInterface 58 //===----------------------------------------------------------------------===// 59 60 bool InlinerInterface::isLegalToInline(Operation *call, Operation *callable, 61 bool wouldBeCloned) const { 62 if (auto *handler = getInterfaceFor(call)) 63 return handler->isLegalToInline(call, callable, wouldBeCloned); 64 return false; 65 } 66 67 bool InlinerInterface::isLegalToInline( 68 Region *dest, Region *src, bool wouldBeCloned, 69 BlockAndValueMapping &valueMapping) const { 70 if (auto *handler = getInterfaceFor(dest->getParentOp())) 71 return handler->isLegalToInline(dest, src, wouldBeCloned, valueMapping); 72 return false; 73 } 74 75 bool InlinerInterface::isLegalToInline( 76 Operation *op, Region *dest, bool wouldBeCloned, 77 BlockAndValueMapping &valueMapping) const { 78 if (auto *handler = getInterfaceFor(op)) 79 return handler->isLegalToInline(op, dest, wouldBeCloned, valueMapping); 80 return false; 81 } 82 83 bool InlinerInterface::shouldAnalyzeRecursively(Operation *op) const { 84 auto *handler = getInterfaceFor(op); 85 return handler ? handler->shouldAnalyzeRecursively(op) : true; 86 } 87 88 /// Handle the given inlined terminator by replacing it with a new operation 89 /// as necessary. 90 void InlinerInterface::handleTerminator(Operation *op, Block *newDest) const { 91 auto *handler = getInterfaceFor(op); 92 assert(handler && "expected valid dialect handler"); 93 handler->handleTerminator(op, newDest); 94 } 95 96 /// Handle the given inlined terminator by replacing it with a new operation 97 /// as necessary. 98 void InlinerInterface::handleTerminator(Operation *op, 99 ArrayRef<Value> valuesToRepl) const { 100 auto *handler = getInterfaceFor(op); 101 assert(handler && "expected valid dialect handler"); 102 handler->handleTerminator(op, valuesToRepl); 103 } 104 105 void InlinerInterface::processInlinedCallBlocks( 106 Operation *call, iterator_range<Region::iterator> inlinedBlocks) const { 107 auto *handler = getInterfaceFor(call); 108 assert(handler && "expected valid dialect handler"); 109 handler->processInlinedCallBlocks(call, inlinedBlocks); 110 } 111 112 /// Utility to check that all of the operations within 'src' can be inlined. 113 static bool isLegalToInline(InlinerInterface &interface, Region *src, 114 Region *insertRegion, bool shouldCloneInlinedRegion, 115 BlockAndValueMapping &valueMapping) { 116 for (auto &block : *src) { 117 for (auto &op : block) { 118 // Check this operation. 119 if (!interface.isLegalToInline(&op, insertRegion, 120 shouldCloneInlinedRegion, valueMapping)) { 121 LLVM_DEBUG({ 122 llvm::dbgs() << "* Illegal to inline because of op: "; 123 op.dump(); 124 }); 125 return false; 126 } 127 // Check any nested regions. 128 if (interface.shouldAnalyzeRecursively(&op) && 129 llvm::any_of(op.getRegions(), [&](Region ®ion) { 130 return !isLegalToInline(interface, ®ion, insertRegion, 131 shouldCloneInlinedRegion, valueMapping); 132 })) 133 return false; 134 } 135 } 136 return true; 137 } 138 139 //===----------------------------------------------------------------------===// 140 // Inline Methods 141 //===----------------------------------------------------------------------===// 142 143 static LogicalResult 144 inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock, 145 Block::iterator inlinePoint, BlockAndValueMapping &mapper, 146 ValueRange resultsToReplace, TypeRange regionResultTypes, 147 Optional<Location> inlineLoc, bool shouldCloneInlinedRegion, 148 Operation *call = nullptr) { 149 assert(resultsToReplace.size() == regionResultTypes.size()); 150 // We expect the region to have at least one block. 151 if (src->empty()) 152 return failure(); 153 154 // Check that all of the region arguments have been mapped. 155 auto *srcEntryBlock = &src->front(); 156 if (llvm::any_of(srcEntryBlock->getArguments(), 157 [&](BlockArgument arg) { return !mapper.contains(arg); })) 158 return failure(); 159 160 // Check that the operations within the source region are valid to inline. 161 Region *insertRegion = inlineBlock->getParent(); 162 if (!interface.isLegalToInline(insertRegion, src, shouldCloneInlinedRegion, 163 mapper) || 164 !isLegalToInline(interface, src, insertRegion, shouldCloneInlinedRegion, 165 mapper)) 166 return failure(); 167 168 // Check to see if the region is being cloned, or moved inline. In either 169 // case, move the new blocks after the 'insertBlock' to improve IR 170 // readability. 171 Block *postInsertBlock = inlineBlock->splitBlock(inlinePoint); 172 if (shouldCloneInlinedRegion) 173 src->cloneInto(insertRegion, postInsertBlock->getIterator(), mapper); 174 else 175 insertRegion->getBlocks().splice(postInsertBlock->getIterator(), 176 src->getBlocks(), src->begin(), 177 src->end()); 178 179 // Get the range of newly inserted blocks. 180 auto newBlocks = llvm::make_range(std::next(inlineBlock->getIterator()), 181 postInsertBlock->getIterator()); 182 Block *firstNewBlock = &*newBlocks.begin(); 183 184 // Remap the locations of the inlined operations if a valid source location 185 // was provided. 186 if (inlineLoc && !inlineLoc->isa<UnknownLoc>()) 187 remapInlinedLocations(newBlocks, *inlineLoc); 188 189 // If the blocks were moved in-place, make sure to remap any necessary 190 // operands. 191 if (!shouldCloneInlinedRegion) 192 remapInlinedOperands(newBlocks, mapper); 193 194 // Process the newly inlined blocks. 195 if (call) 196 interface.processInlinedCallBlocks(call, newBlocks); 197 interface.processInlinedBlocks(newBlocks); 198 199 // Handle the case where only a single block was inlined. 200 if (std::next(newBlocks.begin()) == newBlocks.end()) { 201 // Have the interface handle the terminator of this block. 202 auto *firstBlockTerminator = firstNewBlock->getTerminator(); 203 interface.handleTerminator(firstBlockTerminator, 204 llvm::to_vector<6>(resultsToReplace)); 205 firstBlockTerminator->erase(); 206 207 // Merge the post insert block into the cloned entry block. 208 firstNewBlock->getOperations().splice(firstNewBlock->end(), 209 postInsertBlock->getOperations()); 210 postInsertBlock->erase(); 211 } else { 212 // Otherwise, there were multiple blocks inlined. Add arguments to the post 213 // insertion block to represent the results to replace. 214 for (const auto &resultToRepl : llvm::enumerate(resultsToReplace)) { 215 resultToRepl.value().replaceAllUsesWith( 216 postInsertBlock->addArgument(regionResultTypes[resultToRepl.index()], 217 resultToRepl.value().getLoc())); 218 } 219 220 /// Handle the terminators for each of the new blocks. 221 for (auto &newBlock : newBlocks) 222 interface.handleTerminator(newBlock.getTerminator(), postInsertBlock); 223 } 224 225 // Splice the instructions of the inlined entry block into the insert block. 226 inlineBlock->getOperations().splice(inlineBlock->end(), 227 firstNewBlock->getOperations()); 228 firstNewBlock->erase(); 229 return success(); 230 } 231 232 static LogicalResult 233 inlineRegionImpl(InlinerInterface &interface, Region *src, Block *inlineBlock, 234 Block::iterator inlinePoint, ValueRange inlinedOperands, 235 ValueRange resultsToReplace, Optional<Location> inlineLoc, 236 bool shouldCloneInlinedRegion, Operation *call = nullptr) { 237 // We expect the region to have at least one block. 238 if (src->empty()) 239 return failure(); 240 241 auto *entryBlock = &src->front(); 242 if (inlinedOperands.size() != entryBlock->getNumArguments()) 243 return failure(); 244 245 // Map the provided call operands to the arguments of the region. 246 BlockAndValueMapping mapper; 247 for (unsigned i = 0, e = inlinedOperands.size(); i != e; ++i) { 248 // Verify that the types of the provided values match the function argument 249 // types. 250 BlockArgument regionArg = entryBlock->getArgument(i); 251 if (inlinedOperands[i].getType() != regionArg.getType()) 252 return failure(); 253 mapper.map(regionArg, inlinedOperands[i]); 254 } 255 256 // Call into the main region inliner function. 257 return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, mapper, 258 resultsToReplace, resultsToReplace.getTypes(), 259 inlineLoc, shouldCloneInlinedRegion, call); 260 } 261 262 LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, 263 Operation *inlinePoint, 264 BlockAndValueMapping &mapper, 265 ValueRange resultsToReplace, 266 TypeRange regionResultTypes, 267 Optional<Location> inlineLoc, 268 bool shouldCloneInlinedRegion) { 269 return inlineRegion(interface, src, inlinePoint->getBlock(), 270 ++inlinePoint->getIterator(), mapper, resultsToReplace, 271 regionResultTypes, inlineLoc, shouldCloneInlinedRegion); 272 } 273 LogicalResult 274 mlir::inlineRegion(InlinerInterface &interface, Region *src, Block *inlineBlock, 275 Block::iterator inlinePoint, BlockAndValueMapping &mapper, 276 ValueRange resultsToReplace, TypeRange regionResultTypes, 277 Optional<Location> inlineLoc, 278 bool shouldCloneInlinedRegion) { 279 return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, mapper, 280 resultsToReplace, regionResultTypes, inlineLoc, 281 shouldCloneInlinedRegion); 282 } 283 284 LogicalResult mlir::inlineRegion(InlinerInterface &interface, Region *src, 285 Operation *inlinePoint, 286 ValueRange inlinedOperands, 287 ValueRange resultsToReplace, 288 Optional<Location> inlineLoc, 289 bool shouldCloneInlinedRegion) { 290 return inlineRegion(interface, src, inlinePoint->getBlock(), 291 ++inlinePoint->getIterator(), inlinedOperands, 292 resultsToReplace, inlineLoc, shouldCloneInlinedRegion); 293 } 294 LogicalResult 295 mlir::inlineRegion(InlinerInterface &interface, Region *src, Block *inlineBlock, 296 Block::iterator inlinePoint, ValueRange inlinedOperands, 297 ValueRange resultsToReplace, Optional<Location> inlineLoc, 298 bool shouldCloneInlinedRegion) { 299 return inlineRegionImpl(interface, src, inlineBlock, inlinePoint, 300 inlinedOperands, resultsToReplace, inlineLoc, 301 shouldCloneInlinedRegion); 302 } 303 304 /// Utility function used to generate a cast operation from the given interface, 305 /// or return nullptr if a cast could not be generated. 306 static Value materializeConversion(const DialectInlinerInterface *interface, 307 SmallVectorImpl<Operation *> &castOps, 308 OpBuilder &castBuilder, Value arg, Type type, 309 Location conversionLoc) { 310 if (!interface) 311 return nullptr; 312 313 // Check to see if the interface for the call can materialize a conversion. 314 Operation *castOp = interface->materializeCallConversion(castBuilder, arg, 315 type, conversionLoc); 316 if (!castOp) 317 return nullptr; 318 castOps.push_back(castOp); 319 320 // Ensure that the generated cast is correct. 321 assert(castOp->getNumOperands() == 1 && castOp->getOperand(0) == arg && 322 castOp->getNumResults() == 1 && *castOp->result_type_begin() == type); 323 return castOp->getResult(0); 324 } 325 326 /// This function inlines a given region, 'src', of a callable operation, 327 /// 'callable', into the location defined by the given call operation. This 328 /// function returns failure if inlining is not possible, success otherwise. On 329 /// failure, no changes are made to the module. 'shouldCloneInlinedRegion' 330 /// corresponds to whether the source region should be cloned into the 'call' or 331 /// spliced directly. 332 LogicalResult mlir::inlineCall(InlinerInterface &interface, 333 CallOpInterface call, 334 CallableOpInterface callable, Region *src, 335 bool shouldCloneInlinedRegion) { 336 // We expect the region to have at least one block. 337 if (src->empty()) 338 return failure(); 339 auto *entryBlock = &src->front(); 340 ArrayRef<Type> callableResultTypes = callable.getCallableResults(); 341 342 // Make sure that the number of arguments and results matchup between the call 343 // and the region. 344 SmallVector<Value, 8> callOperands(call.getArgOperands()); 345 SmallVector<Value, 8> callResults(call->getResults()); 346 if (callOperands.size() != entryBlock->getNumArguments() || 347 callResults.size() != callableResultTypes.size()) 348 return failure(); 349 350 // A set of cast operations generated to matchup the signature of the region 351 // with the signature of the call. 352 SmallVector<Operation *, 4> castOps; 353 castOps.reserve(callOperands.size() + callResults.size()); 354 355 // Functor used to cleanup generated state on failure. 356 auto cleanupState = [&] { 357 for (auto *op : castOps) { 358 op->getResult(0).replaceAllUsesWith(op->getOperand(0)); 359 op->erase(); 360 } 361 return failure(); 362 }; 363 364 // Builder used for any conversion operations that need to be materialized. 365 OpBuilder castBuilder(call); 366 Location castLoc = call.getLoc(); 367 const auto *callInterface = interface.getInterfaceFor(call->getDialect()); 368 369 // Map the provided call operands to the arguments of the region. 370 BlockAndValueMapping mapper; 371 for (unsigned i = 0, e = callOperands.size(); i != e; ++i) { 372 BlockArgument regionArg = entryBlock->getArgument(i); 373 Value operand = callOperands[i]; 374 375 // If the call operand doesn't match the expected region argument, try to 376 // generate a cast. 377 Type regionArgType = regionArg.getType(); 378 if (operand.getType() != regionArgType) { 379 if (!(operand = materializeConversion(callInterface, castOps, castBuilder, 380 operand, regionArgType, castLoc))) 381 return cleanupState(); 382 } 383 mapper.map(regionArg, operand); 384 } 385 386 // Ensure that the resultant values of the call match the callable. 387 castBuilder.setInsertionPointAfter(call); 388 for (unsigned i = 0, e = callResults.size(); i != e; ++i) { 389 Value callResult = callResults[i]; 390 if (callResult.getType() == callableResultTypes[i]) 391 continue; 392 393 // Generate a conversion that will produce the original type, so that the IR 394 // is still valid after the original call gets replaced. 395 Value castResult = 396 materializeConversion(callInterface, castOps, castBuilder, callResult, 397 callResult.getType(), castLoc); 398 if (!castResult) 399 return cleanupState(); 400 callResult.replaceAllUsesWith(castResult); 401 castResult.getDefiningOp()->replaceUsesOfWith(castResult, callResult); 402 } 403 404 // Check that it is legal to inline the callable into the call. 405 if (!interface.isLegalToInline(call, callable, shouldCloneInlinedRegion)) 406 return cleanupState(); 407 408 // Attempt to inline the call. 409 if (failed(inlineRegionImpl(interface, src, call->getBlock(), 410 ++call->getIterator(), mapper, callResults, 411 callableResultTypes, call.getLoc(), 412 shouldCloneInlinedRegion, call))) 413 return cleanupState(); 414 return success(); 415 } 416